# PySpikeBySpike_04_2025/load_model.py
# 2025-04-08 15:21:11 +02:00
#
# 22 lines
# 474 B
# Python

# %%
import torch

# Load the saved state dict of the model.
# map_location="cpu" lets a checkpoint written from a GPU run be opened on a
# machine without CUDA; load_state_dict() later copies the tensors onto the
# model's own device. weights_only=True restricts unpickling to tensors and
# plain containers (all a state dict needs) instead of arbitrary pickled
# objects from the checkpoint file.
state = torch.load(
    "Models/Model_seed_4_30.pt",
    map_location="cpu",
    weights_only=True,
)
print(state)
# %%
# Rebuild the network architecture from its JSON config and load the weights
# read in the previous cell. make_network returns a 3-tuple; only the first
# element (the model) is used here.
import torch
from tools.make_network import make_network

# Fall back to the CPU when no CUDA device is present so the script also runs
# on machines without a GPU; on GPU machines this is still "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

model, _, _ = make_network(
    input_dim_x=28,
    input_dim_y=28,
    input_number_of_channel=3,
    device=device,
    config_network_filename="config_network.json",
)
# Default strict loading: raises if the checkpoint's keys do not match the
# parameters of the model built from config_network.json.
model.load_state_dict(state)
print(model)
# %%
# Count the negative entries of layer 8's weight before and after its
# parametrization. With torch.nn.utils.parametrize, the state dict stores the
# raw tensor under "<module>.parametrizations.weight.original", while
# model[8].weight is recomputed by applying the registered parametrization.
# NOTE(review): presumably the parametrization constrains the sign of the
# weights (e.g. non-negativity) -- confirm in tools.make_network.
with torch.no_grad():  # inspection only; no autograd graph needed
    raw_weight = state["8.parametrizations.weight.original"]
    n_negative_original = (raw_weight < 0).sum().item()
    n_negative_effective = (model[8].weight < 0).sum().item()
# Print explicitly so the result is also visible when run as a plain script.
print(n_negative_original, n_negative_effective)
# %%