Add files via upload

This commit is contained in:
David Rotermund 2023-02-04 14:22:45 +01:00 committed by GitHub
parent 41df07230d
commit 5ac2f1dc96
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -4,10 +4,17 @@ import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import sys import sys
if len(sys.argv) < 2:
order_id: float | int | None = None
else:
order_id = float(sys.argv[1])
import torch import torch
import dataconf import dataconf
import logging import logging
from datetime import datetime from datetime import datetime
import math
from network.Parameter import Config from network.Parameter import Config
@ -24,10 +31,16 @@ from network.loop_train_test import (
loop_test_reconstruction, loop_test_reconstruction,
) )
from network.SbSReconstruction import SbSReconstruction from network.SbSReconstruction import SbSReconstruction
from network.InputSpikeImage import InputSpikeImage
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
if order_id is None:
order_id_string: str = ""
else:
order_id_string = f"_{order_id}"
# ###################################################################### # ######################################################################
# We want to log what is going on into a file and screen # We want to log what is going on into a file and screen
# ###################################################################### # ######################################################################
@ -35,7 +48,7 @@ from torch.utils.tensorboard import SummaryWriter
now = datetime.now() now = datetime.now()
dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S") dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S")
logging.basicConfig( logging.basicConfig(
filename="log_" + dt_string_filename + ".txt", filename=f"log_{dt_string_filename}{order_id_string}.txt",
filemode="w", filemode="w",
level=logging.INFO, level=logging.INFO,
format="%(asctime)s %(message)s", format="%(asctime)s %(message)s",
@ -57,7 +70,9 @@ if os.path.exists("dataset.json") is False:
raise Exception("Config file not found! dataset.json") raise Exception("Config file not found! dataset.json")
cfg = dataconf.multi.file("network.json").file("dataset.json").file("def.json").on(Config) cfg = (
dataconf.multi.file("network.json").file("dataset.json").file("def.json").on(Config)
)
logging.info(cfg) logging.info(cfg)
logging.info(f"Number of spikes: {cfg.number_of_spikes}") logging.info(f"Number of spikes: {cfg.number_of_spikes}")
@ -71,7 +86,7 @@ logging.info("")
logging.info("*** Config loaded.") logging.info("*** Config loaded.")
logging.info("") logging.info("")
tb = SummaryWriter(log_dir=cfg.log_path) tb = SummaryWriter(log_dir=f"{cfg.log_path}{order_id_string}")
# ########################################### # ###########################################
# GPU Yes / NO ? # GPU Yes / NO ?
@ -114,10 +129,33 @@ load_previous_weights(
logging=logging, logging=logging,
device=device, device=device,
default_dtype=default_dtype, default_dtype=default_dtype,
order_id=order_id,
) )
logging.info("") logging.info("")
# Fiddling around with the amount of spikes in the input layer
if order_id is not None:
image_size_x = (
the_dataset_train.initial_size[0] - 2 * cfg.augmentation.crop_width_in_pixel
)
image_size_y = (
the_dataset_train.initial_size[1] - 2 * cfg.augmentation.crop_width_in_pixel
)
number_of_spikes_in_input_layer = int(
math.ceil(
order_id * the_dataset_train.channel_size * image_size_x * image_size_y
)
)
assert len(cfg.number_of_spikes) > 0
cfg.number_of_spikes[0] = number_of_spikes_in_input_layer
if isinstance(network[0], InputSpikeImage) is True:
network[0].number_of_spikes = number_of_spikes_in_input_layer
last_test_performance: float = -1.0 last_test_performance: float = -1.0
with torch.no_grad(): with torch.no_grad():
if cfg.learning_parameters.learning_active is True: if cfg.learning_parameters.learning_active is True:
@ -145,6 +183,7 @@ with torch.no_grad():
adapt_learning_rate=cfg.learning_parameters.adapt_learning_rate_after_minibatch, adapt_learning_rate=cfg.learning_parameters.adapt_learning_rate_after_minibatch,
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
last_test_performance=last_test_performance, last_test_performance=last_test_performance,
order_id=order_id,
) )
# Let the torch learning rate scheduler update the # Let the torch learning rate scheduler update the