Add files via upload
This commit is contained in:
parent
88e877734c
commit
2f8880f0d2
1 changed files with 4 additions and 4 deletions
|
@ -98,8 +98,8 @@ def make_cnn(
|
||||||
logger.info(f"Replace bias in conv2d {i} with {filename_load}")
|
logger.info(f"Replace bias in conv2d {i} with {filename_load}")
|
||||||
temp = np.load(filename_load)
|
temp = np.load(filename_load)
|
||||||
assert torch.equal(
|
assert torch.equal(
|
||||||
torch.Tensor(temp.shape),
|
torch.tensor(temp.shape, dtype=torch.int),
|
||||||
torch.Tensor(cnn[-1]._parameters["bias"].shape),
|
torch.tensor(cnn[-1]._parameters["bias"].data.shape, dtype=torch.int),
|
||||||
)
|
)
|
||||||
cnn[-1]._parameters["bias"] = torch.tensor(
|
cnn[-1]._parameters["bias"] = torch.tensor(
|
||||||
temp,
|
temp,
|
||||||
|
@ -116,8 +116,8 @@ def make_cnn(
|
||||||
logger.info(f"Replace weights in conv2d {i} with {filename_load}")
|
logger.info(f"Replace weights in conv2d {i} with {filename_load}")
|
||||||
temp = np.load(filename_load)
|
temp = np.load(filename_load)
|
||||||
assert torch.equal(
|
assert torch.equal(
|
||||||
torch.Tensor(temp.shape),
|
torch.tensor(temp.shape, dtype=torch.int),
|
||||||
torch.Tensor(cnn[-1]._parameters["weight"].shape),
|
torch.tensor(cnn[-1]._parameters["weight"].data.shape, dtype=torch.int),
|
||||||
)
|
)
|
||||||
cnn[-1]._parameters["weight"] = torch.tensor(
|
cnn[-1]._parameters["weight"] = torch.tensor(
|
||||||
temp,
|
temp,
|
||||||
|
|
Loading…
Reference in a new issue