From dcee82fca67ce8bc1d1b79c666a0b16c0c40b6fd Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Sun, 29 Jan 2023 00:58:28 +0100 Subject: [PATCH] Add files via upload --- network/Parameter.py | 18 +---- network/SbS.py | 161 +++++++++++++++++++++++++++---------- network/build_network.py | 3 +- network/loop_train_test.py | 41 ++++++++-- 4 files changed, 158 insertions(+), 65 deletions(-) diff --git a/network/Parameter.py b/network/Parameter.py index 1469334..b9e3924 100644 --- a/network/Parameter.py +++ b/network/Parameter.py @@ -1,7 +1,6 @@ # %% from dataclasses import dataclass, field import numpy as np -import torch import os @@ -101,6 +100,8 @@ class Config: default_factory=ApproximationSetting ) + extract_noisy_pictures: bool = field(default=False) + # For labeling simulations # (not actively used) simulation_id: int = field(default=0) @@ -163,21 +164,6 @@ class Config: self.batch_size = np.max((self.batch_size, self.number_of_cpu_processes)) self.batch_size = int(self.batch_size) - def get_epsilon_t(self, number_of_spikes: int): - """Generates the time series of the basic epsilon.""" - t = np.arange(0, number_of_spikes, dtype=np.float32) + 1 - np_epsilon_t: np.ndarray = t ** ( - -1.0 / 2.0 - ) # np.ones((number_of_spikes), dtype=np.float32) - - if (self.cooldown_after_number_of_spikes < number_of_spikes) and ( - self.cooldown_after_number_of_spikes >= 0 - ): - np_epsilon_t[ - self.cooldown_after_number_of_spikes : number_of_spikes - ] /= self.reduction_cooldown - return torch.tensor(np_epsilon_t) - def get_update_after_x_pattern(self): """Tells us after how many pattern we need to update the weights.""" return ( diff --git a/network/SbS.py b/network/SbS.py index da56e01..9ff9ac9 100644 --- a/network/SbS.py +++ b/network/SbS.py @@ -3,6 +3,8 @@ import torch from network.CPP.PySpikeGeneration2DManyIP import SpikeGeneration2DManyIP from network.CPP.PyHDynamicCNNManyIP import HDynamicCNNManyIP from network.calculate_output_size import calculate_output_size +import os +import numpy as np global_sbs_gpu_setting: list[torch.Tensor] = [] global_sbs_size: list[torch.Tensor] = [] @@ -16,7 +18,6 @@ class SbS(torch.nn.Module): _epsilon_xy: torch.Tensor | None = None _epsilon_0: float - _epsilon_t: torch.Tensor | None = None _weights: torch.nn.parameter.Parameter _weights_exists: bool = False _kernel_size: list[int] @@ -58,6 +59,9 @@ class SbS(torch.nn.Module): spike_generation_cpp_position: int = -1 spike_generation_gpu_setting_position: int = -1 + _cooldown_after_number_of_spikes: int = -1 + _reduction_cooldown: float = 1.0 + def __init__( self, number_of_input_neurons: int, @@ -65,7 +69,6 @@ class SbS(torch.nn.Module): input_size: list[int], forward_kernel_size: list[int], number_of_spikes: int, - epsilon_t: torch.Tensor, epsilon_xy_intitial: float = 0.1, epsilon_0: float = 1.0, weight_noise_range: list[float] = [0.0, 1.0], @@ -83,6 +86,8 @@ class SbS(torch.nn.Module): default_dtype: torch.dtype | None = None, gpu_tuning_factor: int = 5, layer_id: int = -1, + cooldown_after_number_of_spikes: int = -1, + reduction_cooldown: float = 1.0, ) -> None: super().__init__() @@ -107,6 +112,8 @@ class SbS(torch.nn.Module): self._number_of_spikes = int(number_of_spikes) self._weight_noise_range = weight_noise_range self._is_pooling_layer = bool(is_pooling_layer) + self._cooldown_after_number_of_spikes = int(cooldown_after_number_of_spikes) + self.reduction_cooldown = float(reduction_cooldown) assert len(input_size) == 2 self._input_size = input_size @@ -145,8 +152,6 @@ class SbS(torch.nn.Module): forgetting_offset, dtype=self.default_dtype, device=self.device ) - self.epsilon_t = epsilon_t.type(dtype=self.default_dtype).to(device=self.device) - self._output_size = calculate_output_size( value=input_size, kernel_size=self._kernel_size, @@ -158,6 +163,7 @@ class SbS(torch.nn.Module): self.set_h_init_to_uniform() self.functional_sbs = FunctionalSbS.apply + self.functional_spike_generation = FunctionalSpikeGeneration.apply # ############################################################### # Initialize the weights @@ -190,22 +196,23 @@ class SbS(torch.nn.Module): # Variables in and out # #################################################################### - @property - def epsilon_t(self) -> torch.Tensor | None: - return self._epsilon_t + def get_epsilon_t(self, number_of_spikes: int): + """Generates the time series of the basic epsilon.""" + t = np.arange(0, number_of_spikes, dtype=np.float32) + 1 + np_epsilon_t: np.ndarray = t ** ( + -1.0 / 2.0 + ) # np.ones((number_of_spikes), dtype=np.float32) - @epsilon_t.setter - def epsilon_t(self, value: torch.Tensor): - assert value is not None - assert torch.is_tensor(value) is True - assert value.dim() == 1 - assert value.dtype == self.default_dtype - self._epsilon_t = ( - value.detach() - .clone(memory_format=torch.contiguous_format) + if (self._cooldown_after_number_of_spikes < number_of_spikes) and ( + self._cooldown_after_number_of_spikes >= 0 + ): + np_epsilon_t[ + self._cooldown_after_number_of_spikes : number_of_spikes + ] /= self._reduction_cooldown + return ( + torch.tensor(np_epsilon_t) .type(dtype=self.default_dtype) .to(device=self.device) - .requires_grad_(False) ) @property @@ -348,7 +355,13 @@ class SbS(torch.nn.Module): #################################################################### def forward( - self, input: torch.Tensor, labels: torch.Tensor | None = None + self, + input: torch.Tensor, + labels: torch.Tensor | None = None, + extract_noisy_pictures: bool = False, + layer_id: int = -1, + mini_batch_id: int = -1, + overwrite_number_of_spikes: int = -1, ) -> torch.Tensor: # Are we happy with the input? @@ -362,7 +375,6 @@ class SbS(torch.nn.Module): # Are we happy with the rest of the network? assert self._epsilon_0 is not None - assert self._epsilon_t is not None assert self._h_initial is not None assert self._forgetting_offset is not None @@ -405,8 +417,15 @@ class SbS(torch.nn.Module): else: self.last_input_data = None + if overwrite_number_of_spikes >= 1: + _number_of_spikes = int(overwrite_number_of_spikes) + else: + _number_of_spikes = int(self._number_of_spikes) + epsilon_t_0: torch.Tensor = ( - (self._epsilon_t * self._epsilon_0).type(input.dtype).to(input.device) + (self.get_epsilon_t(_number_of_spikes) * self._epsilon_0) + .type(input.dtype) + .to(input.device) ) parameter_list = torch.tensor( @@ -415,7 +434,7 @@ class SbS(torch.nn.Module): int(self._disable_scale_grade), # 1 int(self._keep_last_grad_scale), # 2 int(self._skip_gradient_calculation), # 3 - int(self._number_of_spikes), # 4 + int(_number_of_spikes), # 4 int(self._number_of_cpu_processes), # 5 int(self._output_size[0]), # 6 int(self._output_size[1]), # 7 @@ -448,9 +467,46 @@ class SbS(torch.nn.Module): assert self._epsilon_xy.shape[1] == input_convolved.shape[2] assert self._epsilon_xy.shape[2] == input_convolved.shape[3] + spike = self.functional_spike_generation(input_convolved, parameter_list) + + if ( + (extract_noisy_pictures is True) + and (layer_id == 0) + and (labels is not None) + and (mini_batch_id >= 0) + ): + assert labels.shape[0] == spike.shape[0] + + path_sub: str = "noisy_picture_data" + path_sub_spikes: str = f"{int(_number_of_spikes)}" + path = os.path.join(path_sub, path_sub_spikes) + os.makedirs(path_sub, exist_ok=True) + os.makedirs(path, exist_ok=True) + + the_images = torch.zeros_like( + input_convolved, dtype=torch.int64, device=self.device + ) + + for p_id in range(0, the_images.shape[0]): + for sp_id in range(0, spike.shape[1]): + for x_id in range(0, the_images.shape[2]): + for y_id in range(0, the_images.shape[3]): + the_images[ + p_id, spike[p_id, sp_id, x_id, y_id], x_id, y_id + ] += 1 + + np.savez_compressed( + os.path.join(path, f"{mini_batch_id}.npz"), + the_images=the_images.cpu().numpy(), + labels=labels.cpu().numpy(), + ) + + assert spike.shape[1] == _number_of_spikes + # SbS forward functional output = self.functional_sbs( input_convolved, + spike, self._epsilon_xy, epsilon_t_0, self._weights, @@ -468,19 +524,12 @@ class SbS(torch.nn.Module): return output -class FunctionalSbS(torch.autograd.Function): +class FunctionalSpikeGeneration(torch.autograd.Function): @staticmethod def forward( # type: ignore ctx, input: torch.Tensor, - epsilon_xy: torch.Tensor, - epsilon_t_0: torch.Tensor, - weights: torch.Tensor, - h_initial: torch.Tensor, parameter_list: torch.Tensor, - grad_output_scale: torch.Tensor, - forgetting_offset: torch.Tensor, - labels: torch.Tensor, ) -> torch.Tensor: assert input.dim() == 4 @@ -492,17 +541,6 @@ class FunctionalSbS(torch.autograd.Function): else: spike_number_of_cpu_processes = -1 - if input.device == torch.device("cpu"): - hdyn_number_of_cpu_processes: int = int(parameter_list[5]) - else: - hdyn_number_of_cpu_processes = -1 - - output_size_0: int = int(parameter_list[6]) - output_size_1: int = int(parameter_list[7]) - gpu_tuning_factor: int = int(parameter_list[8]) - - sbs_gpu_setting_position = int(parameter_list[11]) - sbs_hdynamic_cpp_position = int(parameter_list[12]) spike_generation_cpp_position = int(parameter_list[13]) spike_generation_gpu_setting_position = int(parameter_list[14]) @@ -615,6 +653,45 @@ class FunctionalSbS(torch.autograd.Function): del random_values del input_cumsum + return spikes + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output + grad_parameter_list = None + return (grad_input, grad_parameter_list) + + +class FunctionalSbS(torch.autograd.Function): + @staticmethod + def forward( # type: ignore + ctx, + input: torch.Tensor, + spikes: torch.Tensor, + epsilon_xy: torch.Tensor, + epsilon_t_0: torch.Tensor, + weights: torch.Tensor, + h_initial: torch.Tensor, + parameter_list: torch.Tensor, + grad_output_scale: torch.Tensor, + forgetting_offset: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + + number_of_spikes: int = int(parameter_list[4]) + + if input.device == torch.device("cpu"): + hdyn_number_of_cpu_processes: int = int(parameter_list[5]) + else: + hdyn_number_of_cpu_processes = -1 + + output_size_0: int = int(parameter_list[6]) + output_size_1: int = int(parameter_list[7]) + gpu_tuning_factor: int = int(parameter_list[8]) + + sbs_gpu_setting_position = int(parameter_list[11]) + sbs_hdynamic_cpp_position = int(parameter_list[12]) + # ########################################################### # H dynamic # ########################################################### @@ -713,7 +790,6 @@ class FunctionalSbS(torch.autograd.Function): float(forgetting_offset.item()), int(gpu_tuning_factor), ) - del spikes # ########################################################### # Save the necessary data for the backward pass @@ -750,6 +826,7 @@ class FunctionalSbS(torch.autograd.Function): # Default output # ############################################## grad_input = None + grad_spikes = None grad_eps_xy = None grad_epsilon_t_0 = None grad_weights = None @@ -789,6 +866,7 @@ class FunctionalSbS(torch.autograd.Function): return ( grad_input, + grad_spikes, grad_eps_xy, grad_epsilon_t_0, grad_weights, @@ -894,6 +972,7 @@ class FunctionalSbS(torch.autograd.Function): return ( grad_input, + grad_spikes, grad_eps_xy, grad_epsilon_t_0, grad_weights, diff --git a/network/build_network.py b/network/build_network.py index f075a0b..c506eb2 100644 --- a/network/build_network.py +++ b/network/build_network.py @@ -150,7 +150,6 @@ def build_network( input_size=input_size[-1], forward_kernel_size=kernel_size, number_of_spikes=number_of_spikes, - epsilon_t=cfg.get_epsilon_t(number_of_spikes), epsilon_xy_intitial=cfg.learning_parameters.eps_xy_intitial, epsilon_0=cfg.epsilon_0, weight_noise_range=weight_noise_range, @@ -167,6 +166,8 @@ def build_network( device=device, default_dtype=default_dtype, layer_id=layer_id, + cooldown_after_number_of_spikes=cfg.cooldown_after_number_of_spikes, + reduction_cooldown=cfg.reduction_cooldown, ) ) # Adding the x,y output dimensions diff --git a/network/loop_train_test.py b/network/loop_train_test.py index fc75c9e..2593af8 100644 --- a/network/loop_train_test.py +++ b/network/loop_train_test.py @@ -185,11 +185,14 @@ def forward_pass_train( def forward_pass_test( input: torch.Tensor, + labels: torch.Tensor | None, the_dataset_test, cfg: Config, network: torch.nn.modules.container.Sequential, device: torch.device, default_dtype: torch.dtype, + mini_batch_id: int = -1, + overwrite_number_of_spikes: int = -1, ) -> list[torch.Tensor]: h_collection = [] @@ -199,7 +202,22 @@ def forward_pass_test( .to(device=device) ) for id in range(0, len(network)): - h_collection.append(network[id](h_collection[-1])) + if (cfg.extract_noisy_pictures is True) or (overwrite_number_of_spikes != -1): + if isinstance(network[id], SbS) is True: + h_collection.append( + network[id]( + h_collection[-1], + layer_id=id, + labels=labels, + extract_noisy_pictures=cfg.extract_noisy_pictures, + mini_batch_id=mini_batch_id, + overwrite_number_of_spikes=overwrite_number_of_spikes, + ) + ) + else: + h_collection.append(network[id](h_collection[-1])) + else: + h_collection.append(network[id](h_collection[-1])) return h_collection @@ -545,7 +563,8 @@ def loop_test( device: torch.device, default_dtype: torch.dtype, logging, - tb: SummaryWriter, + tb: SummaryWriter | None, + overwrite_number_of_spikes: int = -1, ) -> float: test_correct = 0 @@ -554,17 +573,21 @@ def loop_test( logging.info("") logging.info("Testing:") + mini_batch_id: int = 0 for h_x, h_x_labels in my_loader_test: time_0 = time.perf_counter() h_collection = forward_pass_test( input=h_x, + labels=h_x_labels, the_dataset_test=the_dataset_test, cfg=cfg, network=network, device=device, default_dtype=default_dtype, + mini_batch_id=mini_batch_id, + overwrite_number_of_spikes=overwrite_number_of_spikes, ) h_h: torch.Tensor = h_collection[-1].detach().clone().cpu() @@ -580,11 +603,13 @@ def loop_test( f" with {performance/100:^6.2%} \t Time used: {time_measure_a:^6.2f}sec" ) ) + mini_batch_id += 1 logging.info("") - tb.add_scalar("Test Error", 100.0 - performance, epoch_id) - tb.flush() + if tb is not None: + tb.add_scalar("Test Error", 100.0 - performance, epoch_id) + tb.flush() return performance @@ -598,7 +623,7 @@ def loop_test_reconstruction( device: torch.device, default_dtype: torch.dtype, logging, - tb: SummaryWriter, + tb: SummaryWriter | None, ) -> float: test_count: int = 0 @@ -613,6 +638,7 @@ def loop_test_reconstruction( h_collection = forward_pass_test( input=h_x, + labels=None, the_dataset_test=the_dataset_test, cfg=cfg, network=network, @@ -645,7 +671,8 @@ def loop_test_reconstruction( logging.info("") - tb.add_scalar("Test Error", performance, epoch_id) - tb.flush() + if tb is not None: + tb.add_scalar("Test Error", performance, epoch_id) + tb.flush() return performance