Add files via upload

This commit is contained in:
David Rotermund 2023-07-28 12:58:31 +02:00 committed by GitHub
parent 88e877734c
commit 2f8880f0d2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -98,8 +98,8 @@ def make_cnn(
logger.info(f"Replace bias in conv2d {i} with {filename_load}")
temp = np.load(filename_load)
assert torch.equal(
torch.Tensor(temp.shape),
torch.Tensor(cnn[-1]._parameters["bias"].shape),
torch.tensor(temp.shape, dtype=torch.int),
torch.tensor(cnn[-1]._parameters["bias"].data.shape, dtype=torch.int),
)
cnn[-1]._parameters["bias"] = torch.tensor(
temp,
@ -116,8 +116,8 @@ def make_cnn(
logger.info(f"Replace weights in conv2d {i} with {filename_load}")
temp = np.load(filename_load)
assert torch.equal(
torch.Tensor(temp.shape),
torch.Tensor(cnn[-1]._parameters["weight"].shape),
torch.tensor(temp.shape, dtype=torch.int),
torch.tensor(cnn[-1]._parameters["weight"].data.shape, dtype=torch.int),
)
cnn[-1]._parameters["weight"] = torch.tensor(
temp,