diff --git a/network/SbS.py b/network/SbS.py
deleted file mode 100644
index 7bf43ce..0000000
--- a/network/SbS.py
+++ /dev/null
@@ -1,997 +0,0 @@
-import torch
-
-from network.PySpikeGenerationCPU import SpikeGenerationCPU
-from network.PySpikeGenerationGPU import SpikeGenerationGPU
-
-from network.PyHDynamicCNNCPU import HDynamicCNNCPU
-from network.PyHDynamicCNNGPU import HDynamicCNNGPU
-
-from network.calculate_output_size import calculate_output_size
-import os
-import numpy as np
-
-global_sbs_gpu_setting: list[torch.Tensor] = []
-global_sbs_size: list[torch.Tensor] = []
-global_sbs_hdynamic_cpp: list[HDynamicCNNCPU | HDynamicCNNGPU] = []
-global_spike_generation_gpu_setting: list[torch.Tensor] = []
-global_spike_size: list[torch.Tensor] = []
-global_spike_generation_cpp: list[SpikeGenerationCPU | SpikeGenerationGPU] = []
-
-
-class SbS(torch.nn.Module):
-
-    _epsilon_xy: torch.Tensor | None = None
-    _epsilon_0: float
-    _weights: torch.nn.parameter.Parameter
-    _weights_exists: bool = False
-    _kernel_size: list[int]
-    _stride: list[int]
-    _dilation: list[int]
-    _padding: list[int]
-    _output_size: torch.Tensor
-    _number_of_spikes: int
-    _number_of_cpu_processes: int
-    _number_of_neurons: int
-    _number_of_input_neurons: int
-    _epsilon_xy_intitial: float
-    _h_initial: torch.Tensor | None = None
-    _w_trainable: bool
-    _last_grad_scale: torch.nn.parameter.Parameter
-    _keep_last_grad_scale: bool
-    _disable_scale_grade: bool
-    _forgetting_offset: torch.Tensor | None = None
-    _weight_noise_range: list[float]
-    _skip_gradient_calculation: bool
-    _is_pooling_layer: bool
-    _input_size: list[int]
-    _output_layer: bool = False
-    _local_learning: bool = False
-
-    device: torch.device
-    default_dtype: torch.dtype
-    _gpu_tuning_factor: int
-
-    _max_grad_weights: torch.Tensor | None = None
-
-    _number_of_grad_weight_contributions: float = 0.0
-
-    last_input_store: bool = False
-    last_input_data: torch.Tensor | None = None
-
-    sbs_gpu_setting_position: int = -1
-    sbs_hdynamic_cpp_position: int = -1
-    spike_generation_cpp_position: int = -1
-    spike_generation_gpu_setting_position: int = -1
-
-    _cooldown_after_number_of_spikes: int = -1
-    _reduction_cooldown: float = 1.0
-    _layer_id: int = -1
-
-    def __init__(
-        self,
-        number_of_input_neurons: int,
-        number_of_neurons: int,
-        input_size: list[int],
-        forward_kernel_size: list[int],
-        number_of_spikes: int,
-        epsilon_xy_intitial: float = 0.1,
-        epsilon_0: float = 1.0,
-        weight_noise_range: list[float] = [0.0, 1.0],
-        is_pooling_layer: bool = False,
-        strides: list[int] = [1, 1],
-        dilation: list[int] = [0, 0],
-        padding: list[int] = [0, 0],
-        number_of_cpu_processes: int = 1,
-        w_trainable: bool = False,
-        keep_last_grad_scale: bool = False,
-        disable_scale_grade: bool = True,
-        forgetting_offset: float = -1.0,
-        skip_gradient_calculation: bool = False,
-        device: torch.device | None = None,
-        default_dtype: torch.dtype | None = None,
-        gpu_tuning_factor: int = 5,
-        layer_id: int = -1,
-        cooldown_after_number_of_spikes: int = -1,
-        reduction_cooldown: float = 1.0,
-    ) -> None:
-        super().__init__()
-
-        assert device is not None
-        assert default_dtype is not None
-        self.device = device
-        self.default_dtype = default_dtype
-
-        self._w_trainable = bool(w_trainable)
-        self._keep_last_grad_scale = bool(keep_last_grad_scale)
-        self._skip_gradient_calculation = bool(skip_gradient_calculation)
-        self._disable_scale_grade = bool(disable_scale_grade)
-        self._epsilon_xy_intitial = float(epsilon_xy_intitial)
-        self._stride = strides
-        self._dilation = dilation
-        self._padding = padding
-        self._kernel_size = forward_kernel_size
-        self._number_of_input_neurons = int(number_of_input_neurons)
-        self._number_of_neurons = int(number_of_neurons)
-        self._epsilon_0 = float(epsilon_0)
-        self._number_of_cpu_processes = int(number_of_cpu_processes)
-        self._number_of_spikes = int(number_of_spikes)
-        self._weight_noise_range = weight_noise_range
-        self._is_pooling_layer = bool(is_pooling_layer)
-        self._cooldown_after_number_of_spikes = int(cooldown_after_number_of_spikes)
-        self._reduction_cooldown = float(reduction_cooldown)
-        self._layer_id = layer_id
-
-        assert len(input_size) == 2
-        self._input_size = input_size
-
-        global_sbs_gpu_setting.append(torch.tensor([0]))
-        global_spike_generation_gpu_setting.append(torch.tensor([0]))
-        global_sbs_size.append(torch.tensor([0, 0, 0, 0]))
-        global_spike_size.append(torch.tensor([0, 0, 0, 0]))
-
-        if device == torch.device("cpu"):
-            global_sbs_hdynamic_cpp.append(HDynamicCNNCPU())
-        else:
-            global_sbs_hdynamic_cpp.append(HDynamicCNNGPU())
-
-        if device == torch.device("cpu"):
-            global_spike_generation_cpp.append(SpikeGenerationCPU())
-        else:
-            global_spike_generation_cpp.append(SpikeGenerationGPU())
-
-        self.sbs_gpu_setting_position = len(global_sbs_gpu_setting) - 1
-        self.sbs_hdynamic_cpp_position = len(global_sbs_hdynamic_cpp) - 1
-        self.spike_generation_cpp_position = len(global_spike_generation_cpp) - 1
-        self.spike_generation_gpu_setting_position = (
-            len(global_spike_generation_gpu_setting) - 1
-        )
-
-        # The GPU hates me...
-        # Too many SbS threads == bad
-        # Thus I need to limit them...
-        # (Reminder: We cannot access the mini-batch size here,
-        # which is part of the GPU thread size calculation...)
-        if (self._input_size[0] * self._input_size[1]) > gpu_tuning_factor:
-            self._gpu_tuning_factor = gpu_tuning_factor
-        else:
-            self._gpu_tuning_factor = 0
-
-        self._last_grad_scale = torch.nn.parameter.Parameter(
-            torch.tensor(-1.0, dtype=self.default_dtype),
-            requires_grad=True,
-        )
-
-        self._forgetting_offset = torch.tensor(
-            forgetting_offset, dtype=self.default_dtype, device=self.device
-        )
-
-        self._output_size = calculate_output_size(
-            value=input_size,
-            kernel_size=self._kernel_size,
-            stride=self._stride,
-            dilation=self._dilation,
-            padding=self._padding,
-        )
-
-        self.set_h_init_to_uniform()
-
-        self.functional_sbs = FunctionalSbS.apply
-        self.functional_spike_generation = FunctionalSpikeGeneration.apply
-
-        # ###############################################################
-        # Initialize the weights
-        # ###############################################################
-
-        if self._is_pooling_layer is True:
-            self.weights = self._make_pooling_weights()
-
-        else:
-            assert len(self._weight_noise_range) == 2
-            weights = torch.empty(
-                (
-                    int(self._kernel_size[0])
-                    * int(self._kernel_size[1])
-                    * int(self._number_of_input_neurons),
-                    int(self._number_of_neurons),
-                ),
-                dtype=self.default_dtype,
-                device=self.device,
-            )
-
-            torch.nn.init.uniform_(
-                weights,
-                a=float(self._weight_noise_range[0]),
-                b=float(self._weight_noise_range[1]),
-            )
-            self.weights = weights
-
-    ####################################################################
-    # Variables in and out                                             #
-    ####################################################################
-
-    def get_epsilon_t(self, number_of_spikes: int):
-        """Generates the time series of the basic epsilon."""
-        t = np.arange(0, number_of_spikes, dtype=np.float32) + 1
-        np_epsilon_t: np.ndarray = t ** (
-            -1.0 / 2.0
-        )  # np.ones((number_of_spikes), dtype=np.float32)
-
-        if (self._cooldown_after_number_of_spikes < number_of_spikes) and (
-            self._cooldown_after_number_of_spikes >= 0
-        ):
-            np_epsilon_t[
-                self._cooldown_after_number_of_spikes : number_of_spikes
-            ] /= self._reduction_cooldown
-        return (
-            torch.tensor(np_epsilon_t)
-            .type(dtype=self.default_dtype)
-            .to(device=self.device)
-        )
-
-    @property
-    def weights(self) -> torch.Tensor | None:
-        if self._weights_exists is False:
-            return None
-        else:
-            return self._weights
-
-    @weights.setter
-    def weights(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert value.dim() == 2
-        temp: torch.Tensor = (
-            value.detach()
-            .clone(memory_format=torch.contiguous_format)
-            .type(dtype=self.default_dtype)
-            .to(device=self.device)
-        )
-        temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype)
-        if self._weights_exists is False:
-            self._weights = torch.nn.parameter.Parameter(temp, requires_grad=True)
-            self._weights_exists = True
-        else:
-            self._weights.data = temp
-
-    @property
-    def h_initial(self) -> torch.Tensor | None:
-        return self._h_initial
-
-    @h_initial.setter
-    def h_initial(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert value.dim() == 1
-        assert value.dtype == self.default_dtype
-        self._h_initial = (
-            value.detach()
-            .clone(memory_format=torch.contiguous_format)
-            .type(dtype=self.default_dtype)
-            .to(device=self.device)
-            .requires_grad_(False)
-        )
-
-    def update_pre_care(self):
-
-        if self._weights.grad is not None:
-            assert self._number_of_grad_weight_contributions > 0
-            self._weights.grad /= self._number_of_grad_weight_contributions
-            self._number_of_grad_weight_contributions = 0.0
-
-    def update_after_care(self, threshold_weight: float):
-
-        if self._w_trainable is True:
-            self.norm_weights()
-            self.threshold_weights(threshold_weight)
-            self.norm_weights()
-
-    def after_batch(self, new_state: bool = False):
-        if self._keep_last_grad_scale is True:
-            self._last_grad_scale.data = self._last_grad_scale.grad
-            self._keep_last_grad_scale = new_state
-
-        self._last_grad_scale.grad = torch.zeros_like(self._last_grad_scale.grad)
-
-    ####################################################################
-    # Helper functions                                                 #
-    ####################################################################
-
-    def _make_pooling_weights(self) -> torch.Tensor:
-        """For generating the pooling weights."""
-
-        assert self._number_of_neurons is not None
-        assert self._kernel_size is not None
-
-        weights: torch.Tensor = torch.zeros(
-            (
-                int(self._kernel_size[0]),
-                int(self._kernel_size[1]),
-                int(self._number_of_neurons),
-                int(self._number_of_neurons),
-            ),
-            dtype=self.default_dtype,
-            device=self.device,
-        )
-
-        for i in range(0, int(self._number_of_neurons)):
-            weights[:, :, i, i] = 1.0
-
-        weights = weights.moveaxis(-1, 0).moveaxis(-1, 1)
-
-        weights = torch.nn.functional.unfold(
-            input=weights,
-            kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
-            dilation=(1, 1),
-            padding=(0, 0),
-            stride=(1, 1),
-        ).squeeze()
-
-        weights = torch.moveaxis(weights, 0, 1)
-
-        return weights
-
-    def set_h_init_to_uniform(self) -> None:
-
-        assert self._number_of_neurons > 2
-
-        self.h_initial: torch.Tensor = torch.full(
-            (self._number_of_neurons,),
-            (1.0 / float(self._number_of_neurons)),
-            dtype=self.default_dtype,
-            device=self.device,
-        )
-
-    def norm_weights(self) -> None:
-        assert self._weights_exists is True
-        temp: torch.Tensor = (
-            self._weights.data.detach()
-            .clone(memory_format=torch.contiguous_format)
-            .type(dtype=self.default_dtype)
-            .to(device=self.device)
-        )
-        temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype)
-        self._weights.data = temp
-
-    def threshold_weights(self, threshold: float) -> None:
-        assert self._weights_exists is True
-        assert threshold >= 0
-
-        torch.clamp(
-            self._weights.data,
-            min=float(threshold),
-            max=None,
-            out=self._weights.data,
-        )
-
-    ####################################################################
-    # Forward                                                          #
-    ####################################################################
-
-    def forward(
-        self,
-        input: torch.Tensor,
-        labels: torch.Tensor | None = None,
-        extract_noisy_pictures: bool = False,
-        layer_id: int = -1,
-        mini_batch_id: int = -1,
-        overwrite_number_of_spikes: int = -1,
-    ) -> torch.Tensor:
-
-        # Are we happy with the input?
-        assert input is not None
-        assert torch.is_tensor(input) is True
-        assert input.dim() == 4
-        assert input.dtype == self.default_dtype
-        assert input.shape[1] == self._number_of_input_neurons
-        assert input.shape[2] == self._input_size[0]
-        assert input.shape[3] == self._input_size[1]
-
-        # Are we happy with the rest of the network?
-        assert self._epsilon_0 is not None
-
-        assert self._h_initial is not None
-        assert self._forgetting_offset is not None
-
-        assert self._weights_exists is True
-        assert self._weights is not None
-
-        assert self.sbs_gpu_setting_position != -1
-        assert self.sbs_hdynamic_cpp_position != -1
-        assert self.spike_generation_cpp_position != -1
-        assert self.spike_generation_gpu_setting_position != -1
-
-        if labels is None:
-            labels_copy: torch.Tensor = torch.tensor(
-                [], dtype=torch.int64, device=self.device
-            )
-        else:
-            labels_copy = (
-                labels.detach().clone().type(dtype=torch.int64).to(device=self.device)
-            )
-
-        input_convolved = torch.nn.functional.fold(
-            torch.nn.functional.unfold(
-                input.requires_grad_(True),
-                kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
-                dilation=(int(self._dilation[0]), int(self._dilation[1])),
-                padding=(int(self._padding[0]), int(self._padding[1])),
-                stride=(int(self._stride[0]), int(self._stride[1])),
-            ),
-            output_size=tuple(self._output_size.tolist()),
-            kernel_size=(1, 1),
-            dilation=(1, 1),
-            padding=(0, 0),
-            stride=(1, 1),
-        )
-
-        if self.last_input_store is True:
-            self.last_input_data = input_convolved.detach().clone()
-            self.last_input_data /= self.last_input_data.sum(dim=1, keepdim=True)
-        else:
-            self.last_input_data = None
-
-        if overwrite_number_of_spikes >= 1:
-            _number_of_spikes = int(overwrite_number_of_spikes)
-        else:
-            _number_of_spikes = int(self._number_of_spikes)
-
-        epsilon_t_0: torch.Tensor = (
-            (self.get_epsilon_t(_number_of_spikes) * self._epsilon_0)
-            .type(input.dtype)
-            .to(input.device)
-        )
-
-        parameter_list = torch.tensor(
-            [
-                int(self._w_trainable),  # 0
-                int(self._disable_scale_grade),  # 1
-                int(self._keep_last_grad_scale),  # 2
-                int(self._skip_gradient_calculation),  # 3
-                int(_number_of_spikes),  # 4
-                int(self._number_of_cpu_processes),  # 5
-                int(self._output_size[0]),  # 6
-                int(self._output_size[1]),  # 7
-                int(self._gpu_tuning_factor),  # 8
-                int(self._output_layer),  # 9
-                int(self._local_learning),  # 10
-                int(self.sbs_gpu_setting_position),  # 11
-                int(self.sbs_hdynamic_cpp_position),  # 12
-                int(self.spike_generation_cpp_position),  # 13
-                int(self.spike_generation_gpu_setting_position),  # 14
-            ],
-            dtype=torch.int64,
-        )
-
-        if self._epsilon_xy is None:
-            self._epsilon_xy = torch.full(
-                (
-                    input_convolved.shape[1],
-                    input_convolved.shape[2],
-                    input_convolved.shape[3],
-                ),
-                float(self._epsilon_xy_intitial),
-                dtype=self.default_dtype,
-                device=self.device,
-            )
-
-        assert self._epsilon_xy is not None
-        # In the case somebody tried to replace the matrix with wrong dimensions
-        assert self._epsilon_xy.shape[0] == input_convolved.shape[1]
-        assert self._epsilon_xy.shape[1] == input_convolved.shape[2]
-        assert self._epsilon_xy.shape[2] == input_convolved.shape[3]
-
-        spike = self.functional_spike_generation(input_convolved, parameter_list)
-
-        if (
-            (extract_noisy_pictures is True)
-            and (layer_id == 0)
-            and (labels is not None)
-            and (mini_batch_id >= 0)
-        ):
-            assert labels.shape[0] == spike.shape[0]
-
-            path_sub: str = "noisy_picture_data"
-            path_sub_spikes: str = f"{int(_number_of_spikes)}"
-            path = os.path.join(path_sub, path_sub_spikes)
-            os.makedirs(path_sub, exist_ok=True)
-            os.makedirs(path, exist_ok=True)
-
-            the_images = torch.zeros_like(
-                input_convolved, dtype=torch.int64, device=self.device
-            )
-
-            for p_id in range(0, the_images.shape[0]):
-                for sp_id in range(0, spike.shape[1]):
-                    for x_id in range(0, the_images.shape[2]):
-                        for y_id in range(0, the_images.shape[3]):
-                            the_images[
-                                p_id, spike[p_id, sp_id, x_id, y_id], x_id, y_id
-                            ] += 1
-
-            np.savez_compressed(
-                os.path.join(path, f"{mini_batch_id}.npz"),
-                the_images=the_images.cpu().numpy(),
-                labels=labels.cpu().numpy(),
-            )
-
-        assert spike.shape[1] == _number_of_spikes
-
-        # SbS forward functional
-        output = self.functional_sbs(
-            input_convolved,
-            spike,
-            self._epsilon_xy,
-            epsilon_t_0,
-            self._weights,
-            self._h_initial,
-            parameter_list,
-            self._last_grad_scale,
-            self._forgetting_offset,
-            labels_copy,
-        )
-
-        self._number_of_grad_weight_contributions += (
-            output.shape[0] * output.shape[-2] * output.shape[-1]
-        )
-
-        return output
-
-
-class FunctionalSpikeGeneration(torch.autograd.Function):
-    @staticmethod
-    def forward(  # type: ignore
-        ctx,
-        input: torch.Tensor,
-        parameter_list: torch.Tensor,
-    ) -> torch.Tensor:
-
-        assert input.dim() == 4
-
-        number_of_spikes: int = int(parameter_list[4])
-
-        if input.device == torch.device("cpu"):
-            spike_number_of_cpu_processes: int = int(parameter_list[5])
-        else:
-            spike_number_of_cpu_processes = -1
-
-        spike_generation_cpp_position = int(parameter_list[13])
-        spike_generation_gpu_setting_position = int(parameter_list[14])
-
-        # ###########################################################
-        # Spike generation
-        # ###########################################################
-
-        # ############################################
-        # Normalized cumsum
-        # (beware of the pytorch bug! Thus .clone()!)
-        # ############################################
-        input_cumsum: torch.Tensor = torch.cumsum(input, dim=1, dtype=input.dtype)
-        input_cumsum_last: torch.Tensor = input_cumsum[:, -1, :, :].unsqueeze(1).clone()
-        input_cumsum /= input_cumsum_last
-
-        # ############################################
-        # Get the required random numbers
-        # ############################################
-        random_values = torch.rand(
-            size=[
-                input_cumsum.shape[0],
-                number_of_spikes,
-                input_cumsum.shape[2],
-                input_cumsum.shape[3],
-            ],
-            dtype=input.dtype,
-            device=input.device,
-        )
-
-        # ############################################
-        # Make space for the results
-        # ############################################
-        spikes = torch.empty_like(random_values, dtype=torch.int64, device=input.device)
-
-        assert input_cumsum.is_contiguous() is True
-        assert random_values.is_contiguous() is True
-        assert spikes.is_contiguous() is True
-
-        # time_start: float = time.perf_counter()
-        spike_generation_profile = global_spike_generation_gpu_setting[
-            spike_generation_gpu_setting_position
-        ].clone()
-
-        spike_generation_size = global_spike_size[
-            spike_generation_gpu_setting_position
-        ].clone()
-
-        if input.device != torch.device("cpu"):
-            if (
-                (spike_generation_profile.numel() == 1)
-                or (spike_generation_size[0] != int(spikes.shape[0]))
-                or (spike_generation_size[1] != int(spikes.shape[1]))
-                or (spike_generation_size[2] != int(spikes.shape[2]))
-                or (spike_generation_size[3] != int(spikes.shape[3]))
-            ):
-                spike_generation_profile = torch.zeros(
-                    (1, 7), dtype=torch.int64, device=torch.device("cpu")
-                )
-
-                global_spike_generation_cpp[
-                    spike_generation_cpp_position
-                ].gpu_occupancy_export(
-                    int(spikes.shape[2]),
-                    int(spikes.shape[3]),
-                    int(spikes.shape[0]),
-                    int(spikes.shape[1]),
-                    spike_generation_profile.data_ptr(),
-                    int(spike_generation_profile.shape[0]),
-                    int(spike_generation_profile.shape[1]),
-                )
-                global_spike_generation_gpu_setting[
-                    spike_generation_gpu_setting_position
-                ] = spike_generation_profile.clone()
-
-                spike_generation_size[0] = int(spikes.shape[0])
-                spike_generation_size[1] = int(spikes.shape[1])
-                spike_generation_size[2] = int(spikes.shape[2])
-                spike_generation_size[3] = int(spikes.shape[3])
-                global_spike_size[
-                    spike_generation_gpu_setting_position
-                ] = spike_generation_size.clone()
-
-            else:
-                global_spike_generation_cpp[
-                    spike_generation_cpp_position
-                ].gpu_occupancy_import(
-                    spike_generation_profile.data_ptr(),
-                    int(spike_generation_profile.shape[0]),
-                    int(spike_generation_profile.shape[1]),
-                )
-
-        global_spike_generation_cpp[spike_generation_cpp_position].spike_generation(
-            input_cumsum.data_ptr(),
-            int(input_cumsum.shape[0]),
-            int(input_cumsum.shape[1]),
-            int(input_cumsum.shape[2]),
-            int(input_cumsum.shape[3]),
-            random_values.data_ptr(),
-            int(random_values.shape[0]),
-            int(random_values.shape[1]),
-            int(random_values.shape[2]),
-            int(random_values.shape[3]),
-            spikes.data_ptr(),
-            int(spikes.shape[0]),
-            int(spikes.shape[1]),
-            int(spikes.shape[2]),
-            int(spikes.shape[3]),
-            int(spike_number_of_cpu_processes),
-        )
-        del random_values
-        del input_cumsum
-
-        return spikes
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        grad_input = grad_output
-        grad_parameter_list = None
-        return (grad_input, grad_parameter_list)
-
-
-class FunctionalSbS(torch.autograd.Function):
-    @staticmethod
-    def forward(  # type: ignore
-        ctx,
-        input: torch.Tensor,
-        spikes: torch.Tensor,
-        epsilon_xy: torch.Tensor,
-        epsilon_t_0: torch.Tensor,
-        weights: torch.Tensor,
-        h_initial: torch.Tensor,
-        parameter_list: torch.Tensor,
-        grad_output_scale: torch.Tensor,
-        forgetting_offset: torch.Tensor,
-        labels: torch.Tensor,
-    ) -> torch.Tensor:
-
-        number_of_spikes: int = int(parameter_list[4])
-
-        if input.device == torch.device("cpu"):
-            hdyn_number_of_cpu_processes: int = int(parameter_list[5])
-        else:
-            hdyn_number_of_cpu_processes = -1
-
-        output_size_0: int = int(parameter_list[6])
-        output_size_1: int = int(parameter_list[7])
-        gpu_tuning_factor: int = int(parameter_list[8])
-
-        sbs_gpu_setting_position = int(parameter_list[11])
-        sbs_hdynamic_cpp_position = int(parameter_list[12])
-
-        # ###########################################################
-        # H dynamic
-        # ###########################################################
-
-        assert epsilon_t_0.ndim == 1
-        assert epsilon_t_0.shape[0] >= number_of_spikes
-
-        # ############################################
-        # Make space for the results
-        # ############################################
-
-        output = torch.empty(
-            (
-                int(input.shape[0]),
-                int(weights.shape[1]),
-                output_size_0,
-                output_size_1,
-            ),
-            dtype=input.dtype,
-            device=input.device,
-        )
-
-        assert output.is_contiguous() is True
-        assert epsilon_xy.is_contiguous() is True
-        assert epsilon_t_0.is_contiguous() is True
-        assert weights.is_contiguous() is True
-        assert spikes.is_contiguous() is True
-        assert h_initial.is_contiguous() is True
-
-        assert epsilon_xy.ndim == 3
-        assert weights.ndim == 2
-        assert h_initial.ndim == 1
-
-        sbs_profile = global_sbs_gpu_setting[sbs_gpu_setting_position].clone()
-
-        sbs_size = global_sbs_size[sbs_gpu_setting_position].clone()
-
-        if input.device != torch.device("cpu"):
-            if (
-                (sbs_profile.numel() == 1)
-                or (sbs_size[0] != int(output.shape[0]))
-                or (sbs_size[1] != int(output.shape[1]))
-                or (sbs_size[2] != int(output.shape[2]))
-                or (sbs_size[3] != int(output.shape[3]))
-            ):
-                sbs_profile = torch.zeros(
-                    (14, 7), dtype=torch.int64, device=torch.device("cpu")
-                )
-
-                global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_export(
-                    int(output.shape[2]),
-                    int(output.shape[3]),
-                    int(output.shape[0]),
-                    int(output.shape[1]),
-                    sbs_profile.data_ptr(),
-                    int(sbs_profile.shape[0]),
-                    int(sbs_profile.shape[1]),
-                )
-                global_sbs_gpu_setting[sbs_gpu_setting_position] = sbs_profile.clone()
-                sbs_size[0] = int(output.shape[0])
-                sbs_size[1] = int(output.shape[1])
-                sbs_size[2] = int(output.shape[2])
-                sbs_size[3] = int(output.shape[3])
-                global_sbs_size[sbs_gpu_setting_position] = sbs_size.clone()
-
-            else:
-                global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_import(
-                    sbs_profile.data_ptr(),
-                    int(sbs_profile.shape[0]),
-                    int(sbs_profile.shape[1]),
-                )
-
-        global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].update(
-            output.data_ptr(),
-            int(output.shape[0]),
-            int(output.shape[1]),
-            int(output.shape[2]),
-            int(output.shape[3]),
-            epsilon_xy.data_ptr(),
-            int(epsilon_xy.shape[0]),
-            int(epsilon_xy.shape[1]),
-            int(epsilon_xy.shape[2]),
-            epsilon_t_0.data_ptr(),
-            int(epsilon_t_0.shape[0]),
-            weights.data_ptr(),
-            int(weights.shape[0]),
-            int(weights.shape[1]),
-            spikes.data_ptr(),
-            int(spikes.shape[0]),
-            int(spikes.shape[1]),
-            int(spikes.shape[2]),
-            int(spikes.shape[3]),
-            h_initial.data_ptr(),
-            int(h_initial.shape[0]),
-            hdyn_number_of_cpu_processes,
-            float(forgetting_offset.item()),
-            int(gpu_tuning_factor),
-        )
-
-        # ###########################################################
-        # Save the necessary data for the backward pass
-        # ###########################################################
-
-        ctx.save_for_backward(
-            input,
-            weights,
-            output,
-            parameter_list,
-            grad_output_scale,
-            labels,
-        )
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # ##############################################
-        # Get the variables back
-        # ##############################################
-        (
-            input,
-            weights,
-            output,
-            parameter_list,
-            last_grad_scale,
-            labels,
-        ) = ctx.saved_tensors
-
-        assert labels.numel() > 0
-
-        # ##############################################
-        # Default output
-        # ##############################################
-        grad_input = None
-        grad_spikes = None
-        grad_eps_xy = None
-        grad_epsilon_t_0 = None
-        grad_weights = None
-        grad_h_initial = None
-        grad_parameter_list = None
-        grad_forgetting_offset = None
-        grad_labels = None
-
-        # ##############################################
-        # Parameters
-        # ##############################################
-        parameter_w_trainable: bool = bool(parameter_list[0])
-        parameter_disable_scale_grade: bool = bool(parameter_list[1])
-        parameter_keep_last_grad_scale: bool = bool(parameter_list[2])
-        parameter_skip_gradient_calculation: bool = bool(parameter_list[3])
-        parameter_output_layer: bool = bool(parameter_list[9])
-        parameter_local_learning: bool = bool(parameter_list[10])
-
-        # ##############################################
-        # Dealing with overall scale of the gradient
-        # ##############################################
-        if parameter_disable_scale_grade is False:
-            if parameter_keep_last_grad_scale is True:
-                last_grad_scale = torch.tensor(
-                    [torch.abs(grad_output).max(), last_grad_scale]
-                ).max()
-            grad_output /= last_grad_scale
-        grad_output_scale = last_grad_scale.clone()
-
-        input /= input.sum(dim=1, keepdim=True, dtype=weights.dtype)
-
-        # #################################################
-        # User doesn't want us to calculate the gradients
-        # #################################################
-
-        if parameter_skip_gradient_calculation is True:
-
-            return (
-                grad_input,
-                grad_spikes,
-                grad_eps_xy,
-                grad_epsilon_t_0,
-                grad_weights,
-                grad_h_initial,
-                grad_parameter_list,
-                grad_output_scale,
-                grad_forgetting_offset,
-                grad_labels,
-            )
-
-        # #################################################
-        # Calculate backprop error (grad_input)
-        # #################################################
-
-        backprop_r: torch.Tensor = weights.unsqueeze(0).unsqueeze(-1).unsqueeze(
-            -1
-        ) * output.unsqueeze(1)
-
-        backprop_bigr: torch.Tensor = backprop_r.sum(dim=2)
-
-        backprop_z: torch.Tensor = backprop_r * (
-            1.0 / (backprop_bigr + 1e-20)
-        ).unsqueeze(2)
-        grad_input: torch.Tensor = (backprop_z * grad_output.unsqueeze(1)).sum(2)
-        del backprop_z
-
-        # #################################################
-        # Calculate weight gradient (grad_weights)
-        # #################################################
-
-        if parameter_w_trainable is False:
-
-            # #################################################
-            # We don't train this weight
-            # #################################################
-            grad_weights = None
-
-        elif (parameter_output_layer is False) and (parameter_local_learning is True):
-            # #################################################
-            # Local learning
-            # #################################################
-            grad_weights = (
-                (-2 * (input - backprop_bigr).unsqueeze(2) * output.unsqueeze(1))
-                .sum(0)
-                .sum(-1)
-                .sum(-1)
-            )
-
-        elif (parameter_output_layer is True) and (parameter_local_learning is True):
-
-            target_one_hot: torch.Tensor = torch.zeros(
-                (
-                    labels.shape[0],
-                    output.shape[1],
-                ),
-                device=input.device,
-                dtype=input.dtype,
-            )
-
-            target_one_hot.scatter_(
-                1,
-                labels.to(input.device).unsqueeze(1),
-                torch.ones(
-                    (labels.shape[0], 1),
-                    device=input.device,
-                    dtype=input.dtype,
-                ),
-            )
-            target_one_hot = target_one_hot.unsqueeze(-1).unsqueeze(-1)
-
-            # (-2 * (input - backprop_bigr).unsqueeze(2) * (target_one_hot-output).unsqueeze(1))
-            # (-2 * input.unsqueeze(2) * (target_one_hot-output).unsqueeze(1))
-            grad_weights = (
-                (
-                    -2
-                    * (input - backprop_bigr).unsqueeze(2)
-                    * target_one_hot.unsqueeze(1)
-                )
-                .sum(0)
-                .sum(-1)
-                .sum(-1)
-            )
-
-        else:
-            # #################################################
-            # Backprop
-            # #################################################
-            backprop_f: torch.Tensor = output.unsqueeze(1) * (
-                input / (backprop_bigr**2 + 1e-20)
-            ).unsqueeze(2)
-
-            result_omega: torch.Tensor = backprop_bigr.unsqueeze(
-                2
-            ) * grad_output.unsqueeze(1)
-            result_omega -= (backprop_r * grad_output.unsqueeze(1)).sum(2).unsqueeze(2)
-            result_omega *= backprop_f
-            del backprop_f
-            grad_weights = result_omega.sum(0).sum(-1).sum(-1)
-            del result_omega
-
-        del backprop_bigr
-        del backprop_r
-
-        return (
-            grad_input,
-            grad_spikes,
-            grad_eps_xy,
-            grad_epsilon_t_0,
-            grad_weights,
-            grad_h_initial,
-            grad_parameter_list,
-            grad_output_scale,
-            grad_forgetting_offset,
-            grad_labels,
-        )
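
Note on the deleted spike-generation path: the native `spike_generation` call above fills `spikes` with int64 channel indices drawn per pixel from the input distribution, using the normalized cumulative sum (`input_cumsum`) compared against uniform random numbers (`random_values`). A minimal pure-PyTorch sketch of that sampling step is given below for reference; `spike_generation_reference` is a hypothetical name, not part of the deleted module, and the C++/CUDA kernels exist precisely to avoid materializing the broadcast comparison tensor this sketch creates and to control GPU thread occupancy.

    import torch

    def spike_generation_reference(inp: torch.Tensor, number_of_spikes: int) -> torch.Tensor:
        # inp: [batch, channels, x, y], non-negative values.
        # Normalized cumulative distribution over the channel dimension
        # (the .clone() mirrors the workaround noted in the original code).
        cumsum = torch.cumsum(inp, dim=1)
        cumsum = cumsum / cumsum[:, -1, :, :].unsqueeze(1).clone()
        # One uniform draw per requested spike and per spatial position.
        rand = torch.rand(
            (inp.shape[0], number_of_spikes, inp.shape[2], inp.shape[3]),
            dtype=inp.dtype,
            device=inp.device,
        )
        # Inverse-CDF sampling: the spike index is the number of cumulative
        # entries that lie below the draw. Result: [batch, number_of_spikes, x, y], int64.
        return (rand.unsqueeze(2) > cumsum.unsqueeze(1)).sum(dim=2)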