# pytorch-sbs/SbS.py
# 2022-05-04 14:42:44 +02:00
#
# 1291 lines
# 44 KiB
# Python

# MIT License
# Copyright 2022 University of Bremen
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
# THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
#
# David Rotermund ( davrot@uni-bremen.de )
#
#
# Release history:
# ================
# 1.0.0 -- 01.05.2022: first release
#
#
# %%
import torch
import numpy as np

# Optional C++ accelerator for the spike generation. When it cannot be
# imported the flag stays False and the pure PyTorch fallback is used.
try:
    import PySpikeGeneration2DManyIP
except Exception:
    cpp_spike: bool = False
else:
    cpp_spike = True

# Optional C++ accelerator for the h-dynamic. Same fallback rule as above.
try:
    import PyHDynamicCNNManyIP
except Exception:
    cpp_sbs: bool = False
else:
    cpp_sbs = True
class SbS(torch.nn.Module):
_epsilon_xy: torch.nn.parameter.Parameter
_epsilon_xy_exists: bool = False
_epsilon_0: torch.Tensor | None = None
_epsilon_t: torch.Tensor | None = None
_weights: torch.nn.parameter.Parameter
_weights_exists: bool = False
_kernel_size: torch.Tensor | None = None
_stride: torch.Tensor | None = None
_dilation: torch.Tensor | None = None
_padding: torch.Tensor | None = None
_output_size: torch.Tensor | None = None
_number_of_spikes: torch.Tensor | None = None
_number_of_cpu_processes: torch.Tensor | None = None
_number_of_neurons: torch.Tensor | None = None
_number_of_input_neurons: torch.Tensor | None = None
_h_initial: torch.Tensor | None = None
_epsilon_xy_backup: torch.Tensor | None = None
_weights_backup: torch.Tensor | None = None
_alpha_number_of_iterations: torch.Tensor | None = None
def __init__(
self,
number_of_input_neurons: int,
number_of_neurons: int,
input_size: list[int],
forward_kernel_size: list[int],
number_of_spikes: int,
epsilon_t: torch.Tensor,
epsilon_xy_intitial: float = 0.1,
epsilon_0: float = 1.0,
weight_noise_amplitude: float = 0.01,
is_pooling_layer: bool = False,
strides: list[int] = [1, 1],
dilation: list[int] = [0, 0],
padding: list[int] = [0, 0],
alpha_number_of_iterations: int = 0,
number_of_cpu_processes: int = 1,
) -> None:
"""Constructor"""
super().__init__()
self.stride = torch.tensor(strides, dtype=torch.int64)
self.dilation = torch.tensor(dilation, dtype=torch.int64)
self.padding = torch.tensor(padding, dtype=torch.int64)
self.kernel_size = torch.tensor(
forward_kernel_size,
dtype=torch.int64,
)
self.number_of_input_neurons = torch.tensor(
number_of_input_neurons,
dtype=torch.int64,
)
self.number_of_neurons = torch.tensor(
number_of_neurons,
dtype=torch.int64,
)
self.alpha_number_of_iterations = torch.tensor(
alpha_number_of_iterations, dtype=torch.int64
)
self.calculate_output_size(torch.tensor(input_size, dtype=torch.int64))
self.set_h_init_to_uniform()
self.initialize_epsilon_xy(epsilon_xy_intitial)
self.epsilon_0 = torch.tensor(epsilon_0, dtype=torch.float32)
self.number_of_cpu_processes = torch.tensor(
number_of_cpu_processes, dtype=torch.int64
)
self.number_of_spikes = torch.tensor(number_of_spikes, dtype=torch.int64)
self.epsilon_t = epsilon_t.type(dtype=torch.float32)
self.initialize_weights(
is_pooling_layer=is_pooling_layer,
noise_amplitude=weight_noise_amplitude,
)
self.functional_sbs = FunctionalSbS.apply
####################################################################
# Variables in and out #
####################################################################
@property
def epsilon_xy(self) -> torch.Tensor | None:
if self._epsilon_xy_exists is False:
return None
else:
return self._epsilon_xy
@epsilon_xy.setter
def epsilon_xy(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 4
assert value.dtype == torch.float32
if self._epsilon_xy_exists is False:
self._epsilon_xy = torch.nn.parameter.Parameter(
value.detach().clone(memory_format=torch.contiguous_format),
requires_grad=True,
)
self._epsilon_xy_exists = True
else:
self._epsilon_xy.data = value.detach().clone(
memory_format=torch.contiguous_format
)
@property
def epsilon_0(self) -> torch.Tensor | None:
return self._epsilon_0
@epsilon_0.setter
def epsilon_0(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert torch.numel(value) == 1
assert value.dtype == torch.float32
assert value.item() > 0
self._epsilon_0 = value.detach().clone(memory_format=torch.contiguous_format)
self._epsilon_0.requires_grad_(False)
@property
def epsilon_t(self) -> torch.Tensor | None:
return self._epsilon_t
@epsilon_t.setter
def epsilon_t(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 1
assert value.dtype == torch.float32
self._epsilon_t = value.detach().clone(memory_format=torch.contiguous_format)
self._epsilon_t.requires_grad_(False)
@property
def weights(self) -> torch.Tensor | None:
if self._weights_exists is False:
return None
else:
return self._weights
@weights.setter
def weights(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 2
assert value.dtype == torch.float32
temp: torch.Tensor = value.detach().clone(memory_format=torch.contiguous_format)
temp /= temp.sum(dim=0, keepdim=True, dtype=torch.float32)
if self._weights_exists is False:
self._weights = torch.nn.parameter.Parameter(
temp,
requires_grad=True,
)
self._weights_exists = True
else:
self._weights.data = temp
@property
def kernel_size(self) -> torch.Tensor | None:
return self._kernel_size
@kernel_size.setter
def kernel_size(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 1
assert torch.numel(value) == 2
assert value.dtype == torch.int64
assert value[0] > 0
assert value[1] > 0
self._kernel_size = value.detach().clone(memory_format=torch.contiguous_format)
self._kernel_size.requires_grad_(False)
@property
def stride(self) -> torch.Tensor | None:
return self._stride
@stride.setter
def stride(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 1
assert torch.numel(value) == 2
assert value.dtype == torch.int64
assert value[0] > 0
assert value[1] > 0
self._stride = value.detach().clone(memory_format=torch.contiguous_format)
self._stride.requires_grad_(False)
@property
def dilation(self) -> torch.Tensor | None:
return self._dilation
@dilation.setter
def dilation(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 1
assert torch.numel(value) == 2
assert value.dtype == torch.int64
assert value[0] > 0
assert value[1] > 0
self._dilation = value.detach().clone(memory_format=torch.contiguous_format)
self._dilation.requires_grad_(False)
@property
def padding(self) -> torch.Tensor | None:
return self._padding
@padding.setter
def padding(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 1
assert torch.numel(value) == 2
assert value.dtype == torch.int64
assert value[0] >= 0
assert value[1] >= 0
self._padding = value.detach().clone(memory_format=torch.contiguous_format)
self._padding.requires_grad_(False)
@property
def output_size(self) -> torch.Tensor | None:
return self._output_size
@output_size.setter
def output_size(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 1
assert torch.numel(value) == 2
assert value.dtype == torch.int64
assert value[0] > 0
assert value[1] > 0
self._output_size = value.detach().clone(memory_format=torch.contiguous_format)
self._output_size.requires_grad_(False)
@property
def number_of_spikes(self) -> torch.Tensor | None:
return self._number_of_spikes
@number_of_spikes.setter
def number_of_spikes(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert torch.numel(value) == 1
assert value.dtype == torch.int64
assert value.item() > 0
self._number_of_spikes = value.detach().clone(
memory_format=torch.contiguous_format
)
self._number_of_spikes.requires_grad_(False)
@property
def number_of_cpu_processes(self) -> torch.Tensor | None:
return self._number_of_cpu_processes
@number_of_cpu_processes.setter
def number_of_cpu_processes(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert torch.numel(value) == 1
assert value.dtype == torch.int64
assert value.item() > 0
self._number_of_cpu_processes = value.detach().clone(
memory_format=torch.contiguous_format
)
self._number_of_cpu_processes.requires_grad_(False)
@property
def number_of_neurons(self) -> torch.Tensor | None:
return self._number_of_neurons
@number_of_neurons.setter
def number_of_neurons(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert torch.numel(value) == 1
assert value.dtype == torch.int64
assert value.item() > 0
self._number_of_neurons = value.detach().clone(
memory_format=torch.contiguous_format
)
self._number_of_neurons.requires_grad_(False)
@property
def number_of_input_neurons(self) -> torch.Tensor | None:
return self._number_of_input_neurons
@number_of_input_neurons.setter
def number_of_input_neurons(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert torch.numel(value) == 1
assert value.dtype == torch.int64
assert value.item() > 0
self._number_of_input_neurons = value.detach().clone(
memory_format=torch.contiguous_format
)
self._number_of_input_neurons.requires_grad_(False)
@property
def h_initial(self) -> torch.Tensor | None:
return self._h_initial
@h_initial.setter
def h_initial(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 1
assert value.dtype == torch.float32
self._h_initial = value.detach().clone(memory_format=torch.contiguous_format)
self._h_initial.requires_grad_(False)
@property
def alpha_number_of_iterations(self) -> torch.Tensor | None:
return self._alpha_number_of_iterations
@alpha_number_of_iterations.setter
def alpha_number_of_iterations(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert torch.numel(value) == 1
assert value.dtype == torch.int64
assert value.item() >= 0
self._alpha_number_of_iterations = value.detach().clone(
memory_format=torch.contiguous_format
)
self._alpha_number_of_iterations.requires_grad_(False)
####################################################################
# Forward #
####################################################################
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""PyTorch Forward method. Does the work."""
# Are we happy with the input?
assert input is not None
assert torch.is_tensor(input) is True
assert input.dim() == 4
assert input.dtype == torch.float32
# Are we happy with the rest of the network?
assert self._epsilon_xy_exists is True
assert self._epsilon_xy is not None
assert self._epsilon_0 is not None
assert self._epsilon_t is not None
assert self._weights_exists is True
assert self._weights is not None
assert self._kernel_size is not None
assert self._stride is not None
assert self._dilation is not None
assert self._padding is not None
assert self._output_size is not None
assert self._number_of_spikes is not None
assert self._number_of_cpu_processes is not None
assert self._h_initial is not None
assert self._alpha_number_of_iterations is not None
# SbS forward functional
return self.functional_sbs(
input,
self._epsilon_xy,
self._epsilon_0,
self._epsilon_t,
self._weights,
self._kernel_size,
self._stride,
self._dilation,
self._padding,
self._output_size,
self._number_of_spikes,
self._number_of_cpu_processes,
self._h_initial,
self._alpha_number_of_iterations,
)
####################################################################
# Helper functions #
####################################################################
def calculate_output_size(self, value: torch.Tensor) -> None:
coordinates_0, coordinates_1 = self._get_coordinates(value)
self._output_size: torch.Tensor = torch.tensor(
[
coordinates_0.shape[1],
coordinates_1.shape[1],
],
dtype=torch.int64,
)
self._output_size.requires_grad_(False)
def _get_coordinates(
self, value: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Function converts parameter in coordinates
for the convolution window"""
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 1
assert torch.numel(value) == 2
assert value.dtype == torch.int64
assert value[0] > 0
assert value[1] > 0
assert self._kernel_size is not None
assert self._stride is not None
assert self._dilation is not None
assert self._padding is not None
assert torch.numel(self._kernel_size) == 2
assert torch.numel(self._stride) == 2
assert torch.numel(self._dilation) == 2
assert torch.numel(self._padding) == 2
unfold_0: torch.nn.Unfold = torch.nn.Unfold(
kernel_size=(int(self._kernel_size[0]), 1),
dilation=int(self._dilation[0]),
padding=int(self._padding[0]),
stride=int(self._stride[0]),
)
unfold_1: torch.nn.Unfold = torch.nn.Unfold(
kernel_size=(1, int(self._kernel_size[1])),
dilation=int(self._dilation[1]),
padding=int(self._padding[1]),
stride=int(self._stride[1]),
)
coordinates_0: torch.Tensor = (
unfold_0(
torch.unsqueeze(
torch.unsqueeze(
torch.unsqueeze(
torch.arange(0, int(value[0]), dtype=torch.float32),
1,
),
0,
),
0,
)
)
.squeeze(0)
.type(torch.int64)
)
coordinates_1: torch.Tensor = (
unfold_1(
torch.unsqueeze(
torch.unsqueeze(
torch.unsqueeze(
torch.arange(0, int(value[1]), dtype=torch.float32),
0,
),
0,
),
0,
)
)
.squeeze(0)
.type(torch.int64)
)
return coordinates_0, coordinates_1
def _initial_random_weights(self, noise_amplitude: torch.Tensor) -> torch.Tensor:
"""Creates initial weights
Uniform plus random noise plus normalization
"""
assert torch.numel(noise_amplitude) == 1
assert noise_amplitude.item() >= 0
assert noise_amplitude.dtype == torch.float32
assert self._number_of_neurons is not None
assert self._number_of_input_neurons is not None
assert self._kernel_size is not None
weights = torch.empty(
(
int(self._kernel_size[0]),
int(self._kernel_size[1]),
int(self._number_of_input_neurons),
int(self._number_of_neurons),
),
dtype=torch.float32,
)
torch.nn.init.uniform_(weights, a=1.0, b=(1.0 + noise_amplitude.item()))
return weights
def _make_pooling_weights(self) -> torch.Tensor:
"""For generating the pooling weights."""
assert self._number_of_neurons is not None
assert self._kernel_size is not None
norm: float = 1.0 / (self._kernel_size[0] * self._kernel_size[1])
weights: torch.Tensor = torch.zeros(
(
int(self._kernel_size[0]),
int(self._kernel_size[1]),
int(self._number_of_neurons),
int(self._number_of_neurons),
),
dtype=torch.float32,
)
for i in range(0, int(self._number_of_neurons)):
weights[:, :, i, i] = norm
return weights
def initialize_weights(
self,
is_pooling_layer: bool = False,
noise_amplitude: float = 0.01,
) -> None:
"""For the generation of the initital weights.
Switches between normal initial random weights and pooling weights."""
assert self._kernel_size is not None
if is_pooling_layer is True:
weights = self._make_pooling_weights()
else:
weights = self._initial_random_weights(
torch.tensor(noise_amplitude, dtype=torch.float32)
)
weights = weights.moveaxis(-1, 0).moveaxis(-1, 1)
weights_t = torch.nn.functional.unfold(
input=weights,
kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
dilation=(1, 1),
padding=(0, 0),
stride=(1, 1),
).squeeze()
weights_t = torch.moveaxis(weights_t, 0, 1)
self.weights = weights_t
def initialize_epsilon_xy(
self,
eps_xy_intitial: float,
) -> None:
"""Creates initial epsilon xy matrices"""
assert self._output_size is not None
assert self._kernel_size is not None
assert eps_xy_intitial > 0
eps_xy_temp: torch.Tensor = torch.full(
(
int(self._output_size[0]),
int(self._output_size[1]),
int(self._kernel_size[0]),
int(self._kernel_size[1]),
),
eps_xy_intitial,
dtype=torch.float32,
)
self.epsilon_xy = eps_xy_temp
def set_h_init_to_uniform(self) -> None:
assert self._number_of_neurons is not None
h_initial: torch.Tensor = torch.full(
(int(self._number_of_neurons.item()),),
(1.0 / float(self._number_of_neurons.item())),
dtype=torch.float32,
)
self.h_initial = h_initial
# Epsilon XY
def backup_epsilon_xy(self) -> None:
assert self._epsilon_xy_exists is True
self._epsilon_xy_backup = self._epsilon_xy.data.clone()
def restore_epsilon_xy(self) -> None:
assert self._epsilon_xy_backup is not None
assert self._epsilon_xy_exists is True
self._epsilon_xy.data = self._epsilon_xy_backup.clone()
def mean_epsilon_xy(self) -> None:
assert self._epsilon_xy_exists is True
fill_value: float = float(self._epsilon_xy.data.mean())
self._epsilon_xy.data = torch.full_like(
self._epsilon_xy.data, fill_value, dtype=torch.float32
)
def threshold_epsilon_xy(self, threshold: float) -> None:
assert self._epsilon_xy_exists is True
assert threshold >= 0
torch.clamp(
self._epsilon_xy.data,
min=float(threshold),
max=None,
out=self._epsilon_xy.data,
)
# Weights
def backup_weights(self) -> None:
assert self._weights_exists is True
self._weights_backup = self._weights.data.clone()
def restore_weights(self) -> None:
assert self._weights_backup is not None
assert self._weights_exists is True
self._weights.data = self._weights_backup.clone()
def norm_weights(self) -> None:
assert self._weights_exists is True
temp: torch.Tensor = (
self._weights.data.detach()
.clone(memory_format=torch.contiguous_format)
.type(dtype=torch.float32)
)
temp /= temp.sum(dim=0, keepdim=True, dtype=torch.float32)
self._weights.data = temp
def threshold_weights(self, threshold: float) -> None:
assert self._weights_exists is True
assert threshold >= 0
torch.clamp(
self._weights.data,
min=float(threshold),
max=None,
out=self._weights.data,
)
class FunctionalSbS(torch.autograd.Function):
    """Autograd function with the forward and backward pass of an SbS layer.

    The forward pass draws spikes from the (unfolded) input, runs the
    h-dynamic over these spikes and optionally applies the alpha step. Spike
    generation and h-dynamic use the C++ modules when the module-level flags
    ``cpp_spike`` / ``cpp_sbs`` are set, otherwise a pure PyTorch fallback.
    """

    @staticmethod
    def forward(  # type: ignore
        ctx,
        input: torch.Tensor,
        epsilon_xy: torch.Tensor,
        epsilon_0: torch.Tensor,
        epsilon_t: torch.Tensor,
        weights: torch.Tensor,
        kernel_size: torch.Tensor,
        stride: torch.Tensor,
        dilation: torch.Tensor,
        padding: torch.Tensor,
        output_size: torch.Tensor,
        number_of_spikes: torch.Tensor,
        number_of_cpu_processes: torch.Tensor,
        h_initial: torch.Tensor,
        alpha_number_of_iterations: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the layer output.

        Args:
            ctx: Autograd context used to store tensors for ``backward``.
            input: 4-D input tensor; axes 2 and 3 are the spatial axes.
            epsilon_xy: 4-D tensor (output x, output y, kernel x, kernel y).
            epsilon_0: One-element tensor.
            epsilon_t: 1-D tensor with at least ``number_of_spikes`` entries.
            weights: 2-D weight tensor (unfolded input features, neurons).
            kernel_size, stride, dilation, padding, output_size: Two-element
                tensors describing the sliding-window geometry.
            number_of_spikes: One-element tensor.
            number_of_cpu_processes: One-element tensor, forwarded to the
                C++ modules.
            h_initial: 1-D tensor with the start value of h.
            alpha_number_of_iterations: One-element tensor; 0 disables the
                alpha step.

        Returns:
            Tensor of shape (batch, neurons, output x, output y).
        """
        torch.set_default_dtype(torch.float32)

        assert input.dim() == 4
        assert torch.numel(kernel_size) == 2
        assert torch.numel(dilation) == 2
        assert torch.numel(padding) == 2
        assert torch.numel(stride) == 2
        assert torch.numel(output_size) == 2

        assert torch.numel(epsilon_0) == 1
        assert torch.numel(number_of_spikes) == 1
        assert torch.numel(number_of_cpu_processes) == 1
        assert torch.numel(alpha_number_of_iterations) == 1

        input_size = torch.tensor([input.shape[2], input.shape[3]])

        # fold / unfold want plain integer tuples.
        kernel_size_tuple = tuple(kernel_size.tolist())
        dilation_tuple = tuple(dilation.tolist())
        padding_tuple = tuple(padding.tolist())
        stride_tuple = tuple(stride.tolist())
        output_size_tuple = tuple(output_size.tolist())
        number_of_spikes_int: int = int(number_of_spikes.item())

        ############################################################
        # Pre convolving the input                                 #
        ############################################################

        input_convolved_temp = torch.nn.functional.unfold(
            input,
            kernel_size=kernel_size_tuple,
            dilation=dilation_tuple,
            padding=padding_tuple,
            stride=stride_tuple,
        )

        # Fold with a 1x1 kernel only reshapes the block axis back into the
        # two spatial output axes.
        input_convolved = torch.nn.functional.fold(
            input_convolved_temp,
            output_size=output_size_tuple,
            kernel_size=(1, 1),
            dilation=(1, 1),
            padding=(0, 0),
            stride=(1, 1),
        ).requires_grad_(True)

        # Bring epsilon xy into the same unfolded layout as the input:
        # (unfolded input features, output x, output y).
        epsilon_xy_convolved: torch.Tensor = (
            (
                torch.nn.functional.unfold(
                    epsilon_xy.reshape(
                        (
                            int(epsilon_xy.shape[0]) * int(epsilon_xy.shape[1]),
                            int(epsilon_xy.shape[2]),
                            int(epsilon_xy.shape[3]),
                        )
                    )
                    .unsqueeze(1)
                    .tile((1, input.shape[1], 1, 1)),
                    kernel_size=kernel_size_tuple,
                    dilation=1,
                    padding=0,
                    stride=1,
                )
                .squeeze(-1)
                .reshape(
                    (
                        int(epsilon_xy.shape[0]),
                        int(epsilon_xy.shape[1]),
                        int(input_convolved.shape[1]),
                    )
                )
            )
            .moveaxis(-1, 0)
            .contiguous(memory_format=torch.contiguous_format)
        )

        ############################################################
        # Spike generation                                         #
        ############################################################

        if cpp_spike is False:
            # Alternative to the C++ module but 5x slower:
            # every (x, y, batch) position is one multinomial distribution
            # over the unfolded input features.
            spikes = (
                (
                    input_convolved.movedim(source=(2, 3), destination=(0, 1))
                    .reshape(
                        shape=(
                            input_convolved.shape[2]
                            * input_convolved.shape[3]
                            * input_convolved.shape[0],
                            input_convolved.shape[1],
                        )
                    )
                    .multinomial(num_samples=number_of_spikes_int, replacement=True)
                )
                .reshape(
                    shape=(
                        input_convolved.shape[2],
                        input_convolved.shape[3],
                        input_convolved.shape[0],
                        number_of_spikes_int,
                    )
                )
                .movedim(source=(0, 1), destination=(2, 3))
            ).contiguous(memory_format=torch.contiguous_format)
        else:
            # Normalized cumsum
            input_cumsum: torch.Tensor = torch.cumsum(
                input_convolved, dim=1, dtype=torch.float32
            )
            input_cumsum_last: torch.Tensor = input_cumsum[:, -1, :, :].unsqueeze(1)
            input_cumsum /= input_cumsum_last

            random_values = torch.rand(
                size=[
                    input_cumsum.shape[0],
                    number_of_spikes_int,
                    input_cumsum.shape[2],
                    input_cumsum.shape[3],
                ],
                dtype=torch.float32,
            )

            spikes = torch.empty_like(random_values, dtype=torch.int64)

            # Prepare for Export (Pointer and stuff)->
            np_input: np.ndarray = input_cumsum.detach().numpy()
            assert input_cumsum.dtype == torch.float32
            assert np_input.flags["C_CONTIGUOUS"] is True
            assert np_input.ndim == 4

            np_random_values: np.ndarray = random_values.detach().numpy()
            assert random_values.dtype == torch.float32
            assert np_random_values.flags["C_CONTIGUOUS"] is True
            assert np_random_values.ndim == 4

            # The numpy array shares its memory with `spikes`; the C++ code
            # receives the raw address of that memory.
            np_spikes: np.ndarray = spikes.detach().numpy()
            assert spikes.dtype == torch.int64
            assert np_spikes.flags["C_CONTIGUOUS"] is True
            assert np_spikes.ndim == 4
            # <- Prepare for Export

            spike_generation: PySpikeGeneration2DManyIP.SpikeGeneration2DManyIP = (
                PySpikeGeneration2DManyIP.SpikeGeneration2DManyIP()
            )

            spike_generation.spike_generation_multi_pattern(
                np_input.__array_interface__["data"][0],
                int(np_input.shape[0]),
                int(np_input.shape[1]),
                int(np_input.shape[2]),
                int(np_input.shape[3]),
                np_random_values.__array_interface__["data"][0],
                int(np_random_values.shape[0]),
                int(np_random_values.shape[1]),
                int(np_random_values.shape[2]),
                int(np_random_values.shape[3]),
                np_spikes.__array_interface__["data"][0],
                int(np_spikes.shape[0]),
                int(np_spikes.shape[1]),
                int(np_spikes.shape[2]),
                int(np_spikes.shape[3]),
                int(number_of_cpu_processes.item()),
            )

        ############################################################
        # H dynamic                                                #
        ############################################################

        assert epsilon_t.ndim == 1
        assert epsilon_t.shape[0] >= number_of_spikes

        if cpp_sbs is False:
            # h has shape (batch, output x, output y, neurons) in this branch.
            h = torch.tile(
                h_initial.unsqueeze(0).unsqueeze(0).unsqueeze(0),
                dims=[int(input.shape[0]), int(output_size[0]), int(output_size[1]), 1],
            )

            # h is kept un-normalised; epsilon_scale accumulates the factor
            # it has to be divided by.
            epsilon_scale: torch.Tensor = torch.ones(
                size=[
                    int(spikes.shape[0]),
                    int(spikes.shape[2]),
                    int(spikes.shape[3]),
                    1,
                ],
                dtype=torch.float32,
            )

            # Index grids that pair every spike with its own (x, y) position;
            # they broadcast against spikes[:, t] of shape (batch, x, y).
            x_index: torch.Tensor = torch.arange(int(spikes.shape[2])).reshape(1, -1, 1)
            y_index: torch.Tensor = torch.arange(int(spikes.shape[3])).reshape(1, 1, -1)

            for t in range(0, spikes.shape[1]):
                # Rescale before the accumulated factor gets too large.
                if epsilon_scale.max() > 1e10:
                    h /= epsilon_scale
                    epsilon_scale = torch.ones_like(epsilon_scale)

                spikes_t: torch.Tensor = spikes[:, t, :, :]

                h_temp: torch.Tensor = weights[spikes_t, :] * h

                # epsilon_temp[b, x, y] = epsilon_xy_convolved[spike, x, y]
                epsilon_temp: torch.Tensor = epsilon_xy_convolved[
                    spikes_t, x_index, y_index
                ]

                epsilon_subsegment: torch.Tensor = (
                    epsilon_temp.unsqueeze(-1) * epsilon_t[t] * epsilon_0
                )

                h_temp_sum: torch.Tensor = (
                    epsilon_scale * epsilon_subsegment / h_temp.sum(dim=3, keepdim=True)
                )
                # A zero sum yields nan / inf; such updates are dropped.
                torch.nan_to_num(
                    h_temp_sum, out=h_temp_sum, nan=0.0, posinf=0.0, neginf=0.0
                )

                h_temp *= h_temp_sum
                h += h_temp

                epsilon_scale *= 1.0 + epsilon_subsegment

            h /= epsilon_scale
            output = h.movedim(3, 1)

        else:
            epsilon_t_0: torch.Tensor = epsilon_t * epsilon_0

            h_shape: tuple[int, int, int, int] = (
                int(input.shape[0]),
                int(weights.shape[1]),
                int(output_size[0]),
                int(output_size[1]),
            )

            output = torch.empty(h_shape, dtype=torch.float32)

            # Prepare the export to C++ ->
            np_h: np.ndarray = output.detach().numpy()
            assert output.dtype == torch.float32
            assert np_h.flags["C_CONTIGUOUS"] is True
            assert np_h.ndim == 4

            np_epsilon_xy: np.ndarray = epsilon_xy_convolved.detach().numpy()
            assert epsilon_xy.dtype == torch.float32
            assert np_epsilon_xy.flags["C_CONTIGUOUS"] is True
            assert np_epsilon_xy.ndim == 3

            np_epsilon_t: np.ndarray = epsilon_t_0.detach().numpy()
            assert epsilon_t_0.dtype == torch.float32
            assert np_epsilon_t.flags["C_CONTIGUOUS"] is True
            assert np_epsilon_t.ndim == 1

            np_weights: np.ndarray = weights.detach().numpy()
            assert weights.dtype == torch.float32
            assert np_weights.flags["C_CONTIGUOUS"] is True
            assert np_weights.ndim == 2

            np_spikes = spikes.contiguous().detach().numpy()
            assert spikes.dtype == torch.int64
            assert np_spikes.flags["C_CONTIGUOUS"] is True
            assert np_spikes.ndim == 4

            np_h_initial = h_initial.contiguous().detach().numpy()
            assert h_initial.dtype == torch.float32
            assert np_h_initial.flags["C_CONTIGUOUS"] is True
            assert np_h_initial.ndim == 1
            # <- Prepare the export to C++

            h_dynamic: PyHDynamicCNNManyIP.HDynamicCNNManyIP = (
                PyHDynamicCNNManyIP.HDynamicCNNManyIP()
            )

            h_dynamic.update_with_init_vector_multi_pattern(
                np_h.__array_interface__["data"][0],
                int(np_h.shape[0]),
                int(np_h.shape[1]),
                int(np_h.shape[2]),
                int(np_h.shape[3]),
                np_epsilon_xy.__array_interface__["data"][0],
                int(np_epsilon_xy.shape[0]),
                int(np_epsilon_xy.shape[1]),
                int(np_epsilon_xy.shape[2]),
                np_epsilon_t.__array_interface__["data"][0],
                int(np_epsilon_t.shape[0]),
                np_weights.__array_interface__["data"][0],
                int(np_weights.shape[0]),
                int(np_weights.shape[1]),
                np_spikes.__array_interface__["data"][0],
                int(np_spikes.shape[0]),
                int(np_spikes.shape[1]),
                int(np_spikes.shape[2]),
                int(np_spikes.shape[3]),
                np_h_initial.__array_interface__["data"][0],
                int(np_h_initial.shape[0]),
                int(number_of_cpu_processes.item()),
            )

        ############################################################
        # Alpha                                                    #
        ############################################################
        alpha_number_of_iterations_int: int = int(alpha_number_of_iterations.item())

        if alpha_number_of_iterations_int > 0:
            # Initialization
            virtual_reconstruction_weight: torch.Tensor = torch.einsum(
                "bixy,ji->bjxy", output, weights
            )
            alpha_fill_value: float = 1.0 / (
                virtual_reconstruction_weight.shape[2]
                * virtual_reconstruction_weight.shape[3]
            )
            alpha_dynamic: torch.Tensor = torch.full(
                (
                    int(virtual_reconstruction_weight.shape[0]),
                    1,
                    int(virtual_reconstruction_weight.shape[2]),
                    int(virtual_reconstruction_weight.shape[3]),
                ),
                alpha_fill_value,
                dtype=torch.float32,
            )

            # Iterations
            for _ in range(0, alpha_number_of_iterations_int):
                alpha_temp: torch.Tensor = alpha_dynamic * virtual_reconstruction_weight
                alpha_temp /= alpha_temp.sum(dim=3, keepdim=True).sum(
                    dim=2, keepdim=True
                )
                torch.nan_to_num(
                    alpha_temp, out=alpha_temp, nan=0.0, posinf=0.0, neginf=0.0
                )

                # Map from the output grid back onto the input grid ...
                alpha_temp = torch.nn.functional.unfold(
                    alpha_temp,
                    kernel_size=(1, 1),
                    dilation=1,
                    padding=0,
                    stride=1,
                )
                alpha_temp = torch.nn.functional.fold(
                    alpha_temp,
                    output_size=tuple(input_size.tolist()),
                    kernel_size=kernel_size_tuple,
                    dilation=dilation_tuple,
                    padding=padding_tuple,
                    stride=stride_tuple,
                )

                alpha_temp = (alpha_temp * input).sum(dim=1, keepdim=True)

                # ... and forward again onto the output grid.
                alpha_temp = torch.nn.functional.unfold(
                    alpha_temp,
                    kernel_size=kernel_size_tuple,
                    dilation=dilation_tuple,
                    padding=padding_tuple,
                    stride=stride_tuple,
                )
                alpha_temp = torch.nn.functional.fold(
                    alpha_temp,
                    output_size=output_size_tuple,
                    kernel_size=(1, 1),
                    dilation=(1, 1),
                    padding=(0, 0),
                    stride=(1, 1),
                )
                alpha_dynamic = alpha_temp.sum(dim=1, keepdim=True)

            alpha_dynamic += 1e-20

            # Alpha normalization
            alpha_dynamic /= alpha_dynamic.sum(dim=3, keepdim=True).sum(
                dim=2, keepdim=True
            )
            torch.nan_to_num(
                alpha_dynamic, out=alpha_dynamic, nan=0.0, posinf=0.0, neginf=0.0
            )

            # Applied to the output
            output *= alpha_dynamic

        ############################################################
        # Save the necessary data for the backward pass            #
        ############################################################

        ctx.save_for_backward(
            input_convolved,
            epsilon_xy_convolved,
            epsilon_0,
            weights,
            output,
            kernel_size,
            stride,
            dilation,
            padding,
            input_size,
        )

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute the gradients for input, epsilon xy and weights.

        Args:
            ctx: Autograd context filled by ``forward``.
            grad_output: Gradient with respect to the output of ``forward``.

        Returns:
            One entry per ``forward`` argument, in the same order; only the
            entries for ``input``, ``epsilon_xy`` and ``weights`` carry a
            gradient, all others are None.
        """
        # Get the variables back
        (
            input_convolved,
            epsilon_xy,
            epsilon_0,
            weights,
            output,
            kernel_size,
            stride,
            dilation,
            padding,
            input_size,
        ) = ctx.saved_tensors

        torch.set_default_dtype(torch.float32)

        # fold wants plain integer tuples (same conversion as in forward).
        kernel_size_tuple = tuple(kernel_size.tolist())
        stride_tuple = tuple(stride.tolist())
        dilation_tuple = tuple(dilation.tolist())
        padding_tuple = tuple(padding.tolist())
        input_size_tuple = tuple(input_size.tolist())

        # Normalise out of place: the saved tensor itself must stay untouched.
        input_normalized: torch.Tensor = input_convolved / input_convolved.sum(
            dim=1, keepdim=True, dtype=torch.float32
        )

        # For debugging:
        # print(
        #     f"S: O: {output.min().item():e} {output.max().item():e} I: {input_normalized.min().item():e} {input_normalized.max().item():e} G: {grad_output.min().item():e} {grad_output.max().item():e}"
        # )

        epsilon_0_float: float = epsilon_0.item()

        temp_e: torch.Tensor = 1.0 / ((epsilon_xy * epsilon_0_float) + 1.0)

        eps_a: torch.Tensor = temp_e.clone()
        eps_a *= epsilon_xy * epsilon_0_float

        eps_b: torch.Tensor = temp_e**2 * epsilon_0_float

        # Shape (batch, input features, neurons, x, y)
        backprop_r: torch.Tensor = weights.unsqueeze(0).unsqueeze(-1).unsqueeze(
            -1
        ) * output.unsqueeze(1)

        # Sum over the neuron axis: (batch, input features, x, y)
        backprop_bigr: torch.Tensor = backprop_r.sum(dim=2)

        # The 1e-20 terms guard against division by zero.
        temp: torch.Tensor = input_normalized / (backprop_bigr**2 + 1e-20)

        backprop_f: torch.Tensor = output.unsqueeze(1) * temp.unsqueeze(2)
        torch.nan_to_num(
            backprop_f, out=backprop_f, nan=1e30, posinf=1e30, neginf=-1e30
        )
        torch.clip(backprop_f, out=backprop_f, min=-1e30, max=1e30)

        tempz: torch.Tensor = 1.0 / (backprop_bigr + 1e-20)

        backprop_z: torch.Tensor = backprop_r * tempz.unsqueeze(2)
        torch.nan_to_num(
            backprop_z, out=backprop_z, nan=1e30, posinf=1e30, neginf=-1e30
        )
        torch.clip(backprop_z, out=backprop_z, min=-1e30, max=1e30)

        result_omega: torch.Tensor = backprop_bigr.unsqueeze(2) * grad_output.unsqueeze(
            1
        )
        result_omega -= torch.einsum(
            "bijxy,bjxy->bixy", backprop_r, grad_output
        ).unsqueeze(2)
        result_omega *= backprop_f

        result_eps_xy: torch.Tensor = (
            (
                (backprop_z * input_normalized.unsqueeze(2) - output.unsqueeze(1))
                * grad_output.unsqueeze(1)
            )
            .sum(dim=2)
            .sum(dim=0)
        ) * eps_b

        result_phi: torch.Tensor = torch.einsum(
            "bijxy,bjxy->bixy", backprop_z, grad_output
        ) * eps_a.unsqueeze(0)

        # Weights: sum over batch and both spatial axes.
        grad_weights = result_omega.sum(0).sum(-1).sum(-1)
        torch.nan_to_num(
            grad_weights, out=grad_weights, nan=1e30, posinf=1e30, neginf=-1e30
        )
        torch.clip(grad_weights, out=grad_weights, min=-1e30, max=1e30)

        # Input: undo the unfolding that forward applied to the input.
        grad_input = torch.nn.functional.fold(
            torch.nn.functional.unfold(
                result_phi,
                kernel_size=(1, 1),
                dilation=1,
                padding=0,
                stride=1,
            ),
            output_size=input_size_tuple,
            kernel_size=kernel_size_tuple,
            dilation=dilation_tuple,
            padding=padding_tuple,
            stride=stride_tuple,
        )
        torch.nan_to_num(
            grad_input, out=grad_input, nan=1e30, posinf=1e30, neginf=-1e30
        )
        torch.clip(grad_input, out=grad_input, min=-1e30, max=1e30)

        # Epsilon xy: fold the feature axis back into (channel, kx, ky) and
        # sum over the channel axis.
        grad_eps_xy_temp = torch.nn.functional.fold(
            result_eps_xy.moveaxis(0, -1)
            .reshape(
                (
                    int(result_eps_xy.shape[1]) * int(result_eps_xy.shape[2]),
                    int(result_eps_xy.shape[0]),
                )
            )
            .unsqueeze(-1),
            output_size=kernel_size_tuple,
            kernel_size=kernel_size_tuple,
        )

        grad_eps_xy = (
            grad_eps_xy_temp.sum(dim=1)
            .reshape(
                (
                    int(result_eps_xy.shape[1]),
                    int(result_eps_xy.shape[2]),
                    int(grad_eps_xy_temp.shape[-2]),
                    int(grad_eps_xy_temp.shape[-1]),
                )
            )
            .contiguous(memory_format=torch.contiguous_format)
        )
        torch.nan_to_num(
            grad_eps_xy, out=grad_eps_xy, nan=1e30, posinf=1e30, neginf=-1e30
        )
        torch.clip(grad_eps_xy, out=grad_eps_xy, min=-1e30, max=1e30)

        grad_epsilon_0 = None
        grad_epsilon_t = None
        grad_kernel_size = None
        grad_stride = None
        grad_dilation = None
        grad_padding = None
        grad_output_size = None
        grad_number_of_spikes = None
        grad_number_of_cpu_processes = None
        grad_h_initial = None
        grad_alpha_number_of_iterations = None

        return (
            grad_input,
            grad_eps_xy,
            grad_epsilon_0,
            grad_epsilon_t,
            grad_weights,
            grad_kernel_size,
            grad_stride,
            grad_dilation,
            grad_padding,
            grad_output_size,
            grad_number_of_spikes,
            grad_number_of_cpu_processes,
            grad_h_initial,
            grad_alpha_number_of_iterations,
        )
# %%