# %%
# Inspect a saved checkpoint, load it into the configured network, and
# compare the sign pattern of a parametrized weight.
import torch

# Single source of truth for the device, used both when loading the
# checkpoint and when building the network, so the tensors and the model agree.
DEVICE = "cuda"
CHECKPOINT_PATH = "Models/Model_seed_4_30.pt"

# Load the state dict of the model.
# map_location moves every tensor onto DEVICE regardless of where it was saved.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs, instead of executing arbitrary pickled objects.
state = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)
print(state)

# %%
from tools.make_network import make_network

# make_network returns a 3-tuple; only the first element (the model) is used here.
model, _, _ = make_network(
    input_dim_x=28,
    input_dim_y=28,
    input_number_of_channel=3,
    device=DEVICE,
    config_network_filename="config_network.json",
)
model.load_state_dict(state)
print(model)

# %%
# Count negative entries in layer 8's weight in two places:
# - the raw tensor stored under `parametrizations.weight.original`;
# - the weight as exposed by the module (presumably the parametrization applied
#   to `original` -- confirm against the network config).
# A mismatch would indicate the parametrization changes signs.
# no_grad: this is a read-only check, so no autograd graph is needed.
with torch.no_grad():
    raw_negative = (state["8.parametrizations.weight.original"] < 0).sum()
    effective_negative = (model[8].weight < 0).sum()
raw_negative, effective_negative

# %%