diff --git a/functions/make_cnn_v2.py b/functions/make_cnn_v2.py index ae4b77f..6179e77 100644 --- a/functions/make_cnn_v2.py +++ b/functions/make_cnn_v2.py @@ -98,8 +98,8 @@ def make_cnn( logger.info(f"Replace bias in conv2d {i} with {filename_load}") temp = np.load(filename_load) assert torch.equal( - torch.Tensor(temp.shape), - torch.Tensor(cnn[-1]._parameters["bias"].shape), + torch.tensor(temp.shape, dtype=torch.int), + torch.tensor(cnn[-1]._parameters["bias"].data.shape, dtype=torch.int), ) cnn[-1]._parameters["bias"] = torch.tensor( temp, @@ -116,8 +116,8 @@ def make_cnn( logger.info(f"Replace weights in conv2d {i} with {filename_load}") temp = np.load(filename_load) assert torch.equal( - torch.Tensor(temp.shape), - torch.Tensor(cnn[-1]._parameters["weight"].shape), + torch.tensor(temp.shape, dtype=torch.int), + torch.tensor(cnn[-1]._parameters["weight"].data.shape, dtype=torch.int), ) cnn[-1]._parameters["weight"] = torch.tensor( temp,