diff --git a/test_it.py b/test_it.py new file mode 100644 index 0000000..71a764c --- /dev/null +++ b/test_it.py @@ -0,0 +1,157 @@ +# %% +import os + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +import sys +import torch +import dataconf +import logging +from datetime import datetime + +from network.Parameter import Config + +from network.build_network import build_network +from network.build_optimizer import build_optimizer +from network.build_lr_scheduler import build_lr_scheduler +from network.build_datasets import build_datasets +from network.load_previous_weights import load_previous_weights + +from network.loop_train_test import ( + loop_test, + loop_train, + run_lr_scheduler, + loop_test_reconstruction, +) +from network.SbSReconstruction import SbSReconstruction + +from torch.utils.tensorboard import SummaryWriter + + +# ###################################################################### +# We want to log what is going on into a file and screen +# ###################################################################### + +now = datetime.now() +dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S") +logging.basicConfig( + filename="log_" + dt_string_filename + ".txt", + filemode="w", + level=logging.INFO, + format="%(asctime)s %(message)s", +) +logging.getLogger().addHandler(logging.StreamHandler()) + +# ###################################################################### +# Load the config data from the json file +# ###################################################################### + +if os.path.exists("def.json") is False: + raise Exception("Config file not found! def.json") + +if os.path.exists("network.json") is False: + raise Exception("Config file not found! network.json") + +if os.path.exists("dataset.json") is False: + raise Exception("Config file not found! dataset.json") + + +cfg = dataconf.multi.file("network.json").file("dataset.json").file("def.json").on(Config) +logging.info(cfg) + +logging.info(f"Number of spikes: {cfg.number_of_spikes}") +logging.info(f"Cooldown after spikes: {cfg.cooldown_after_number_of_spikes}") +logging.info(f"Reduction cooldown: {cfg.reduction_cooldown}") +logging.info("") +logging.info(f"Epsilon 0: {cfg.epsilon_0}") +logging.info(f"Batch size: {cfg.batch_size}") +logging.info(f"Data mode: {cfg.data_mode}") +logging.info("") +logging.info("*** Config loaded.") +logging.info("") + +tb = SummaryWriter(log_dir=cfg.log_path) + +# ########################################### +# GPU Yes / NO ? +# ########################################### +default_dtype = torch.float32 +torch.set_default_dtype(default_dtype) +torch_device: str = "cuda:0" if torch.cuda.is_available() else "cpu" +use_gpu: bool = True if torch.cuda.is_available() else False +logging.info(f"Using {torch_device} device") +device = torch.device(torch_device) + +# ###################################################################### +# Prepare the test and training data +# ###################################################################### + +the_dataset_train, the_dataset_test, my_loader_test, my_loader_train = build_datasets( + cfg +) + +logging.info("*** Data loaded.") + +# ###################################################################### +# Build the network, Optimizer, and LR Scheduler # +# ###################################################################### + +network = build_network( + cfg=cfg, device=device, default_dtype=default_dtype, logging=logging +) +logging.info("") + +optimizer = build_optimizer(network=network, cfg=cfg, logging=logging) + +lr_scheduler = build_lr_scheduler(optimizer=optimizer, cfg=cfg, logging=logging) + +logging.info("*** Network generated.") + +load_previous_weights( + network=network, + overload_path=cfg.learning_parameters.overload_path, + logging=logging, + device=device, + default_dtype=default_dtype, +) + +logging.info("") + +last_test_performance: float = -1.0 + + +# ############################################## +# Run test data +# ############################################## +network.eval() +if isinstance(network[-1], SbSReconstruction) is False: + last_test_performance = loop_test( + epoch_id=cfg.epoch_id, + cfg=cfg, + network=network, + my_loader_test=my_loader_test, + the_dataset_test=the_dataset_test, + device=device, + default_dtype=default_dtype, + logging=logging, + tb=tb, + ) +else: + last_test_performance = loop_test_reconstruction( + epoch_id=cfg.epoch_id, + cfg=cfg, + network=network, + my_loader_test=my_loader_test, + the_dataset_test=the_dataset_test, + device=device, + default_dtype=default_dtype, + logging=logging, + tb=tb, + ) + +tb.close() + + + + +# %% diff --git a/train_it.py b/train_it.py new file mode 100644 index 0000000..1f1da1f --- /dev/null +++ b/train_it.py @@ -0,0 +1,228 @@ +# %% +import os + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +import sys +import torch +import dataconf +import logging +from datetime import datetime + +from network.Parameter import Config + +from network.build_network import build_network +from network.build_optimizer import build_optimizer +from network.build_lr_scheduler import build_lr_scheduler +from network.build_datasets import build_datasets +from network.load_previous_weights import load_previous_weights + +from network.loop_train_test import ( + loop_test, + loop_train, + run_lr_scheduler, + loop_test_reconstruction, +) +from network.SbSReconstruction import SbSReconstruction + +from torch.utils.tensorboard import SummaryWriter + + +# ###################################################################### +# We want to log what is going on into a file and screen +# ###################################################################### + +now = datetime.now() +dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S") +logging.basicConfig( + filename="log_" + dt_string_filename + ".txt", + filemode="w", + level=logging.INFO, + format="%(asctime)s %(message)s", +) +logging.getLogger().addHandler(logging.StreamHandler()) + +# ###################################################################### +# Load the config data from the json file +# ###################################################################### + + +if os.path.exists("def.json") is False: + raise Exception("Config file not found! def.json") + +if os.path.exists("network.json") is False: + raise Exception("Config file not found! network.json") + +if os.path.exists("dataset.json") is False: + raise Exception("Config file not found! dataset.json") + + +cfg = dataconf.multi.file("network.json").file("dataset.json").file("def.json").on(Config) +logging.info(cfg) + +logging.info(f"Number of spikes: {cfg.number_of_spikes}") +logging.info(f"Cooldown after spikes: {cfg.cooldown_after_number_of_spikes}") +logging.info(f"Reduction cooldown: {cfg.reduction_cooldown}") +logging.info("") +logging.info(f"Epsilon 0: {cfg.epsilon_0}") +logging.info(f"Batch size: {cfg.batch_size}") +logging.info(f"Data mode: {cfg.data_mode}") +logging.info("") +logging.info("*** Config loaded.") +logging.info("") + +tb = SummaryWriter(log_dir=cfg.log_path) + +# ########################################### +# GPU Yes / NO ? +# ########################################### +default_dtype = torch.float32 +torch.set_default_dtype(default_dtype) +torch_device: str = "cuda:0" if torch.cuda.is_available() else "cpu" +use_gpu: bool = True if torch.cuda.is_available() else False +logging.info(f"Using {torch_device} device") +device = torch.device(torch_device) + +# ###################################################################### +# Prepare the test and training data +# ###################################################################### + +the_dataset_train, the_dataset_test, my_loader_test, my_loader_train = build_datasets( + cfg +) + +logging.info("*** Data loaded.") + +# ###################################################################### +# Build the network, Optimizer, and LR Scheduler # +# ###################################################################### + +network = build_network( + cfg=cfg, device=device, default_dtype=default_dtype, logging=logging +) +logging.info("") + +optimizer = build_optimizer(network=network, cfg=cfg, logging=logging) + +lr_scheduler = build_lr_scheduler(optimizer=optimizer, cfg=cfg, logging=logging) + +logging.info("*** Network generated.") + +load_previous_weights( + network=network, + overload_path=cfg.learning_parameters.overload_path, + logging=logging, + device=device, + default_dtype=default_dtype, +) + +logging.info("") + +last_test_performance: float = -1.0 +with torch.no_grad(): + if cfg.learning_parameters.learning_active is True: + while cfg.epoch_id < cfg.epoch_id_max: + + # ############################################## + # Run a training data epoch + # ############################################## + network.train() + ( + my_loss_for_batch, + performance_for_batch, + full_loss, + full_correct, + ) = loop_train( + cfg=cfg, + network=network, + my_loader_train=my_loader_train, + the_dataset_train=the_dataset_train, + optimizer=optimizer, + device=device, + default_dtype=default_dtype, + logging=logging, + tb=tb, + adapt_learning_rate=cfg.learning_parameters.adapt_learning_rate_after_minibatch, + lr_scheduler=lr_scheduler, + last_test_performance=last_test_performance, + ) + + # Let the torch learning rate scheduler update the + # learning rates of the optimiers + if cfg.learning_parameters.adapt_learning_rate_after_minibatch is False: + run_lr_scheduler( + cfg=cfg, + lr_scheduler=lr_scheduler, + optimizer=optimizer, + performance_for_batch=performance_for_batch, + my_loss_for_batch=my_loss_for_batch, + tb=tb, + logging=logging, + ) + + # ############################################## + # Run test data + # ############################################## + network.eval() + if isinstance(network[-1], SbSReconstruction) is False: + + last_test_performance = loop_test( + epoch_id=cfg.epoch_id, + cfg=cfg, + network=network, + my_loader_test=my_loader_test, + the_dataset_test=the_dataset_test, + device=device, + default_dtype=default_dtype, + logging=logging, + tb=tb, + ) + else: + last_test_performance = loop_test_reconstruction( + epoch_id=cfg.epoch_id, + cfg=cfg, + network=network, + my_loader_test=my_loader_test, + the_dataset_test=the_dataset_test, + device=device, + default_dtype=default_dtype, + logging=logging, + tb=tb, + ) + + # Next epoch + cfg.epoch_id += 1 + else: + # ############################################## + # Run test data + # ############################################## + network.eval() + if isinstance(network[-1], SbSReconstruction) is False: + last_test_performance = loop_test( + epoch_id=cfg.epoch_id, + cfg=cfg, + network=network, + my_loader_test=my_loader_test, + the_dataset_test=the_dataset_test, + device=device, + default_dtype=default_dtype, + logging=logging, + tb=tb, + ) + else: + last_test_performance = loop_test_reconstruction( + epoch_id=cfg.epoch_id, + cfg=cfg, + network=network, + my_loader_test=my_loader_test, + the_dataset_test=the_dataset_test, + device=device, + default_dtype=default_dtype, + logging=logging, + tb=tb, + ) + + tb.close() + + +# %%