From 9bb3d5befcc453d90145525f6109074e66cf4949 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Wed, 19 Jul 2023 22:23:04 +0200 Subject: [PATCH] Add files via upload --- inspect_weights_conv_0.py | 45 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 inspect_weights_conv_0.py diff --git a/inspect_weights_conv_0.py b/inspect_weights_conv_0.py new file mode 100644 index 0000000..0d917d3 --- /dev/null +++ b/inspect_weights_conv_0.py @@ -0,0 +1,45 @@ +import torch +import torchvision as tv +import matplotlib.pyplot as plt +import os +import glob +from natsort import natsorted + +# import numpy as np + +layer_id: int = 0 +scale_each: bool = False + +model_path: str = "trained_models" +filename_list: list = natsorted(glob.glob(os.path.join(model_path, str("*.pt")))) +assert len(filename_list) > 0 +model_filename: str = filename_list[-1] +print(f"Load filename: {model_filename}") + +model = torch.load(model_filename, map_location=torch.device("cpu")) +assert layer_id < len(model) + +# --- +weights = model[layer_id]._parameters["weight"].data +bias = model[layer_id]._parameters["bias"].data + +weight_grid = tv.utils.make_grid(weights, nrow=8, padding=2, scale_each=scale_each) + +v_max_abs = torch.abs(weight_grid[0, ...]).max() + +plt.subplot(3, 1, (1, 2)) +plt.imshow( + weight_grid[0, ...], + vmin=-v_max_abs, + vmax=v_max_abs, + cmap="cool", +) +plt.axis("off") +plt.colorbar() +plt.title("Weights") + +plt.subplot(3, 1, 3) +plt.plot(bias) +plt.title("Bias") + +plt.show()