"""Visualize the weights and bias of one layer of the most recent checkpoint.

The script picks the last ``*.pt`` file (natural sort order) found in
``model_path``, loads it on the CPU, and shows two panels:

* the first channel of the selected layer's weights, tiled into a grid, and
* the bias vector of the same layer as a line plot.
"""

import glob
import os

import matplotlib.pyplot as plt
import torch
import torchvision as tv
from natsort import natsorted

# Index of the layer (within the loaded container) to visualize.
layer_id: int = 0
# Forwarded to make_grid. NOTE(review): per the torchvision docs this flag is
# only consulted when normalize=True, which is not set below - confirm intent.
scale_each: bool = False

model_path: str = "trained_models"
filename_list: list = natsorted(glob.glob(os.path.join(model_path, "*.pt")))
# Explicit exception instead of `assert`, which is stripped under `python -O`.
if not filename_list:
    raise FileNotFoundError(f"No '*.pt' files found in '{model_path}'")

# Natural sort puts the highest-numbered file last.
model_filename: str = filename_list[-1]
print(f"Load filename: {model_filename}")

# NOTE(security): torch.load unpickles arbitrary Python objects; only use it
# on checkpoint files from a trusted source.
model = torch.load(model_filename, map_location=torch.device("cpu"))
if layer_id >= len(model):
    raise IndexError(
        f"layer_id={layer_id} is out of range for a model with "
        f"{len(model)} layers"
    )

# ---

# Use the public parameter attributes rather than the private `_parameters`.
layer = model[layer_id]
weights = layer.weight.data
bias = layer.bias.data

# NaN padding makes the gaps between the tiles render as "bad" (blank) pixels.
weight_grid = tv.utils.make_grid(
    weights, nrow=8, padding=2, scale_each=scale_each, pad_value=float("NaN")
)

# Only the first channel of the grid is displayed.
grid_channel = weight_grid[0, ...]
# torch.max propagates NaN, so the NaN padding has to be excluded before the
# reduction; otherwise vmin/vmax would both be NaN.
finite_values = grid_channel[~torch.isnan(grid_channel)]
v_max_abs = float(finite_values.abs().max())

plt.subplot(3, 1, (1, 2))
plt.imshow(
    grid_channel,
    # Symmetric limits so that zero sits in the middle of the colormap.
    vmin=-v_max_abs,
    vmax=v_max_abs,
    cmap="cool",
)
plt.axis("off")
plt.colorbar()
plt.title("Weights")

plt.subplot(3, 1, 3)
plt.plot(bias)
plt.title("Bias")

plt.show()