Add files via upload

This commit is contained in:
David Rotermund 2023-01-13 21:31:12 +01:00 committed by GitHub
parent 76ea2096c4
commit ff64a9da14
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 22 deletions

View file

@ -43,5 +43,5 @@ for te_item in te:
temp = np.array(temp) temp = np.array(temp)
print(temp) print(temp)
np.save(f"test_error_{number_of_spikes}.npy", temp) np.save(f"test_error.npy", temp)

View file

@ -17,7 +17,13 @@ from network.build_lr_scheduler import build_lr_scheduler
from network.build_datasets import build_datasets from network.build_datasets import build_datasets
from network.load_previous_weights import load_previous_weights from network.load_previous_weights import load_previous_weights
from network.loop_train_test import loop_test, loop_train, run_lr_scheduler from network.loop_train_test import (
loop_test,
loop_train,
run_lr_scheduler,
loop_test_reconstruction,
)
from network.SbSReconstruction import SbSReconstruction
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -80,7 +86,7 @@ default_dtype = torch.float32
torch.set_default_dtype(default_dtype) torch.set_default_dtype(default_dtype)
torch_device: str = "cuda:0" if torch.cuda.is_available() else "cpu" torch_device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
use_gpu: bool = True if torch.cuda.is_available() else False use_gpu: bool = True if torch.cuda.is_available() else False
print(f"Using {torch_device} device") logging.info(f"Using {torch_device} device")
device = torch.device(torch_device) device = torch.device(torch_device)
# ###################################################################### # ######################################################################
@ -164,6 +170,8 @@ with torch.no_grad():
# Run test data # Run test data
# ############################################## # ##############################################
network.eval() network.eval()
if isinstance(network[-1], SbSReconstruction) is False:
last_test_performance = loop_test( last_test_performance = loop_test(
epoch_id=cfg.epoch_id, epoch_id=cfg.epoch_id,
cfg=cfg, cfg=cfg,
@ -175,6 +183,18 @@ with torch.no_grad():
logging=logging, logging=logging,
tb=tb, tb=tb,
) )
else:
last_test_performance = loop_test_reconstruction(
epoch_id=cfg.epoch_id,
cfg=cfg,
network=network,
my_loader_test=my_loader_test,
the_dataset_test=the_dataset_test,
device=device,
default_dtype=default_dtype,
logging=logging,
tb=tb,
)
# Next epoch # Next epoch
cfg.epoch_id += 1 cfg.epoch_id += 1
@ -183,6 +203,7 @@ with torch.no_grad():
# Run test data # Run test data
# ############################################## # ##############################################
network.eval() network.eval()
if isinstance(network[-1], SbSReconstruction) is False:
last_test_performance = loop_test( last_test_performance = loop_test(
epoch_id=cfg.epoch_id, epoch_id=cfg.epoch_id,
cfg=cfg, cfg=cfg,
@ -194,6 +215,18 @@ with torch.no_grad():
logging=logging, logging=logging,
tb=tb, tb=tb,
) )
else:
last_test_performance = loop_test_reconstruction(
epoch_id=cfg.epoch_id,
cfg=cfg,
network=network,
my_loader_test=my_loader_test,
the_dataset_test=the_dataset_test,
device=device,
default_dtype=default_dtype,
logging=logging,
tb=tb,
)
tb.close() tb.close()