Track and save only changed w and epsilon

This commit is contained in:
David Rotermund 2022-04-30 16:43:13 +02:00 committed by GitHub
parent fba40c9b83
commit f3385b34ac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -452,43 +452,45 @@ with torch.no_grad():
eps_xy[id], dtype=torch.float64 eps_xy[id], dtype=torch.float64
) )
# Save the new values if cfg.network_structure.w_trainable[id] is True:
np.save( # Save the new values
cfg.weight_path np.save(
+ "/Weight_L" cfg.weight_path
+ str(id) + "/Weight_L"
+ "_S" + str(id)
+ str(cfg.learning_step) + "_S"
+ ".npy", + str(cfg.learning_step)
network[id].weights.detach().numpy(), + ".npy",
) network[id].weights.detach().numpy(),
try:
tb.add_histogram(
"Weights " + str(id),
network[id].weights,
cfg.learning_step,
) )
except ValueError:
pass
np.save( try:
cfg.eps_xy_path tb.add_histogram(
+ "/EpsXY_L" "Weights " + str(id),
+ str(id) network[id].weights,
+ "_S" cfg.learning_step,
+ str(cfg.learning_step) )
+ ".npy", except ValueError:
network[id].epsilon_xy.detach().numpy(), pass
)
try: if cfg.network_structure.eps_xy_trainable[id] is True:
tb.add_histogram( np.save(
"Epsilon XY " + str(id), cfg.eps_xy_path
+ "/EpsXY_L"
+ str(id)
+ "_S"
+ str(cfg.learning_step)
+ ".npy",
network[id].epsilon_xy.detach().numpy(), network[id].epsilon_xy.detach().numpy(),
cfg.learning_step,
) )
except ValueError: try:
pass tb.add_histogram(
"Epsilon XY " + str(id),
network[id].epsilon_xy.detach().numpy(),
cfg.learning_step,
)
except ValueError:
pass
# Let the torch learning rate scheduler update the # Let the torch learning rate scheduler update the
# learning rates of the optimiers # learning rates of the optimiers