Better detection of previous files

This commit is contained in:
David Rotermund 2022-04-30 17:46:17 +02:00 committed by GitHub
parent f3385b34ac
commit 9cfd4eb740
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -41,6 +41,7 @@ import time
import dataconf import dataconf
import logging import logging
from datetime import datetime from datetime import datetime
import glob
from Dataset import ( from Dataset import (
DatasetMaster, DatasetMaster,
@ -220,26 +221,43 @@ for id in range(0, len(network)):
+ ".npy" + ".npy"
) )
# Are there weights that overwrite the initial weights? for id in range(0, len(network)):
file_to_load: str = (
cfg.learning_parameters.overload_path + "/Weight_L" + str(id) + ".npy"
)
if os.path.exists(file_to_load) is True:
network[id].weights = torch.tensor(
np.load(file_to_load),
dtype=torch.float64,
)
wf[id] = np.load(file_to_load)
logging.info(f"File used: {file_to_load}")
file_to_load = cfg.learning_parameters.overload_path + "/EpsXY_L" + str(id) + ".npy" # Are there weights that overwrite the initial weights?
if os.path.exists(file_to_load) is True: file_to_load = glob.glob(
network[id].epsilon_xy = torch.tensor( cfg.learning_parameters.overload_path + "/Weight_L" + str(id) + "*.npy"
np.load(file_to_load), )
if len(file_to_load) > 1:
raise Exception(
f"Too many previous weights files {cfg.learning_parameters.overload_path}/Weight_L{id}*.npy"
)
if len(file_to_load) == 1:
network[id].weights = torch.tensor(
np.load(file_to_load[0]),
dtype=torch.float64, dtype=torch.float64,
) )
eps_xy[id] = np.load(file_to_load) wf[id] = np.load(file_to_load[0])
logging.info(f"File used: {file_to_load}") logging.info(f"File used: {file_to_load[0]}")
# Are there epsinlon xy files that overwrite the initial epsilon xy?
file_to_load = glob.glob(
cfg.learning_parameters.overload_path + "/EpsXY_L" + str(id) + "*.npy"
)
if len(file_to_load) > 1:
raise Exception(
f"Too many previous epsilon xy files {cfg.learning_parameters.overload_path}/EpsXY_L{id}*.npy"
)
if len(file_to_load) == 1:
network[id].epsilon_xy = torch.tensor(
np.load(file_to_load[0]),
dtype=torch.float64,
)
eps_xy[id] = np.load(file_to_load[0])
logging.info(f"File used: {file_to_load[0]}")
####################################################################### #######################################################################
# Optimizer and LR Scheduler # # Optimizer and LR Scheduler #