diff --git a/get_perf.py b/get_perf.py index f37bceb..863802a 100644 --- a/get_perf.py +++ b/get_perf.py @@ -43,5 +43,5 @@ for te_item in te: temp = np.array(temp) print(temp) -np.save(f"test_error_{number_of_spikes}.npy", temp) +np.save(f"test_error.npy", temp) diff --git a/learn_it.py b/learn_it.py index 73f0b0e..c0a8a34 100644 --- a/learn_it.py +++ b/learn_it.py @@ -17,7 +17,13 @@ from network.build_lr_scheduler import build_lr_scheduler from network.build_datasets import build_datasets from network.load_previous_weights import load_previous_weights -from network.loop_train_test import loop_test, loop_train, run_lr_scheduler +from network.loop_train_test import ( + loop_test, + loop_train, + run_lr_scheduler, + loop_test_reconstruction, +) +from network.SbSReconstruction import SbSReconstruction from torch.utils.tensorboard import SummaryWriter @@ -80,7 +86,7 @@ default_dtype = torch.float32 torch.set_default_dtype(default_dtype) torch_device: str = "cuda:0" if torch.cuda.is_available() else "cpu" use_gpu: bool = True if torch.cuda.is_available() else False -print(f"Using {torch_device} device") +logging.info(f"Using {torch_device} device") device = torch.device(torch_device) # ###################################################################### @@ -164,6 +170,40 @@ with torch.no_grad(): # Run test data # ############################################## network.eval() + if isinstance(network[-1], SbSReconstruction) is False: + + last_test_performance = loop_test( + epoch_id=cfg.epoch_id, + cfg=cfg, + network=network, + my_loader_test=my_loader_test, + the_dataset_test=the_dataset_test, + device=device, + default_dtype=default_dtype, + logging=logging, + tb=tb, + ) + else: + last_test_performance = loop_test_reconstruction( + epoch_id=cfg.epoch_id, + cfg=cfg, + network=network, + my_loader_test=my_loader_test, + the_dataset_test=the_dataset_test, + device=device, + default_dtype=default_dtype, + logging=logging, + tb=tb, + ) + + # Next epoch + cfg.epoch_id += 1 + else: + # ############################################## + # Run test data + # ############################################## + network.eval() + if isinstance(network[-1], SbSReconstruction) is False: last_test_performance = loop_test( epoch_id=cfg.epoch_id, cfg=cfg, @@ -175,25 +215,18 @@ with torch.no_grad(): logging=logging, tb=tb, ) - - # Next epoch - cfg.epoch_id += 1 - else: - # ############################################## - # Run test data - # ############################################## - network.eval() - last_test_performance = loop_test( - epoch_id=cfg.epoch_id, - cfg=cfg, - network=network, - my_loader_test=my_loader_test, - the_dataset_test=the_dataset_test, - device=device, - default_dtype=default_dtype, - logging=logging, - tb=tb, - ) + else: + last_test_performance = loop_test_reconstruction( + epoch_id=cfg.epoch_id, + cfg=cfg, + network=network, + my_loader_test=my_loader_test, + the_dataset_test=the_dataset_test, + device=device, + default_dtype=default_dtype, + logging=logging, + tb=tb, + ) tb.close()