Add files via upload
This commit is contained in:
parent
88e877734c
commit
2f8880f0d2
1 changed files with 4 additions and 4 deletions
|
@ -98,8 +98,8 @@ def make_cnn(
|
|||
logger.info(f"Replace bias in conv2d {i} with {filename_load}")
|
||||
temp = np.load(filename_load)
|
||||
assert torch.equal(
|
||||
torch.Tensor(temp.shape),
|
||||
torch.Tensor(cnn[-1]._parameters["bias"].shape),
|
||||
torch.tensor(temp.shape, dtype=torch.int),
|
||||
torch.tensor(cnn[-1]._parameters["bias"].data.shape, dtype=torch.int),
|
||||
)
|
||||
cnn[-1]._parameters["bias"] = torch.tensor(
|
||||
temp,
|
||||
|
@ -116,8 +116,8 @@ def make_cnn(
|
|||
logger.info(f"Replace weights in conv2d {i} with {filename_load}")
|
||||
temp = np.load(filename_load)
|
||||
assert torch.equal(
|
||||
torch.Tensor(temp.shape),
|
||||
torch.Tensor(cnn[-1]._parameters["weight"].shape),
|
||||
torch.tensor(temp.shape, dtype=torch.int),
|
||||
torch.tensor(cnn[-1]._parameters["weight"].data.shape, dtype=torch.int),
|
||||
)
|
||||
cnn[-1]._parameters["weight"] = torch.tensor(
|
||||
temp,
|
||||
|
|
Loading…
Reference in a new issue