46 lines
1,003 B
Python
46 lines
1,003 B
Python
|
import torch
|
||
|
import torchvision as tv
|
||
|
import matplotlib.pyplot as plt
|
||
|
import os
|
||
|
import glob
|
||
|
from natsort import natsorted
|
||
|
|
||
|
# import numpy as np
|
||
|
|
||
|
# Visualise the weights and bias of one layer of the most recently saved model.
#
# The script picks the last "*.pt" file (natural sort order) inside
# `model_path`, loads it on the CPU, and shows two panels: a grid image of the
# selected layer's weights (channel 0 of the grid only) and a line plot of its
# bias vector.

# --- Configuration ---
layer_id: int = 0  # index of the layer inside the loaded model
scale_each: bool = False  # forwarded to torchvision.utils.make_grid

model_path: str = "trained_models"

# --- Locate the newest checkpoint ---
filename_list: list = natsorted(glob.glob(os.path.join(model_path, "*.pt")))
# Explicit exception instead of `assert`: asserts are stripped under `-O`,
# which would turn a missing checkpoint into an obscure IndexError below.
if not filename_list:
    raise FileNotFoundError(f"No '*.pt' files found in '{model_path}'")
model_filename: str = filename_list[-1]
print(f"Load filename: {model_filename}")

# --- Load the model ---
# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source.
# NOTE(review): the file is expected to contain a whole indexable model (the
# code below calls len(model) and model[layer_id]). Recent PyTorch releases
# default to weights_only=True, which may reject such files -- confirm against
# the installed version.
model = torch.load(model_filename, map_location=torch.device("cpu"))
if not layer_id < len(model):
    raise IndexError(
        f"layer_id={layer_id} is out of range for a model with {len(model)} layers"
    )

# ---
# Use the public `weight` / `bias` attributes rather than the private
# `_parameters` dict, and tolerate layers that were built without a bias.
layer = model[layer_id]
weight_param = getattr(layer, "weight", None)
if weight_param is None:
    raise ValueError(
        f"Layer {layer_id} ({type(layer).__name__}) has no 'weight' parameter"
    )
weights = weight_param.detach()

bias_param = getattr(layer, "bias", None)
bias = None if bias_param is None else bias_param.detach()

# NOTE(review): per the torchvision documentation, `scale_each` only takes
# effect when `normalize=True`; with the default `normalize=False` it is
# ignored. Confirm whether normalisation was intended here.
weight_grid = tv.utils.make_grid(weights, nrow=8, padding=2, scale_each=scale_each)

# Only the first channel of the grid is displayed; use a symmetric colour
# range around zero so that positive and negative weights are comparable.
grid_channel = weight_grid[0, ...]
v_max_abs = float(torch.abs(grid_channel).max())

# --- Plot: weights occupy the upper two thirds, bias the lower third ---
plt.subplot(3, 1, (1, 2))
plt.imshow(
    grid_channel.numpy(),
    vmin=-v_max_abs,
    vmax=v_max_abs,
    cmap="cool",
)
plt.axis("off")
plt.colorbar()
plt.title("Weights")

plt.subplot(3, 1, 3)
if bias is None:
    # The layer has no bias term; keep the panel but leave it empty.
    plt.title("Bias (layer has no bias)")
else:
    plt.plot(bias.numpy())
    plt.title("Bias")

plt.show()
|