diff --git a/train_it.py b/train_it.py index 1f1da1f..6cab3c2 100644 --- a/train_it.py +++ b/train_it.py @@ -4,10 +4,17 @@ import os os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import sys + +if len(sys.argv) < 2: + order_id: float | int | None = None +else: + order_id = float(sys.argv[1]) + import torch import dataconf import logging from datetime import datetime +import math from network.Parameter import Config @@ -24,10 +31,16 @@ from network.loop_train_test import ( loop_test_reconstruction, ) from network.SbSReconstruction import SbSReconstruction +from network.InputSpikeImage import InputSpikeImage from torch.utils.tensorboard import SummaryWriter +if order_id is None: + order_id_string: str = "" +else: + order_id_string = f"_{order_id}" + # ###################################################################### # We want to log what is going on into a file and screen # ###################################################################### @@ -35,7 +48,7 @@ from torch.utils.tensorboard import SummaryWriter now = datetime.now() dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S") logging.basicConfig( - filename="log_" + dt_string_filename + ".txt", + filename=f"log_{dt_string_filename}{order_id_string}.txt", filemode="w", level=logging.INFO, format="%(asctime)s %(message)s", @@ -57,7 +70,9 @@ if os.path.exists("dataset.json") is False: raise Exception("Config file not found! dataset.json") -cfg = dataconf.multi.file("network.json").file("dataset.json").file("def.json").on(Config) +cfg = ( + dataconf.multi.file("network.json").file("dataset.json").file("def.json").on(Config) +) logging.info(cfg) logging.info(f"Number of spikes: {cfg.number_of_spikes}") @@ -71,7 +86,7 @@ logging.info("") logging.info("*** Config loaded.") logging.info("") -tb = SummaryWriter(log_dir=cfg.log_path) +tb = SummaryWriter(log_dir=f"{cfg.log_path}{order_id_string}") # ########################################### # GPU Yes / NO ? @@ -114,10 +129,33 @@ load_previous_weights( logging=logging, device=device, default_dtype=default_dtype, + order_id=order_id, ) logging.info("") +# Fiddling around with the amount of spikes in the input layer +if order_id is not None: + + image_size_x = ( + the_dataset_train.initial_size[0] - 2 * cfg.augmentation.crop_width_in_pixel + ) + image_size_y = ( + the_dataset_train.initial_size[1] - 2 * cfg.augmentation.crop_width_in_pixel + ) + number_of_spikes_in_input_layer = int( + math.ceil( + order_id * the_dataset_train.channel_size * image_size_x * image_size_y + ) + ) + + assert len(cfg.number_of_spikes) > 0 + cfg.number_of_spikes[0] = number_of_spikes_in_input_layer + + if isinstance(network[0], InputSpikeImage) is True: + network[0].number_of_spikes = number_of_spikes_in_input_layer + + last_test_performance: float = -1.0 with torch.no_grad(): if cfg.learning_parameters.learning_active is True: @@ -145,6 +183,7 @@ with torch.no_grad(): adapt_learning_rate=cfg.learning_parameters.adapt_learning_rate_after_minibatch, lr_scheduler=lr_scheduler, last_test_performance=last_test_performance, + order_id=order_id, ) # Let the torch learning rate scheduler update the