Track and save only changed w and epsilon

This commit is contained in:
David Rotermund 2022-04-30 16:43:13 +02:00 committed by GitHub
parent fba40c9b83
commit f3385b34ac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -452,43 +452,45 @@ with torch.no_grad():
eps_xy[id], dtype=torch.float64
)
# Save the new values
np.save(
cfg.weight_path
+ "/Weight_L"
+ str(id)
+ "_S"
+ str(cfg.learning_step)
+ ".npy",
network[id].weights.detach().numpy(),
)
try:
tb.add_histogram(
"Weights " + str(id),
network[id].weights,
cfg.learning_step,
if cfg.network_structure.w_trainable[id] is True:
# Save the new values
np.save(
cfg.weight_path
+ "/Weight_L"
+ str(id)
+ "_S"
+ str(cfg.learning_step)
+ ".npy",
network[id].weights.detach().numpy(),
)
except ValueError:
pass
np.save(
cfg.eps_xy_path
+ "/EpsXY_L"
+ str(id)
+ "_S"
+ str(cfg.learning_step)
+ ".npy",
network[id].epsilon_xy.detach().numpy(),
)
try:
tb.add_histogram(
"Epsilon XY " + str(id),
try:
tb.add_histogram(
"Weights " + str(id),
network[id].weights,
cfg.learning_step,
)
except ValueError:
pass
if cfg.network_structure.eps_xy_trainable[id] is True:
np.save(
cfg.eps_xy_path
+ "/EpsXY_L"
+ str(id)
+ "_S"
+ str(cfg.learning_step)
+ ".npy",
network[id].epsilon_xy.detach().numpy(),
cfg.learning_step,
)
except ValueError:
pass
try:
tb.add_histogram(
"Epsilon XY " + str(id),
network[id].epsilon_xy.detach().numpy(),
cfg.learning_step,
)
except ValueError:
pass
# Let the torch learning rate scheduler update the
# learning rates of the optimiers