From f3385b34ac7d7342f5cf289f56f6d2653b98a7e9 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Sat, 30 Apr 2022 16:43:13 +0200 Subject: [PATCH] Track and save only changed w and epsilon --- learn_it.py | 68 +++++++++++++++++++++++++++-------------------------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/learn_it.py b/learn_it.py index a167f5b..a06d4db 100644 --- a/learn_it.py +++ b/learn_it.py @@ -452,43 +452,45 @@ with torch.no_grad(): eps_xy[id], dtype=torch.float64 ) - # Save the new values - np.save( - cfg.weight_path - + "/Weight_L" - + str(id) - + "_S" - + str(cfg.learning_step) - + ".npy", - network[id].weights.detach().numpy(), - ) - - try: - tb.add_histogram( - "Weights " + str(id), - network[id].weights, - cfg.learning_step, + if cfg.network_structure.w_trainable[id] is True: + # Save the new values + np.save( + cfg.weight_path + + "/Weight_L" + + str(id) + + "_S" + + str(cfg.learning_step) + + ".npy", + network[id].weights.detach().numpy(), ) - except ValueError: - pass - np.save( - cfg.eps_xy_path - + "/EpsXY_L" - + str(id) - + "_S" - + str(cfg.learning_step) - + ".npy", - network[id].epsilon_xy.detach().numpy(), - ) - try: - tb.add_histogram( - "Epsilon XY " + str(id), + try: + tb.add_histogram( + "Weights " + str(id), + network[id].weights, + cfg.learning_step, + ) + except ValueError: + pass + + if cfg.network_structure.eps_xy_trainable[id] is True: + np.save( + cfg.eps_xy_path + + "/EpsXY_L" + + str(id) + + "_S" + + str(cfg.learning_step) + + ".npy", network[id].epsilon_xy.detach().numpy(), - cfg.learning_step, ) - except ValueError: - pass + try: + tb.add_histogram( + "Epsilon XY " + str(id), + network[id].epsilon_xy.detach().numpy(), + cfg.learning_step, + ) + except ValueError: + pass # Let the torch learning rate scheduler update the # learning rates of the optimiers