Add files via upload
This commit is contained in:
parent
9f7c88df91
commit
659fbd071f
9 changed files with 931 additions and 0 deletions
400
cnn_training.py
Normal file
400
cnn_training.py
Normal file
|
@ -0,0 +1,400 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import datetime
|
||||||
|
import argh
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from jsmin import jsmin
|
||||||
|
|
||||||
|
from functions.alicorn_data_loader import alicorn_data_loader
|
||||||
|
from functions.train import train
|
||||||
|
from functions.test import test
|
||||||
|
from functions.make_cnn import make_cnn
|
||||||
|
from functions.set_seed import set_seed
|
||||||
|
from functions.plot_intermediate import plot_intermediate
|
||||||
|
from functions.create_logger import create_logger
|
||||||
|
|
||||||
|
|
||||||
|
# to disable logging output from Tensorflow
|
||||||
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
idx_conv_out_channels_list: int = 0,
|
||||||
|
idx_conv_kernel_sizes: int = 0,
|
||||||
|
idx_conv_stride_sizes: int = 0,
|
||||||
|
seed_counter: int = 0,
|
||||||
|
) -> None:
|
||||||
|
config_filenname = "config.json"
|
||||||
|
with open(config_filenname, "r") as file_handle:
|
||||||
|
config = json.loads(jsmin(file_handle.read()))
|
||||||
|
|
||||||
|
logger = create_logger(
|
||||||
|
save_logging_messages=bool(config["save_logging_messages"]),
|
||||||
|
display_logging_messages=bool(config["display_logging_messages"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# network settings:
|
||||||
|
conv_out_channels_list: list[list[int]] = config["conv_out_channels_list"]
|
||||||
|
conv_kernel_sizes: list[list[int]] = config["conv_kernel_sizes"]
|
||||||
|
conv_stride_sizes: list[int] = config["conv_stride_sizes"]
|
||||||
|
|
||||||
|
num_pfinkel: list = np.arange(
|
||||||
|
int(config["num_pfinkel_start"]),
|
||||||
|
int(config["num_pfinkel_stop"]),
|
||||||
|
int(config["num_pfinkel_step"]),
|
||||||
|
).tolist()
|
||||||
|
|
||||||
|
run_network(
|
||||||
|
out_channels=conv_out_channels_list[int(idx_conv_out_channels_list)],
|
||||||
|
kernel_size=conv_kernel_sizes[int(idx_conv_kernel_sizes)],
|
||||||
|
stride=conv_stride_sizes[int(idx_conv_stride_sizes)],
|
||||||
|
activation_function=str(config["activation_function"]),
|
||||||
|
train_first_layer=bool(config["train_first_layer"]),
|
||||||
|
seed_counter=seed_counter,
|
||||||
|
minimum_learning_rate=float(config["minimum_learning_rate"]),
|
||||||
|
conv_0_kernel_size=int(config["conv_0_kernel_size"]),
|
||||||
|
mp_1_kernel_size=int(config["mp_1_kernel_size"]),
|
||||||
|
mp_1_stride=int(config["mp_1_stride"]),
|
||||||
|
batch_size_train=int(config["batch_size_train"]),
|
||||||
|
batch_size_test=int(config["batch_size_test"]),
|
||||||
|
learning_rate=float(config["learning_rate"]),
|
||||||
|
max_epochs=int(config["max_epochs"]),
|
||||||
|
save_model=bool(config["save_model"]),
|
||||||
|
stimuli_per_pfinkel=int(config["stimuli_per_pfinkel"]),
|
||||||
|
num_pfinkel=num_pfinkel,
|
||||||
|
logger=logger,
|
||||||
|
save_ever_x_epochs=int(config["save_ever_x_epochs"]),
|
||||||
|
scheduler_patience=int(config["scheduler_patience"]),
|
||||||
|
condition=str(config["condition"]),
|
||||||
|
data_path=str(config["data_path"]),
|
||||||
|
pooling_type=str(config["pooling_type"]),
|
||||||
|
conv_0_enable_softmax=bool(config["conv_0_enable_softmax"]),
|
||||||
|
scale_data=int(config["scale_data"]),
|
||||||
|
use_scheduler=bool(config["use_scheduler"]),
|
||||||
|
use_adam=bool(config["use_adam"]),
|
||||||
|
use_plot_intermediate=bool(config["use_plot_intermediate"]),
|
||||||
|
leak_relu_negative_slope=float(config["leak_relu_negative_slope"]),
|
||||||
|
scheduler_verbose=bool(config["scheduler_verbose"]),
|
||||||
|
scheduler_factor=float(config["scheduler_factor"]),
|
||||||
|
precision_100_percent=int(config["precision_100_percent"]),
|
||||||
|
scheduler_threshold=float(config["scheduler_threshold"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_network(
|
||||||
|
out_channels: list[int],
|
||||||
|
kernel_size: list[int],
|
||||||
|
num_pfinkel: list,
|
||||||
|
logger,
|
||||||
|
stride: int,
|
||||||
|
activation_function: str,
|
||||||
|
train_first_layer: bool,
|
||||||
|
seed_counter: int,
|
||||||
|
minimum_learning_rate: float,
|
||||||
|
conv_0_kernel_size: int,
|
||||||
|
mp_1_kernel_size: int,
|
||||||
|
mp_1_stride: int,
|
||||||
|
scheduler_patience: int,
|
||||||
|
batch_size_train: int,
|
||||||
|
batch_size_test: int,
|
||||||
|
learning_rate: float,
|
||||||
|
max_epochs: int,
|
||||||
|
save_model: bool,
|
||||||
|
stimuli_per_pfinkel: int,
|
||||||
|
save_ever_x_epochs: int,
|
||||||
|
condition: str,
|
||||||
|
data_path: str,
|
||||||
|
pooling_type: str,
|
||||||
|
conv_0_enable_softmax: bool,
|
||||||
|
scale_data: float,
|
||||||
|
use_scheduler: bool,
|
||||||
|
use_adam: bool,
|
||||||
|
use_plot_intermediate: bool,
|
||||||
|
leak_relu_negative_slope: float,
|
||||||
|
scheduler_verbose: bool,
|
||||||
|
scheduler_factor: float,
|
||||||
|
precision_100_percent: int,
|
||||||
|
scheduler_threshold: float,
|
||||||
|
) -> None:
|
||||||
|
# define device:
|
||||||
|
device_str: str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||||
|
logger.info(f"Using {device_str} device")
|
||||||
|
device: torch.device = torch.device(device_str)
|
||||||
|
torch.set_default_dtype(torch.float32)
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
logger.info("-==- START -==-")
|
||||||
|
|
||||||
|
train_accuracy: list[float] = []
|
||||||
|
train_losses: list[float] = []
|
||||||
|
train_loss: list[float] = []
|
||||||
|
test_accuracy: list[float] = []
|
||||||
|
test_losses: list[float] = []
|
||||||
|
|
||||||
|
# prepare data:
|
||||||
|
|
||||||
|
logger.info(num_pfinkel)
|
||||||
|
logger.info(condition)
|
||||||
|
|
||||||
|
logger.info("Loading training data")
|
||||||
|
data_train = alicorn_data_loader(
|
||||||
|
num_pfinkel=num_pfinkel,
|
||||||
|
load_stimuli_per_pfinkel=stimuli_per_pfinkel,
|
||||||
|
condition=condition,
|
||||||
|
logger=logger,
|
||||||
|
data_path=data_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Loading test data")
|
||||||
|
data_test = alicorn_data_loader(
|
||||||
|
num_pfinkel=num_pfinkel,
|
||||||
|
load_stimuli_per_pfinkel=stimuli_per_pfinkel,
|
||||||
|
condition=condition,
|
||||||
|
logger=logger,
|
||||||
|
data_path=data_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Loading done!")
|
||||||
|
|
||||||
|
# data loader
|
||||||
|
loader_train = torch.utils.data.DataLoader(
|
||||||
|
data_train, shuffle=True, batch_size=batch_size_train
|
||||||
|
)
|
||||||
|
loader_test = torch.utils.data.DataLoader(
|
||||||
|
data_test, shuffle=False, batch_size=batch_size_test
|
||||||
|
)
|
||||||
|
|
||||||
|
previous_test_acc: float = -1
|
||||||
|
|
||||||
|
# set seed for reproducibility
|
||||||
|
set_seed(seed=int(seed_counter), logger=logger)
|
||||||
|
|
||||||
|
# number conv layer:
|
||||||
|
if train_first_layer:
|
||||||
|
num_conv_layers = len(out_channels)
|
||||||
|
else:
|
||||||
|
num_conv_layers = len(out_channels) if len(out_channels) >= 2 else 1
|
||||||
|
|
||||||
|
# determine num conv layers
|
||||||
|
model_name = (
|
||||||
|
f"ArghCNN__MPk3s2_numConvLayers{num_conv_layers}"
|
||||||
|
f"_outChannels{out_channels}_kernelSize{kernel_size}_"
|
||||||
|
f"{activation_function}_stride{stride}_"
|
||||||
|
f"trainFirstConvLayer{train_first_layer}_"
|
||||||
|
f"seed{seed_counter}_{condition}_MPk3s2"
|
||||||
|
)
|
||||||
|
current = datetime.datetime.now().strftime("%d%m-%H%M")
|
||||||
|
|
||||||
|
# new tb session
|
||||||
|
os.makedirs("tb_runs", exist_ok=True)
|
||||||
|
path: str = os.path.join("tb_runs", f"{model_name}")
|
||||||
|
tb = SummaryWriter(path)
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# print network configuration:
|
||||||
|
logger.info("----------------------------------------------------")
|
||||||
|
logger.info(f"Number conv layers: {num_conv_layers}")
|
||||||
|
logger.info(f"Output channels: {out_channels}")
|
||||||
|
logger.info(f"Kernel sizes: {kernel_size}")
|
||||||
|
logger.info(f"Stride: {stride}")
|
||||||
|
logger.info(f"Activation function: {activation_function}")
|
||||||
|
logger.info(f"Training conv 0: {train_first_layer}")
|
||||||
|
logger.info(f"Seed: {seed_counter}")
|
||||||
|
logger.info(f"LR-scheduler patience: {scheduler_patience}")
|
||||||
|
logger.info(f"Pooling layer kernel: {mp_1_kernel_size}, stride: {mp_1_stride}")
|
||||||
|
|
||||||
|
# define model:
|
||||||
|
model = make_cnn(
|
||||||
|
conv_out_channels_list=out_channels,
|
||||||
|
conv_kernel_size=kernel_size,
|
||||||
|
conv_stride_size=stride,
|
||||||
|
conv_activation_function=activation_function,
|
||||||
|
train_conv_0=train_first_layer,
|
||||||
|
conv_0_kernel_size=conv_0_kernel_size,
|
||||||
|
mp_1_kernel_size=mp_1_kernel_size,
|
||||||
|
mp_1_stride=mp_1_stride,
|
||||||
|
logger=logger,
|
||||||
|
pooling_type=pooling_type,
|
||||||
|
conv_0_enable_softmax=conv_0_enable_softmax,
|
||||||
|
l_relu_negative_slope=leak_relu_negative_slope,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
logger.info(model)
|
||||||
|
|
||||||
|
old_params: dict = {}
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
old_params[name] = param.data.detach().cpu().clone()
|
||||||
|
|
||||||
|
# pararmeters for training:
|
||||||
|
param_list: list = []
|
||||||
|
|
||||||
|
for i in range(0, len(model)):
|
||||||
|
if (not train_first_layer) and (i == 0):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
for name, param in model[i].named_parameters():
|
||||||
|
logger.info(f"Learning parameter: layer: {i} name: {name}")
|
||||||
|
param_list.append(param)
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
assert (
|
||||||
|
torch.isfinite(param.data).sum().cpu()
|
||||||
|
== torch.tensor(param.data.size()).prod()
|
||||||
|
), name
|
||||||
|
|
||||||
|
# optimizer and learning rate scheduler
|
||||||
|
if use_adam:
|
||||||
|
optimizer = torch.optim.Adam(param_list, lr=learning_rate)
|
||||||
|
else:
|
||||||
|
optimizer = torch.optim.SGD(param_list, lr=learning_rate) # type: ignore
|
||||||
|
|
||||||
|
if use_scheduler:
|
||||||
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
|
optimizer,
|
||||||
|
patience=scheduler_patience,
|
||||||
|
eps=minimum_learning_rate / 10,
|
||||||
|
verbose=scheduler_verbose,
|
||||||
|
factor=scheduler_factor,
|
||||||
|
threshold=scheduler_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
# training loop:
|
||||||
|
logger.info("-==- Data and network loader: Done -==-")
|
||||||
|
t_dis0 = time.perf_counter()
|
||||||
|
for epoch in range(1, max_epochs + 1):
|
||||||
|
# train
|
||||||
|
logger.info("-==- Training... -==-")
|
||||||
|
running_loss = train(
|
||||||
|
model=model,
|
||||||
|
loader=loader_train,
|
||||||
|
optimizer=optimizer,
|
||||||
|
epoch=epoch,
|
||||||
|
device=device,
|
||||||
|
tb=tb,
|
||||||
|
test_acc=previous_test_acc,
|
||||||
|
logger=logger,
|
||||||
|
train_accuracy=train_accuracy,
|
||||||
|
train_losses=train_losses,
|
||||||
|
train_loss=train_loss,
|
||||||
|
scale_data=scale_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
# logging:
|
||||||
|
logger.info("")
|
||||||
|
|
||||||
|
logger.info("Check for changes in the weights:")
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if isinstance(old_params[name], torch.Tensor) and isinstance(
|
||||||
|
param.data, torch.Tensor
|
||||||
|
):
|
||||||
|
temp_torch = param.data.detach().cpu().clone()
|
||||||
|
if old_params[name].ndim == temp_torch.ndim:
|
||||||
|
if old_params[name].size() == temp_torch.size():
|
||||||
|
abs_diff = torch.abs(old_params[name] - temp_torch).max()
|
||||||
|
logger.info(f"Parameter {name}: {abs_diff:.3e}")
|
||||||
|
|
||||||
|
old_params[name] = temp_torch
|
||||||
|
|
||||||
|
logger.info("")
|
||||||
|
|
||||||
|
logger.info("-==- Testing... -==-")
|
||||||
|
previous_test_acc = test( # type: ignore
|
||||||
|
model=model,
|
||||||
|
loader=loader_test,
|
||||||
|
device=device,
|
||||||
|
tb=tb,
|
||||||
|
epoch=epoch,
|
||||||
|
logger=logger,
|
||||||
|
test_accuracy=test_accuracy,
|
||||||
|
test_losses=test_losses,
|
||||||
|
scale_data=scale_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Time required: {time.perf_counter()-t_dis0:.2e} sec")
|
||||||
|
|
||||||
|
# save model after every 100th epoch:
|
||||||
|
if save_model and (epoch % save_ever_x_epochs == 0):
|
||||||
|
pt_filename: str = f"{model_name}_{epoch}Epoch_{current}.pt"
|
||||||
|
logger.info("")
|
||||||
|
logger.info(f"Saved model: {pt_filename}")
|
||||||
|
os.makedirs("trained_models", exist_ok=True)
|
||||||
|
torch.save(
|
||||||
|
model,
|
||||||
|
os.path.join(
|
||||||
|
"trained_models",
|
||||||
|
pt_filename,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# check nan
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
assert (
|
||||||
|
torch.isfinite(param.data).sum().cpu()
|
||||||
|
== torch.tensor(param.data.size()).prod()
|
||||||
|
), name
|
||||||
|
|
||||||
|
# update scheduler
|
||||||
|
if use_scheduler:
|
||||||
|
if scheduler_verbose and isinstance(scheduler.best, float):
|
||||||
|
logger.info(
|
||||||
|
"Step LR scheduler: "
|
||||||
|
f"Loss: {running_loss:.2e} "
|
||||||
|
f"Best: {scheduler.best:.2e} "
|
||||||
|
f"Delta: {running_loss-scheduler.best:.2e} "
|
||||||
|
f"Threshold: {scheduler.threshold:.2e} "
|
||||||
|
f"Number of bad epochs: {scheduler.num_bad_epochs} "
|
||||||
|
f"Patience: {scheduler.patience} "
|
||||||
|
)
|
||||||
|
scheduler.step(running_loss)
|
||||||
|
|
||||||
|
# stop learning: lr too small
|
||||||
|
if optimizer.param_groups[0]["lr"] <= minimum_learning_rate:
|
||||||
|
logger.info("Learning rate is too small. Stop training.")
|
||||||
|
break
|
||||||
|
|
||||||
|
# stop learning: done
|
||||||
|
if round(previous_test_acc, precision_100_percent) == 100.0:
|
||||||
|
logger.info("100% test performance reached. Stop training.")
|
||||||
|
break
|
||||||
|
|
||||||
|
if use_plot_intermediate:
|
||||||
|
plot_intermediate(
|
||||||
|
train_accuracy=train_accuracy,
|
||||||
|
test_accuracy=test_accuracy,
|
||||||
|
train_losses=train_losses,
|
||||||
|
test_losses=test_losses,
|
||||||
|
save_name=model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
os.makedirs("performance_data", exist_ok=True)
|
||||||
|
np.savez(
|
||||||
|
os.path.join("performance_data", f"performances_{model_name}.npz"),
|
||||||
|
train_accuracy=np.array(train_accuracy),
|
||||||
|
test_accuracy=np.array(test_accuracy),
|
||||||
|
train_losses=np.array(train_losses),
|
||||||
|
test_losses=np.array(test_losses),
|
||||||
|
)
|
||||||
|
|
||||||
|
# end TB session:
|
||||||
|
tb.close()
|
||||||
|
|
||||||
|
# print model name:
|
||||||
|
logger.info("")
|
||||||
|
logger.info(f"Saved model: {model_name}_{epoch}Epoch_{current}")
|
||||||
|
if save_model:
|
||||||
|
os.makedirs("trained_models", exist_ok=True)
|
||||||
|
torch.save(
|
||||||
|
model,
|
||||||
|
os.path.join(
|
||||||
|
"trained_models",
|
||||||
|
f"{model_name}_{epoch}Epoch_{current}.pt",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
argh.dispatch_command(main)
|
52
config.json
Normal file
52
config.json
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
{
|
||||||
|
"data_path": "/home/kk/Documents/Semester4/code/RenderStimuli/Output/",
|
||||||
|
"save_logging_messages": true, // (true), false
|
||||||
|
"display_logging_messages": true, // (true), false
|
||||||
|
"batch_size_train": 250,
|
||||||
|
"batch_size_test": 500,
|
||||||
|
"max_epochs": 2000,
|
||||||
|
"save_model": true,
|
||||||
|
"conv_0_kernel_size": 11,
|
||||||
|
"mp_1_kernel_size": 3,
|
||||||
|
"mp_1_stride": 2,
|
||||||
|
"use_plot_intermediate": false, // true, (false)
|
||||||
|
"stimuli_per_pfinkel": 30000,
|
||||||
|
"num_pfinkel_start": 0,
|
||||||
|
"num_pfinkel_stop": 10,
|
||||||
|
"num_pfinkel_step": 10,
|
||||||
|
"precision_100_percent": 4, // (4)
|
||||||
|
"train_first_layer": true, // true, (false)
|
||||||
|
"save_ever_x_epochs": 10, // (10)
|
||||||
|
"activation_function": "leaky relu", // tanh, relu, (leaky relu), none
|
||||||
|
"leak_relu_negative_slope": 0.1, // (0.1)
|
||||||
|
// LR Scheduler ->
|
||||||
|
"use_scheduler": true, // (true), false
|
||||||
|
"scheduler_verbose": true,
|
||||||
|
"scheduler_factor": 0.1, //(0.1)
|
||||||
|
"scheduler_patience": 10, // (10)
|
||||||
|
"scheduler_threshold": 1e-5, // (1e-4)
|
||||||
|
"minimum_learning_rate": 1e-8,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
// <- LR Scheduler
|
||||||
|
"pooling_type": "max", // (max), average, none
|
||||||
|
"conv_0_enable_softmax": false, // true, (false)
|
||||||
|
"use_adam": true, // (true) => adam, false => SGD
|
||||||
|
"condition": "Coignless",
|
||||||
|
"scale_data": 255.0, // (255.0)
|
||||||
|
"conv_out_channels_list": [
|
||||||
|
[
|
||||||
|
32,
|
||||||
|
8,
|
||||||
|
8
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"conv_kernel_sizes": [
|
||||||
|
[
|
||||||
|
7,
|
||||||
|
15
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"conv_stride_sizes": [
|
||||||
|
1
|
||||||
|
]
|
||||||
|
}
|
105
functions/alicorn_data_loader.py
Normal file
105
functions/alicorn_data_loader.py
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def alicorn_data_loader(
|
||||||
|
num_pfinkel: list[int] | None,
|
||||||
|
load_stimuli_per_pfinkel: int,
|
||||||
|
condition: str,
|
||||||
|
logger,
|
||||||
|
data_path: str,
|
||||||
|
) -> torch.utils.data.TensorDataset:
|
||||||
|
"""
|
||||||
|
- num_pfinkel: list of the angles that should be loaded (ranging from
|
||||||
|
0-90). If None: all pfinkels loaded
|
||||||
|
- stimuli_per_pfinkel: defines amount of stimuli per path angle but
|
||||||
|
for label 0 and label 1 seperatly (e.g., stimuli_per_pfinkel = 1000:
|
||||||
|
1000 stimuli = label 1, 1000 stimuli = label 0)
|
||||||
|
"""
|
||||||
|
filename: str | None = None
|
||||||
|
if condition == "Angular":
|
||||||
|
filename = "angular_angle"
|
||||||
|
elif condition == "Coignless":
|
||||||
|
filename = "base_angle"
|
||||||
|
elif condition == "Natural":
|
||||||
|
filename = "corner_angle"
|
||||||
|
else:
|
||||||
|
filename = None
|
||||||
|
assert filename is not None
|
||||||
|
filepaths: str = os.path.join(data_path, f"{condition}")
|
||||||
|
|
||||||
|
stimuli_per_pfinkel: int = 100000
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
|
||||||
|
# for angles and batches
|
||||||
|
if num_pfinkel is None:
|
||||||
|
angle: list[int] = np.arange(0, 100, 10).tolist()
|
||||||
|
else:
|
||||||
|
angle = num_pfinkel
|
||||||
|
|
||||||
|
assert isinstance(angle, list)
|
||||||
|
|
||||||
|
batch: list[int] = np.arange(1, 11, 1).tolist()
|
||||||
|
|
||||||
|
if load_stimuli_per_pfinkel <= (stimuli_per_pfinkel // len(batch)):
|
||||||
|
num_img_per_pfinkel: int = load_stimuli_per_pfinkel
|
||||||
|
num_batches: int = 1
|
||||||
|
else:
|
||||||
|
# handle case where more than 10,000 stimuli per pfinkel needed
|
||||||
|
num_batches = load_stimuli_per_pfinkel // (stimuli_per_pfinkel // len(batch))
|
||||||
|
num_img_per_pfinkel = load_stimuli_per_pfinkel // num_batches
|
||||||
|
|
||||||
|
logger.info(f"{num_batches} batches")
|
||||||
|
logger.info(f"{num_img_per_pfinkel} stimuli per pfinkel.")
|
||||||
|
|
||||||
|
# initialize data and label tensors:
|
||||||
|
num_stimuli: int = len(angle) * num_batches * num_img_per_pfinkel * 2
|
||||||
|
data_tensor: torch.Tensor = torch.empty(
|
||||||
|
(num_stimuli, 200, 200), dtype=torch.uint8, device=torch.device("cpu")
|
||||||
|
)
|
||||||
|
label_tensor: torch.Tensor = torch.empty(
|
||||||
|
(num_stimuli), dtype=torch.int64, device=torch.device("cpu")
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"data tensor shape: {data_tensor.shape}")
|
||||||
|
logger.info(f"label tensor shape: {label_tensor.shape}")
|
||||||
|
|
||||||
|
# append data
|
||||||
|
idx: int = 0
|
||||||
|
for i in range(len(angle)):
|
||||||
|
for j in range(num_batches):
|
||||||
|
# load contour
|
||||||
|
temp_filename: str = (
|
||||||
|
f"{filename}_{angle[i]:03}_b{batch[j]:03}_n10000_RENDERED.npz"
|
||||||
|
)
|
||||||
|
contour_filename: str = os.path.join(filepaths, temp_filename)
|
||||||
|
c_data = np.load(contour_filename)
|
||||||
|
data_tensor[idx : idx + num_img_per_pfinkel, ...] = torch.tensor(
|
||||||
|
c_data["gaborfield"][:num_img_per_pfinkel, ...],
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
)
|
||||||
|
label_tensor[idx : idx + num_img_per_pfinkel] = int(1)
|
||||||
|
idx += num_img_per_pfinkel
|
||||||
|
|
||||||
|
# next append distractor stimuli
|
||||||
|
for i in range(len(angle)):
|
||||||
|
for j in range(num_batches):
|
||||||
|
# load distractor
|
||||||
|
temp_filename = (
|
||||||
|
f"{filename}_{angle[i]:03}_dist_b{batch[j]:03}_n10000_RENDERED.npz"
|
||||||
|
)
|
||||||
|
distractor_filename: str = os.path.join(filepaths, temp_filename)
|
||||||
|
nc_data = np.load(distractor_filename)
|
||||||
|
data_tensor[idx : idx + num_img_per_pfinkel, ...] = torch.tensor(
|
||||||
|
nc_data["gaborfield"][:num_img_per_pfinkel, ...],
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
)
|
||||||
|
label_tensor[idx : idx + num_img_per_pfinkel] = int(0)
|
||||||
|
idx += num_img_per_pfinkel
|
||||||
|
|
||||||
|
return torch.utils.data.TensorDataset(label_tensor, data_tensor.unsqueeze(1))
|
35
functions/create_logger.py
Normal file
35
functions/create_logger.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
import logging
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def create_logger(save_logging_messages: bool, display_logging_messages: bool):
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S")
|
||||||
|
|
||||||
|
logger = logging.getLogger("MyLittleLogger")
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
if save_logging_messages:
|
||||||
|
time_format = "%b %-d %Y %H:%M:%S"
|
||||||
|
logformat = "%(asctime)s %(message)s"
|
||||||
|
file_formatter = logging.Formatter(fmt=logformat, datefmt=time_format)
|
||||||
|
os.makedirs("logs", exist_ok=True)
|
||||||
|
file_handler = logging.FileHandler(
|
||||||
|
os.path.join("logs", f"log_{dt_string_filename}.txt")
|
||||||
|
)
|
||||||
|
file_handler.setLevel(logging.INFO)
|
||||||
|
file_handler.setFormatter(file_formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
if display_logging_messages:
|
||||||
|
time_format = "%H:%M:%S"
|
||||||
|
logformat = "%(asctime)s %(message)s"
|
||||||
|
stream_formatter = logging.Formatter(fmt=logformat, datefmt=time_format)
|
||||||
|
|
||||||
|
stream_handler = logging.StreamHandler()
|
||||||
|
stream_handler.setLevel(logging.INFO)
|
||||||
|
stream_handler.setFormatter(stream_formatter)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
return logger
|
114
functions/make_cnn.py
Normal file
114
functions/make_cnn.py
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def make_cnn(
|
||||||
|
conv_out_channels_list: list[int],
|
||||||
|
conv_kernel_size: list[int],
|
||||||
|
conv_stride_size: int,
|
||||||
|
conv_activation_function: str,
|
||||||
|
train_conv_0: bool,
|
||||||
|
logger,
|
||||||
|
conv_0_kernel_size: int,
|
||||||
|
mp_1_kernel_size: int,
|
||||||
|
mp_1_stride: int,
|
||||||
|
pooling_type: str,
|
||||||
|
conv_0_enable_softmax: bool,
|
||||||
|
l_relu_negative_slope: float,
|
||||||
|
) -> torch.nn.Sequential:
|
||||||
|
assert len(conv_out_channels_list) >= 1
|
||||||
|
assert len(conv_out_channels_list) == len(conv_kernel_size) + 1
|
||||||
|
|
||||||
|
cnn = torch.nn.Sequential()
|
||||||
|
|
||||||
|
# Fixed structure
|
||||||
|
cnn.append(
|
||||||
|
torch.nn.Conv2d(
|
||||||
|
in_channels=1,
|
||||||
|
out_channels=conv_out_channels_list[0] if train_conv_0 else 32,
|
||||||
|
kernel_size=conv_0_kernel_size,
|
||||||
|
stride=1,
|
||||||
|
bias=train_conv_0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if conv_0_enable_softmax:
|
||||||
|
cnn.append(torch.nn.Softmax(dim=1))
|
||||||
|
|
||||||
|
setting_understood: bool = False
|
||||||
|
if conv_activation_function.upper() == str("relu").upper():
|
||||||
|
cnn.append(torch.nn.ReLU())
|
||||||
|
setting_understood = True
|
||||||
|
elif conv_activation_function.upper() == str("leaky relu").upper():
|
||||||
|
cnn.append(torch.nn.LeakyReLU(negative_slope=l_relu_negative_slope))
|
||||||
|
setting_understood = True
|
||||||
|
elif conv_activation_function.upper() == str("tanh").upper():
|
||||||
|
cnn.append(torch.nn.Tanh())
|
||||||
|
setting_understood = True
|
||||||
|
elif conv_activation_function.upper() == str("none").upper():
|
||||||
|
setting_understood = True
|
||||||
|
assert setting_understood
|
||||||
|
|
||||||
|
setting_understood = False
|
||||||
|
if pooling_type.upper() == str("max").upper():
|
||||||
|
cnn.append(torch.nn.MaxPool2d(kernel_size=mp_1_kernel_size, stride=mp_1_stride))
|
||||||
|
setting_understood = True
|
||||||
|
elif pooling_type.upper() == str("average").upper():
|
||||||
|
cnn.append(torch.nn.AvgPool2d(kernel_size=mp_1_kernel_size, stride=mp_1_stride))
|
||||||
|
setting_understood = True
|
||||||
|
elif pooling_type.upper() == str("none").upper():
|
||||||
|
setting_understood = True
|
||||||
|
assert setting_understood
|
||||||
|
|
||||||
|
# Changing structure
|
||||||
|
for i in range(1, len(conv_out_channels_list)):
|
||||||
|
if i == 1 and not train_conv_0:
|
||||||
|
in_channels = 32
|
||||||
|
else:
|
||||||
|
in_channels = conv_out_channels_list[i - 1]
|
||||||
|
cnn.append(
|
||||||
|
torch.nn.Conv2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=conv_out_channels_list[i],
|
||||||
|
kernel_size=conv_kernel_size[i - 1],
|
||||||
|
stride=conv_stride_size,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
setting_understood = False
|
||||||
|
if conv_activation_function.upper() == str("relu").upper():
|
||||||
|
cnn.append(torch.nn.ReLU())
|
||||||
|
setting_understood = True
|
||||||
|
elif conv_activation_function.upper() == str("leaky relu").upper():
|
||||||
|
cnn.append(torch.nn.LeakyReLU(negative_slope=l_relu_negative_slope))
|
||||||
|
setting_understood = True
|
||||||
|
elif conv_activation_function.upper() == str("tanh").upper():
|
||||||
|
cnn.append(torch.nn.Tanh())
|
||||||
|
setting_understood = True
|
||||||
|
elif conv_activation_function.upper() == str("none").upper():
|
||||||
|
setting_understood = True
|
||||||
|
|
||||||
|
assert setting_understood
|
||||||
|
|
||||||
|
# Fixed structure
|
||||||
|
# define fully connected layer:
|
||||||
|
cnn.append(torch.nn.Flatten(start_dim=1))
|
||||||
|
cnn.append(torch.nn.LazyLinear(2, bias=True))
|
||||||
|
|
||||||
|
# if conv1 not trained:
|
||||||
|
filename_load_weight_0: str | None = None
|
||||||
|
if train_conv_0 is False and cnn[0]._parameters["weight"].shape[0] == 32:
|
||||||
|
filename_load_weight_0 = "weights_radius10.npy"
|
||||||
|
if train_conv_0 is False and cnn[0]._parameters["weight"].shape[0] == 16:
|
||||||
|
filename_load_weight_0 = "8orient_2phase_weights.npy"
|
||||||
|
|
||||||
|
if filename_load_weight_0 is not None:
|
||||||
|
logger.info(f"Replace weights in CNN 0 with {filename_load_weight_0}")
|
||||||
|
cnn[0]._parameters["weight"] = torch.tensor(
|
||||||
|
np.load(filename_load_weight_0),
|
||||||
|
dtype=cnn[0]._parameters["weight"].dtype,
|
||||||
|
requires_grad=False,
|
||||||
|
device=cnn[0]._parameters["weight"].device,
|
||||||
|
)
|
||||||
|
|
||||||
|
return cnn
|
76
functions/plot_intermediate.py
Normal file
76
functions/plot_intermediate.py
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib as mpl
|
||||||
|
import os
|
||||||
|
|
||||||
|
mpl.rcParams["text.usetex"] = True
|
||||||
|
mpl.rcParams["font.family"] = "serif"
|
||||||
|
|
||||||
|
|
||||||
|
def plot_intermediate(
|
||||||
|
train_accuracy: list[float],
|
||||||
|
test_accuracy: list[float],
|
||||||
|
train_losses: list[float],
|
||||||
|
test_losses: list[float],
|
||||||
|
save_name: str,
|
||||||
|
reduction_factor: int = 1,
|
||||||
|
) -> None:
|
||||||
|
assert len(train_accuracy) == len(test_accuracy)
|
||||||
|
assert len(train_accuracy) == len(train_losses)
|
||||||
|
assert len(train_accuracy) == len(test_losses)
|
||||||
|
|
||||||
|
max_epochs: int = len(train_accuracy)
|
||||||
|
# set stepsize
|
||||||
|
x = np.arange(1, max_epochs + 1)
|
||||||
|
|
||||||
|
stepsize = max_epochs // reduction_factor
|
||||||
|
|
||||||
|
# accuracies
|
||||||
|
plt.figure(figsize=[12, 7])
|
||||||
|
plt.subplot(2, 1, 1)
|
||||||
|
|
||||||
|
plt.plot(x, np.array(train_accuracy), label="Train")
|
||||||
|
plt.plot(x, np.array(test_accuracy), label="Test")
|
||||||
|
plt.title("Training and Testing Accuracy", fontsize=18)
|
||||||
|
plt.xlabel("Epoch", fontsize=18)
|
||||||
|
plt.ylabel("Accuracy (\\%)", fontsize=18)
|
||||||
|
plt.legend(fontsize=14)
|
||||||
|
plt.xticks(
|
||||||
|
np.concatenate((np.array([1]), np.arange(stepsize, max_epochs + 1, stepsize))),
|
||||||
|
np.concatenate((np.array([1]), np.arange(stepsize, max_epochs + 1, stepsize))),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Increase tick label font size
|
||||||
|
plt.xticks(fontsize=16)
|
||||||
|
plt.yticks(fontsize=16)
|
||||||
|
plt.grid(True)
|
||||||
|
|
||||||
|
# losses
|
||||||
|
plt.subplot(2, 1, 2)
|
||||||
|
plt.plot(x, np.array(train_losses), label="Train")
|
||||||
|
plt.plot(x, np.array(test_losses), label="Test")
|
||||||
|
plt.title("Training and Testing Losses", fontsize=18)
|
||||||
|
plt.xlabel("Epoch", fontsize=18)
|
||||||
|
plt.ylabel("Loss", fontsize=18)
|
||||||
|
plt.legend(fontsize=14)
|
||||||
|
plt.xticks(
|
||||||
|
np.concatenate((np.array([1]), np.arange(stepsize, max_epochs + 1, stepsize))),
|
||||||
|
np.concatenate((np.array([1]), np.arange(stepsize, max_epochs + 1, stepsize))),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Increase tick label font size
|
||||||
|
plt.xticks(fontsize=16)
|
||||||
|
plt.yticks(fontsize=16)
|
||||||
|
plt.grid(True)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
os.makedirs("performance_plots", exist_ok=True)
|
||||||
|
plt.savefig(
|
||||||
|
os.path.join(
|
||||||
|
"performance_plots",
|
||||||
|
f"performance_{save_name}.pdf",
|
||||||
|
),
|
||||||
|
dpi=300,
|
||||||
|
bbox_inches="tight",
|
||||||
|
)
|
||||||
|
plt.show()
|
11
functions/set_seed.py
Normal file
11
functions/set_seed.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed(seed: int, logger) -> None:
|
||||||
|
# set seed for all used modules
|
||||||
|
logger.info(f"set seed to {seed}")
|
||||||
|
torch.manual_seed(seed=seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed=seed)
|
||||||
|
np.random.seed(seed=seed)
|
58
functions/test.py
Normal file
58
functions/test.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test(
|
||||||
|
model: torch.nn.modules.container.Sequential,
|
||||||
|
loader: torch.utils.data.dataloader.DataLoader,
|
||||||
|
device: torch.device,
|
||||||
|
tb,
|
||||||
|
epoch: int,
|
||||||
|
logger: logging.Logger,
|
||||||
|
test_accuracy: list[float],
|
||||||
|
test_losses: list[float],
|
||||||
|
scale_data: float,
|
||||||
|
) -> float:
|
||||||
|
test_loss: float = 0.0
|
||||||
|
correct: int = 0
|
||||||
|
pattern_count: float = 0.0
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
for data in loader:
|
||||||
|
label = data[0].to(device)
|
||||||
|
image = data[1].type(dtype=torch.float32).to(device)
|
||||||
|
if scale_data > 0:
|
||||||
|
image /= scale_data
|
||||||
|
|
||||||
|
output = model(image)
|
||||||
|
|
||||||
|
# loss and optimization
|
||||||
|
loss = torch.nn.functional.cross_entropy(output, label, reduction="sum")
|
||||||
|
pattern_count += float(label.shape[0])
|
||||||
|
test_loss += loss.item()
|
||||||
|
prediction = output.argmax(dim=1)
|
||||||
|
correct += prediction.eq(label).sum().item()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
(
|
||||||
|
"Test set:"
|
||||||
|
f" Average loss: {test_loss / pattern_count:.3e},"
|
||||||
|
f" Accuracy: {correct}/{pattern_count},"
|
||||||
|
f"({100.0 * correct / pattern_count:.2f}%)"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info("")
|
||||||
|
|
||||||
|
acc = 100.0 * correct / pattern_count
|
||||||
|
test_losses.append(test_loss / pattern_count)
|
||||||
|
test_accuracy.append(acc)
|
||||||
|
|
||||||
|
# add to tb:
|
||||||
|
tb.add_scalar("Test Loss", (test_loss / pattern_count), epoch)
|
||||||
|
tb.add_scalar("Test Performance", 100.0 * correct / pattern_count, epoch)
|
||||||
|
tb.add_scalar("Test Number Correct", correct, epoch)
|
||||||
|
tb.flush()
|
||||||
|
|
||||||
|
return acc
|
80
functions/train.py
Normal file
80
functions/train.py
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
def train(
|
||||||
|
model: torch.nn.modules.container.Sequential,
|
||||||
|
loader: torch.utils.data.dataloader.DataLoader,
|
||||||
|
optimizer: torch.optim.Adam | torch.optim.SGD,
|
||||||
|
epoch: int,
|
||||||
|
device: torch.device,
|
||||||
|
tb,
|
||||||
|
test_acc,
|
||||||
|
logger: logging.Logger,
|
||||||
|
train_accuracy: list[float],
|
||||||
|
train_losses: list[float],
|
||||||
|
train_loss: list[float],
|
||||||
|
scale_data: float,
|
||||||
|
) -> float:
|
||||||
|
num_train_pattern: int = 0
|
||||||
|
running_loss: float = 0.0
|
||||||
|
correct: int = 0
|
||||||
|
pattern_count: float = 0.0
|
||||||
|
|
||||||
|
model.train()
|
||||||
|
for data in loader:
|
||||||
|
label = data[0].to(device)
|
||||||
|
image = data[1].type(dtype=torch.float32).to(device)
|
||||||
|
if scale_data > 0:
|
||||||
|
image /= scale_data
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
output = model(image)
|
||||||
|
loss = torch.nn.functional.cross_entropy(output, label, reduction="sum")
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# for loss and accuracy plotting:
|
||||||
|
num_train_pattern += int(label.shape[0])
|
||||||
|
pattern_count += float(label.shape[0])
|
||||||
|
running_loss += float(loss)
|
||||||
|
train_loss.append(float(loss))
|
||||||
|
prediction = output.argmax(dim=1)
|
||||||
|
correct += prediction.eq(label).sum().item()
|
||||||
|
|
||||||
|
total_number_of_pattern: int = int(len(loader)) * int(label.shape[0])
|
||||||
|
|
||||||
|
# infos:
|
||||||
|
logger.info(
|
||||||
|
(
|
||||||
|
"Train Epoch:"
|
||||||
|
f" {epoch}"
|
||||||
|
f" [{int(pattern_count)}/{total_number_of_pattern}"
|
||||||
|
f" ({100.0 * pattern_count / total_number_of_pattern:.2f}%)],"
|
||||||
|
f" Loss: {float(running_loss) / float(num_train_pattern):.4e},"
|
||||||
|
f" Acc: {(100.0 * correct / num_train_pattern):.2f}"
|
||||||
|
f" Test Acc: {test_acc:.2f}%,"
|
||||||
|
f" LR: {optimizer.param_groups[0]['lr']:.2e}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
acc = 100.0 * correct / num_train_pattern
|
||||||
|
train_accuracy.append(acc)
|
||||||
|
|
||||||
|
epoch_loss = running_loss / pattern_count
|
||||||
|
train_losses.append(epoch_loss)
|
||||||
|
|
||||||
|
# add to tb:
|
||||||
|
tb.add_scalar("Train Loss", loss.item(), epoch)
|
||||||
|
tb.add_scalar("Train Performance", torch.tensor(acc), epoch)
|
||||||
|
tb.add_scalar("Train Number Correct", torch.tensor(correct), epoch)
|
||||||
|
|
||||||
|
# for parameters:
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if "weight" in name or "bias" in name:
|
||||||
|
tb.add_histogram(f"{name}", param.data.clone(), epoch)
|
||||||
|
|
||||||
|
tb.flush()
|
||||||
|
|
||||||
|
return epoch_loss
|
Loading…
Reference in a new issue