Add files via upload
This commit is contained in:
parent
34a688d546
commit
8505df62e3
2 changed files with 436 additions and 407 deletions
|
@ -5,7 +5,9 @@ import argh
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import glob
|
||||||
from jsmin import jsmin
|
from jsmin import jsmin
|
||||||
|
from natsort import natsorted
|
||||||
|
|
||||||
from functions.alicorn_data_loader import alicorn_data_loader
|
from functions.alicorn_data_loader import alicorn_data_loader
|
||||||
from functions.train import train
|
from functions.train import train
|
||||||
|
@ -81,6 +83,11 @@ def main(
|
||||||
scheduler_factor=float(config["scheduler_factor"]),
|
scheduler_factor=float(config["scheduler_factor"]),
|
||||||
precision_100_percent=int(config["precision_100_percent"]),
|
precision_100_percent=int(config["precision_100_percent"]),
|
||||||
scheduler_threshold=float(config["scheduler_threshold"]),
|
scheduler_threshold=float(config["scheduler_threshold"]),
|
||||||
|
model_continue=bool(config["model_continue"]),
|
||||||
|
initial_model_path=str(config["initial_model_path"]),
|
||||||
|
tb_runs_path=str(config["tb_runs_path"]),
|
||||||
|
trained_models_path=str(config["trained_models_path"]),
|
||||||
|
performance_data_path=str(config["performance_data_path"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -118,6 +125,11 @@ def run_network(
|
||||||
scheduler_factor: float,
|
scheduler_factor: float,
|
||||||
precision_100_percent: int,
|
precision_100_percent: int,
|
||||||
scheduler_threshold: float,
|
scheduler_threshold: float,
|
||||||
|
model_continue: bool,
|
||||||
|
initial_model_path: str,
|
||||||
|
tb_runs_path: str,
|
||||||
|
trained_models_path: str,
|
||||||
|
performance_data_path: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
# define device:
|
# define device:
|
||||||
device_str: str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
device_str: str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||||
|
@ -189,8 +201,9 @@ def run_network(
|
||||||
current = datetime.datetime.now().strftime("%d%m-%H%M")
|
current = datetime.datetime.now().strftime("%d%m-%H%M")
|
||||||
|
|
||||||
# new tb session
|
# new tb session
|
||||||
os.makedirs("tb_runs", exist_ok=True)
|
|
||||||
path: str = os.path.join("tb_runs", f"{model_name}")
|
os.makedirs(tb_runs_path, exist_ok=True)
|
||||||
|
path: str = os.path.join(tb_runs_path, f"{model_name}")
|
||||||
tb = SummaryWriter(path)
|
tb = SummaryWriter(path)
|
||||||
|
|
||||||
# --------------------------------------------------------------------------
|
# --------------------------------------------------------------------------
|
||||||
|
@ -208,20 +221,30 @@ def run_network(
|
||||||
logger.info(f"Pooling layer kernel: {mp_1_kernel_size}, stride: {mp_1_stride}")
|
logger.info(f"Pooling layer kernel: {mp_1_kernel_size}, stride: {mp_1_stride}")
|
||||||
|
|
||||||
# define model:
|
# define model:
|
||||||
model = make_cnn(
|
if model_continue:
|
||||||
conv_out_channels_list=out_channels,
|
filename_list: list = natsorted(
|
||||||
conv_kernel_size=kernel_size,
|
glob.glob(os.path.join(initial_model_path, str("*.pt")))
|
||||||
conv_stride_size=stride,
|
)
|
||||||
conv_activation_function=activation_function,
|
assert len(filename_list) > 0
|
||||||
train_conv_0=train_first_layer,
|
model_filename: str = filename_list[-1]
|
||||||
conv_0_kernel_size=conv_0_kernel_size,
|
logger.info(f"Load filename: {model_filename}")
|
||||||
mp_1_kernel_size=mp_1_kernel_size,
|
else:
|
||||||
mp_1_stride=mp_1_stride,
|
model = make_cnn(
|
||||||
logger=logger,
|
conv_out_channels_list=out_channels,
|
||||||
pooling_type=pooling_type,
|
conv_kernel_size=kernel_size,
|
||||||
conv_0_enable_softmax=conv_0_enable_softmax,
|
conv_stride_size=stride,
|
||||||
l_relu_negative_slope=leak_relu_negative_slope,
|
conv_activation_function=activation_function,
|
||||||
).to(device)
|
train_conv_0=train_first_layer,
|
||||||
|
conv_0_kernel_size=conv_0_kernel_size,
|
||||||
|
mp_1_kernel_size=mp_1_kernel_size,
|
||||||
|
mp_1_stride=mp_1_stride,
|
||||||
|
logger=logger,
|
||||||
|
pooling_type=pooling_type,
|
||||||
|
conv_0_enable_softmax=conv_0_enable_softmax,
|
||||||
|
l_relu_negative_slope=leak_relu_negative_slope,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
model = torch.load(model_filename, map_location=device)
|
||||||
|
|
||||||
logger.info(model)
|
logger.info(model)
|
||||||
|
|
||||||
|
@ -321,11 +344,12 @@ def run_network(
|
||||||
pt_filename: str = f"{model_name}_{epoch}Epoch_{current}.pt"
|
pt_filename: str = f"{model_name}_{epoch}Epoch_{current}.pt"
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info(f"Saved model: {pt_filename}")
|
logger.info(f"Saved model: {pt_filename}")
|
||||||
os.makedirs("trained_models", exist_ok=True)
|
|
||||||
|
os.makedirs(trained_models_path, exist_ok=True)
|
||||||
torch.save(
|
torch.save(
|
||||||
model,
|
model,
|
||||||
os.path.join(
|
os.path.join(
|
||||||
"trained_models",
|
trained_models_path,
|
||||||
pt_filename,
|
pt_filename,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -370,9 +394,9 @@ def run_network(
|
||||||
save_name=model_name,
|
save_name=model_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
os.makedirs("performance_data", exist_ok=True)
|
os.makedirs(performance_data_path, exist_ok=True)
|
||||||
np.savez(
|
np.savez(
|
||||||
os.path.join("performance_data", f"performances_{model_name}.npz"),
|
os.path.join(performance_data_path, f"performances_{model_name}.npz"),
|
||||||
train_accuracy=np.array(train_accuracy),
|
train_accuracy=np.array(train_accuracy),
|
||||||
test_accuracy=np.array(test_accuracy),
|
test_accuracy=np.array(test_accuracy),
|
||||||
train_losses=np.array(train_losses),
|
train_losses=np.array(train_losses),
|
||||||
|
@ -386,11 +410,11 @@ def run_network(
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info(f"Saved model: {model_name}_{epoch}Epoch_{current}")
|
logger.info(f"Saved model: {model_name}_{epoch}Epoch_{current}")
|
||||||
if save_model:
|
if save_model:
|
||||||
os.makedirs("trained_models", exist_ok=True)
|
os.makedirs(trained_models_path, exist_ok=True)
|
||||||
torch.save(
|
torch.save(
|
||||||
model,
|
model,
|
||||||
os.path.join(
|
os.path.join(
|
||||||
"trained_models",
|
trained_models_path,
|
||||||
f"{model_name}_{epoch}Epoch_{current}.pt",
|
f"{model_name}_{epoch}Epoch_{current}.pt",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
19
config.json
19
config.json
|
@ -1,10 +1,11 @@
|
||||||
{
|
{
|
||||||
"data_path": "/home/kk/Documents/Semester4/code/RenderStimuli/Output/",
|
"data_path": "/home/kk/Documents/Semester4/code/RenderStimuli/Output/",
|
||||||
|
"model_continue": false, // true, (false)
|
||||||
"save_logging_messages": true, // (true), false
|
"save_logging_messages": true, // (true), false
|
||||||
"display_logging_messages": true, // (true), false
|
"display_logging_messages": true, // (true), false
|
||||||
"batch_size_train": 250,
|
"batch_size_train": 250,
|
||||||
"batch_size_test": 500,
|
"batch_size_test": 500,
|
||||||
"max_epochs": 2000,
|
"max_epochs": 5000,
|
||||||
"save_model": true,
|
"save_model": true,
|
||||||
"conv_0_kernel_size": 11,
|
"conv_0_kernel_size": 11,
|
||||||
"mp_1_kernel_size": 3,
|
"mp_1_kernel_size": 3,
|
||||||
|
@ -22,14 +23,14 @@
|
||||||
// LR Scheduler ->
|
// LR Scheduler ->
|
||||||
"use_scheduler": true, // (true), false
|
"use_scheduler": true, // (true), false
|
||||||
"scheduler_verbose": true,
|
"scheduler_verbose": true,
|
||||||
"scheduler_factor": 0.1, //(0.1)
|
"scheduler_factor": 0.025, //(0.1)
|
||||||
"scheduler_patience": 10, // (10)
|
"scheduler_patience": 100, // (10)
|
||||||
"scheduler_threshold": 1e-5, // (1e-4)
|
"scheduler_threshold": 1e-5, // (1e-4)
|
||||||
"minimum_learning_rate": 1e-8,
|
"minimum_learning_rate": 1e-10,
|
||||||
"learning_rate": 0.0001,
|
"learning_rate": 1e-5,
|
||||||
// <- LR Scheduler
|
// <- LR Scheduler
|
||||||
"pooling_type": "max", // (max), average, none
|
"pooling_type": "max", // (max), average, none
|
||||||
"conv_0_enable_softmax": false, // true, (false)
|
"conv_0_enable_softmax": true, // true, (false)
|
||||||
"use_adam": true, // (true) => adam, false => SGD
|
"use_adam": true, // (true) => adam, false => SGD
|
||||||
"condition": "Coignless",
|
"condition": "Coignless",
|
||||||
"scale_data": 255.0, // (255.0)
|
"scale_data": 255.0, // (255.0)
|
||||||
|
@ -48,5 +49,9 @@
|
||||||
],
|
],
|
||||||
"conv_stride_sizes": [
|
"conv_stride_sizes": [
|
||||||
1
|
1
|
||||||
]
|
],
|
||||||
|
"initial_model_path": "initial_models",
|
||||||
|
"tb_runs_path": "tb_runs",
|
||||||
|
"trained_models_path": "trained_models",
|
||||||
|
"performance_data_path": "performance_data"
|
||||||
}
|
}
|
Loading…
Reference in a new issue