From 33f92b742997114c282b2f82d6bd46ec2597b621 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Wed, 15 Mar 2023 16:46:09 +0100 Subject: [PATCH] Add files via upload --- network/loop_train_test.py | 134 ++++++++++++++++++++++++++++++++++--- 1 file changed, 126 insertions(+), 8 deletions(-) diff --git a/network/loop_train_test.py b/network/loop_train_test.py index ce9ed02..0153365 100644 --- a/network/loop_train_test.py +++ b/network/loop_train_test.py @@ -5,6 +5,7 @@ from torch.utils.tensorboard import SummaryWriter from network.SbSLayer import SbSLayer from network.NNMFLayer import NNMFLayer +from network.NNMFLayerSbSBP import NNMFLayerSbSBP from network.save_weight_and_bias import save_weight_and_bias from network.SbSReconstruction import SbSReconstruction @@ -20,8 +21,10 @@ def add_weight_and_bias_to_histogram( # ################################################ # Log the SbS Weights # ################################################ - if (isinstance(network[id], SbSLayer) is True) or ( - isinstance(network[id], NNMFLayer) is True + if ( + (isinstance(network[id], SbSLayer) is True) + or (isinstance(network[id], NNMFLayer) is True) + or (isinstance(network[id], NNMFLayerSbSBP) is True) ): if network[id]._w_trainable is True: @@ -231,8 +234,10 @@ def run_optimizer( cfg: Config, ) -> None: for id in range(0, len(network)): - if (isinstance(network[id], SbSLayer) is True) or ( - isinstance(network[id], NNMFLayer) is True + if ( + (isinstance(network[id], SbSLayer) is True) + or (isinstance(network[id], NNMFLayer) is True) + or (isinstance(network[id], NNMFLayerSbSBP) is True) ): network[id].update_pre_care() @@ -241,8 +246,10 @@ def run_optimizer( optimizer_item.step() for id in range(0, len(network)): - if (isinstance(network[id], SbSLayer) is True) or ( - isinstance(network[id], NNMFLayer) is True + if ( + (isinstance(network[id], SbSLayer) is True) + or (isinstance(network[id], NNMFLayer) is True) + or (isinstance(network[id], NNMFLayerSbSBP) is True) ): network[id].update_after_care( cfg.learning_parameters.learning_rate_threshold_w @@ -295,11 +302,15 @@ def run_lr_scheduler( def deal_with_gradient_scale(epoch_id: int, mini_batch_number: int, network): if (epoch_id == 0) and (mini_batch_number == 0): for id in range(0, len(network)): - if isinstance(network[id], SbSLayer) is True: + if (isinstance(network[id], SbSLayer) is True) or ( + isinstance(network[id], NNMFLayerSbSBP) is True + ): network[id].after_batch(True) else: for id in range(0, len(network)): - if isinstance(network[id], SbSLayer) is True: + if (isinstance(network[id], SbSLayer) is True) or ( + isinstance(network[id], NNMFLayerSbSBP) is True + ): network[id].after_batch() @@ -775,3 +786,110 @@ def loop_test_reconstruction( tb.flush() return performance + + +def loop_train_h_activity( + cfg: Config, + network: torch.nn.modules.container.Sequential, + my_loader_train: torch.utils.data.dataloader.DataLoader, + the_dataset_train, + device: torch.device, + default_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + + first_run = True + position_counter = 0 + for h_x, h_x_labels in my_loader_train: + print(f"{position_counter} of {the_dataset_train.number_of_pattern}") + # ##################################################### + # The network does the forward pass (training) + # ##################################################### + h_collection = forward_pass_train( + input=h_x, + labels=h_x_labels, + the_dataset_train=the_dataset_train, + cfg=cfg, + network=network, + device=device, + default_dtype=default_dtype, + ) + + h_h: torch.Tensor = h_collection[-1].detach().clone().cpu() + + if first_run is True: + h_activity = torch.zeros( + ( + the_dataset_train.number_of_pattern, + h_h.shape[-3], + h_h.shape[-2], + h_h.shape[-1], + ), + device=device, + dtype=default_dtype, + ) + + h_labels = torch.zeros( + (the_dataset_train.number_of_pattern), + device=device, + dtype=torch.int, + ) + first_run = False + + h_activity[ + position_counter : position_counter + int(h_h.shape[0]), :, :, : + ] = h_h + + h_labels[ + position_counter : position_counter + int(h_h.shape[0]), + ] = h_x_labels + position_counter += int(h_h.shape[0]) + + return h_activity, h_labels + + +def loop_train_h_confusion( + cfg: Config, + network: torch.nn.modules.container.Sequential, + my_loader_train: torch.utils.data.dataloader.DataLoader, + the_dataset_train, + optimizer: list, + device: torch.device, + default_dtype: torch.dtype, + logging, + adapt_learning_rate: bool, + tb: SummaryWriter, + lr_scheduler, + last_test_performance: float, + order_id: float | int | None = None, +) -> torch.Tensor: + + first_run = True + + position_counter = 0 + for h_x, h_x_labels in my_loader_train: + print(f"{position_counter} of {the_dataset_train.number_of_pattern}") + # ##################################################### + # The network does the forward pass (training) + # ##################################################### + h_collection = forward_pass_train( + input=h_x, + labels=h_x_labels, + the_dataset_train=the_dataset_train, + cfg=cfg, + network=network, + device=device, + default_dtype=default_dtype, + ) + + h_h: torch.Tensor = h_collection[-1] + idx = h_h.argmax(dim=1).squeeze().cpu().numpy() + if first_run is True: + first_run = False + result = torch.zeros((h_h.shape[1], h_h.shape[1]), dtype=torch.int) + + for i in range(0, idx.shape[0]): + result[int(h_x_labels[i]), int(idx[i])] += 1 + + position_counter += int(h_h.shape[0]) + + return result