Add files via upload

This commit is contained in:
David Rotermund 2023-07-22 14:53:31 +02:00 committed by GitHub
parent 34a688d546
commit 8505df62e3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 436 additions and 407 deletions

View file

@ -5,7 +5,9 @@ import argh
import time import time
import os import os
import json import json
import glob
from jsmin import jsmin from jsmin import jsmin
from natsort import natsorted
from functions.alicorn_data_loader import alicorn_data_loader from functions.alicorn_data_loader import alicorn_data_loader
from functions.train import train from functions.train import train
@ -81,6 +83,11 @@ def main(
scheduler_factor=float(config["scheduler_factor"]), scheduler_factor=float(config["scheduler_factor"]),
precision_100_percent=int(config["precision_100_percent"]), precision_100_percent=int(config["precision_100_percent"]),
scheduler_threshold=float(config["scheduler_threshold"]), scheduler_threshold=float(config["scheduler_threshold"]),
model_continue=bool(config["model_continue"]),
initial_model_path=str(config["initial_model_path"]),
tb_runs_path=str(config["tb_runs_path"]),
trained_models_path=str(config["trained_models_path"]),
performance_data_path=str(config["performance_data_path"]),
) )
@ -118,6 +125,11 @@ def run_network(
scheduler_factor: float, scheduler_factor: float,
precision_100_percent: int, precision_100_percent: int,
scheduler_threshold: float, scheduler_threshold: float,
model_continue: bool,
initial_model_path: str,
tb_runs_path: str,
trained_models_path: str,
performance_data_path: str,
) -> None: ) -> None:
# define device: # define device:
device_str: str = "cuda:0" if torch.cuda.is_available() else "cpu" device_str: str = "cuda:0" if torch.cuda.is_available() else "cpu"
@ -189,8 +201,9 @@ def run_network(
current = datetime.datetime.now().strftime("%d%m-%H%M") current = datetime.datetime.now().strftime("%d%m-%H%M")
# new tb session # new tb session
os.makedirs("tb_runs", exist_ok=True)
path: str = os.path.join("tb_runs", f"{model_name}") os.makedirs(tb_runs_path, exist_ok=True)
path: str = os.path.join(tb_runs_path, f"{model_name}")
tb = SummaryWriter(path) tb = SummaryWriter(path)
# -------------------------------------------------------------------------- # --------------------------------------------------------------------------
@ -208,6 +221,14 @@ def run_network(
logger.info(f"Pooling layer kernel: {mp_1_kernel_size}, stride: {mp_1_stride}") logger.info(f"Pooling layer kernel: {mp_1_kernel_size}, stride: {mp_1_stride}")
# define model: # define model:
if model_continue:
filename_list: list = natsorted(
glob.glob(os.path.join(initial_model_path, str("*.pt")))
)
assert len(filename_list) > 0
model_filename: str = filename_list[-1]
logger.info(f"Load filename: {model_filename}")
else:
model = make_cnn( model = make_cnn(
conv_out_channels_list=out_channels, conv_out_channels_list=out_channels,
conv_kernel_size=kernel_size, conv_kernel_size=kernel_size,
@ -223,6 +244,8 @@ def run_network(
l_relu_negative_slope=leak_relu_negative_slope, l_relu_negative_slope=leak_relu_negative_slope,
).to(device) ).to(device)
model = torch.load(model_filename, map_location=device)
logger.info(model) logger.info(model)
old_params: dict = {} old_params: dict = {}
@ -321,11 +344,12 @@ def run_network(
pt_filename: str = f"{model_name}_{epoch}Epoch_{current}.pt" pt_filename: str = f"{model_name}_{epoch}Epoch_{current}.pt"
logger.info("") logger.info("")
logger.info(f"Saved model: {pt_filename}") logger.info(f"Saved model: {pt_filename}")
os.makedirs("trained_models", exist_ok=True)
os.makedirs(trained_models_path, exist_ok=True)
torch.save( torch.save(
model, model,
os.path.join( os.path.join(
"trained_models", trained_models_path,
pt_filename, pt_filename,
), ),
) )
@ -370,9 +394,9 @@ def run_network(
save_name=model_name, save_name=model_name,
) )
os.makedirs("performance_data", exist_ok=True) os.makedirs(performance_data_path, exist_ok=True)
np.savez( np.savez(
os.path.join("performance_data", f"performances_{model_name}.npz"), os.path.join(performance_data_path, f"performances_{model_name}.npz"),
train_accuracy=np.array(train_accuracy), train_accuracy=np.array(train_accuracy),
test_accuracy=np.array(test_accuracy), test_accuracy=np.array(test_accuracy),
train_losses=np.array(train_losses), train_losses=np.array(train_losses),
@ -386,11 +410,11 @@ def run_network(
logger.info("") logger.info("")
logger.info(f"Saved model: {model_name}_{epoch}Epoch_{current}") logger.info(f"Saved model: {model_name}_{epoch}Epoch_{current}")
if save_model: if save_model:
os.makedirs("trained_models", exist_ok=True) os.makedirs(trained_models_path, exist_ok=True)
torch.save( torch.save(
model, model,
os.path.join( os.path.join(
"trained_models", trained_models_path,
f"{model_name}_{epoch}Epoch_{current}.pt", f"{model_name}_{epoch}Epoch_{current}.pt",
), ),
) )

View file

@ -1,10 +1,11 @@
{ {
"data_path": "/home/kk/Documents/Semester4/code/RenderStimuli/Output/", "data_path": "/home/kk/Documents/Semester4/code/RenderStimuli/Output/",
"model_continue": false, // true, (false)
"save_logging_messages": true, // (true), false "save_logging_messages": true, // (true), false
"display_logging_messages": true, // (true), false "display_logging_messages": true, // (true), false
"batch_size_train": 250, "batch_size_train": 250,
"batch_size_test": 500, "batch_size_test": 500,
"max_epochs": 2000, "max_epochs": 5000,
"save_model": true, "save_model": true,
"conv_0_kernel_size": 11, "conv_0_kernel_size": 11,
"mp_1_kernel_size": 3, "mp_1_kernel_size": 3,
@ -22,14 +23,14 @@
// LR Scheduler -> // LR Scheduler ->
"use_scheduler": true, // (true), false "use_scheduler": true, // (true), false
"scheduler_verbose": true, "scheduler_verbose": true,
"scheduler_factor": 0.1, //(0.1) "scheduler_factor": 0.025, //(0.1)
"scheduler_patience": 10, // (10) "scheduler_patience": 100, // (10)
"scheduler_threshold": 1e-5, // (1e-4) "scheduler_threshold": 1e-5, // (1e-4)
"minimum_learning_rate": 1e-8, "minimum_learning_rate": 1e-10,
"learning_rate": 0.0001, "learning_rate": 1e-5,
// <- LR Scheduler // <- LR Scheduler
"pooling_type": "max", // (max), average, none "pooling_type": "max", // (max), average, none
"conv_0_enable_softmax": false, // true, (false) "conv_0_enable_softmax": true, // true, (false)
"use_adam": true, // (true) => adam, false => SGD "use_adam": true, // (true) => adam, false => SGD
"condition": "Coignless", "condition": "Coignless",
"scale_data": 255.0, // (255.0) "scale_data": 255.0, // (255.0)
@ -48,5 +49,9 @@
], ],
"conv_stride_sizes": [ "conv_stride_sizes": [
1 1
] ],
"initial_model_path": "initial_models",
"tb_runs_path": "tb_runs",
"trained_models_path": "trained_models",
"performance_data_path": "performance_data"
} }