diff --git a/inspect_weights_conv_0.py b/inspect_weights_conv_0.py new file mode 100644 index 0000000..0d917d3 --- /dev/null +++ b/inspect_weights_conv_0.py @@ -0,0 +1,45 @@ +import torch +import torchvision as tv +import matplotlib.pyplot as plt +import os +import glob +from natsort import natsorted + +# import numpy as np + +layer_id: int = 0 +scale_each: bool = False + +model_path: str = "trained_models" +filename_list: list = natsorted(glob.glob(os.path.join(model_path, str("*.pt")))) +assert len(filename_list) > 0 +model_filename: str = filename_list[-1] +print(f"Load filename: {model_filename}") + +model = torch.load(model_filename, map_location=torch.device("cpu")) +assert layer_id < len(model) + +# --- +weights = model[layer_id]._parameters["weight"].data +bias = model[layer_id]._parameters["bias"].data + +weight_grid = tv.utils.make_grid(weights, nrow=8, padding=2, scale_each=scale_each) + +v_max_abs = torch.abs(weight_grid[0, ...]).max() + +plt.subplot(3, 1, (1, 2)) +plt.imshow( + weight_grid[0, ...], + vmin=-v_max_abs, + vmax=v_max_abs, + cmap="cool", +) +plt.axis("off") +plt.colorbar() +plt.title("Weights") + +plt.subplot(3, 1, 3) +plt.plot(bias) +plt.title("Bias") + +plt.show()