Add files via upload
This commit is contained in:
parent
031e35b0b8
commit
fbc4516e58
1 changed files with 62 additions and 0 deletions
62
inspect_weights_conv_x.py
Normal file
62
inspect_weights_conv_x.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
import torch
|
||||
import torchvision as tv
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
import glob
|
||||
from natsort import natsorted
|
||||
|
||||
# import numpy as np
|
||||
|
||||
layer_id: int = 3
|
||||
scale_each_inner: bool = False
|
||||
scale_each_outer: bool = False
|
||||
|
||||
model_path: str = "trained_models"
|
||||
filename_list: list = natsorted(glob.glob(os.path.join(model_path, str("*.pt"))))
|
||||
assert len(filename_list) > 0
|
||||
model_filename: str = filename_list[-1]
|
||||
print(f"Load filename: {model_filename}")
|
||||
|
||||
model = torch.load(model_filename, map_location=torch.device("cpu"))
|
||||
assert layer_id < len(model)
|
||||
|
||||
print("Full network:")
|
||||
print(model)
|
||||
print("")
|
||||
print(f"Selected layer {layer_id}:")
|
||||
print(model[layer_id])
|
||||
|
||||
# ---
|
||||
weights = model[layer_id]._parameters["weight"].data
|
||||
bias = model[layer_id]._parameters["bias"].data
|
||||
|
||||
weight_grid = tv.utils.make_grid(
|
||||
weights.movedim(0, 1),
|
||||
nrow=8,
|
||||
padding=2,
|
||||
scale_each=scale_each_inner,
|
||||
pad_value=float("NaN"),
|
||||
)
|
||||
weight_grid = tv.utils.make_grid(
|
||||
weight_grid.unsqueeze(1), nrow=4, padding=2, scale_each=scale_each_outer
|
||||
)
|
||||
|
||||
|
||||
v_max_abs = torch.abs(weight_grid[0, ...]).max()
|
||||
|
||||
plt.subplot(3, 1, (1, 2))
|
||||
plt.imshow(
|
||||
weight_grid[0, ...],
|
||||
vmin=-v_max_abs,
|
||||
vmax=v_max_abs,
|
||||
cmap="cool",
|
||||
)
|
||||
plt.axis("off")
|
||||
plt.colorbar()
|
||||
plt.title("Weights")
|
||||
|
||||
plt.subplot(3, 1, 3)
|
||||
plt.plot(bias)
|
||||
plt.title("Bias")
|
||||
|
||||
plt.show()
|
Loading…
Reference in a new issue