Better detection of previous files
This commit is contained in:
parent
f3385b34ac
commit
9cfd4eb740
1 changed files with 35 additions and 17 deletions
52
learn_it.py
52
learn_it.py
|
@ -41,6 +41,7 @@ import time
|
|||
import dataconf
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import glob
|
||||
|
||||
from Dataset import (
|
||||
DatasetMaster,
|
||||
|
@ -220,26 +221,43 @@ for id in range(0, len(network)):
|
|||
+ ".npy"
|
||||
)
|
||||
|
||||
# Are there weights that overwrite the initial weights?
|
||||
file_to_load: str = (
|
||||
cfg.learning_parameters.overload_path + "/Weight_L" + str(id) + ".npy"
|
||||
)
|
||||
if os.path.exists(file_to_load) is True:
|
||||
network[id].weights = torch.tensor(
|
||||
np.load(file_to_load),
|
||||
dtype=torch.float64,
|
||||
)
|
||||
wf[id] = np.load(file_to_load)
|
||||
logging.info(f"File used: {file_to_load}")
|
||||
for id in range(0, len(network)):
|
||||
|
||||
file_to_load = cfg.learning_parameters.overload_path + "/EpsXY_L" + str(id) + ".npy"
|
||||
if os.path.exists(file_to_load) is True:
|
||||
network[id].epsilon_xy = torch.tensor(
|
||||
np.load(file_to_load),
|
||||
# Are there weights that overwrite the initial weights?
|
||||
file_to_load = glob.glob(
|
||||
cfg.learning_parameters.overload_path + "/Weight_L" + str(id) + "*.npy"
|
||||
)
|
||||
|
||||
if len(file_to_load) > 1:
|
||||
raise Exception(
|
||||
f"Too many previous weights files {cfg.learning_parameters.overload_path}/Weight_L{id}*.npy"
|
||||
)
|
||||
|
||||
if len(file_to_load) == 1:
|
||||
network[id].weights = torch.tensor(
|
||||
np.load(file_to_load[0]),
|
||||
dtype=torch.float64,
|
||||
)
|
||||
eps_xy[id] = np.load(file_to_load)
|
||||
logging.info(f"File used: {file_to_load}")
|
||||
wf[id] = np.load(file_to_load[0])
|
||||
logging.info(f"File used: {file_to_load[0]}")
|
||||
|
||||
# Are there epsinlon xy files that overwrite the initial epsilon xy?
|
||||
file_to_load = glob.glob(
|
||||
cfg.learning_parameters.overload_path + "/EpsXY_L" + str(id) + "*.npy"
|
||||
)
|
||||
|
||||
if len(file_to_load) > 1:
|
||||
raise Exception(
|
||||
f"Too many previous epsilon xy files {cfg.learning_parameters.overload_path}/EpsXY_L{id}*.npy"
|
||||
)
|
||||
|
||||
if len(file_to_load) == 1:
|
||||
network[id].epsilon_xy = torch.tensor(
|
||||
np.load(file_to_load[0]),
|
||||
dtype=torch.float64,
|
||||
)
|
||||
eps_xy[id] = np.load(file_to_load[0])
|
||||
logging.info(f"File used: {file_to_load[0]}")
|
||||
|
||||
#######################################################################
|
||||
# Optimizer and LR Scheduler #
|
||||
|
|
Loading…
Reference in a new issue