From a85805e92c0ae9fd0c7e9a50a768c481d60cf558 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Tue, 21 Feb 2023 14:37:51 +0100 Subject: [PATCH] Add files via upload --- network/Adam.py | 6 +- network/Dataset.py | 2 + network/DatasetMix.py | 90 ++++++++++ network/HDynamicLayer.py | 1 - network/InputSpikeImage.py | 13 +- network/NNMFLayer.py | 280 +++++++++++++++++++++++++++++++ network/Parameter.py | 3 +- network/SbSLayer.py | 4 + network/SplitOnOffLayer.py | 35 ++-- network/build_datasets.py | 26 +++ network/build_lr_scheduler.py | 3 +- network/build_network.py | 85 +++++++++- network/build_optimizer.py | 8 + network/load_previous_weights.py | 19 +++ network/loop_train_test.py | 101 ++++++++++- network/save_weight_and_bias.py | 29 +++- 16 files changed, 659 insertions(+), 46 deletions(-) create mode 100644 network/DatasetMix.py create mode 100644 network/NNMFLayer.py diff --git a/network/Adam.py b/network/Adam.py index 4ed7baf..08443b2 100644 --- a/network/Adam.py +++ b/network/Adam.py @@ -151,11 +151,7 @@ class Adam(torch.optim.Optimizer): if sbs_setting[i] is False: param -= step_size * (exp_avg / denom) else: - # delta = torch.exp(-step_size * (exp_avg / denom)) - delta = torch.tanh(-step_size * (exp_avg / denom)) - delta += 1.0 - delta *= 0.5 - delta += 0.5 + delta = 0.5 * torch.tanh(-step_size * (exp_avg / denom)) + 1.0 self._logging.info( f"ADAM: Layer {i} -> dw_min:{float(delta.min()):.4e} dw_max:{float(delta.max()):.4e} lr:{lr:.4e}" ) diff --git a/network/Dataset.py b/network/Dataset.py index 83b18fb..55bd88b 100644 --- a/network/Dataset.py +++ b/network/Dataset.py @@ -15,6 +15,8 @@ class DatasetMaster(torch.utils.data.Dataset, ABC): initial_size: list[int] channel_size: int + alpha: float + # Initialize def __init__( self, diff --git a/network/DatasetMix.py b/network/DatasetMix.py new file mode 100644 index 0000000..ab36199 --- /dev/null +++ b/network/DatasetMix.py @@ -0,0 +1,90 @@ +import torch +from network.Dataset import DatasetMNIST, DatasetFashionMNIST, DatasetCIFAR +import math + + +class DatasetMNISTMix(DatasetMNIST): + def __init__( + self, + train: bool = False, + path_pattern: str = "./", + path_label: str = "./", + alpha: float = 1.0, + ) -> None: + super().__init__(train, path_pattern, path_label) + self.alpha = alpha + + def __getitem__(self, index: int) -> tuple[torch.Tensor, list[int]]: # type: ignore + + assert self.alpha >= 0.0 + assert self.alpha <= 1.0 + + image_a, target_a = super().__getitem__(index) + + target_b: int = target_a + while target_b == target_a: + image_b, target_b = super().__getitem__( + int(math.floor(self.number_of_pattern * torch.rand((1)).item())) + ) + + image = self.alpha * image_a + (1.0 - self.alpha) * image_b + target = [target_a, target_b] + return image, target + + +class DatasetFashionMNISTMix(DatasetFashionMNIST): + def __init__( + self, + train: bool = False, + path_pattern: str = "./", + path_label: str = "./", + alpha: float = 1.0, + ) -> None: + super().__init__(train, path_pattern, path_label) + self.alpha = alpha + + def __getitem__(self, index: int) -> tuple[torch.Tensor, list[int]]: # type: ignore + + assert self.alpha >= 0.0 + assert self.alpha <= 1.0 + + image_a, target_a = super().__getitem__(index) + + target_b: int = target_a + while target_b == target_a: + image_b, target_b = super().__getitem__( + int(math.floor(self.number_of_pattern * torch.rand((1)).item())) + ) + + image = self.alpha * image_a + (1.0 - self.alpha) * image_b + target = [target_a, target_b] + return 
image, target + + +class DatasetCIFARMix(DatasetCIFAR): + def __init__( + self, + train: bool = False, + path_pattern: str = "./", + path_label: str = "./", + alpha: float = 1.0, + ) -> None: + super().__init__(train, path_pattern, path_label) + self.alpha = alpha + + def __getitem__(self, index: int) -> tuple[torch.Tensor, list[int]]: # type: ignore + + assert self.alpha >= 0.0 + assert self.alpha <= 1.0 + + image_a, target_a = super().__getitem__(index) + + target_b: int = target_a + while target_b == target_a: + image_b, target_b = super().__getitem__( + int(math.floor(self.number_of_pattern * torch.rand((1)).item())) + ) + + image = self.alpha * image_a + (1.0 - self.alpha) * image_b + target = [target_a, target_b] + return image, target diff --git a/network/HDynamicLayer.py b/network/HDynamicLayer.py index 33712c1..12c9641 100644 --- a/network/HDynamicLayer.py +++ b/network/HDynamicLayer.py @@ -443,7 +443,6 @@ class FunctionalSbS(torch.autograd.Function): ) elif (parameter_output_layer is True) and (parameter_local_learning is True): - target_one_hot: torch.Tensor = torch.zeros( ( labels.shape[0], diff --git a/network/InputSpikeImage.py b/network/InputSpikeImage.py index 03de7da..254a83f 100644 --- a/network/InputSpikeImage.py +++ b/network/InputSpikeImage.py @@ -51,7 +51,13 @@ class InputSpikeImage(torch.nn.Module): def forward(self, input: torch.Tensor) -> torch.Tensor: if self.number_of_spikes < 1: - return input + output = input + output = output.type(dtype=input.dtype) + if self._normalize is True: + output = output * output.shape[-1] * output.shape[-2] * output.shape[-3] / output.sum(dim=-1, keepdim=True).sum( + dim=-2, keepdim=True + ).sum(dim=-3, keepdim=True) + return output input_shape: list[int] = [ int(input.shape[0]), @@ -95,9 +101,10 @@ class InputSpikeImage(torch.nn.Module): ) ) + output = output.type(dtype=input_work.dtype) + if self._normalize is True: - output = output.type(dtype=input_work.dtype) - output = output / output.sum(dim=-1, keepdim=True).sum( + output = output * output.shape[-1] * output.shape[-2] * output.shape[-3] / output.sum(dim=-1, keepdim=True).sum( dim=-2, keepdim=True ).sum(dim=-3, keepdim=True) diff --git a/network/NNMFLayer.py b/network/NNMFLayer.py new file mode 100644 index 0000000..7b7d5e8 --- /dev/null +++ b/network/NNMFLayer.py @@ -0,0 +1,280 @@ +import torch + +from network.calculate_output_size import calculate_output_size + + +class NNMFLayer(torch.nn.Module): + + _epsilon_0: float + _weights: torch.nn.parameter.Parameter + _weights_exists: bool = False + _kernel_size: list[int] + _stride: list[int] + _dilation: list[int] + _padding: list[int] + _output_size: torch.Tensor + _number_of_neurons: int + _number_of_input_neurons: int + _h_initial: torch.Tensor | None = None + _w_trainable: bool + _weight_noise_range: list[float] + _input_size: list[int] + _output_layer: bool = False + _number_of_iterations: int + _local_learning: bool = False + + device: torch.device + default_dtype: torch.dtype + + _number_of_grad_weight_contributions: float = 0.0 + + last_input_store: bool = False + last_input_data: torch.Tensor | None = None + + _layer_id: int = -1 + + def __init__( + self, + number_of_input_neurons: int, + number_of_neurons: int, + input_size: list[int], + forward_kernel_size: list[int], + number_of_iterations: int, + epsilon_0: float = 1.0, + weight_noise_range: list[float] = [0.0, 1.0], + strides: list[int] = [1, 1], + dilation: list[int] = [0, 0], + padding: list[int] = [0, 0], + w_trainable: bool = False, + device: torch.device | None 
= None, + default_dtype: torch.dtype | None = None, + layer_id: int = -1, + local_learning: bool = False, + output_layer: bool = False, + ) -> None: + super().__init__() + + assert device is not None + assert default_dtype is not None + self.device = device + self.default_dtype = default_dtype + + self._w_trainable = bool(w_trainable) + self._stride = strides + self._dilation = dilation + self._padding = padding + self._kernel_size = forward_kernel_size + self._number_of_input_neurons = int(number_of_input_neurons) + self._number_of_neurons = int(number_of_neurons) + self._epsilon_0 = float(epsilon_0) + self._number_of_iterations = int(number_of_iterations) + self._weight_noise_range = weight_noise_range + self._layer_id = layer_id + self._local_learning = local_learning + self._output_layer = output_layer + + assert len(input_size) == 2 + self._input_size = input_size + + self._output_size = calculate_output_size( + value=input_size, + kernel_size=self._kernel_size, + stride=self._stride, + dilation=self._dilation, + padding=self._padding, + ) + + self.set_h_init_to_uniform() + + # ############################################################### + # Initialize the weights + # ############################################################### + + assert len(self._weight_noise_range) == 2 + weights = torch.empty( + ( + int(self._kernel_size[0]) + * int(self._kernel_size[1]) + * int(self._number_of_input_neurons), + int(self._number_of_neurons), + ), + dtype=self.default_dtype, + device=self.device, + ) + + torch.nn.init.uniform_( + weights, + a=float(self._weight_noise_range[0]), + b=float(self._weight_noise_range[1]), + ) + self.weights = weights + + @property + def weights(self) -> torch.Tensor | None: + if self._weights_exists is False: + return None + else: + return self._weights + + @weights.setter + def weights(self, value: torch.Tensor): + assert value is not None + assert torch.is_tensor(value) is True + assert value.dim() == 2 + temp: torch.Tensor = ( + value.detach() + .clone(memory_format=torch.contiguous_format) + .type(dtype=self.default_dtype) + .to(device=self.device) + ) + temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype) + if self._weights_exists is False: + self._weights = torch.nn.parameter.Parameter(temp, requires_grad=True) + self._weights_exists = True + else: + self._weights.data = temp + + @property + def h_initial(self) -> torch.Tensor | None: + return self._h_initial + + @h_initial.setter + def h_initial(self, value: torch.Tensor): + assert value is not None + assert torch.is_tensor(value) is True + assert value.dim() == 1 + assert value.dtype == self.default_dtype + self._h_initial = ( + value.detach() + .clone(memory_format=torch.contiguous_format) + .type(dtype=self.default_dtype) + .to(device=self.device) + .requires_grad_(False) + ) + + def update_pre_care(self): + + if self._weights.grad is not None: + assert self._number_of_grad_weight_contributions > 0 + self._weights.grad /= self._number_of_grad_weight_contributions + self._number_of_grad_weight_contributions = 0.0 + + def update_after_care(self, threshold_weight: float): + + if self._w_trainable is True: + self.norm_weights() + self.threshold_weights(threshold_weight) + self.norm_weights() + + def set_h_init_to_uniform(self) -> None: + + assert self._number_of_neurons > 2 + + self.h_initial: torch.Tensor = torch.full( + (self._number_of_neurons,), + (1.0 / float(self._number_of_neurons)), + dtype=self.default_dtype, + device=self.device, + ) + + def norm_weights(self) -> None: + assert 
self._weights_exists is True + temp: torch.Tensor = ( + self._weights.data.detach() + .clone(memory_format=torch.contiguous_format) + .type(dtype=self.default_dtype) + .to(device=self.device) + ) + temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype) + self._weights.data = temp + + def threshold_weights(self, threshold: float) -> None: + assert self._weights_exists is True + assert threshold >= 0 + + torch.clamp( + self._weights.data, + min=float(threshold), + max=None, + out=self._weights.data, + ) + + #################################################################### + # Forward # + #################################################################### + + def forward( + self, + input: torch.Tensor, + ) -> torch.Tensor: + + # Are we happy with the input? + assert input is not None + assert torch.is_tensor(input) is True + assert input.dim() == 4 + assert input.dtype == self.default_dtype + assert input.shape[1] == self._number_of_input_neurons + assert input.shape[2] == self._input_size[0] + assert input.shape[3] == self._input_size[1] + + # Are we happy with the rest of the network? + assert self._epsilon_0 is not None + assert self._h_initial is not None + assert self._weights_exists is True + assert self._weights is not None + + # Convolution of the input... + # Well, this is a convoltion layer + # there needs to be convolution somewhere + input_convolved = torch.nn.functional.fold( + torch.nn.functional.unfold( + input.requires_grad_(True), + kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])), + dilation=(int(self._dilation[0]), int(self._dilation[1])), + padding=(int(self._padding[0]), int(self._padding[1])), + stride=(int(self._stride[0]), int(self._stride[1])), + ), + output_size=tuple(self._output_size.tolist()), + kernel_size=(1, 1), + dilation=(1, 1), + padding=(0, 0), + stride=(1, 1), + ) + + # We might need the convolved input for other layers + # let us keep it for the future + if self.last_input_store is True: + self.last_input_data = input_convolved.detach().clone() + self.last_input_data /= self.last_input_data.sum(dim=1, keepdim=True) + else: + self.last_input_data = None + + input_convolved = input_convolved / input_convolved.sum(dim=1, keepdim=True) + + h = torch.tile( + self._h_initial.unsqueeze(0).unsqueeze(-1).unsqueeze(-1), + dims=[ + int(input.shape[0]), + 1, + int(self._output_size[0]), + int(self._output_size[1]), + ], + ).requires_grad_(True) + + for _ in range(0, self._number_of_iterations): + h_w = h.unsqueeze(1) * self._weights.unsqueeze(0).unsqueeze(-1).unsqueeze( + -1 + ) + h_w = h_w / (h_w.sum(dim=2, keepdim=True) + 1e-20) + h_w = (h_w * input_convolved.unsqueeze(2)).sum(dim=1) + if self._epsilon_0 > 0: + h = h + self._epsilon_0 * h_w + else: + h = h_w + h = h / (h.sum(dim=1, keepdim=True) + 1e-20) + + self._number_of_grad_weight_contributions += ( + h.shape[0] * h.shape[-2] * h.shape[-1] + ) + + return h diff --git a/network/Parameter.py b/network/Parameter.py index 8847b4e..8c6a1bf 100644 --- a/network/Parameter.py +++ b/network/Parameter.py @@ -44,7 +44,7 @@ class LearningParameters: overload_path: str = field(default="Previous") weight_noise_range: list[float] = field(default_factory=list) - eps_xy_intitial: float = field(default=0.1) + eps_xy_intitial: float = field(default=1.0) disable_scale_grade: bool = field(default=False) kepp_last_grad_scale: bool = field(default=True) @@ -55,7 +55,6 @@ class LearningParameters: w_trainable: list[bool] = field(default_factory=list) - @dataclass class Augmentation: """Parameters used 
for data augmentation.""" diff --git a/network/SbSLayer.py b/network/SbSLayer.py index ae593a6..ce86abf 100644 --- a/network/SbSLayer.py +++ b/network/SbSLayer.py @@ -87,6 +87,8 @@ class SbSLayer(torch.nn.Module): spike_full_layer_input_distribution: bool = False, force_forward_spike_on_cpu: bool = False, force_forward_spike_output_on_cpu: bool = False, + local_learning: bool = False, + output_layer: bool = False, ) -> None: super().__init__() @@ -117,6 +119,8 @@ class SbSLayer(torch.nn.Module): self._epsilon_xy_use = epsilon_xy_use self._force_forward_h_dynamic_on_cpu = force_forward_h_dynamic_on_cpu self._spike_full_layer_input_distribution = spike_full_layer_input_distribution + self._local_learning = local_learning + self._output_layer = output_layer assert len(input_size) == 2 self._input_size = input_size diff --git a/network/SplitOnOffLayer.py b/network/SplitOnOffLayer.py index 4190917..a884703 100644 --- a/network/SplitOnOffLayer.py +++ b/network/SplitOnOffLayer.py @@ -28,27 +28,28 @@ class SplitOnOffLayer(torch.nn.Module): def forward(self, input: torch.Tensor) -> torch.Tensor: assert input.ndim == 4 - # self.training is switched by network.eval() and network.train() - if self.training is True: - mean_temp = ( - input.mean(dim=0, keepdim=True) - .mean(dim=1, keepdim=True) - .detach() - .clone() - ) +# # self.training is switched by network.eval() and network.train() +# if self.training is True: +# mean_temp = ( +# input.mean(dim=0, keepdim=True) +# .mean(dim=1, keepdim=True) +# .detach() +# .clone() +# ) +# +# if self.mean is None: +# self.mean = mean_temp +# else: +# self.mean = (1.0 - self.epsilon) * self.mean + self.epsilon * mean_temp +# +# assert self.mean is not None - if self.mean is None: - self.mean = mean_temp - else: - self.mean = (1.0 - self.epsilon) * self.mean + self.epsilon * mean_temp - - assert self.mean is not None - - temp = input - self.mean.detach().clone() +# temp = input - self.mean.detach().clone() + temp = input - 0.5 temp_a = torch.nn.functional.relu(temp) temp_b = torch.nn.functional.relu(-temp) output = torch.cat((temp_a, temp_b), dim=1) - output /= output.sum(dim=1, keepdim=True) + 1e-20 + #output /= output.sum(dim=1, keepdim=True) + 1e-20 return output diff --git a/network/build_datasets.py b/network/build_datasets.py index 3fdb637..6a67fcb 100644 --- a/network/build_datasets.py +++ b/network/build_datasets.py @@ -6,6 +6,11 @@ from network.Dataset import ( DatasetMNIST, DatasetFashionMNIST, ) +from network.DatasetMix import ( + DatasetCIFARMix, + DatasetMNISTMix, + DatasetFashionMNISTMix, +) from network.Parameter import Config @@ -42,6 +47,27 @@ def build_datasets( the_dataset_test = DatasetFashionMNIST( train=False, path_pattern=cfg.data_path, path_label=cfg.data_path ) + elif cfg.data_mode == "MIX_CIFAR10": + the_dataset_train = DatasetCIFARMix( + train=True, path_pattern=cfg.data_path, path_label=cfg.data_path + ) + the_dataset_test = DatasetCIFARMix( + train=False, path_pattern=cfg.data_path, path_label=cfg.data_path + ) + elif cfg.data_mode == "MIX_MNIST": + the_dataset_train = DatasetMNISTMix( + train=True, path_pattern=cfg.data_path, path_label=cfg.data_path + ) + the_dataset_test = DatasetMNISTMix( + train=False, path_pattern=cfg.data_path, path_label=cfg.data_path + ) + elif cfg.data_mode == "MIX_MNIST_FASHION": + the_dataset_train = DatasetFashionMNISTMix( + train=True, path_pattern=cfg.data_path, path_label=cfg.data_path + ) + the_dataset_test = DatasetFashionMNISTMix( + train=False, path_pattern=cfg.data_path, path_label=cfg.data_path + ) 
else: raise Exception("data_mode unknown") diff --git a/network/build_lr_scheduler.py b/network/build_lr_scheduler.py index e771950..34ed75a 100644 --- a/network/build_lr_scheduler.py +++ b/network/build_lr_scheduler.py @@ -39,7 +39,8 @@ def build_lr_scheduler( ): lr_scheduler_list.append( torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer[id_optimizer],eps=1e-14, + optimizer[id_optimizer], + eps=1e-14, ) ) else: diff --git a/network/build_network.py b/network/build_network.py index ed95395..c0482b2 100644 --- a/network/build_network.py +++ b/network/build_network.py @@ -4,6 +4,7 @@ import torch from network.calculate_output_size import calculate_output_size from network.Parameter import Config from network.SbSLayer import SbSLayer +from network.NNMFLayer import NNMFLayer from network.SplitOnOffLayer import SplitOnOffLayer from network.Conv2dApproximation import Conv2dApproximation from network.SbSReconstruction import SbSReconstruction @@ -153,6 +154,14 @@ def build_network( if cfg.network_structure.layer_type[layer_id].upper().find("POOLING") != -1: is_pooling_layer = True + local_learning = False + if cfg.network_structure.layer_type[layer_id].upper().find("LOCAL") != -1: + local_learning = True + + output_layer = False + if layer_id == len(cfg.network_structure.layer_type) - 1: + output_layer = True + network.append( SbSLayer( number_of_input_neurons=in_channels, @@ -180,19 +189,13 @@ def build_network( reduction_cooldown=cfg.reduction_cooldown, force_forward_h_dynamic_on_cpu=cfg.force_forward_h_dynamic_on_cpu, spike_full_layer_input_distribution=spike_full_layer_input_distribution, + local_learning=local_learning, + output_layer=output_layer, ) ) # Adding the x,y output dimensions input_size.append(network[-1]._output_size.tolist()) - network[-1]._output_layer = False - if layer_id == len(cfg.network_structure.layer_type) - 1: - network[-1]._output_layer = True - - network[-1]._local_learning = False - if cfg.network_structure.layer_type[layer_id].upper().find("LOCAL") != -1: - network[-1]._local_learning = True - elif ( cfg.network_structure.layer_type[layer_id] .upper() @@ -276,6 +279,8 @@ def build_network( ): logging.info(f"Layer: {layer_id} -> RELU Layer") network.append(torch.nn.ReLU()) + network[-1]._w_trainable = False + input_size.append(input_size[-1]) # ############################################################# @@ -296,6 +301,8 @@ def build_network( ) ) + network[-1]._w_trainable = False + # Calculate the x,y output dimensions input_size_temp = calculate_output_size( value=input_size[-1], @@ -323,6 +330,9 @@ def build_network( padding=(int(padding[0]), int(padding[1])), ) ) + + network[-1]._w_trainable = False + # Calculate the x,y output dimensions input_size_temp = calculate_output_size( value=input_size[-1], @@ -405,8 +415,67 @@ def build_network( ) ) + network[-1]._w_trainable = False + input_size.append(input_size[-1]) + # ############################################################# + # NNMF: + # ############################################################# + + elif ( + cfg.network_structure.layer_type[layer_id].upper().startswith("NNMF") + is True + ): + + assert in_channels > 0 + assert out_channels > 0 + + number_of_iterations: int = -1 + if len(cfg.number_of_spikes) > layer_id: + number_of_iterations = cfg.number_of_spikes[layer_id] + elif len(cfg.number_of_spikes) == 1: + number_of_iterations = cfg.number_of_spikes[0] + + assert number_of_iterations > 0 + + logging.info( + ( + f"Layer: {layer_id} -> NNMF Layer with {number_of_iterations} iterations " + ) 
+ ) + + local_learning = False + if cfg.network_structure.layer_type[layer_id].upper().find("LOCAL") != -1: + local_learning = True + + output_layer = False + if layer_id == len(cfg.network_structure.layer_type) - 1: + output_layer = True + + network.append( + NNMFLayer( + number_of_input_neurons=in_channels, + number_of_neurons=out_channels, + input_size=input_size[-1], + forward_kernel_size=kernel_size, + number_of_iterations=number_of_iterations, + epsilon_0=cfg.epsilon_0, + weight_noise_range=weight_noise_range, + strides=strides, + dilation=dilation, + padding=padding, + w_trainable=w_trainable, + device=device, + default_dtype=default_dtype, + layer_id=layer_id, + local_learning=local_learning, + output_layer=output_layer, + ) + ) + # Adding the x,y output dimensions + input_size.append(network[-1]._output_size.tolist()) + # ############################################################# # Failure becaue we didn't found the selection of layer # ############################################################# diff --git a/network/build_optimizer.py b/network/build_optimizer.py index 91888fe..ca2f1dd 100644 --- a/network/build_optimizer.py +++ b/network/build_optimizer.py @@ -2,6 +2,8 @@ import torch from network.Parameter import Config from network.SbSLayer import SbSLayer +from network.NNMFLayer import NNMFLayer + from network.Conv2dApproximation import Conv2dApproximation from network.Adam import Adam @@ -26,6 +28,12 @@ def build_optimizer( parameter_list_weights.append(network[id]._weights) parameter_list_sbs.append(True) + if (isinstance(network[id], NNMFLayer) is True) and ( + network[id]._w_trainable is True + ): + parameter_list_weights.append(network[id]._weights) + parameter_list_sbs.append(True) + if (isinstance(network[id], torch.nn.modules.conv.Conv2d) is True) and ( network[id]._w_trainable is True ): diff --git a/network/load_previous_weights.py b/network/load_previous_weights.py index 80fd13c..999c3a2 100644 --- a/network/load_previous_weights.py +++ b/network/load_previous_weights.py @@ -4,6 +4,8 @@ import glob import numpy as np from network.SbSLayer import SbSLayer +from network.NNMFLayer import NNMFLayer + from network.SplitOnOffLayer import SplitOnOffLayer from network.Conv2dApproximation import Conv2dApproximation import os @@ -46,6 +48,23 @@ def load_previous_weights( ) logging.info(f"Weights file used for layer {id} : {file_to_load[0]}") + if isinstance(network[id], NNMFLayer) is True: + filename_wilcard = os.path.join( + overload_path, f"Weight_L{id}_*{post_fix}.npy" + ) + file_to_load = glob.glob(filename_wilcard) + + if len(file_to_load) > 1: + raise Exception(f"Too many previous weights files {filename_wilcard}") + + if len(file_to_load) == 1: + network[id].weights = torch.tensor( + np.load(file_to_load[0]), + dtype=default_dtype, + device=device, + ) + logging.info(f"Weights file used for layer {id} : {file_to_load[0]}") + if isinstance(network[id], torch.nn.modules.conv.Conv2d) is True: # ################################################# diff --git a/network/loop_train_test.py b/network/loop_train_test.py index bc594de..ce9ed02 100644 --- a/network/loop_train_test.py +++ b/network/loop_train_test.py @@ -4,6 +4,7 @@ from network.Parameter import Config from torch.utils.tensorboard import SummaryWriter from network.SbSLayer import SbSLayer +from network.NNMFLayer import NNMFLayer from network.save_weight_and_bias import save_weight_and_bias from network.SbSReconstruction import SbSReconstruction @@ -19,7 +20,9 @@ def add_weight_and_bias_to_histogram( # 
################################################ # Log the SbS Weights # ################################################ - if isinstance(network[id], SbSLayer) is True: + if (isinstance(network[id], SbSLayer) is True) or ( + isinstance(network[id], NNMFLayer) is True + ): if network[id]._w_trainable is True: try: @@ -228,7 +231,9 @@ def run_optimizer( cfg: Config, ) -> None: for id in range(0, len(network)): - if isinstance(network[id], SbSLayer) is True: + if (isinstance(network[id], SbSLayer) is True) or ( + isinstance(network[id], NNMFLayer) is True + ): network[id].update_pre_care() for optimizer_item in optimizer: @@ -236,7 +241,9 @@ def run_optimizer( optimizer_item.step() for id in range(0, len(network)): - if isinstance(network[id], SbSLayer) is True: + if (isinstance(network[id], SbSLayer) is True) or ( + isinstance(network[id], NNMFLayer) is True + ): network[id].update_after_care( cfg.learning_parameters.learning_rate_threshold_w / float( @@ -618,6 +625,94 @@ def loop_test( return performance +def loop_test_mix( + epoch_id: int, + cfg: Config, + network: torch.nn.modules.container.Sequential, + my_loader_test: torch.utils.data.dataloader.DataLoader, + the_dataset_test, + device: torch.device, + default_dtype: torch.dtype, + logging, + tb: SummaryWriter | None, + overwrite_number_of_spikes: int = -1, +) -> tuple[float, float]: + + test_correct_a_0: int = 0 + test_correct_a_1: int = 0 + test_correct_b_0: int = 0 + test_correct_b_1: int = 0 + + test_count: int = 0 + test_complete: int = the_dataset_test.__len__() + + logging.info("") + logging.info("Testing:") + mini_batch_id: int = 0 + + for h_x, h_x_labels in my_loader_test: + assert len(h_x_labels) == 2 + label_a = h_x_labels[0] + label_b = h_x_labels[1] + assert label_a.shape[0] == label_b.shape[0] + assert h_x.shape[0] == label_b.shape[0] + + time_0 = time.perf_counter() + + h_collection = forward_pass_test( + input=h_x, + labels=label_a, + the_dataset_test=the_dataset_test, + cfg=cfg, + network=network, + device=device, + default_dtype=default_dtype, + mini_batch_id=mini_batch_id, + overwrite_number_of_spikes=overwrite_number_of_spikes, + ) + h_h: torch.Tensor = h_collection[-1].detach().clone().cpu() + + # ------------- + + for id in range(0, h_h.shape[0]): + temp = h_h[id, ...].squeeze().argsort(descending=True) + test_correct_a_0 += float(label_a[id] == int(temp[0])) + test_correct_a_1 += float(label_a[id] == int(temp[1])) + test_correct_b_0 += float(label_b[id] == int(temp[0])) + test_correct_b_1 += float(label_b[id] == int(temp[1])) + + test_count += h_h.shape[0] + performance_a_0: float = 100.0 * test_correct_a_0 / test_count + performance_a_1: float = 100.0 * test_correct_a_1 / test_count + performance_b_0: float = 100.0 * test_correct_b_0 / test_count + performance_b_1: float = 100.0 * test_correct_b_1 / test_count + time_1 = time.perf_counter() + time_measure_a = time_1 - time_0 + + logging.info( + ( + f"\t\t{test_count} of {test_complete}" + f" with {performance_a_0/100:^6.2%}, " + f"{performance_a_1/100:^6.2%}, " + f"{performance_b_0/100:^6.2%}, " + f"{performance_b_1/100:^6.2%} \t " + f"Time used: {time_measure_a:^6.2f}sec" + ) + ) + mini_batch_id += 1 + + logging.info("") + + if tb is not None: + tb.add_scalar("Test Error A0", 100.0 - performance_a_0, epoch_id) + tb.add_scalar("Test Error A1", 100.0 - performance_a_1, epoch_id) + tb.add_scalar("Test Error B0", 100.0 - performance_b_0, epoch_id) + tb.add_scalar("Test Error B1", 100.0 - performance_b_1, epoch_id) + tb.flush() + + return performance_a_0, 
performance_a_1, performance_b_0, performance_b_1 + + def loop_test_reconstruction( epoch_id: int, cfg: Config, diff --git a/network/save_weight_and_bias.py b/network/save_weight_and_bias.py index ad4e178..61f70a8 100644 --- a/network/save_weight_and_bias.py +++ b/network/save_weight_and_bias.py @@ -4,6 +4,7 @@ from network.Parameter import Config import numpy as np from network.SbSLayer import SbSLayer +from network.NNMFLayer import NNMFLayer from network.SplitOnOffLayer import SplitOnOffLayer from network.Conv2dApproximation import Conv2dApproximation @@ -38,6 +39,21 @@ def save_weight_and_bias( network[id].weights.detach().cpu().numpy(), ) + # ################################################ + # Save the NNMF Weights + # ################################################ + + if isinstance(network[id], NNMFLayer) is True: + if network[id]._w_trainable is True: + + np.save( + os.path.join( + cfg.weight_path, + f"Weight_L{id}_S{iteration_number}{post_fix}.npy", + ), + network[id].weights.detach().cpu().numpy(), + ) + # ################################################ # Save the Conv2 Weights and Biases # ################################################ @@ -88,9 +104,10 @@ def save_weight_and_bias( if isinstance(network[id], SplitOnOffLayer) is True: - np.save( - os.path.join( - cfg.weight_path, f"Mean_L{id}_S{iteration_number}{post_fix}.npy" - ), - network[id].mean.detach().cpu().numpy(), - ) + if network[id].mean is not None: + np.save( + os.path.join( + cfg.weight_path, f"Mean_L{id}_S{iteration_number}{post_fix}.npy" + ), + network[id].mean.detach().cpu().numpy(), + )
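
Note on the new NNMFLayer: its forward() iterates a non-negative update `number_of_spikes` times at every output position of the unfolded input. Below is a minimal, self-contained sketch of that single-position update; the names x, W, h, the toy sizes, eps0 and the iteration count are illustrative assumptions, not part of this patch. The layer itself runs the same arithmetic over full (batch, feature, H_out, W_out) tensors via torch.nn.functional.unfold/fold.

import torch

# Sketch of the NNMFLayer update for one output position (one input patch).
# x : input patch, length C = in_channels * kernel_h * kernel_w, normalized to sum 1
# W : weights of shape (C, N), columns normalized as in the weights setter
# h : latent activity over the N output neurons, uniform init (set_h_init_to_uniform)
C, N = 8, 4            # toy sizes, illustrative only
eps0 = 1.0             # plays the role of cfg.epsilon_0
iterations = 20        # plays the role of cfg.number_of_spikes[layer_id]

x = torch.rand(C)
x = x / x.sum()
W = torch.rand(C, N)
W = W / W.sum(dim=0, keepdim=True)
h = torch.full((N,), 1.0 / N)

for _ in range(iterations):
    h_w = h.unsqueeze(0) * W                            # (C, N): responsibility of neuron j for input feature i
    h_w = h_w / (h_w.sum(dim=1, keepdim=True) + 1e-20)  # normalize responsibilities over the neurons
    h_new = (h_w * x.unsqueeze(1)).sum(dim=0)           # (N,): redistribute the input mass to the neurons
    h = h + eps0 * h_new if eps0 > 0 else h_new         # damped additive step (eps0 > 0) or plain reassignment
    h = h / (h.sum() + 1e-20)                           # keep h non-negative and summing to 1

print(h)

With epsilon_0 <= 0 the update reduces to the plain reassignment h <- h_w, i.e. a multiplicative NNMF-style step; with epsilon_0 > 0 (default 1.0) it is the damped, additive variant used above, matching the branch in NNMFLayer.forward().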