Add files via upload
This commit is contained in:
parent
41df07230d
commit
5ac2f1dc96
1 changed files with 42 additions and 3 deletions
45
train_it.py
45
train_it.py
|
@ -4,10 +4,17 @@ import os
|
||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
order_id: float | int | None = None
|
||||||
|
else:
|
||||||
|
order_id = float(sys.argv[1])
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import dataconf
|
import dataconf
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import math
|
||||||
|
|
||||||
from network.Parameter import Config
|
from network.Parameter import Config
|
||||||
|
|
||||||
|
@ -24,10 +31,16 @@ from network.loop_train_test import (
|
||||||
loop_test_reconstruction,
|
loop_test_reconstruction,
|
||||||
)
|
)
|
||||||
from network.SbSReconstruction import SbSReconstruction
|
from network.SbSReconstruction import SbSReconstruction
|
||||||
|
from network.InputSpikeImage import InputSpikeImage
|
||||||
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
|
||||||
|
if order_id is None:
|
||||||
|
order_id_string: str = ""
|
||||||
|
else:
|
||||||
|
order_id_string = f"_{order_id}"
|
||||||
|
|
||||||
# ######################################################################
|
# ######################################################################
|
||||||
# We want to log what is going on into a file and screen
|
# We want to log what is going on into a file and screen
|
||||||
# ######################################################################
|
# ######################################################################
|
||||||
|
@ -35,7 +48,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S")
|
dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S")
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
filename="log_" + dt_string_filename + ".txt",
|
filename=f"log_{dt_string_filename}{order_id_string}.txt",
|
||||||
filemode="w",
|
filemode="w",
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format="%(asctime)s %(message)s",
|
format="%(asctime)s %(message)s",
|
||||||
|
@ -57,7 +70,9 @@ if os.path.exists("dataset.json") is False:
|
||||||
raise Exception("Config file not found! dataset.json")
|
raise Exception("Config file not found! dataset.json")
|
||||||
|
|
||||||
|
|
||||||
cfg = dataconf.multi.file("network.json").file("dataset.json").file("def.json").on(Config)
|
cfg = (
|
||||||
|
dataconf.multi.file("network.json").file("dataset.json").file("def.json").on(Config)
|
||||||
|
)
|
||||||
logging.info(cfg)
|
logging.info(cfg)
|
||||||
|
|
||||||
logging.info(f"Number of spikes: {cfg.number_of_spikes}")
|
logging.info(f"Number of spikes: {cfg.number_of_spikes}")
|
||||||
|
@ -71,7 +86,7 @@ logging.info("")
|
||||||
logging.info("*** Config loaded.")
|
logging.info("*** Config loaded.")
|
||||||
logging.info("")
|
logging.info("")
|
||||||
|
|
||||||
tb = SummaryWriter(log_dir=cfg.log_path)
|
tb = SummaryWriter(log_dir=f"{cfg.log_path}{order_id_string}")
|
||||||
|
|
||||||
# ###########################################
|
# ###########################################
|
||||||
# GPU Yes / NO ?
|
# GPU Yes / NO ?
|
||||||
|
@ -114,10 +129,33 @@ load_previous_weights(
|
||||||
logging=logging,
|
logging=logging,
|
||||||
device=device,
|
device=device,
|
||||||
default_dtype=default_dtype,
|
default_dtype=default_dtype,
|
||||||
|
order_id=order_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("")
|
logging.info("")
|
||||||
|
|
||||||
|
# Fiddling around with the amount of spikes in the input layer
|
||||||
|
if order_id is not None:
|
||||||
|
|
||||||
|
image_size_x = (
|
||||||
|
the_dataset_train.initial_size[0] - 2 * cfg.augmentation.crop_width_in_pixel
|
||||||
|
)
|
||||||
|
image_size_y = (
|
||||||
|
the_dataset_train.initial_size[1] - 2 * cfg.augmentation.crop_width_in_pixel
|
||||||
|
)
|
||||||
|
number_of_spikes_in_input_layer = int(
|
||||||
|
math.ceil(
|
||||||
|
order_id * the_dataset_train.channel_size * image_size_x * image_size_y
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(cfg.number_of_spikes) > 0
|
||||||
|
cfg.number_of_spikes[0] = number_of_spikes_in_input_layer
|
||||||
|
|
||||||
|
if isinstance(network[0], InputSpikeImage) is True:
|
||||||
|
network[0].number_of_spikes = number_of_spikes_in_input_layer
|
||||||
|
|
||||||
|
|
||||||
last_test_performance: float = -1.0
|
last_test_performance: float = -1.0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if cfg.learning_parameters.learning_active is True:
|
if cfg.learning_parameters.learning_active is True:
|
||||||
|
@ -145,6 +183,7 @@ with torch.no_grad():
|
||||||
adapt_learning_rate=cfg.learning_parameters.adapt_learning_rate_after_minibatch,
|
adapt_learning_rate=cfg.learning_parameters.adapt_learning_rate_after_minibatch,
|
||||||
lr_scheduler=lr_scheduler,
|
lr_scheduler=lr_scheduler,
|
||||||
last_test_performance=last_test_performance,
|
last_test_performance=last_test_performance,
|
||||||
|
order_id=order_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Let the torch learning rate scheduler update the
|
# Let the torch learning rate scheduler update the
|
||||||
|
|
Loading…
Reference in a new issue