Better detection of previous files
This commit is contained in:
parent
f3385b34ac
commit
9cfd4eb740
1 changed files with 35 additions and 17 deletions
52
learn_it.py
52
learn_it.py
|
@ -41,6 +41,7 @@ import time
|
||||||
import dataconf
|
import dataconf
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import glob
|
||||||
|
|
||||||
from Dataset import (
|
from Dataset import (
|
||||||
DatasetMaster,
|
DatasetMaster,
|
||||||
|
@ -220,26 +221,43 @@ for id in range(0, len(network)):
|
||||||
+ ".npy"
|
+ ".npy"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Are there weights that overwrite the initial weights?
|
for id in range(0, len(network)):
|
||||||
file_to_load: str = (
|
|
||||||
cfg.learning_parameters.overload_path + "/Weight_L" + str(id) + ".npy"
|
|
||||||
)
|
|
||||||
if os.path.exists(file_to_load) is True:
|
|
||||||
network[id].weights = torch.tensor(
|
|
||||||
np.load(file_to_load),
|
|
||||||
dtype=torch.float64,
|
|
||||||
)
|
|
||||||
wf[id] = np.load(file_to_load)
|
|
||||||
logging.info(f"File used: {file_to_load}")
|
|
||||||
|
|
||||||
file_to_load = cfg.learning_parameters.overload_path + "/EpsXY_L" + str(id) + ".npy"
|
# Are there weights that overwrite the initial weights?
|
||||||
if os.path.exists(file_to_load) is True:
|
file_to_load = glob.glob(
|
||||||
network[id].epsilon_xy = torch.tensor(
|
cfg.learning_parameters.overload_path + "/Weight_L" + str(id) + "*.npy"
|
||||||
np.load(file_to_load),
|
)
|
||||||
|
|
||||||
|
if len(file_to_load) > 1:
|
||||||
|
raise Exception(
|
||||||
|
f"Too many previous weights files {cfg.learning_parameters.overload_path}/Weight_L{id}*.npy"
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(file_to_load) == 1:
|
||||||
|
network[id].weights = torch.tensor(
|
||||||
|
np.load(file_to_load[0]),
|
||||||
dtype=torch.float64,
|
dtype=torch.float64,
|
||||||
)
|
)
|
||||||
eps_xy[id] = np.load(file_to_load)
|
wf[id] = np.load(file_to_load[0])
|
||||||
logging.info(f"File used: {file_to_load}")
|
logging.info(f"File used: {file_to_load[0]}")
|
||||||
|
|
||||||
|
# Are there epsinlon xy files that overwrite the initial epsilon xy?
|
||||||
|
file_to_load = glob.glob(
|
||||||
|
cfg.learning_parameters.overload_path + "/EpsXY_L" + str(id) + "*.npy"
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(file_to_load) > 1:
|
||||||
|
raise Exception(
|
||||||
|
f"Too many previous epsilon xy files {cfg.learning_parameters.overload_path}/EpsXY_L{id}*.npy"
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(file_to_load) == 1:
|
||||||
|
network[id].epsilon_xy = torch.tensor(
|
||||||
|
np.load(file_to_load[0]),
|
||||||
|
dtype=torch.float64,
|
||||||
|
)
|
||||||
|
eps_xy[id] = np.load(file_to_load[0])
|
||||||
|
logging.info(f"File used: {file_to_load[0]}")
|
||||||
|
|
||||||
#######################################################################
|
#######################################################################
|
||||||
# Optimizer and LR Scheduler #
|
# Optimizer and LR Scheduler #
|
||||||
|
|
Loading…
Reference in a new issue