Add files via upload
This commit is contained in:
parent
41df07230d
commit
5ac2f1dc96
1 changed files with 42 additions and 3 deletions
45
train_it.py
45
train_it.py
|
@ -4,10 +4,17 @@ import os
|
|||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
|
||||
import sys
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
order_id: float | int | None = None
|
||||
else:
|
||||
order_id = float(sys.argv[1])
|
||||
|
||||
import torch
|
||||
import dataconf
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import math
|
||||
|
||||
from network.Parameter import Config
|
||||
|
||||
|
@ -24,10 +31,16 @@ from network.loop_train_test import (
|
|||
loop_test_reconstruction,
|
||||
)
|
||||
from network.SbSReconstruction import SbSReconstruction
|
||||
from network.InputSpikeImage import InputSpikeImage
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
|
||||
if order_id is None:
|
||||
order_id_string: str = ""
|
||||
else:
|
||||
order_id_string = f"_{order_id}"
|
||||
|
||||
# ######################################################################
|
||||
# We want to log what is going on into a file and screen
|
||||
# ######################################################################
|
||||
|
@ -35,7 +48,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||
now = datetime.now()
|
||||
dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S")
|
||||
logging.basicConfig(
|
||||
filename="log_" + dt_string_filename + ".txt",
|
||||
filename=f"log_{dt_string_filename}{order_id_string}.txt",
|
||||
filemode="w",
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(message)s",
|
||||
|
@ -57,7 +70,9 @@ if os.path.exists("dataset.json") is False:
|
|||
raise Exception("Config file not found! dataset.json")
|
||||
|
||||
|
||||
cfg = dataconf.multi.file("network.json").file("dataset.json").file("def.json").on(Config)
|
||||
cfg = (
|
||||
dataconf.multi.file("network.json").file("dataset.json").file("def.json").on(Config)
|
||||
)
|
||||
logging.info(cfg)
|
||||
|
||||
logging.info(f"Number of spikes: {cfg.number_of_spikes}")
|
||||
|
@ -71,7 +86,7 @@ logging.info("")
|
|||
logging.info("*** Config loaded.")
|
||||
logging.info("")
|
||||
|
||||
tb = SummaryWriter(log_dir=cfg.log_path)
|
||||
tb = SummaryWriter(log_dir=f"{cfg.log_path}{order_id_string}")
|
||||
|
||||
# ###########################################
|
||||
# GPU Yes / NO ?
|
||||
|
@ -114,10 +129,33 @@ load_previous_weights(
|
|||
logging=logging,
|
||||
device=device,
|
||||
default_dtype=default_dtype,
|
||||
order_id=order_id,
|
||||
)
|
||||
|
||||
logging.info("")
|
||||
|
||||
# Fiddling around with the amount of spikes in the input layer
|
||||
if order_id is not None:
|
||||
|
||||
image_size_x = (
|
||||
the_dataset_train.initial_size[0] - 2 * cfg.augmentation.crop_width_in_pixel
|
||||
)
|
||||
image_size_y = (
|
||||
the_dataset_train.initial_size[1] - 2 * cfg.augmentation.crop_width_in_pixel
|
||||
)
|
||||
number_of_spikes_in_input_layer = int(
|
||||
math.ceil(
|
||||
order_id * the_dataset_train.channel_size * image_size_x * image_size_y
|
||||
)
|
||||
)
|
||||
|
||||
assert len(cfg.number_of_spikes) > 0
|
||||
cfg.number_of_spikes[0] = number_of_spikes_in_input_layer
|
||||
|
||||
if isinstance(network[0], InputSpikeImage) is True:
|
||||
network[0].number_of_spikes = number_of_spikes_in_input_layer
|
||||
|
||||
|
||||
last_test_performance: float = -1.0
|
||||
with torch.no_grad():
|
||||
if cfg.learning_parameters.learning_active is True:
|
||||
|
@ -145,6 +183,7 @@ with torch.no_grad():
|
|||
adapt_learning_rate=cfg.learning_parameters.adapt_learning_rate_after_minibatch,
|
||||
lr_scheduler=lr_scheduler,
|
||||
last_test_performance=last_test_performance,
|
||||
order_id=order_id,
|
||||
)
|
||||
|
||||
# Let the torch learning rate scheduler update the
|
||||
|
|
Loading…
Reference in a new issue