From 2f8880f0d2640243ab9f42e84911a4bbfb792918 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Fri, 28 Jul 2023 12:58:31 +0200 Subject: [PATCH] Add files via upload --- functions/make_cnn_v2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/functions/make_cnn_v2.py b/functions/make_cnn_v2.py index ae4b77f..6179e77 100644 --- a/functions/make_cnn_v2.py +++ b/functions/make_cnn_v2.py @@ -98,8 +98,8 @@ def make_cnn( logger.info(f"Replace bias in conv2d {i} with {filename_load}") temp = np.load(filename_load) assert torch.equal( - torch.Tensor(temp.shape), - torch.Tensor(cnn[-1]._parameters["bias"].shape), + torch.tensor(temp.shape, dtype=torch.int), + torch.tensor(cnn[-1]._parameters["bias"].data.shape, dtype=torch.int), ) cnn[-1]._parameters["bias"] = torch.tensor( temp, @@ -116,8 +116,8 @@ def make_cnn( logger.info(f"Replace weights in conv2d {i} with {filename_load}") temp = np.load(filename_load) assert torch.equal( - torch.Tensor(temp.shape), - torch.Tensor(cnn[-1]._parameters["weight"].shape), + torch.tensor(temp.shape, dtype=torch.int), + torch.tensor(cnn[-1]._parameters["weight"].data.shape, dtype=torch.int), ) cnn[-1]._parameters["weight"] = torch.tensor( temp,