22 lines
474 B
Python
22 lines
474 B
Python
# %%
|
|
import torch
|
|
|
|
# Read the saved state dict of the model from disk and dump it for inspection.
# NOTE(review): torch.load unpickles the file, so only point this at trusted
# checkpoints; if the checkpoint was saved on a different device than the one
# available here, a map_location argument may be needed -- confirm.
checkpoint_path = "Models/Model_seed_4_30.pt"
state = torch.load(checkpoint_path)
print(state)
|
|
# %%
|
|
from tools.make_network import make_network
|
|
|
|
# Build the network for 28x28 inputs with 3 channels on the CUDA device, using
# the layer layout described in config_network.json. make_network returns
# three values; only the first (the model) is used here.
network_kwargs = {
    "input_dim_x": 28,
    "input_dim_y": 28,
    "input_number_of_channel": 3,
    "device": "cuda",
    "config_network_filename": "config_network.json",
}
model, _, _ = make_network(**network_kwargs)

# Copy the checkpoint's tensors into the freshly built model, then show the
# resulting module structure.
model.load_state_dict(state)
print(model)
|
|
|
|
# %%
|
|
# Count the strictly negative entries in the stored tensor under
# '8.parametrizations.weight.original' and in the weight exposed by the model
# at index 8, side by side.
# NOTE(review): the key naming suggests layer 8 uses a weight parametrization,
# so this presumably compares the raw stored weight against the effective one
# -- confirm against tools.make_network.
negatives_in_checkpoint = (state['8.parametrizations.weight.original'] < 0).sum()
negatives_in_model = (model[8].weight < 0).sum()
# Bare tuple expression so an interactive cell displays both counts.
negatives_in_checkpoint, negatives_in_model
|
|
# %%
|