1221 lines
42 KiB
Python
1221 lines
42 KiB
Python
|
# MIT License
|
||
|
# Copyright 2022 University of Bremen
|
||
|
#
|
||
|
# Permission is hereby granted, free of charge, to any person obtaining
|
||
|
# a copy of this software and associated documentation files (the "Software"),
|
||
|
# to deal in the Software without restriction, including without limitation
|
||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||
|
# Software is furnished to do so, subject to the following conditions:
|
||
|
#
|
||
|
# The above copyright notice and this permission notice shall be included
|
||
|
# in all copies or substantial portions of the Software.
|
||
|
#
|
||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||
|
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||
|
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
||
|
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
|
||
|
# THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||
|
#
|
||
|
#
|
||
|
# David Rotermund ( davrot@uni-bremen.de )
|
||
|
#
|
||
|
#
|
||
|
# Release history:
|
||
|
# ================
|
||
|
# 1.0.0 -- 01.05.2022: first release
|
||
|
#
|
||
|
#
|
||
|
|
||
|
# %%
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
|
||
|
try:
|
||
|
import PySpikeGeneration2DManyIP
|
||
|
|
||
|
cpp_spike: bool = True
|
||
|
except Exception:
|
||
|
cpp_spike = False
|
||
|
|
||
|
try:
|
||
|
import PyHDynamicCNNManyIP
|
||
|
|
||
|
cpp_sbs: bool = True
|
||
|
except Exception:
|
||
|
cpp_sbs = False
|
||
|
|
||
|
|
||
|
class SbS(torch.nn.Module):
|
||
|
|
||
|
_epsilon_xy: torch.nn.parameter.Parameter
|
||
|
_epsilon_xy_exists: bool = False
|
||
|
_epsilon_0: torch.Tensor | None = None
|
||
|
_epsilon_t: torch.Tensor | None = None
|
||
|
_weights: torch.nn.parameter.Parameter
|
||
|
_weights_exists: bool = False
|
||
|
_kernel_size: torch.Tensor | None = None
|
||
|
_stride: torch.Tensor | None = None
|
||
|
_dilation: torch.Tensor | None = None
|
||
|
_padding: torch.Tensor | None = None
|
||
|
_output_size: torch.Tensor | None = None
|
||
|
_number_of_spikes: torch.Tensor | None = None
|
||
|
_number_of_cpu_processes: torch.Tensor | None = None
|
||
|
_number_of_neurons: torch.Tensor | None = None
|
||
|
_number_of_input_neurons: torch.Tensor | None = None
|
||
|
_h_initial: torch.Tensor | None = None
|
||
|
_epsilon_xy_backup: torch.Tensor | None = None
|
||
|
_weights_backup: torch.Tensor | None = None
|
||
|
_alpha_number_of_iterations: torch.Tensor | None = None
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
number_of_input_neurons: int,
|
||
|
number_of_neurons: int,
|
||
|
input_size: list[int],
|
||
|
forward_kernel_size: list[int],
|
||
|
number_of_spikes: int,
|
||
|
epsilon_t: torch.Tensor,
|
||
|
epsilon_xy_intitial: float = 0.1,
|
||
|
epsilon_0: float = 1.0,
|
||
|
weight_noise_amplitude: float = 0.01,
|
||
|
is_pooling_layer: bool = False,
|
||
|
strides: list[int] = [1, 1],
|
||
|
dilation: list[int] = [0, 0],
|
||
|
padding: list[int] = [0, 0],
|
||
|
alpha_number_of_iterations: int = 0,
|
||
|
number_of_cpu_processes: int = 1,
|
||
|
) -> None:
|
||
|
"""Constructor"""
|
||
|
super().__init__()
|
||
|
|
||
|
self.stride = torch.tensor(strides, dtype=torch.int64)
|
||
|
|
||
|
self.dilation = torch.tensor(dilation, dtype=torch.int64)
|
||
|
|
||
|
self.padding = torch.tensor(padding, dtype=torch.int64)
|
||
|
|
||
|
self.kernel_size = torch.tensor(
|
||
|
forward_kernel_size,
|
||
|
dtype=torch.int64,
|
||
|
)
|
||
|
|
||
|
self.number_of_input_neurons = torch.tensor(
|
||
|
number_of_input_neurons,
|
||
|
dtype=torch.int64,
|
||
|
)
|
||
|
|
||
|
self.number_of_neurons = torch.tensor(
|
||
|
number_of_neurons,
|
||
|
dtype=torch.int64,
|
||
|
)
|
||
|
|
||
|
self.alpha_number_of_iterations = torch.tensor(
|
||
|
alpha_number_of_iterations, dtype=torch.int64
|
||
|
)
|
||
|
|
||
|
self.calculate_output_size(torch.tensor(input_size, dtype=torch.int64))
|
||
|
|
||
|
self.set_h_init_to_uniform()
|
||
|
|
||
|
self.initialize_epsilon_xy(epsilon_xy_intitial)
|
||
|
|
||
|
self.epsilon_0 = torch.tensor(epsilon_0, dtype=torch.float64)
|
||
|
|
||
|
self.number_of_cpu_processes = torch.tensor(
|
||
|
number_of_cpu_processes, dtype=torch.int64
|
||
|
)
|
||
|
|
||
|
self.number_of_spikes = torch.tensor(number_of_spikes, dtype=torch.int64)
|
||
|
|
||
|
self.epsilon_t = epsilon_t.type(dtype=torch.float64)
|
||
|
|
||
|
self.initialize_weights(
|
||
|
is_pooling_layer=is_pooling_layer,
|
||
|
noise_amplitude=weight_noise_amplitude,
|
||
|
)
|
||
|
|
||
|
self.functional_sbs = FunctionalSbS.apply
|
||
|
|
||
|
####################################################################
|
||
|
# Variables in and out #
|
||
|
####################################################################
|
||
|
|
||
|
@property
|
||
|
def epsilon_xy(self) -> torch.Tensor | None:
|
||
|
if self._epsilon_xy_exists is False:
|
||
|
return None
|
||
|
else:
|
||
|
return self._epsilon_xy
|
||
|
|
||
|
@epsilon_xy.setter
|
||
|
def epsilon_xy(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert value.dim() == 2
|
||
|
assert value.dtype == torch.float64
|
||
|
if self._epsilon_xy_exists is False:
|
||
|
self._epsilon_xy = torch.nn.parameter.Parameter(
|
||
|
value.detach().clone(memory_format=torch.contiguous_format),
|
||
|
requires_grad=True,
|
||
|
)
|
||
|
self._epsilon_xy_exists = True
|
||
|
else:
|
||
|
self._epsilon_xy.data = value.detach().clone(
|
||
|
memory_format=torch.contiguous_format
|
||
|
)
|
||
|
|
||
|
@property
|
||
|
def epsilon_0(self) -> torch.Tensor | None:
|
||
|
return self._epsilon_0
|
||
|
|
||
|
@epsilon_0.setter
|
||
|
def epsilon_0(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert torch.numel(value) == 1
|
||
|
assert value.dtype == torch.float64
|
||
|
assert value.item() > 0
|
||
|
self._epsilon_0 = value.detach().clone(memory_format=torch.contiguous_format)
|
||
|
self._epsilon_0.requires_grad_(False)
|
||
|
|
||
|
@property
|
||
|
def epsilon_t(self) -> torch.Tensor | None:
|
||
|
return self._epsilon_t
|
||
|
|
||
|
@epsilon_t.setter
|
||
|
def epsilon_t(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert value.dim() == 1
|
||
|
assert value.dtype == torch.float64
|
||
|
self._epsilon_t = value.detach().clone(memory_format=torch.contiguous_format)
|
||
|
self._epsilon_t.requires_grad_(False)
|
||
|
|
||
|
@property
|
||
|
def weights(self) -> torch.Tensor | None:
|
||
|
if self._weights_exists is False:
|
||
|
return None
|
||
|
else:
|
||
|
return self._weights
|
||
|
|
||
|
@weights.setter
|
||
|
def weights(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert value.dim() == 2
|
||
|
assert value.dtype == torch.float64
|
||
|
temp: torch.Tensor = value.detach().clone(memory_format=torch.contiguous_format)
|
||
|
temp /= temp.sum(dim=0, keepdim=True, dtype=torch.float64)
|
||
|
if self._weights_exists is False:
|
||
|
self._weights = torch.nn.parameter.Parameter(
|
||
|
temp,
|
||
|
requires_grad=True,
|
||
|
)
|
||
|
self._weights_exists = True
|
||
|
else:
|
||
|
self._weights.data = temp
|
||
|
|
||
|
@property
|
||
|
def kernel_size(self) -> torch.Tensor | None:
|
||
|
return self._kernel_size
|
||
|
|
||
|
@kernel_size.setter
|
||
|
def kernel_size(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert value.dim() == 1
|
||
|
assert torch.numel(value) == 2
|
||
|
assert value.dtype == torch.int64
|
||
|
assert value[0] > 0
|
||
|
assert value[1] > 0
|
||
|
self._kernel_size = value.detach().clone(memory_format=torch.contiguous_format)
|
||
|
self._kernel_size.requires_grad_(False)
|
||
|
|
||
|
@property
|
||
|
def stride(self) -> torch.Tensor | None:
|
||
|
return self._stride
|
||
|
|
||
|
@stride.setter
|
||
|
def stride(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert value.dim() == 1
|
||
|
assert torch.numel(value) == 2
|
||
|
assert value.dtype == torch.int64
|
||
|
assert value[0] > 0
|
||
|
assert value[1] > 0
|
||
|
self._stride = value.detach().clone(memory_format=torch.contiguous_format)
|
||
|
self._stride.requires_grad_(False)
|
||
|
|
||
|
@property
|
||
|
def dilation(self) -> torch.Tensor | None:
|
||
|
return self._dilation
|
||
|
|
||
|
@dilation.setter
|
||
|
def dilation(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert value.dim() == 1
|
||
|
assert torch.numel(value) == 2
|
||
|
assert value.dtype == torch.int64
|
||
|
assert value[0] > 0
|
||
|
assert value[1] > 0
|
||
|
self._dilation = value.detach().clone(memory_format=torch.contiguous_format)
|
||
|
self._dilation.requires_grad_(False)
|
||
|
|
||
|
@property
|
||
|
def padding(self) -> torch.Tensor | None:
|
||
|
return self._padding
|
||
|
|
||
|
@padding.setter
|
||
|
def padding(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert value.dim() == 1
|
||
|
assert torch.numel(value) == 2
|
||
|
assert value.dtype == torch.int64
|
||
|
assert value[0] >= 0
|
||
|
assert value[1] >= 0
|
||
|
self._padding = value.detach().clone(memory_format=torch.contiguous_format)
|
||
|
self._padding.requires_grad_(False)
|
||
|
|
||
|
@property
|
||
|
def output_size(self) -> torch.Tensor | None:
|
||
|
return self._output_size
|
||
|
|
||
|
@output_size.setter
|
||
|
def output_size(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert value.dim() == 1
|
||
|
assert torch.numel(value) == 2
|
||
|
assert value.dtype == torch.int64
|
||
|
assert value[0] > 0
|
||
|
assert value[1] > 0
|
||
|
self._output_size = value.detach().clone(memory_format=torch.contiguous_format)
|
||
|
self._output_size.requires_grad_(False)
|
||
|
|
||
|
@property
|
||
|
def number_of_spikes(self) -> torch.Tensor | None:
|
||
|
return self._number_of_spikes
|
||
|
|
||
|
@number_of_spikes.setter
|
||
|
def number_of_spikes(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert torch.numel(value) == 1
|
||
|
assert value.dtype == torch.int64
|
||
|
assert value.item() > 0
|
||
|
self._number_of_spikes = value.detach().clone(
|
||
|
memory_format=torch.contiguous_format
|
||
|
)
|
||
|
self._number_of_spikes.requires_grad_(False)
|
||
|
|
||
|
@property
|
||
|
def number_of_cpu_processes(self) -> torch.Tensor | None:
|
||
|
return self._number_of_cpu_processes
|
||
|
|
||
|
@number_of_cpu_processes.setter
|
||
|
def number_of_cpu_processes(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert torch.numel(value) == 1
|
||
|
assert value.dtype == torch.int64
|
||
|
assert value.item() > 0
|
||
|
self._number_of_cpu_processes = value.detach().clone(
|
||
|
memory_format=torch.contiguous_format
|
||
|
)
|
||
|
self._number_of_cpu_processes.requires_grad_(False)
|
||
|
|
||
|
@property
|
||
|
def number_of_neurons(self) -> torch.Tensor | None:
|
||
|
return self._number_of_neurons
|
||
|
|
||
|
@number_of_neurons.setter
|
||
|
def number_of_neurons(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert torch.numel(value) == 1
|
||
|
assert value.dtype == torch.int64
|
||
|
assert value.item() > 0
|
||
|
self._number_of_neurons = value.detach().clone(
|
||
|
memory_format=torch.contiguous_format
|
||
|
)
|
||
|
self._number_of_neurons.requires_grad_(False)
|
||
|
|
||
|
@property
|
||
|
def number_of_input_neurons(self) -> torch.Tensor | None:
|
||
|
return self._number_of_input_neurons
|
||
|
|
||
|
@number_of_input_neurons.setter
|
||
|
def number_of_input_neurons(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert torch.numel(value) == 1
|
||
|
assert value.dtype == torch.int64
|
||
|
assert value.item() > 0
|
||
|
self._number_of_input_neurons = value.detach().clone(
|
||
|
memory_format=torch.contiguous_format
|
||
|
)
|
||
|
self._number_of_input_neurons.requires_grad_(False)
|
||
|
|
||
|
@property
|
||
|
def h_initial(self) -> torch.Tensor | None:
|
||
|
return self._h_initial
|
||
|
|
||
|
@h_initial.setter
|
||
|
def h_initial(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert value.dim() == 1
|
||
|
assert value.dtype == torch.float32
|
||
|
self._h_initial = value.detach().clone(memory_format=torch.contiguous_format)
|
||
|
self._h_initial.requires_grad_(False)
|
||
|
|
||
|
@property
|
||
|
def alpha_number_of_iterations(self) -> torch.Tensor | None:
|
||
|
return self._alpha_number_of_iterations
|
||
|
|
||
|
@alpha_number_of_iterations.setter
|
||
|
def alpha_number_of_iterations(self, value: torch.Tensor):
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert torch.numel(value) == 1
|
||
|
assert value.dtype == torch.int64
|
||
|
assert value.item() >= 0
|
||
|
self._alpha_number_of_iterations = value.detach().clone(
|
||
|
memory_format=torch.contiguous_format
|
||
|
)
|
||
|
self._alpha_number_of_iterations.requires_grad_(False)
|
||
|
|
||
|
####################################################################
|
||
|
# Forward #
|
||
|
####################################################################
|
||
|
|
||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||
|
"""PyTorch Forward method. Does the work."""
|
||
|
|
||
|
# Are we happy with the input?
|
||
|
assert input is not None
|
||
|
assert torch.is_tensor(input) is True
|
||
|
assert input.dim() == 4
|
||
|
assert input.dtype == torch.float64
|
||
|
|
||
|
# Are we happy with the rest of the network?
|
||
|
assert self._epsilon_xy_exists is True
|
||
|
assert self._epsilon_xy is not None
|
||
|
assert self._epsilon_0 is not None
|
||
|
assert self._epsilon_t is not None
|
||
|
assert self._weights_exists is True
|
||
|
assert self._weights is not None
|
||
|
assert self._kernel_size is not None
|
||
|
assert self._stride is not None
|
||
|
assert self._dilation is not None
|
||
|
assert self._padding is not None
|
||
|
assert self._output_size is not None
|
||
|
assert self._number_of_spikes is not None
|
||
|
assert self._number_of_cpu_processes is not None
|
||
|
assert self._h_initial is not None
|
||
|
assert self._alpha_number_of_iterations is not None
|
||
|
|
||
|
# SbS forward functional
|
||
|
return self.functional_sbs(
|
||
|
input,
|
||
|
self._epsilon_xy,
|
||
|
self._epsilon_0,
|
||
|
self._epsilon_t,
|
||
|
self._weights,
|
||
|
self._kernel_size,
|
||
|
self._stride,
|
||
|
self._dilation,
|
||
|
self._padding,
|
||
|
self._output_size,
|
||
|
self._number_of_spikes,
|
||
|
self._number_of_cpu_processes,
|
||
|
self._h_initial,
|
||
|
self._alpha_number_of_iterations,
|
||
|
)
|
||
|
|
||
|
####################################################################
|
||
|
# Helper functions #
|
||
|
####################################################################
|
||
|
|
||
|
def calculate_output_size(self, value: torch.Tensor) -> None:
|
||
|
|
||
|
coordinates_0, coordinates_1 = self._get_coordinates(value)
|
||
|
|
||
|
self._output_size: torch.Tensor = torch.tensor(
|
||
|
[
|
||
|
coordinates_0.shape[1],
|
||
|
coordinates_1.shape[1],
|
||
|
],
|
||
|
dtype=torch.int64,
|
||
|
)
|
||
|
self._output_size.requires_grad_(False)
|
||
|
|
||
|
def _get_coordinates(
|
||
|
self, value: torch.Tensor
|
||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||
|
"""Function converts parameter in coordinates
|
||
|
for the convolution window"""
|
||
|
|
||
|
assert value is not None
|
||
|
assert torch.is_tensor(value) is True
|
||
|
assert value.dim() == 1
|
||
|
assert torch.numel(value) == 2
|
||
|
assert value.dtype == torch.int64
|
||
|
assert value[0] > 0
|
||
|
assert value[1] > 0
|
||
|
|
||
|
assert self._kernel_size is not None
|
||
|
assert self._stride is not None
|
||
|
assert self._dilation is not None
|
||
|
assert self._padding is not None
|
||
|
|
||
|
assert torch.numel(self._kernel_size) == 2
|
||
|
assert torch.numel(self._stride) == 2
|
||
|
assert torch.numel(self._dilation) == 2
|
||
|
assert torch.numel(self._padding) == 2
|
||
|
|
||
|
unfold_0: torch.nn.Unfold = torch.nn.Unfold(
|
||
|
kernel_size=(int(self._kernel_size[0]), 1),
|
||
|
dilation=int(self._dilation[0]),
|
||
|
padding=int(self._padding[0]),
|
||
|
stride=int(self._stride[0]),
|
||
|
)
|
||
|
|
||
|
unfold_1: torch.nn.Unfold = torch.nn.Unfold(
|
||
|
kernel_size=(1, int(self._kernel_size[1])),
|
||
|
dilation=int(self._dilation[1]),
|
||
|
padding=int(self._padding[1]),
|
||
|
stride=int(self._stride[1]),
|
||
|
)
|
||
|
|
||
|
coordinates_0: torch.Tensor = (
|
||
|
unfold_0(
|
||
|
torch.unsqueeze(
|
||
|
torch.unsqueeze(
|
||
|
torch.unsqueeze(
|
||
|
torch.arange(0, int(value[0]), dtype=torch.float64),
|
||
|
1,
|
||
|
),
|
||
|
0,
|
||
|
),
|
||
|
0,
|
||
|
)
|
||
|
)
|
||
|
.squeeze(0)
|
||
|
.type(torch.int64)
|
||
|
)
|
||
|
|
||
|
coordinates_1: torch.Tensor = (
|
||
|
unfold_1(
|
||
|
torch.unsqueeze(
|
||
|
torch.unsqueeze(
|
||
|
torch.unsqueeze(
|
||
|
torch.arange(0, int(value[1]), dtype=torch.float64),
|
||
|
0,
|
||
|
),
|
||
|
0,
|
||
|
),
|
||
|
0,
|
||
|
)
|
||
|
)
|
||
|
.squeeze(0)
|
||
|
.type(torch.int64)
|
||
|
)
|
||
|
|
||
|
return coordinates_0, coordinates_1
|
||
|
|
||
|
def _initial_random_weights(self, noise_amplitude: torch.Tensor) -> torch.Tensor:
|
||
|
"""Creates initial weights
|
||
|
Uniform plus random noise plus normalization
|
||
|
"""
|
||
|
|
||
|
assert torch.numel(noise_amplitude) == 1
|
||
|
assert noise_amplitude.item() >= 0
|
||
|
assert noise_amplitude.dtype == torch.float64
|
||
|
|
||
|
assert self._number_of_neurons is not None
|
||
|
assert self._number_of_input_neurons is not None
|
||
|
assert self._kernel_size is not None
|
||
|
|
||
|
weights = torch.empty(
|
||
|
(
|
||
|
int(self._kernel_size[0]),
|
||
|
int(self._kernel_size[1]),
|
||
|
int(self._number_of_input_neurons),
|
||
|
int(self._number_of_neurons),
|
||
|
),
|
||
|
dtype=torch.float64,
|
||
|
)
|
||
|
torch.nn.init.uniform_(weights, a=1.0, b=(1.0 + noise_amplitude.item()))
|
||
|
|
||
|
return weights
|
||
|
|
||
|
def _make_pooling_weights(self) -> torch.Tensor:
|
||
|
"""For generating the pooling weights."""
|
||
|
|
||
|
assert self._number_of_neurons is not None
|
||
|
assert self._kernel_size is not None
|
||
|
|
||
|
norm: float = 1.0 / (self._kernel_size[0] * self._kernel_size[1])
|
||
|
|
||
|
weights: torch.Tensor = torch.zeros(
|
||
|
(
|
||
|
int(self._kernel_size[0]),
|
||
|
int(self._kernel_size[1]),
|
||
|
int(self._number_of_neurons),
|
||
|
int(self._number_of_neurons),
|
||
|
),
|
||
|
dtype=torch.float64,
|
||
|
)
|
||
|
|
||
|
for i in range(0, int(self._number_of_neurons)):
|
||
|
weights[:, :, i, i] = norm
|
||
|
|
||
|
return weights
|
||
|
|
||
|
def initialize_weights(
|
||
|
self,
|
||
|
is_pooling_layer: bool = False,
|
||
|
noise_amplitude: float = 0.01,
|
||
|
) -> None:
|
||
|
"""For the generation of the initital weights.
|
||
|
Switches between normal initial random weights and pooling weights."""
|
||
|
|
||
|
assert self._kernel_size is not None
|
||
|
|
||
|
if is_pooling_layer is True:
|
||
|
weights = self._make_pooling_weights()
|
||
|
else:
|
||
|
weights = self._initial_random_weights(
|
||
|
torch.tensor(noise_amplitude, dtype=torch.float64)
|
||
|
)
|
||
|
|
||
|
weights = weights.moveaxis(-1, 0).moveaxis(-1, 1)
|
||
|
|
||
|
weights_t = torch.nn.functional.unfold(
|
||
|
input=weights,
|
||
|
kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
|
||
|
dilation=(1, 1),
|
||
|
padding=(0, 0),
|
||
|
stride=(1, 1),
|
||
|
).squeeze()
|
||
|
|
||
|
weights_t = torch.moveaxis(weights_t, 0, 1)
|
||
|
|
||
|
self.weights = weights_t
|
||
|
|
||
|
def initialize_epsilon_xy(
|
||
|
self,
|
||
|
eps_xy_intitial: float,
|
||
|
) -> None:
|
||
|
"""Creates initial epsilon xy matrices"""
|
||
|
|
||
|
assert self._output_size is not None
|
||
|
assert eps_xy_intitial > 0
|
||
|
|
||
|
eps_xy_temp: torch.Tensor = torch.full(
|
||
|
(int(self._output_size[0]), int(self._output_size[1])),
|
||
|
eps_xy_intitial,
|
||
|
dtype=torch.float64,
|
||
|
)
|
||
|
|
||
|
self.epsilon_xy = eps_xy_temp
|
||
|
|
||
|
def set_h_init_to_uniform(self) -> None:
|
||
|
|
||
|
assert self._number_of_neurons is not None
|
||
|
|
||
|
h_initial: torch.Tensor = torch.full(
|
||
|
(int(self._number_of_neurons.item()),),
|
||
|
(1.0 / float(self._number_of_neurons.item())),
|
||
|
dtype=torch.float32,
|
||
|
)
|
||
|
|
||
|
self.h_initial = h_initial
|
||
|
|
||
|
# Epsilon XY
|
||
|
def backup_epsilon_xy(self) -> None:
|
||
|
assert self._epsilon_xy_exists is True
|
||
|
self._epsilon_xy_backup = self._epsilon_xy.data.clone()
|
||
|
|
||
|
def restore_epsilon_xy(self) -> None:
|
||
|
assert self._epsilon_xy_backup is not None
|
||
|
assert self._epsilon_xy_exists is True
|
||
|
self._epsilon_xy.data = self._epsilon_xy_backup.clone()
|
||
|
|
||
|
def mean_epsilon_xy(self) -> None:
|
||
|
assert self._epsilon_xy_exists is True
|
||
|
|
||
|
fill_value: float = float(self._epsilon_xy.data.mean())
|
||
|
self._epsilon_xy.data = torch.full_like(
|
||
|
self._epsilon_xy.data, fill_value, dtype=torch.float64
|
||
|
)
|
||
|
|
||
|
def threshold_epsilon_xy(self, threshold: float) -> None:
|
||
|
assert self._epsilon_xy_exists is True
|
||
|
assert threshold >= 0
|
||
|
torch.clamp(
|
||
|
self._epsilon_xy.data,
|
||
|
min=float(threshold),
|
||
|
max=None,
|
||
|
out=self._epsilon_xy.data,
|
||
|
)
|
||
|
|
||
|
# Weights
|
||
|
def backup_weights(self) -> None:
|
||
|
assert self._weights_exists is True
|
||
|
self._weights_backup = self._weights.data.clone()
|
||
|
|
||
|
def restore_weights(self) -> None:
|
||
|
assert self._weights_backup is not None
|
||
|
assert self._weights_exists is True
|
||
|
self._weights.data = self._weights_backup.clone()
|
||
|
|
||
|
def norm_weights(self) -> None:
|
||
|
assert self._weights_exists is True
|
||
|
temp: torch.Tensor = (
|
||
|
self._weights.data.detach()
|
||
|
.clone(memory_format=torch.contiguous_format)
|
||
|
.type(dtype=torch.float64)
|
||
|
)
|
||
|
temp /= temp.sum(dim=0, keepdim=True, dtype=torch.float64)
|
||
|
self._weights.data = temp
|
||
|
|
||
|
def threshold_weights(self, threshold: float) -> None:
|
||
|
assert self._weights_exists is True
|
||
|
assert threshold >= 0
|
||
|
torch.clamp(
|
||
|
self._weights.data,
|
||
|
min=float(threshold),
|
||
|
max=None,
|
||
|
out=self._weights.data,
|
||
|
)
|
||
|
|
||
|
|
||
|
class FunctionalSbS(torch.autograd.Function):
|
||
|
@staticmethod
|
||
|
def forward( # type: ignore
|
||
|
ctx,
|
||
|
input_float64: torch.Tensor,
|
||
|
epsilon_xy_float64: torch.Tensor,
|
||
|
epsilon_0_float64: torch.Tensor,
|
||
|
epsilon_t_float64: torch.Tensor,
|
||
|
weights_float64: torch.Tensor,
|
||
|
kernel_size: torch.Tensor,
|
||
|
stride: torch.Tensor,
|
||
|
dilation: torch.Tensor,
|
||
|
padding: torch.Tensor,
|
||
|
output_size: torch.Tensor,
|
||
|
number_of_spikes: torch.Tensor,
|
||
|
number_of_cpu_processes: torch.Tensor,
|
||
|
h_initial: torch.Tensor,
|
||
|
alpha_number_of_iterations: torch.Tensor,
|
||
|
) -> torch.Tensor:
|
||
|
|
||
|
input = input_float64.type(dtype=torch.float32)
|
||
|
epsilon_xy = epsilon_xy_float64.type(dtype=torch.float32)
|
||
|
weights = weights_float64.type(dtype=torch.float32)
|
||
|
epsilon_0 = epsilon_0_float64.type(dtype=torch.float32)
|
||
|
epsilon_t = epsilon_t_float64.type(dtype=torch.float32)
|
||
|
|
||
|
assert input.dim() == 4
|
||
|
assert torch.numel(kernel_size) == 2
|
||
|
assert torch.numel(dilation) == 2
|
||
|
assert torch.numel(padding) == 2
|
||
|
assert torch.numel(stride) == 2
|
||
|
assert torch.numel(output_size) == 2
|
||
|
|
||
|
assert torch.numel(epsilon_0) == 1
|
||
|
assert torch.numel(number_of_spikes) == 1
|
||
|
assert torch.numel(number_of_cpu_processes) == 1
|
||
|
assert torch.numel(alpha_number_of_iterations) == 1
|
||
|
|
||
|
input_size = torch.tensor([input.shape[2], input.shape[3]])
|
||
|
|
||
|
############################################################
|
||
|
# Pre convolving the input #
|
||
|
############################################################
|
||
|
|
||
|
input_convolved_temp = torch.nn.functional.unfold(
|
||
|
input,
|
||
|
kernel_size=tuple(kernel_size.tolist()),
|
||
|
dilation=tuple(dilation.tolist()),
|
||
|
padding=tuple(padding.tolist()),
|
||
|
stride=tuple(stride.tolist()),
|
||
|
)
|
||
|
|
||
|
input_convolved = torch.nn.functional.fold(
|
||
|
input_convolved_temp,
|
||
|
output_size=tuple(output_size.tolist()),
|
||
|
kernel_size=(1, 1),
|
||
|
dilation=(1, 1),
|
||
|
padding=(0, 0),
|
||
|
stride=(1, 1),
|
||
|
).requires_grad_(True)
|
||
|
|
||
|
############################################################
|
||
|
# Spike generation #
|
||
|
############################################################
|
||
|
|
||
|
if cpp_spike is False:
|
||
|
# Alternative to the C++ module but 5x slower:
|
||
|
spikes = (
|
||
|
(
|
||
|
input_convolved.movedim(source=(2, 3), destination=(0, 1))
|
||
|
.reshape(
|
||
|
shape=(
|
||
|
input_convolved.shape[2]
|
||
|
* input_convolved.shape[3]
|
||
|
* input_convolved.shape[0],
|
||
|
input_convolved.shape[1],
|
||
|
)
|
||
|
)
|
||
|
.multinomial(
|
||
|
num_samples=int(number_of_spikes.item()), replacement=True
|
||
|
)
|
||
|
)
|
||
|
.reshape(
|
||
|
shape=(
|
||
|
input_convolved.shape[2],
|
||
|
input_convolved.shape[3],
|
||
|
input_convolved.shape[0],
|
||
|
int(number_of_spikes.item()),
|
||
|
)
|
||
|
)
|
||
|
.movedim(source=(0, 1), destination=(2, 3))
|
||
|
).contiguous(memory_format=torch.contiguous_format)
|
||
|
else:
|
||
|
# Normalized cumsum
|
||
|
input_cumsum: torch.Tensor = torch.cumsum(
|
||
|
input_convolved, dim=1, dtype=torch.float32
|
||
|
)
|
||
|
input_cumsum_last: torch.Tensor = input_cumsum[:, -1, :, :].unsqueeze(1)
|
||
|
input_cumsum /= input_cumsum_last
|
||
|
|
||
|
random_values = torch.rand(
|
||
|
size=[
|
||
|
input_cumsum.shape[0],
|
||
|
int(number_of_spikes.item()),
|
||
|
input_cumsum.shape[2],
|
||
|
input_cumsum.shape[3],
|
||
|
],
|
||
|
dtype=torch.float32,
|
||
|
)
|
||
|
|
||
|
spikes = torch.empty_like(random_values, dtype=torch.int64)
|
||
|
|
||
|
# Prepare for Export (Pointer and stuff)->
|
||
|
np_input: np.ndarray = input_cumsum.detach().numpy()
|
||
|
assert input_cumsum.dtype == torch.float32
|
||
|
assert np_input.flags["C_CONTIGUOUS"] is True
|
||
|
assert np_input.ndim == 4
|
||
|
|
||
|
np_random_values: np.ndarray = random_values.detach().numpy()
|
||
|
assert random_values.dtype == torch.float32
|
||
|
assert np_random_values.flags["C_CONTIGUOUS"] is True
|
||
|
assert np_random_values.ndim == 4
|
||
|
|
||
|
np_spikes: np.ndarray = spikes.detach().numpy()
|
||
|
assert spikes.dtype == torch.int64
|
||
|
assert np_spikes.flags["C_CONTIGUOUS"] is True
|
||
|
assert np_spikes.ndim == 4
|
||
|
# <- Prepare for Export
|
||
|
|
||
|
spike_generation: PySpikeGeneration2DManyIP.SpikeGeneration2DManyIP = (
|
||
|
PySpikeGeneration2DManyIP.SpikeGeneration2DManyIP()
|
||
|
)
|
||
|
|
||
|
spike_generation.spike_generation_multi_pattern(
|
||
|
np_input.__array_interface__["data"][0],
|
||
|
int(np_input.shape[0]),
|
||
|
int(np_input.shape[1]),
|
||
|
int(np_input.shape[2]),
|
||
|
int(np_input.shape[3]),
|
||
|
np_random_values.__array_interface__["data"][0],
|
||
|
int(np_random_values.shape[0]),
|
||
|
int(np_random_values.shape[1]),
|
||
|
int(np_random_values.shape[2]),
|
||
|
int(np_random_values.shape[3]),
|
||
|
np_spikes.__array_interface__["data"][0],
|
||
|
int(np_spikes.shape[0]),
|
||
|
int(np_spikes.shape[1]),
|
||
|
int(np_spikes.shape[2]),
|
||
|
int(np_spikes.shape[3]),
|
||
|
int(number_of_cpu_processes.item()),
|
||
|
)
|
||
|
|
||
|
############################################################
|
||
|
# H dynamic #
|
||
|
############################################################
|
||
|
|
||
|
assert epsilon_t.ndim == 1
|
||
|
assert epsilon_t.shape[0] >= number_of_spikes
|
||
|
|
||
|
if cpp_sbs is False:
|
||
|
h = torch.tile(
|
||
|
h_initial.unsqueeze(0).unsqueeze(0).unsqueeze(0),
|
||
|
dims=[int(input.shape[0]), int(output_size[0]), int(output_size[1]), 1],
|
||
|
)
|
||
|
|
||
|
epsilon_scale: torch.Tensor = torch.ones(
|
||
|
size=[1, int(epsilon_xy.shape[0]), int(epsilon_xy.shape[1]), 1],
|
||
|
dtype=torch.float32,
|
||
|
)
|
||
|
|
||
|
for t in range(0, spikes.shape[1]):
|
||
|
|
||
|
if epsilon_scale.max() > 1e10:
|
||
|
h /= epsilon_scale
|
||
|
epsilon_scale = torch.ones_like(epsilon_scale)
|
||
|
|
||
|
h_temp: torch.Tensor = weights[spikes[:, t, :, :], :] * h
|
||
|
epsilon_subsegment: torch.Tensor = (
|
||
|
epsilon_xy.unsqueeze(0).unsqueeze(-1) * epsilon_t[t] * epsilon_0
|
||
|
)
|
||
|
h_temp_sum: torch.Tensor = (
|
||
|
epsilon_scale * epsilon_subsegment / h_temp.sum(dim=3, keepdim=True)
|
||
|
)
|
||
|
torch.nan_to_num(
|
||
|
h_temp_sum, out=h_temp_sum, nan=0.0, posinf=0.0, neginf=0.0
|
||
|
)
|
||
|
h_temp *= h_temp_sum
|
||
|
h += h_temp
|
||
|
|
||
|
epsilon_scale *= 1.0 + epsilon_subsegment
|
||
|
|
||
|
h /= epsilon_scale
|
||
|
output = h.movedim(3, 1)
|
||
|
else:
|
||
|
epsilon_t_0: torch.Tensor = epsilon_t * epsilon_0
|
||
|
|
||
|
h_shape: tuple[int, int, int, int] = (
|
||
|
int(input.shape[0]),
|
||
|
int(weights.shape[1]),
|
||
|
int(output_size[0]),
|
||
|
int(output_size[1]),
|
||
|
)
|
||
|
|
||
|
output = torch.empty(h_shape, dtype=torch.float32)
|
||
|
|
||
|
# Prepare the export to C++ ->
|
||
|
np_h: np.ndarray = output.detach().numpy()
|
||
|
assert output.dtype == torch.float32
|
||
|
assert np_h.flags["C_CONTIGUOUS"] is True
|
||
|
assert np_h.ndim == 4
|
||
|
|
||
|
np_epsilon_xy: np.ndarray = epsilon_xy.detach().numpy()
|
||
|
assert epsilon_xy.dtype == torch.float32
|
||
|
assert np_epsilon_xy.flags["C_CONTIGUOUS"] is True
|
||
|
assert np_epsilon_xy.ndim == 2
|
||
|
|
||
|
np_epsilon_t: np.ndarray = epsilon_t_0.detach().numpy()
|
||
|
assert epsilon_t_0.dtype == torch.float32
|
||
|
assert np_epsilon_t.flags["C_CONTIGUOUS"] is True
|
||
|
assert np_epsilon_t.ndim == 1
|
||
|
|
||
|
np_weights: np.ndarray = weights.detach().numpy()
|
||
|
assert weights.dtype == torch.float32
|
||
|
assert np_weights.flags["C_CONTIGUOUS"] is True
|
||
|
assert np_weights.ndim == 2
|
||
|
|
||
|
np_spikes = spikes.contiguous().detach().numpy()
|
||
|
assert spikes.dtype == torch.int64
|
||
|
assert np_spikes.flags["C_CONTIGUOUS"] is True
|
||
|
assert np_spikes.ndim == 4
|
||
|
|
||
|
np_h_initial = h_initial.contiguous().detach().numpy()
|
||
|
assert h_initial.dtype == torch.float32
|
||
|
assert np_h_initial.flags["C_CONTIGUOUS"] is True
|
||
|
assert np_h_initial.ndim == 1
|
||
|
# <- Prepare the export to C++
|
||
|
|
||
|
h_dynamic: PyHDynamicCNNManyIP.HDynamicCNNManyIP = (
|
||
|
PyHDynamicCNNManyIP.HDynamicCNNManyIP()
|
||
|
)
|
||
|
|
||
|
h_dynamic.update_with_init_vector_multi_pattern(
|
||
|
np_h.__array_interface__["data"][0],
|
||
|
int(np_h.shape[0]),
|
||
|
int(np_h.shape[1]),
|
||
|
int(np_h.shape[2]),
|
||
|
int(np_h.shape[3]),
|
||
|
np_epsilon_xy.__array_interface__["data"][0],
|
||
|
int(np_epsilon_xy.shape[0]),
|
||
|
int(np_epsilon_xy.shape[1]),
|
||
|
np_epsilon_t.__array_interface__["data"][0],
|
||
|
int(np_epsilon_t.shape[0]),
|
||
|
np_weights.__array_interface__["data"][0],
|
||
|
int(np_weights.shape[0]),
|
||
|
int(np_weights.shape[1]),
|
||
|
np_spikes.__array_interface__["data"][0],
|
||
|
int(np_spikes.shape[0]),
|
||
|
int(np_spikes.shape[1]),
|
||
|
int(np_spikes.shape[2]),
|
||
|
int(np_spikes.shape[3]),
|
||
|
np_h_initial.__array_interface__["data"][0],
|
||
|
int(np_h_initial.shape[0]),
|
||
|
int(number_of_cpu_processes.item()),
|
||
|
)
|
||
|
|
||
|
############################################################
|
||
|
# Alpha #
|
||
|
############################################################
|
||
|
alpha_number_of_iterations_int: int = int(alpha_number_of_iterations.item())
|
||
|
|
||
|
if alpha_number_of_iterations_int > 0:
|
||
|
# Initialization
|
||
|
virtual_reconstruction_weight: torch.Tensor = torch.einsum(
|
||
|
"bixy,ji->bjxy", output, weights
|
||
|
)
|
||
|
alpha_fill_value: float = 1.0 / (
|
||
|
virtual_reconstruction_weight.shape[2]
|
||
|
* virtual_reconstruction_weight.shape[3]
|
||
|
)
|
||
|
alpha_dynamic: torch.Tensor = torch.full(
|
||
|
(
|
||
|
int(virtual_reconstruction_weight.shape[0]),
|
||
|
1,
|
||
|
int(virtual_reconstruction_weight.shape[2]),
|
||
|
int(virtual_reconstruction_weight.shape[3]),
|
||
|
),
|
||
|
alpha_fill_value,
|
||
|
dtype=torch.float32,
|
||
|
)
|
||
|
|
||
|
# Iterations
|
||
|
for _ in range(0, alpha_number_of_iterations_int):
|
||
|
alpha_temp: torch.Tensor = alpha_dynamic * virtual_reconstruction_weight
|
||
|
alpha_temp /= alpha_temp.sum(dim=3, keepdim=True).sum(
|
||
|
dim=2, keepdim=True
|
||
|
)
|
||
|
torch.nan_to_num(
|
||
|
alpha_temp, out=alpha_temp, nan=0.0, posinf=0.0, neginf=0.0
|
||
|
)
|
||
|
|
||
|
alpha_temp = torch.nn.functional.unfold(
|
||
|
alpha_temp,
|
||
|
kernel_size=(1, 1),
|
||
|
dilation=1,
|
||
|
padding=0,
|
||
|
stride=1,
|
||
|
)
|
||
|
|
||
|
alpha_temp = torch.nn.functional.fold(
|
||
|
alpha_temp,
|
||
|
output_size=tuple(input_size.tolist()),
|
||
|
kernel_size=tuple(kernel_size.tolist()),
|
||
|
dilation=tuple(dilation.tolist()),
|
||
|
padding=tuple(padding.tolist()),
|
||
|
stride=tuple(stride.tolist()),
|
||
|
)
|
||
|
|
||
|
alpha_temp = (alpha_temp * input).sum(dim=1, keepdim=True)
|
||
|
|
||
|
alpha_temp = torch.nn.functional.unfold(
|
||
|
alpha_temp,
|
||
|
kernel_size=tuple(kernel_size.tolist()),
|
||
|
dilation=tuple(dilation.tolist()),
|
||
|
padding=tuple(padding.tolist()),
|
||
|
stride=tuple(stride.tolist()),
|
||
|
)
|
||
|
|
||
|
alpha_temp = torch.nn.functional.fold(
|
||
|
alpha_temp,
|
||
|
output_size=tuple(output_size.tolist()),
|
||
|
kernel_size=(1, 1),
|
||
|
dilation=(1, 1),
|
||
|
padding=(0, 0),
|
||
|
stride=(1, 1),
|
||
|
)
|
||
|
alpha_dynamic = alpha_temp.sum(dim=1, keepdim=True)
|
||
|
|
||
|
alpha_dynamic += torch.finfo(torch.float32).eps * 1000
|
||
|
|
||
|
# Alpha normalization
|
||
|
alpha_dynamic /= alpha_dynamic.sum(dim=3, keepdim=True).sum(
|
||
|
dim=2, keepdim=True
|
||
|
)
|
||
|
torch.nan_to_num(
|
||
|
alpha_dynamic, out=alpha_dynamic, nan=0.0, posinf=0.0, neginf=0.0
|
||
|
)
|
||
|
|
||
|
# Applied to the output
|
||
|
output *= alpha_dynamic
|
||
|
|
||
|
############################################################
|
||
|
# Save the necessary data for the backward pass #
|
||
|
############################################################
|
||
|
|
||
|
output = output.type(dtype=torch.float64)
|
||
|
|
||
|
ctx.save_for_backward(
|
||
|
input_convolved,
|
||
|
epsilon_xy_float64,
|
||
|
epsilon_0_float64,
|
||
|
weights_float64,
|
||
|
output,
|
||
|
kernel_size,
|
||
|
stride,
|
||
|
dilation,
|
||
|
padding,
|
||
|
input_size,
|
||
|
)
|
||
|
|
||
|
return output
|
||
|
|
||
|
@staticmethod
|
||
|
def backward(ctx, grad_output):
|
||
|
|
||
|
# Get the variables back
|
||
|
(
|
||
|
input_float32,
|
||
|
epsilon_xy,
|
||
|
epsilon_0,
|
||
|
weights,
|
||
|
output,
|
||
|
kernel_size,
|
||
|
stride,
|
||
|
dilation,
|
||
|
padding,
|
||
|
input_size,
|
||
|
) = ctx.saved_tensors
|
||
|
|
||
|
input = input_float32.type(dtype=torch.float64)
|
||
|
input /= input.sum(dim=1, keepdim=True, dtype=torch.float64)
|
||
|
|
||
|
# For debugging:
|
||
|
# print(
|
||
|
# f"S: O: {output.min().item():e} {output.max().item():e} I: {input.min().item():e} {input.max().item():e} G: {grad_output.min().item():e} {grad_output.max().item():e}"
|
||
|
# )
|
||
|
|
||
|
epsilon_0_float: float = epsilon_0.item()
|
||
|
|
||
|
temp_e: torch.Tensor = 1.0 / ((epsilon_xy * epsilon_0_float) + 1.0)
|
||
|
|
||
|
eps_a: torch.Tensor = temp_e.clone()
|
||
|
eps_a *= epsilon_xy * epsilon_0_float
|
||
|
|
||
|
eps_b: torch.Tensor = temp_e**2 * epsilon_0_float
|
||
|
|
||
|
backprop_r: torch.Tensor = weights.unsqueeze(0).unsqueeze(-1).unsqueeze(
|
||
|
-1
|
||
|
) * output.unsqueeze(1)
|
||
|
|
||
|
backprop_bigr: torch.Tensor = backprop_r.sum(axis=2)
|
||
|
|
||
|
temp: torch.Tensor = input / backprop_bigr**2
|
||
|
|
||
|
backprop_f: torch.Tensor = output.unsqueeze(1) * temp.unsqueeze(2)
|
||
|
torch.nan_to_num(
|
||
|
backprop_f, out=backprop_f, nan=1e300, posinf=1e300, neginf=-1e300
|
||
|
)
|
||
|
torch.clip(backprop_f, out=backprop_f, min=-1e300, max=1e300)
|
||
|
|
||
|
tempz: torch.Tensor = 1.0 / backprop_bigr
|
||
|
|
||
|
backprop_z: torch.Tensor = backprop_r * tempz.unsqueeze(2)
|
||
|
torch.nan_to_num(
|
||
|
backprop_z, out=backprop_z, nan=1e300, posinf=1e300, neginf=-1e300
|
||
|
)
|
||
|
torch.clip(backprop_z, out=backprop_z, min=-1e300, max=1e300)
|
||
|
|
||
|
backprop_y: torch.Tensor = (
|
||
|
torch.einsum("bijxy,bixy->bjxy", backprop_z, input) - output
|
||
|
)
|
||
|
|
||
|
result_omega: torch.Tensor = backprop_bigr.unsqueeze(2) * grad_output.unsqueeze(
|
||
|
1
|
||
|
)
|
||
|
result_omega -= torch.einsum(
|
||
|
"bijxy,bjxy->bixy", backprop_r, grad_output
|
||
|
).unsqueeze(2)
|
||
|
result_omega *= backprop_f
|
||
|
torch.nan_to_num(
|
||
|
result_omega, out=result_omega, nan=1e300, posinf=1e300, neginf=-1e300
|
||
|
)
|
||
|
torch.clip(result_omega, out=result_omega, min=-1e300, max=1e300)
|
||
|
|
||
|
result_eps_xy: torch.Tensor = (
|
||
|
torch.einsum("bixy,bixy->bxy", backprop_y, grad_output) * eps_b
|
||
|
)
|
||
|
torch.nan_to_num(
|
||
|
result_eps_xy, out=result_eps_xy, nan=1e300, posinf=1e300, neginf=-1e300
|
||
|
)
|
||
|
torch.clip(result_eps_xy, out=result_eps_xy, min=-1e300, max=1e300)
|
||
|
|
||
|
result_phi: torch.Tensor = torch.einsum(
|
||
|
"bijxy,bjxy->bixy", backprop_z, grad_output
|
||
|
) * eps_a.unsqueeze(0).unsqueeze(0)
|
||
|
torch.nan_to_num(
|
||
|
result_phi, out=result_phi, nan=1e300, posinf=1e300, neginf=-1e300
|
||
|
)
|
||
|
torch.clip(result_phi, out=result_phi, min=-1e300, max=1e300)
|
||
|
|
||
|
grad_weights = result_omega.sum(0).sum(-1).sum(-1)
|
||
|
grad_eps_xy = result_eps_xy.sum(0)
|
||
|
|
||
|
grad_input = torch.nn.functional.fold(
|
||
|
torch.nn.functional.unfold(
|
||
|
result_phi,
|
||
|
kernel_size=(1, 1),
|
||
|
dilation=1,
|
||
|
padding=0,
|
||
|
stride=1,
|
||
|
),
|
||
|
output_size=input_size,
|
||
|
kernel_size=kernel_size,
|
||
|
dilation=dilation,
|
||
|
padding=padding,
|
||
|
stride=stride,
|
||
|
)
|
||
|
|
||
|
torch.nan_to_num(
|
||
|
grad_input, out=grad_input, nan=1e300, posinf=1e300, neginf=-1e300
|
||
|
)
|
||
|
torch.clip(grad_input, out=grad_input, min=-1e300, max=1e300)
|
||
|
|
||
|
torch.nan_to_num(
|
||
|
grad_eps_xy, out=grad_eps_xy, nan=1e300, posinf=1e300, neginf=-1e300
|
||
|
)
|
||
|
torch.clip(grad_eps_xy, out=grad_eps_xy, min=-1e300, max=1e300)
|
||
|
|
||
|
torch.nan_to_num(
|
||
|
grad_weights, out=grad_weights, nan=1e300, posinf=1e300, neginf=-1e300
|
||
|
)
|
||
|
torch.clip(grad_weights, out=grad_weights, min=-1e300, max=1e300)
|
||
|
|
||
|
grad_epsilon_0 = None
|
||
|
grad_epsilon_t = None
|
||
|
grad_kernel_size = None
|
||
|
grad_stride = None
|
||
|
grad_dilation = None
|
||
|
grad_padding = None
|
||
|
grad_output_size = None
|
||
|
grad_number_of_spikes = None
|
||
|
grad_number_of_cpu_processes = None
|
||
|
grad_h_initial = None
|
||
|
grad_alpha_number_of_iterations = None
|
||
|
|
||
|
return (
|
||
|
grad_input,
|
||
|
grad_eps_xy,
|
||
|
grad_epsilon_0,
|
||
|
grad_epsilon_t,
|
||
|
grad_weights,
|
||
|
grad_kernel_size,
|
||
|
grad_stride,
|
||
|
grad_dilation,
|
||
|
grad_padding,
|
||
|
grad_output_size,
|
||
|
grad_number_of_spikes,
|
||
|
grad_number_of_cpu_processes,
|
||
|
grad_h_initial,
|
||
|
grad_alpha_number_of_iterations,
|
||
|
)
|