From 9cfd4eb740d41629c91aa93cc75303a8bf01bf4f Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Sat, 30 Apr 2022 17:46:17 +0200 Subject: [PATCH] Better detection of previous files --- learn_it.py | 52 +++++++++++++++++++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/learn_it.py b/learn_it.py index a06d4db..b01d345 100644 --- a/learn_it.py +++ b/learn_it.py @@ -41,6 +41,7 @@ import time import dataconf import logging from datetime import datetime +import glob from Dataset import ( DatasetMaster, @@ -220,26 +221,43 @@ for id in range(0, len(network)): + ".npy" ) - # Are there weights that overwrite the initial weights? - file_to_load: str = ( - cfg.learning_parameters.overload_path + "/Weight_L" + str(id) + ".npy" - ) - if os.path.exists(file_to_load) is True: - network[id].weights = torch.tensor( - np.load(file_to_load), - dtype=torch.float64, - ) - wf[id] = np.load(file_to_load) - logging.info(f"File used: {file_to_load}") +for id in range(0, len(network)): - file_to_load = cfg.learning_parameters.overload_path + "/EpsXY_L" + str(id) + ".npy" - if os.path.exists(file_to_load) is True: - network[id].epsilon_xy = torch.tensor( - np.load(file_to_load), + # Are there weights that overwrite the initial weights? + file_to_load = glob.glob( + cfg.learning_parameters.overload_path + "/Weight_L" + str(id) + "*.npy" + ) + + if len(file_to_load) > 1: + raise Exception( + f"Too many previous weights files {cfg.learning_parameters.overload_path}/Weight_L{id}*.npy" + ) + + if len(file_to_load) == 1: + network[id].weights = torch.tensor( + np.load(file_to_load[0]), dtype=torch.float64, ) - eps_xy[id] = np.load(file_to_load) - logging.info(f"File used: {file_to_load}") + wf[id] = np.load(file_to_load[0]) + logging.info(f"File used: {file_to_load[0]}") + + # Are there epsinlon xy files that overwrite the initial epsilon xy? + file_to_load = glob.glob( + cfg.learning_parameters.overload_path + "/EpsXY_L" + str(id) + "*.npy" + ) + + if len(file_to_load) > 1: + raise Exception( + f"Too many previous epsilon xy files {cfg.learning_parameters.overload_path}/EpsXY_L{id}*.npy" + ) + + if len(file_to_load) == 1: + network[id].epsilon_xy = torch.tensor( + np.load(file_to_load[0]), + dtype=torch.float64, + ) + eps_xy[id] = np.load(file_to_load[0]) + logging.info(f"File used: {file_to_load[0]}") ####################################################################### # Optimizer and LR Scheduler #