Add files via upload
This commit is contained in:
parent
76ea2096c4
commit
ff64a9da14
2 changed files with 55 additions and 22 deletions
|
@ -43,5 +43,5 @@ for te_item in te:
|
|||
temp = np.array(temp)
|
||||
|
||||
print(temp)
|
||||
np.save(f"test_error_{number_of_spikes}.npy", temp)
|
||||
np.save(f"test_error.npy", temp)
|
||||
|
||||
|
|
37
learn_it.py
37
learn_it.py
|
@ -17,7 +17,13 @@ from network.build_lr_scheduler import build_lr_scheduler
|
|||
from network.build_datasets import build_datasets
|
||||
from network.load_previous_weights import load_previous_weights
|
||||
|
||||
from network.loop_train_test import loop_test, loop_train, run_lr_scheduler
|
||||
from network.loop_train_test import (
|
||||
loop_test,
|
||||
loop_train,
|
||||
run_lr_scheduler,
|
||||
loop_test_reconstruction,
|
||||
)
|
||||
from network.SbSReconstruction import SbSReconstruction
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
|
@ -80,7 +86,7 @@ default_dtype = torch.float32
|
|||
torch.set_default_dtype(default_dtype)
|
||||
torch_device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
use_gpu: bool = True if torch.cuda.is_available() else False
|
||||
print(f"Using {torch_device} device")
|
||||
logging.info(f"Using {torch_device} device")
|
||||
device = torch.device(torch_device)
|
||||
|
||||
# ######################################################################
|
||||
|
@ -164,6 +170,8 @@ with torch.no_grad():
|
|||
# Run test data
|
||||
# ##############################################
|
||||
network.eval()
|
||||
if isinstance(network[-1], SbSReconstruction) is False:
|
||||
|
||||
last_test_performance = loop_test(
|
||||
epoch_id=cfg.epoch_id,
|
||||
cfg=cfg,
|
||||
|
@ -175,6 +183,18 @@ with torch.no_grad():
|
|||
logging=logging,
|
||||
tb=tb,
|
||||
)
|
||||
else:
|
||||
last_test_performance = loop_test_reconstruction(
|
||||
epoch_id=cfg.epoch_id,
|
||||
cfg=cfg,
|
||||
network=network,
|
||||
my_loader_test=my_loader_test,
|
||||
the_dataset_test=the_dataset_test,
|
||||
device=device,
|
||||
default_dtype=default_dtype,
|
||||
logging=logging,
|
||||
tb=tb,
|
||||
)
|
||||
|
||||
# Next epoch
|
||||
cfg.epoch_id += 1
|
||||
|
@ -183,6 +203,7 @@ with torch.no_grad():
|
|||
# Run test data
|
||||
# ##############################################
|
||||
network.eval()
|
||||
if isinstance(network[-1], SbSReconstruction) is False:
|
||||
last_test_performance = loop_test(
|
||||
epoch_id=cfg.epoch_id,
|
||||
cfg=cfg,
|
||||
|
@ -194,6 +215,18 @@ with torch.no_grad():
|
|||
logging=logging,
|
||||
tb=tb,
|
||||
)
|
||||
else:
|
||||
last_test_performance = loop_test_reconstruction(
|
||||
epoch_id=cfg.epoch_id,
|
||||
cfg=cfg,
|
||||
network=network,
|
||||
my_loader_test=my_loader_test,
|
||||
the_dataset_test=the_dataset_test,
|
||||
device=device,
|
||||
default_dtype=default_dtype,
|
||||
logging=logging,
|
||||
tb=tb,
|
||||
)
|
||||
|
||||
tb.close()
|
||||
|
||||
|
|
Loading…
Reference in a new issue