diff --git a/SbS.py b/SbS.py
deleted file mode 100644
index 59cd9fe..0000000
--- a/SbS.py
+++ /dev/null
@@ -1,1291 +0,0 @@
-# MIT License
-# Copyright 2022 University of Bremen
-#
-# Permission is hereby granted, free of charge, to any person obtaining
-# a copy of this software and associated documentation files (the "Software"),
-# to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense,
-# and/or sell copies of the Software, and to permit persons to whom the
-# Software is furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included
-# in all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
-# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
-# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
-# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
-# THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-#
-#
-# David Rotermund ( davrot@uni-bremen.de )
-#
-#
-# Release history:
-# ================
-# 1.0.0 -- 01.05.2022: first release
-#
-#
-
-# %%
-import torch
-import numpy as np
-
-# The C++ extensions are optional: if they cannot be imported, the pure
-# PyTorch fallback paths below are used instead.
-try:
-    import PySpikeGeneration2DManyIP
-
-    cpp_spike: bool = True
-except Exception:
-    cpp_spike = False
-
-try:
-    import PyHDynamicCNNManyIP
-
-    cpp_sbs: bool = True
-except Exception:
-    cpp_sbs = False
-
-
-class SbS(torch.nn.Module):
-
-    _epsilon_xy: torch.nn.parameter.Parameter
-    _epsilon_xy_exists: bool = False
-    _epsilon_0: torch.Tensor | None = None
-    _epsilon_t: torch.Tensor | None = None
-    _weights: torch.nn.parameter.Parameter
-    _weights_exists: bool = False
-    _kernel_size: torch.Tensor | None = None
-    _stride: torch.Tensor | None = None
-    _dilation: torch.Tensor | None = None
-    _padding: torch.Tensor | None = None
-    _output_size: torch.Tensor | None = None
-    _number_of_spikes: torch.Tensor | None = None
-    _number_of_cpu_processes: torch.Tensor | None = None
-    _number_of_neurons: torch.Tensor | None = None
-    _number_of_input_neurons: torch.Tensor | None = None
-    _h_initial: torch.Tensor | None = None
-    _epsilon_xy_backup: torch.Tensor | None = None
-    _weights_backup: torch.Tensor | None = None
-    _alpha_number_of_iterations: torch.Tensor | None = None
-
-    def __init__(
-        self,
-        number_of_input_neurons: int,
-        number_of_neurons: int,
-        input_size: list[int],
-        forward_kernel_size: list[int],
-        number_of_spikes: int,
-        epsilon_t: torch.Tensor,
-        epsilon_xy_initial: float = 0.1,
-        epsilon_0: float = 1.0,
-        weight_noise_amplitude: float = 0.01,
-        is_pooling_layer: bool = False,
-        strides: list[int] = [1, 1],
-        dilation: list[int] = [1, 1],
-        padding: list[int] = [0, 0],
-        alpha_number_of_iterations: int = 0,
-        number_of_cpu_processes: int = 1,
-    ) -> None:
-        """Constructor"""
-        super().__init__()
-
-        self.stride = torch.tensor(strides, dtype=torch.int64)
-
-        self.dilation = torch.tensor(dilation, dtype=torch.int64)
-
-        self.padding = torch.tensor(padding, dtype=torch.int64)
-
-        self.kernel_size = torch.tensor(
-            forward_kernel_size,
-            dtype=torch.int64,
-        )
-
-        self.number_of_input_neurons = torch.tensor(
-            number_of_input_neurons,
-            dtype=torch.int64,
-        )
-
-        self.number_of_neurons = torch.tensor(
-            number_of_neurons,
-            dtype=torch.int64,
-        )
-
-        self.alpha_number_of_iterations = torch.tensor(
-            alpha_number_of_iterations, dtype=torch.int64
-        )
-
-        self.calculate_output_size(torch.tensor(input_size, dtype=torch.int64))
-
-        self.set_h_init_to_uniform()
-
-        self.initialize_epsilon_xy(epsilon_xy_initial)
-
-        self.epsilon_0 = torch.tensor(epsilon_0, dtype=torch.float32)
-
-        self.number_of_cpu_processes = torch.tensor(
-            number_of_cpu_processes, dtype=torch.int64
-        )
-
-        self.number_of_spikes = torch.tensor(number_of_spikes, dtype=torch.int64)
-
-        self.epsilon_t = epsilon_t.type(dtype=torch.float32)
-
-        self.initialize_weights(
-            is_pooling_layer=is_pooling_layer,
-            noise_amplitude=weight_noise_amplitude,
-        )
-
-        self.functional_sbs = FunctionalSbS.apply
-
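-    # A minimal usage sketch (editor's example; the sizes and the flat
-    # epsilon_t schedule are assumptions, not values taken from this
-    # repository):
-    #
-    #   layer = SbS(
-    #       number_of_input_neurons=2,
-    #       number_of_neurons=32,
-    #       input_size=[24, 24],
-    #       forward_kernel_size=[5, 5],
-    #       number_of_spikes=1024,
-    #       epsilon_t=torch.full((1024,), 0.05),
-    #   )
-    #   output = layer(torch.rand((10, 2, 24, 24)))  # -> [10, 32, 20, 20]
-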
-    ####################################################################
-    # Variables in and out                                             #
-    ####################################################################
-
-    @property
-    def epsilon_xy(self) -> torch.Tensor | None:
-        if self._epsilon_xy_exists is False:
-            return None
-        else:
-            return self._epsilon_xy
-
-    @epsilon_xy.setter
-    def epsilon_xy(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert value.dim() == 4
-        assert value.dtype == torch.float32
-        if self._epsilon_xy_exists is False:
-            self._epsilon_xy = torch.nn.parameter.Parameter(
-                value.detach().clone(memory_format=torch.contiguous_format),
-                requires_grad=True,
-            )
-            self._epsilon_xy_exists = True
-        else:
-            self._epsilon_xy.data = value.detach().clone(
-                memory_format=torch.contiguous_format
-            )
-
-    @property
-    def epsilon_0(self) -> torch.Tensor | None:
-        return self._epsilon_0
-
-    @epsilon_0.setter
-    def epsilon_0(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert torch.numel(value) == 1
-        assert value.dtype == torch.float32
-        assert value.item() > 0
-        self._epsilon_0 = value.detach().clone(memory_format=torch.contiguous_format)
-        self._epsilon_0.requires_grad_(False)
-
-    @property
-    def epsilon_t(self) -> torch.Tensor | None:
-        return self._epsilon_t
-
-    @epsilon_t.setter
-    def epsilon_t(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert value.dim() == 1
-        assert value.dtype == torch.float32
-        self._epsilon_t = value.detach().clone(memory_format=torch.contiguous_format)
-        self._epsilon_t.requires_grad_(False)
-
-    @property
-    def weights(self) -> torch.Tensor | None:
-        if self._weights_exists is False:
-            return None
-        else:
-            return self._weights
-
-    @weights.setter
-    def weights(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert value.dim() == 2
-        assert value.dtype == torch.float32
-        # Normalize the columns: each output neuron's weight vector sums to
-        # one over the flattened input window.
-        temp: torch.Tensor = value.detach().clone(memory_format=torch.contiguous_format)
-        temp /= temp.sum(dim=0, keepdim=True, dtype=torch.float32)
-        if self._weights_exists is False:
-            self._weights = torch.nn.parameter.Parameter(
-                temp,
-                requires_grad=True,
-            )
-            self._weights_exists = True
-        else:
-            self._weights.data = temp
-
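-    # The 2D weights are stored flattened: one row per input neuron inside the
-    # convolution window (kernel_x * kernel_y * number_of_input_neurons rows)
-    # and one column per output neuron. A sketch of the invariant the setter
-    # above enforces (editor's example):
-    #
-    #   layer.weights = torch.rand((5 * 5 * 2, 32))
-    #   assert torch.allclose(layer.weights.sum(dim=0), torch.ones(32))
-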
-    @property
-    def kernel_size(self) -> torch.Tensor | None:
-        return self._kernel_size
-
-    @kernel_size.setter
-    def kernel_size(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert value.dim() == 1
-        assert torch.numel(value) == 2
-        assert value.dtype == torch.int64
-        assert value[0] > 0
-        assert value[1] > 0
-        self._kernel_size = value.detach().clone(memory_format=torch.contiguous_format)
-        self._kernel_size.requires_grad_(False)
-
-    @property
-    def stride(self) -> torch.Tensor | None:
-        return self._stride
-
-    @stride.setter
-    def stride(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert value.dim() == 1
-        assert torch.numel(value) == 2
-        assert value.dtype == torch.int64
-        assert value[0] > 0
-        assert value[1] > 0
-        self._stride = value.detach().clone(memory_format=torch.contiguous_format)
-        self._stride.requires_grad_(False)
-
-    @property
-    def dilation(self) -> torch.Tensor | None:
-        return self._dilation
-
-    @dilation.setter
-    def dilation(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert value.dim() == 1
-        assert torch.numel(value) == 2
-        assert value.dtype == torch.int64
-        assert value[0] > 0
-        assert value[1] > 0
-        self._dilation = value.detach().clone(memory_format=torch.contiguous_format)
-        self._dilation.requires_grad_(False)
-
-    @property
-    def padding(self) -> torch.Tensor | None:
-        return self._padding
-
-    @padding.setter
-    def padding(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert value.dim() == 1
-        assert torch.numel(value) == 2
-        assert value.dtype == torch.int64
-        assert value[0] >= 0
-        assert value[1] >= 0
-        self._padding = value.detach().clone(memory_format=torch.contiguous_format)
-        self._padding.requires_grad_(False)
-
-    @property
-    def output_size(self) -> torch.Tensor | None:
-        return self._output_size
-
-    @output_size.setter
-    def output_size(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert value.dim() == 1
-        assert torch.numel(value) == 2
-        assert value.dtype == torch.int64
-        assert value[0] > 0
-        assert value[1] > 0
-        self._output_size = value.detach().clone(memory_format=torch.contiguous_format)
-        self._output_size.requires_grad_(False)
-
-    @property
-    def number_of_spikes(self) -> torch.Tensor | None:
-        return self._number_of_spikes
-
-    @number_of_spikes.setter
-    def number_of_spikes(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert torch.numel(value) == 1
-        assert value.dtype == torch.int64
-        assert value.item() > 0
-        self._number_of_spikes = value.detach().clone(
-            memory_format=torch.contiguous_format
-        )
-        self._number_of_spikes.requires_grad_(False)
-
-    @property
-    def number_of_cpu_processes(self) -> torch.Tensor | None:
-        return self._number_of_cpu_processes
-
-    @number_of_cpu_processes.setter
-    def number_of_cpu_processes(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert torch.numel(value) == 1
-        assert value.dtype == torch.int64
-        assert value.item() > 0
-        self._number_of_cpu_processes = value.detach().clone(
-            memory_format=torch.contiguous_format
-        )
-        self._number_of_cpu_processes.requires_grad_(False)
-
-    @property
-    def number_of_neurons(self) -> torch.Tensor | None:
-        return self._number_of_neurons
-
-    @number_of_neurons.setter
-    def number_of_neurons(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert torch.numel(value) == 1
-        assert value.dtype == torch.int64
-        assert value.item() > 0
-        self._number_of_neurons = value.detach().clone(
-            memory_format=torch.contiguous_format
-        )
-        self._number_of_neurons.requires_grad_(False)
-
-    @property
-    def number_of_input_neurons(self) -> torch.Tensor | None:
-        return self._number_of_input_neurons
-
-    @number_of_input_neurons.setter
-    def number_of_input_neurons(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert torch.numel(value) == 1
-        assert value.dtype == torch.int64
-        assert value.item() > 0
-        self._number_of_input_neurons = value.detach().clone(
-            memory_format=torch.contiguous_format
-        )
-        self._number_of_input_neurons.requires_grad_(False)
-
-    @property
-    def h_initial(self) -> torch.Tensor | None:
-        return self._h_initial
-
-    @h_initial.setter
-    def h_initial(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert value.dim() == 1
-        assert value.dtype == torch.float32
-        self._h_initial = value.detach().clone(memory_format=torch.contiguous_format)
-        self._h_initial.requires_grad_(False)
-
-    @property
-    def alpha_number_of_iterations(self) -> torch.Tensor | None:
-        return self._alpha_number_of_iterations
-
-    @alpha_number_of_iterations.setter
-    def alpha_number_of_iterations(self, value: torch.Tensor):
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert torch.numel(value) == 1
-        assert value.dtype == torch.int64
-        assert value.item() >= 0
-        self._alpha_number_of_iterations = value.detach().clone(
-            memory_format=torch.contiguous_format
-        )
-        self._alpha_number_of_iterations.requires_grad_(False)
-
-    ####################################################################
-    # Forward                                                          #
-    ####################################################################
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        """PyTorch Forward method. Does the work."""
-
-        # Are we happy with the input?
-        assert input is not None
-        assert torch.is_tensor(input) is True
-        assert input.dim() == 4
-        assert input.dtype == torch.float32
-
-        # Are we happy with the rest of the network?
-        assert self._epsilon_xy_exists is True
-        assert self._epsilon_xy is not None
-        assert self._epsilon_0 is not None
-        assert self._epsilon_t is not None
-        assert self._weights_exists is True
-        assert self._weights is not None
-        assert self._kernel_size is not None
-        assert self._stride is not None
-        assert self._dilation is not None
-        assert self._padding is not None
-        assert self._output_size is not None
-        assert self._number_of_spikes is not None
-        assert self._number_of_cpu_processes is not None
-        assert self._h_initial is not None
-        assert self._alpha_number_of_iterations is not None
-
-        # SbS forward functional
-        return self.functional_sbs(
-            input,
-            self._epsilon_xy,
-            self._epsilon_0,
-            self._epsilon_t,
-            self._weights,
-            self._kernel_size,
-            self._stride,
-            self._dilation,
-            self._padding,
-            self._output_size,
-            self._number_of_spikes,
-            self._number_of_cpu_processes,
-            self._h_initial,
-            self._alpha_number_of_iterations,
-        )
-
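-    # Shape contract of forward(), derived from the asserts above and from
-    # FunctionalSbS below: input is [batch, number_of_input_neurons, H, W]
-    # in float32, and the result is
-    # [batch, number_of_neurons, output_size[0], output_size[1]].
-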
-    ####################################################################
-    # Helper functions                                                 #
-    ####################################################################
-
-    def calculate_output_size(self, value: torch.Tensor) -> None:
-
-        coordinates_0, coordinates_1 = self._get_coordinates(value)
-
-        self._output_size: torch.Tensor = torch.tensor(
-            [
-                coordinates_0.shape[1],
-                coordinates_1.shape[1],
-            ],
-            dtype=torch.int64,
-        )
-        self._output_size.requires_grad_(False)
-
-    def _get_coordinates(
-        self, value: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Converts the input size into the coordinates
-        of the convolution windows."""
-
-        assert value is not None
-        assert torch.is_tensor(value) is True
-        assert value.dim() == 1
-        assert torch.numel(value) == 2
-        assert value.dtype == torch.int64
-        assert value[0] > 0
-        assert value[1] > 0
-
-        assert self._kernel_size is not None
-        assert self._stride is not None
-        assert self._dilation is not None
-        assert self._padding is not None
-
-        assert torch.numel(self._kernel_size) == 2
-        assert torch.numel(self._stride) == 2
-        assert torch.numel(self._dilation) == 2
-        assert torch.numel(self._padding) == 2
-
-        unfold_0: torch.nn.Unfold = torch.nn.Unfold(
-            kernel_size=(int(self._kernel_size[0]), 1),
-            dilation=int(self._dilation[0]),
-            padding=int(self._padding[0]),
-            stride=int(self._stride[0]),
-        )
-
-        unfold_1: torch.nn.Unfold = torch.nn.Unfold(
-            kernel_size=(1, int(self._kernel_size[1])),
-            dilation=int(self._dilation[1]),
-            padding=int(self._padding[1]),
-            stride=int(self._stride[1]),
-        )
-
-        coordinates_0: torch.Tensor = (
-            unfold_0(
-                torch.unsqueeze(
-                    torch.unsqueeze(
-                        torch.unsqueeze(
-                            torch.arange(0, int(value[0]), dtype=torch.float32),
-                            1,
-                        ),
-                        0,
-                    ),
-                    0,
-                )
-            )
-            .squeeze(0)
-            .type(torch.int64)
-        )
-
-        coordinates_1: torch.Tensor = (
-            unfold_1(
-                torch.unsqueeze(
-                    torch.unsqueeze(
-                        torch.unsqueeze(
-                            torch.arange(0, int(value[1]), dtype=torch.float32),
-                            0,
-                        ),
-                        0,
-                    ),
-                    0,
-                )
-            )
-            .squeeze(0)
-            .type(torch.int64)
-        )
-
-        return coordinates_0, coordinates_1
-
-    def _initial_random_weights(self, noise_amplitude: torch.Tensor) -> torch.Tensor:
-        """Creates the initial weights: a uniform base value plus random
-        noise. (The normalization happens later, in the weights setter.)"""
-
-        assert torch.numel(noise_amplitude) == 1
-        assert noise_amplitude.item() >= 0
-        assert noise_amplitude.dtype == torch.float32
-
-        assert self._number_of_neurons is not None
-        assert self._number_of_input_neurons is not None
-        assert self._kernel_size is not None
-
-        weights = torch.empty(
-            (
-                int(self._kernel_size[0]),
-                int(self._kernel_size[1]),
-                int(self._number_of_input_neurons),
-                int(self._number_of_neurons),
-            ),
-            dtype=torch.float32,
-        )
-        torch.nn.init.uniform_(weights, a=1.0, b=(1.0 + noise_amplitude.item()))
-
-        return weights
-
-    def _make_pooling_weights(self) -> torch.Tensor:
-        """For generating the pooling weights."""
-
-        assert self._number_of_neurons is not None
-        assert self._kernel_size is not None
-
-        norm: float = 1.0 / float(self._kernel_size[0] * self._kernel_size[1])
-
-        weights: torch.Tensor = torch.zeros(
-            (
-                int(self._kernel_size[0]),
-                int(self._kernel_size[1]),
-                int(self._number_of_neurons),
-                int(self._number_of_neurons),
-            ),
-            dtype=torch.float32,
-        )
-
-        for i in range(0, int(self._number_of_neurons)):
-            weights[:, :, i, i] = norm
-
-        return weights
-
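-    # _make_pooling_weights builds an identity mapping per feature: output
-    # neuron i receives the weight 1/(kernel_x * kernel_y) from input neuron i
-    # at every position of the window, i.e. plain spatial averaging. For a
-    # 2x2 kernel every diagonal entry is 0.25.
-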
-    def initialize_weights(
-        self,
-        is_pooling_layer: bool = False,
-        noise_amplitude: float = 0.01,
-    ) -> None:
-        """For the generation of the initial weights.
-        Switches between normal initial random weights and pooling weights."""
-
-        assert self._kernel_size is not None
-
-        if is_pooling_layer is True:
-            weights = self._make_pooling_weights()
-        else:
-            weights = self._initial_random_weights(
-                torch.tensor(noise_amplitude, dtype=torch.float32)
-            )
-
-        # Flatten the [kernel_x, kernel_y, input, output] weights into the
-        # 2D [kernel_x * kernel_y * input, output] layout used everywhere else.
-        weights = weights.moveaxis(-1, 0).moveaxis(-1, 1)
-
-        weights_t = torch.nn.functional.unfold(
-            input=weights,
-            kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
-            dilation=(1, 1),
-            padding=(0, 0),
-            stride=(1, 1),
-        ).squeeze()
-
-        weights_t = torch.moveaxis(weights_t, 0, 1)
-
-        self.weights = weights_t
-
-    def initialize_epsilon_xy(
-        self,
-        eps_xy_initial: float,
-    ) -> None:
-        """Creates the initial epsilon xy matrices."""
-
-        assert self._output_size is not None
-        assert self._kernel_size is not None
-        assert eps_xy_initial > 0
-
-        eps_xy_temp: torch.Tensor = torch.full(
-            (
-                int(self._output_size[0]),
-                int(self._output_size[1]),
-                int(self._kernel_size[0]),
-                int(self._kernel_size[1]),
-            ),
-            eps_xy_initial,
-            dtype=torch.float32,
-        )
-
-        self.epsilon_xy = eps_xy_temp
-
-    def set_h_init_to_uniform(self) -> None:
-
-        assert self._number_of_neurons is not None
-
-        h_initial: torch.Tensor = torch.full(
-            (int(self._number_of_neurons.item()),),
-            (1.0 / float(self._number_of_neurons.item())),
-            dtype=torch.float32,
-        )
-
-        self.h_initial = h_initial
-
-    # Epsilon XY
-    def backup_epsilon_xy(self) -> None:
-        assert self._epsilon_xy_exists is True
-        self._epsilon_xy_backup = self._epsilon_xy.data.clone()
-
-    def restore_epsilon_xy(self) -> None:
-        assert self._epsilon_xy_backup is not None
-        assert self._epsilon_xy_exists is True
-        self._epsilon_xy.data = self._epsilon_xy_backup.clone()
-
-    def mean_epsilon_xy(self) -> None:
-        assert self._epsilon_xy_exists is True
-
-        fill_value: float = float(self._epsilon_xy.data.mean())
-        self._epsilon_xy.data = torch.full_like(
-            self._epsilon_xy.data, fill_value, dtype=torch.float32
-        )
-
-    def threshold_epsilon_xy(self, threshold: float) -> None:
-        assert self._epsilon_xy_exists is True
-        assert threshold >= 0
-        torch.clamp(
-            self._epsilon_xy.data,
-            min=float(threshold),
-            max=None,
-            out=self._epsilon_xy.data,
-        )
-
-    # Weights
-    def backup_weights(self) -> None:
-        assert self._weights_exists is True
-        self._weights_backup = self._weights.data.clone()
-
-    def restore_weights(self) -> None:
-        assert self._weights_backup is not None
-        assert self._weights_exists is True
-        self._weights.data = self._weights_backup.clone()
-
-    def norm_weights(self) -> None:
-        assert self._weights_exists is True
-        temp: torch.Tensor = (
-            self._weights.data.detach()
-            .clone(memory_format=torch.contiguous_format)
-            .type(dtype=torch.float32)
-        )
-        temp /= temp.sum(dim=0, keepdim=True, dtype=torch.float32)
-        self._weights.data = temp
-
-    def threshold_weights(self, threshold: float) -> None:
-        assert self._weights_exists is True
-        assert threshold >= 0
-        torch.clamp(
-            self._weights.data,
-            min=float(threshold),
-            max=None,
-            out=self._weights.data,
-        )
-
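-
-# A sketch of how the maintenance helpers above can be combined around an
-# optimizer step (editor's example; `optimizer` and `loss` are assumptions):
-#
-#   loss.backward()
-#   optimizer.step()
-#   layer.norm_weights()           # re-normalize the columns after the update
-#   layer.threshold_weights(1e-4)  # keep the weight probabilities positive
-#   layer.threshold_epsilon_xy(1e-4)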
-
-
-class FunctionalSbS(torch.autograd.Function):
-    @staticmethod
-    def forward(  # type: ignore
-        ctx,
-        input: torch.Tensor,
-        epsilon_xy: torch.Tensor,
-        epsilon_0: torch.Tensor,
-        epsilon_t: torch.Tensor,
-        weights: torch.Tensor,
-        kernel_size: torch.Tensor,
-        stride: torch.Tensor,
-        dilation: torch.Tensor,
-        padding: torch.Tensor,
-        output_size: torch.Tensor,
-        number_of_spikes: torch.Tensor,
-        number_of_cpu_processes: torch.Tensor,
-        h_initial: torch.Tensor,
-        alpha_number_of_iterations: torch.Tensor,
-    ) -> torch.Tensor:
-
-        torch.set_default_dtype(torch.float32)
-
-        assert input.dim() == 4
-        assert torch.numel(kernel_size) == 2
-        assert torch.numel(dilation) == 2
-        assert torch.numel(padding) == 2
-        assert torch.numel(stride) == 2
-        assert torch.numel(output_size) == 2
-
-        assert torch.numel(epsilon_0) == 1
-        assert torch.numel(number_of_spikes) == 1
-        assert torch.numel(number_of_cpu_processes) == 1
-        assert torch.numel(alpha_number_of_iterations) == 1
-
-        input_size = torch.tensor([input.shape[2], input.shape[3]])
-
-        ############################################################
-        # Pre convolving the input                                 #
-        ############################################################
-
-        input_convolved_temp = torch.nn.functional.unfold(
-            input,
-            kernel_size=tuple(kernel_size.tolist()),
-            dilation=tuple(dilation.tolist()),
-            padding=tuple(padding.tolist()),
-            stride=tuple(stride.tolist()),
-        )
-
-        input_convolved = torch.nn.functional.fold(
-            input_convolved_temp,
-            output_size=tuple(output_size.tolist()),
-            kernel_size=(1, 1),
-            dilation=(1, 1),
-            padding=(0, 0),
-            stride=(1, 1),
-        ).requires_grad_(True)
-
-        epsilon_xy_convolved: torch.Tensor = (
-            (
-                torch.nn.functional.unfold(
-                    epsilon_xy.reshape(
-                        (
-                            int(epsilon_xy.shape[0]) * int(epsilon_xy.shape[1]),
-                            int(epsilon_xy.shape[2]),
-                            int(epsilon_xy.shape[3]),
-                        )
-                    )
-                    .unsqueeze(1)
-                    .tile((1, input.shape[1], 1, 1)),
-                    kernel_size=tuple(kernel_size.tolist()),
-                    dilation=1,
-                    padding=0,
-                    stride=1,
-                )
-                .squeeze(-1)
-                .reshape(
-                    (
-                        int(epsilon_xy.shape[0]),
-                        int(epsilon_xy.shape[1]),
-                        int(input_convolved.shape[1]),
-                    )
-                )
-            )
-            .moveaxis(-1, 0)
-            .contiguous(memory_format=torch.contiguous_format)
-        )
-
-        ############################################################
-        # Spike generation                                         #
-        ############################################################
-
-        if cpp_spike is False:
-            # Alternative to the C++ module but 5x slower:
-            spikes = (
-                (
-                    input_convolved.movedim(source=(2, 3), destination=(0, 1))
-                    .reshape(
-                        shape=(
-                            input_convolved.shape[2]
-                            * input_convolved.shape[3]
-                            * input_convolved.shape[0],
-                            input_convolved.shape[1],
-                        )
-                    )
-                    .multinomial(
-                        num_samples=int(number_of_spikes.item()), replacement=True
-                    )
-                )
-                .reshape(
-                    shape=(
-                        input_convolved.shape[2],
-                        input_convolved.shape[3],
-                        input_convolved.shape[0],
-                        int(number_of_spikes.item()),
-                    )
-                )
-                .movedim(source=(0, 1), destination=(2, 3))
-            ).contiguous(memory_format=torch.contiguous_format)
-        else:
-            # Normalized cumsum
-            input_cumsum: torch.Tensor = torch.cumsum(
-                input_convolved, dim=1, dtype=torch.float32
-            )
-            input_cumsum_last: torch.Tensor = input_cumsum[:, -1, :, :].unsqueeze(1)
-            input_cumsum /= input_cumsum_last
-
-            random_values = torch.rand(
-                size=[
-                    input_cumsum.shape[0],
-                    int(number_of_spikes.item()),
-                    input_cumsum.shape[2],
-                    input_cumsum.shape[3],
-                ],
-                dtype=torch.float32,
-            )
-
-            spikes = torch.empty_like(random_values, dtype=torch.int64)
-
-            # Prepare for export (pointers and metadata) ->
-            np_input: np.ndarray = input_cumsum.detach().numpy()
-            assert input_cumsum.dtype == torch.float32
-            assert np_input.flags["C_CONTIGUOUS"] is True
-            assert np_input.ndim == 4
-
-            np_random_values: np.ndarray = random_values.detach().numpy()
-            assert random_values.dtype == torch.float32
-            assert np_random_values.flags["C_CONTIGUOUS"] is True
-            assert np_random_values.ndim == 4
-
-            np_spikes: np.ndarray = spikes.detach().numpy()
-            assert spikes.dtype == torch.int64
-            assert np_spikes.flags["C_CONTIGUOUS"] is True
-            assert np_spikes.ndim == 4
-            # <- Prepare for export
-
-            spike_generation: PySpikeGeneration2DManyIP.SpikeGeneration2DManyIP = (
-                PySpikeGeneration2DManyIP.SpikeGeneration2DManyIP()
-            )
-
-            spike_generation.spike_generation_multi_pattern(
-                np_input.__array_interface__["data"][0],
-                int(np_input.shape[0]),
-                int(np_input.shape[1]),
-                int(np_input.shape[2]),
-                int(np_input.shape[3]),
-                np_random_values.__array_interface__["data"][0],
-                int(np_random_values.shape[0]),
-                int(np_random_values.shape[1]),
-                int(np_random_values.shape[2]),
-                int(np_random_values.shape[3]),
-                np_spikes.__array_interface__["data"][0],
-                int(np_spikes.shape[0]),
-                int(np_spikes.shape[1]),
-                int(np_spikes.shape[2]),
-                int(np_spikes.shape[3]),
-                int(number_of_cpu_processes.item()),
-            )
-
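-        # Both branches draw, for every batch entry and output position,
-        # number_of_spikes indices of input-window neurons with probability
-        # proportional to input_convolved: the Python branch uses
-        # torch.multinomial directly, while the C++ branch implements the
-        # same inverse-CDF sampling on the normalized cumulative sum
-        # computed above.
-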
-        ############################################################
-        # H dynamic                                                #
-        ############################################################
-
-        assert epsilon_t.ndim == 1
-        assert epsilon_t.shape[0] >= number_of_spikes
-
-        if cpp_sbs is False:
-            h = torch.tile(
-                h_initial.unsqueeze(0).unsqueeze(0).unsqueeze(0),
-                dims=[int(input.shape[0]), int(output_size[0]), int(output_size[1]), 1],
-            )
-
-            epsilon_scale: torch.Tensor = torch.ones(
-                size=[
-                    int(spikes.shape[0]),
-                    int(spikes.shape[2]),
-                    int(spikes.shape[3]),
-                    1,
-                ],
-                dtype=torch.float32,
-            )
-
-            for t in range(0, spikes.shape[1]):
-
-                # Rescale before the running normalization factor overflows.
-                if epsilon_scale.max() > 1e10:
-                    h /= epsilon_scale
-                    epsilon_scale = torch.ones_like(epsilon_scale)
-
-                h_temp: torch.Tensor = weights[spikes[:, t, :, :], :] * h
-
-                if t == 0:
-                    epsilon_temp: torch.Tensor = torch.empty(
-                        (
-                            int(spikes.shape[0]),
-                            int(spikes.shape[2]),
-                            int(spikes.shape[3]),
-                        ),
-                        dtype=torch.float32,
-                    )
-
-                for wx in range(0, int(spikes.shape[2])):
-                    for wy in range(0, int(spikes.shape[3])):
-                        epsilon_temp[:, wx, wy] = epsilon_xy_convolved[
-                            spikes[:, t, wx, wy], wx, wy
-                        ]
-
-                epsilon_subsegment: torch.Tensor = (
-                    epsilon_temp.unsqueeze(-1) * epsilon_t[t] * epsilon_0
-                )
-
-                h_temp_sum: torch.Tensor = (
-                    epsilon_scale * epsilon_subsegment / h_temp.sum(dim=3, keepdim=True)
-                )
-                torch.nan_to_num(
-                    h_temp_sum, out=h_temp_sum, nan=0.0, posinf=0.0, neginf=0.0
-                )
-                h_temp *= h_temp_sum
-                h += h_temp
-
-                epsilon_scale *= 1.0 + epsilon_subsegment
-
-            h /= epsilon_scale
-            output = h.movedim(3, 1)
-
-        else:
-            epsilon_t_0: torch.Tensor = epsilon_t * epsilon_0
-
-            h_shape: tuple[int, int, int, int] = (
-                int(input.shape[0]),
-                int(weights.shape[1]),
-                int(output_size[0]),
-                int(output_size[1]),
-            )
-
-            output = torch.empty(h_shape, dtype=torch.float32)
-
-            # Prepare the export to C++ ->
-            np_h: np.ndarray = output.detach().numpy()
-            assert output.dtype == torch.float32
-            assert np_h.flags["C_CONTIGUOUS"] is True
-            assert np_h.ndim == 4
-
-            np_epsilon_xy: np.ndarray = epsilon_xy_convolved.detach().numpy()
-            assert epsilon_xy_convolved.dtype == torch.float32
-            assert np_epsilon_xy.flags["C_CONTIGUOUS"] is True
-            assert np_epsilon_xy.ndim == 3
-
-            np_epsilon_t: np.ndarray = epsilon_t_0.detach().numpy()
-            assert epsilon_t_0.dtype == torch.float32
-            assert np_epsilon_t.flags["C_CONTIGUOUS"] is True
-            assert np_epsilon_t.ndim == 1
-
-            np_weights: np.ndarray = weights.detach().numpy()
-            assert weights.dtype == torch.float32
-            assert np_weights.flags["C_CONTIGUOUS"] is True
-            assert np_weights.ndim == 2
-
-            np_spikes = spikes.contiguous().detach().numpy()
-            assert spikes.dtype == torch.int64
-            assert np_spikes.flags["C_CONTIGUOUS"] is True
-            assert np_spikes.ndim == 4
-
-            np_h_initial = h_initial.contiguous().detach().numpy()
-            assert h_initial.dtype == torch.float32
-            assert np_h_initial.flags["C_CONTIGUOUS"] is True
-            assert np_h_initial.ndim == 1
-            # <- Prepare the export to C++
-
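-            # The raw data pointers (__array_interface__["data"][0]) of these
-            # NumPy views are handed to the C++ extension, which reads and
-            # writes the tensors in place; the asserts above guarantee the
-            # dtype and C-contiguous layout this zero-copy hand-off relies on.
-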
-            h_dynamic: PyHDynamicCNNManyIP.HDynamicCNNManyIP = (
-                PyHDynamicCNNManyIP.HDynamicCNNManyIP()
-            )
-
-            h_dynamic.update_with_init_vector_multi_pattern(
-                np_h.__array_interface__["data"][0],
-                int(np_h.shape[0]),
-                int(np_h.shape[1]),
-                int(np_h.shape[2]),
-                int(np_h.shape[3]),
-                np_epsilon_xy.__array_interface__["data"][0],
-                int(np_epsilon_xy.shape[0]),
-                int(np_epsilon_xy.shape[1]),
-                int(np_epsilon_xy.shape[2]),
-                np_epsilon_t.__array_interface__["data"][0],
-                int(np_epsilon_t.shape[0]),
-                np_weights.__array_interface__["data"][0],
-                int(np_weights.shape[0]),
-                int(np_weights.shape[1]),
-                np_spikes.__array_interface__["data"][0],
-                int(np_spikes.shape[0]),
-                int(np_spikes.shape[1]),
-                int(np_spikes.shape[2]),
-                int(np_spikes.shape[3]),
-                np_h_initial.__array_interface__["data"][0],
-                int(np_h_initial.shape[0]),
-                int(number_of_cpu_processes.item()),
-            )
-
-        ############################################################
-        # Alpha                                                    #
-        ############################################################
-        alpha_number_of_iterations_int: int = int(alpha_number_of_iterations.item())
-
-        if alpha_number_of_iterations_int > 0:
-            # Initialization
-            virtual_reconstruction_weight: torch.Tensor = torch.einsum(
-                "bixy,ji->bjxy", output, weights
-            )
-            alpha_fill_value: float = 1.0 / (
-                virtual_reconstruction_weight.shape[2]
-                * virtual_reconstruction_weight.shape[3]
-            )
-            alpha_dynamic: torch.Tensor = torch.full(
-                (
-                    int(virtual_reconstruction_weight.shape[0]),
-                    1,
-                    int(virtual_reconstruction_weight.shape[2]),
-                    int(virtual_reconstruction_weight.shape[3]),
-                ),
-                alpha_fill_value,
-                dtype=torch.float32,
-            )
-
-            # Iterations
-            for _ in range(0, alpha_number_of_iterations_int):
-                alpha_temp: torch.Tensor = alpha_dynamic * virtual_reconstruction_weight
-                alpha_temp /= alpha_temp.sum(dim=3, keepdim=True).sum(
-                    dim=2, keepdim=True
-                )
-                torch.nan_to_num(
-                    alpha_temp, out=alpha_temp, nan=0.0, posinf=0.0, neginf=0.0
-                )
-
-                alpha_temp = torch.nn.functional.unfold(
-                    alpha_temp,
-                    kernel_size=(1, 1),
-                    dilation=1,
-                    padding=0,
-                    stride=1,
-                )
-
-                alpha_temp = torch.nn.functional.fold(
-                    alpha_temp,
-                    output_size=tuple(input_size.tolist()),
-                    kernel_size=tuple(kernel_size.tolist()),
-                    dilation=tuple(dilation.tolist()),
-                    padding=tuple(padding.tolist()),
-                    stride=tuple(stride.tolist()),
-                )
-
-                alpha_temp = (alpha_temp * input).sum(dim=1, keepdim=True)
-
-                alpha_temp = torch.nn.functional.unfold(
-                    alpha_temp,
-                    kernel_size=tuple(kernel_size.tolist()),
-                    dilation=tuple(dilation.tolist()),
-                    padding=tuple(padding.tolist()),
-                    stride=tuple(stride.tolist()),
-                )
-
-                alpha_temp = torch.nn.functional.fold(
-                    alpha_temp,
-                    output_size=tuple(output_size.tolist()),
-                    kernel_size=(1, 1),
-                    dilation=(1, 1),
-                    padding=(0, 0),
-                    stride=(1, 1),
-                )
-                alpha_dynamic = alpha_temp.sum(dim=1, keepdim=True)
-
-                alpha_dynamic += 1e-20
-
-                # Alpha normalization
-                alpha_dynamic /= alpha_dynamic.sum(dim=3, keepdim=True).sum(
-                    dim=2, keepdim=True
-                )
-                torch.nan_to_num(
-                    alpha_dynamic, out=alpha_dynamic, nan=0.0, posinf=0.0, neginf=0.0
-                )
-
-            # Applied to the output
-            output *= alpha_dynamic
-
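-        # The alpha block above estimates a [batch, 1, X, Y] spatial map from
-        # the reconstruction einsum("bixy,ji->bjxy", output, weights) matched
-        # against the input, re-normalizes it each iteration to sum to one
-        # over space, and finally multiplies it into the output.
-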
-        ############################################################
-        # Save the necessary data for the backward pass            #
-        ############################################################
-
-        ctx.save_for_backward(
-            input_convolved,
-            epsilon_xy_convolved,
-            epsilon_0,
-            weights,
-            output,
-            kernel_size,
-            stride,
-            dilation,
-            padding,
-            input_size,
-        )
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-
-        # Get the variables back
-        (
-            input,
-            epsilon_xy,
-            epsilon_0,
-            weights,
-            output,
-            kernel_size,
-            stride,
-            dilation,
-            padding,
-            input_size,
-        ) = ctx.saved_tensors
-
-        torch.set_default_dtype(torch.float32)
-
-        input /= input.sum(dim=1, keepdim=True, dtype=torch.float32)
-
-        # For debugging:
-        # print(
-        #     f"S: O: {output.min().item():e} {output.max().item():e} I: {input.min().item():e} {input.max().item():e} G: {grad_output.min().item():e} {grad_output.max().item():e}"
-        # )
-
-        epsilon_0_float: float = epsilon_0.item()
-
-        temp_e: torch.Tensor = 1.0 / ((epsilon_xy * epsilon_0_float) + 1.0)
-
-        eps_a: torch.Tensor = temp_e.clone()
-        eps_a *= epsilon_xy * epsilon_0_float
-
-        eps_b: torch.Tensor = temp_e**2 * epsilon_0_float
-
-        backprop_r: torch.Tensor = weights.unsqueeze(0).unsqueeze(-1).unsqueeze(
-            -1
-        ) * output.unsqueeze(1)
-
-        backprop_bigr: torch.Tensor = backprop_r.sum(dim=2)
-
-        temp: torch.Tensor = input / (backprop_bigr**2 + 1e-20)
-
-        backprop_f: torch.Tensor = output.unsqueeze(1) * temp.unsqueeze(2)
-        torch.nan_to_num(
-            backprop_f, out=backprop_f, nan=1e30, posinf=1e30, neginf=-1e30
-        )
-        torch.clip(backprop_f, out=backprop_f, min=-1e30, max=1e30)
-
-        tempz: torch.Tensor = 1.0 / (backprop_bigr + 1e-20)
-
-        backprop_z: torch.Tensor = backprop_r * tempz.unsqueeze(2)
-        torch.nan_to_num(
-            backprop_z, out=backprop_z, nan=1e30, posinf=1e30, neginf=-1e30
-        )
-        torch.clip(backprop_z, out=backprop_z, min=-1e30, max=1e30)
-
-        result_omega: torch.Tensor = backprop_bigr.unsqueeze(2) * grad_output.unsqueeze(
-            1
-        )
-        result_omega -= torch.einsum(
-            "bijxy,bjxy->bixy", backprop_r, grad_output
-        ).unsqueeze(2)
-        result_omega *= backprop_f
-
-        result_eps_xy: torch.Tensor = (
-            (
-                (backprop_z * input.unsqueeze(2) - output.unsqueeze(1))
-                * grad_output.unsqueeze(1)
-            )
-            .sum(dim=2)
-            .sum(dim=0)
-        ) * eps_b
-
-        result_phi: torch.Tensor = torch.einsum(
-            "bijxy,bjxy->bixy", backprop_z, grad_output
-        ) * eps_a.unsqueeze(0)
-
-        grad_weights = result_omega.sum(0).sum(-1).sum(-1)
-        torch.nan_to_num(
-            grad_weights, out=grad_weights, nan=1e30, posinf=1e30, neginf=-1e30
-        )
-        torch.clip(grad_weights, out=grad_weights, min=-1e30, max=1e30)
-
-        grad_input = torch.nn.functional.fold(
-            torch.nn.functional.unfold(
-                result_phi,
-                kernel_size=(1, 1),
-                dilation=1,
-                padding=0,
-                stride=1,
-            ),
-            output_size=tuple(input_size.tolist()),
-            kernel_size=tuple(kernel_size.tolist()),
-            dilation=tuple(dilation.tolist()),
-            padding=tuple(padding.tolist()),
-            stride=tuple(stride.tolist()),
-        )
-        torch.nan_to_num(
-            grad_input, out=grad_input, nan=1e30, posinf=1e30, neginf=-1e30
-        )
-        torch.clip(grad_input, out=grad_input, min=-1e30, max=1e30)
-
-        grad_eps_xy_temp = torch.nn.functional.fold(
-            result_eps_xy.moveaxis(0, -1)
-            .reshape(
-                (
-                    int(result_eps_xy.shape[1]) * int(result_eps_xy.shape[2]),
-                    int(result_eps_xy.shape[0]),
-                )
-            )
-            .unsqueeze(-1),
-            output_size=tuple(kernel_size.tolist()),
-            kernel_size=tuple(kernel_size.tolist()),
-        )
-
-        grad_eps_xy = (
-            grad_eps_xy_temp.sum(dim=1)
-            .reshape(
-                (
-                    int(result_eps_xy.shape[1]),
-                    int(result_eps_xy.shape[2]),
-                    int(grad_eps_xy_temp.shape[-2]),
-                    int(grad_eps_xy_temp.shape[-1]),
-                )
-            )
-            .contiguous(memory_format=torch.contiguous_format)
-        )
-        torch.nan_to_num(
-            grad_eps_xy, out=grad_eps_xy, nan=1e30, posinf=1e30, neginf=-1e30
-        )
-        torch.clip(grad_eps_xy, out=grad_eps_xy, min=-1e30, max=1e30)
-
-        # Only the input, epsilon_xy, and the weights receive gradients; all
-        # configuration tensors get None.
-        grad_epsilon_0 = None
-        grad_epsilon_t = None
-        grad_kernel_size = None
-        grad_stride = None
-        grad_dilation = None
-        grad_padding = None
-        grad_output_size = None
-        grad_number_of_spikes = None
-        grad_number_of_cpu_processes = None
-        grad_h_initial = None
-        grad_alpha_number_of_iterations = None
-
-        return (
-            grad_input,
-            grad_eps_xy,
-            grad_epsilon_0,
-            grad_epsilon_t,
-            grad_weights,
-            grad_kernel_size,
-            grad_stride,
-            grad_dilation,
-            grad_padding,
-            grad_output_size,
-            grad_number_of_spikes,
-            grad_number_of_cpu_processes,
-            grad_h_initial,
-            grad_alpha_number_of_iterations,
-        )
-
-
-# %%
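-# A quick smoke test of the custom autograd function (editor's sketch; all
-# sizes are assumptions). torch.autograd.gradcheck is of limited use here
-# because the forward pass is stochastic, so this only checks that shapes
-# match and that gradients arrive:
-#
-#   layer = SbS(
-#       number_of_input_neurons=2,
-#       number_of_neurons=4,
-#       input_size=[8, 8],
-#       forward_kernel_size=[3, 3],
-#       number_of_spikes=64,
-#       epsilon_t=torch.full((64,), 0.05),
-#   )
-#   x = torch.rand((1, 2, 8, 8))
-#   y = layer(x)  # -> [1, 4, 6, 6]
-#   y.sum().backward()
-#   assert layer.weights.grad is not None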