mirror of
https://github.com/davrot/pytorch-sbs.git
synced 2025-07-03 11:00:03 +02:00
Track and save only changed w and epsilon
This commit is contained in:
parent
fba40c9b83
commit
f3385b34ac
1 changed files with 35 additions and 33 deletions
|
@ -452,6 +452,7 @@ with torch.no_grad():
|
|||
eps_xy[id], dtype=torch.float64
|
||||
)
|
||||
|
||||
if cfg.network_structure.w_trainable[id] is True:
|
||||
# Save the new values
|
||||
np.save(
|
||||
cfg.weight_path
|
||||
|
@ -472,6 +473,7 @@ with torch.no_grad():
|
|||
except ValueError:
|
||||
pass
|
||||
|
||||
if cfg.network_structure.eps_xy_trainable[id] is True:
|
||||
np.save(
|
||||
cfg.eps_xy_path
|
||||
+ "/EpsXY_L"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue