import torch

from network.PyHDynamicCNNCPU import HDynamicCNNCPU
from network.PyHDynamicCNNGPU import HDynamicCNNGPU
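# Module-level registries shared by all HDynamicLayer instances. Every layer
# appends one entry to each list in __init__ and remembers its own index:
# a GPU kernel-occupancy profile, the cached output-tensor size that profile
# was computed for, and the compiled C++/CUDA backend object.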
global_sbs_gpu_setting: list[torch.Tensor] = []
global_sbs_size: list[torch.Tensor] = []
global_sbs_hdynamic_cpp: list[HDynamicCNNCPU | HDynamicCNNGPU] = []

class HDynamicLayer(torch.nn.Module):

    _sbs_gpu_setting_position: int
    _sbs_hdynamic_cpp_position: int
    _gpu_tuning_factor: int
    _number_of_cpu_processes: int
    _output_size: list[int]
    _w_trainable: bool
    _skip_gradient_calculation: bool
    _output_layer: bool
    _local_learning: bool
    device: torch.device
    default_dtype: torch.dtype
    _force_forward_h_dynamic_on_cpu: bool

    def __init__(
        self,
        output_size: list[int],
        output_layer: bool = False,
        local_learning: bool = False,
        number_of_cpu_processes: int = 1,
        w_trainable: bool = False,
        skip_gradient_calculation: bool = False,
        device: torch.device | None = None,
        default_dtype: torch.dtype | None = None,
        gpu_tuning_factor: int = 5,
        force_forward_h_dynamic_on_cpu: bool = False,
    ) -> None:
        super().__init__()

        assert device is not None
        self.device = device
        self.default_dtype = default_dtype

        self._gpu_tuning_factor = int(gpu_tuning_factor)
        self._number_of_cpu_processes = int(number_of_cpu_processes)
        self._w_trainable = bool(w_trainable)
        self._skip_gradient_calculation = bool(skip_gradient_calculation)
        self._output_size = output_size
        self._output_layer = bool(output_layer)
        self._local_learning = bool(local_learning)
        self._force_forward_h_dynamic_on_cpu = force_forward_h_dynamic_on_cpu

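        # Register this instance in the module-level lists: a placeholder
        # occupancy profile, a placeholder size record, and the compiled
        # backend (CPU build if the layer lives on the CPU or is forced
        # onto it, CUDA build otherwise).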
        global_sbs_gpu_setting.append(torch.tensor([0]))
        global_sbs_size.append(torch.tensor([0, 0, 0, 0]))

        if (device == torch.device("cpu")) or (
            self._force_forward_h_dynamic_on_cpu is True
        ):
            global_sbs_hdynamic_cpp.append(HDynamicCNNCPU())
        else:
            global_sbs_hdynamic_cpp.append(HDynamicCNNGPU())

        self._sbs_gpu_setting_position = len(global_sbs_gpu_setting) - 1
        self._sbs_hdynamic_cpp_position = len(global_sbs_hdynamic_cpp) - 1

        self.functional_sbs = FunctionalSbS.apply

    ####################################################################
    # Forward                                                          #
    ####################################################################

    def forward(
        self,
        input: torch.Tensor,
        spike: torch.Tensor,
        epsilon_xy: torch.Tensor,
        epsilon_t_0: torch.Tensor,
        weights: torch.Tensor,
        h_initial: torch.Tensor,
        last_grad_scale: torch.Tensor,
        labels: torch.Tensor | None = None,
        keep_last_grad_scale: bool = False,
        disable_scale_grade: bool = True,
        forgetting_offset: float = -1.0,
    ) -> torch.Tensor:

        if labels is None:
            labels_copy: torch.Tensor = torch.tensor(
                [], dtype=torch.int64, device=self.device
            )
        else:
            labels_copy = (
                labels.detach().clone().type(dtype=torch.int64).to(device=self.device)
            )

        if (spike.shape[-2] * spike.shape[-1]) > self._gpu_tuning_factor:
            gpu_tuning_factor = self._gpu_tuning_factor
        else:
            gpu_tuning_factor = 0

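        # Pack all integer-valued settings into a single int64 tensor so they
        # can be passed through torch.autograd.Function.apply as one argument.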
        parameter_list = torch.tensor(
            [
                int(self._number_of_cpu_processes),  # 0
                int(self._output_size[0]),  # 1
                int(self._output_size[1]),  # 2
                int(gpu_tuning_factor),  # 3
                int(self._sbs_gpu_setting_position),  # 4
                int(self._sbs_hdynamic_cpp_position),  # 5
                int(self._w_trainable),  # 6
                int(disable_scale_grade),  # 7
                int(keep_last_grad_scale),  # 8
                int(self._skip_gradient_calculation),  # 9
                int(self._output_layer),  # 10
                int(self._local_learning),  # 11
            ],
            dtype=torch.int64,
        )

        # SbS forward functional
        return self.functional_sbs(
            input,
            spike,
            epsilon_xy,
            epsilon_t_0,
            weights,
            h_initial,
            parameter_list,
            last_grad_scale,
            torch.tensor(
                forgetting_offset, device=self.device, dtype=self.default_dtype
            ),
            labels_copy,
        )

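# Autograd wrapper around the compiled H-dynamics kernel: the forward pass
# calls into the C++/CUDA backend, the backward pass implements the weight
# gradient (or a local learning rule) directly in PyTorch.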
class FunctionalSbS(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore
        ctx,
        input: torch.Tensor,
        spikes: torch.Tensor,
        epsilon_xy: torch.Tensor | None,
        epsilon_t_0: torch.Tensor,
        weights: torch.Tensor,
        h_initial: torch.Tensor,
        parameter_list: torch.Tensor,
        grad_output_scale: torch.Tensor,
        forgetting_offset: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        number_of_spikes: int = int(spikes.shape[1])

        output_size_0: int = int(parameter_list[1])
        output_size_1: int = int(parameter_list[2])
        gpu_tuning_factor: int = int(parameter_list[3])

        sbs_gpu_setting_position = int(parameter_list[4])
        sbs_hdynamic_cpp_position = int(parameter_list[5])

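        # Decide where the compiled kernel runs: the CPU build always works on
        # host memory, the CUDA build works on the device the input lives on.
        # If the work device differs from the input's device, the tensors are
        # copied over below and the result is copied back afterwards.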
        if (
            isinstance(
                global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position], HDynamicCNNCPU
            )
            is True
        ):
            are_we_on_a_cpu: bool = True
            work_device: torch.device = torch.device("cpu")
        else:
            are_we_on_a_cpu = False
            work_device = input.device

        target_device: torch.device = input.device

        if target_device == work_device:
            data_is_on_the_same_device: bool = True
        else:
            data_is_on_the_same_device = False

        if are_we_on_a_cpu is True:
            hdyn_number_of_cpu_processes: int = int(parameter_list[0])
        else:
            hdyn_number_of_cpu_processes = -1

        # ###########################################################
        # H dynamic
        # ###########################################################

        assert epsilon_t_0.ndim == 1
        assert epsilon_t_0.shape[0] >= number_of_spikes

        # ############################################
        # Make space for the results
        # ############################################
        output_work: torch.Tensor = torch.empty(
            (
                int(input.shape[0]),
                int(weights.shape[1]),
                output_size_0,
                output_size_1,
            ),
            dtype=input.dtype,
            device=work_device,
        )

        assert output_work.is_contiguous() is True
        if epsilon_xy is not None:
            assert epsilon_xy.is_contiguous() is True
            assert epsilon_xy.ndim == 3
            if data_is_on_the_same_device is False:
                epsilon_xy_work = epsilon_xy.to(work_device)
            else:
                epsilon_xy_work = epsilon_xy
        else:
            epsilon_xy_work = None

        assert epsilon_t_0.is_contiguous() is True
        if data_is_on_the_same_device is False:
            epsilon_t_0_work = epsilon_t_0.to(work_device)
        else:
            epsilon_t_0_work = epsilon_t_0

        assert weights.is_contiguous() is True
        if data_is_on_the_same_device is False:
            weights_work = weights.to(work_device)
        else:
            weights_work = weights

        assert spikes.is_contiguous() is True
        if data_is_on_the_same_device is False:
            spikes_work = spikes.to(work_device)
        else:
            spikes_work = spikes

        assert h_initial.is_contiguous() is True
        if data_is_on_the_same_device is False:
            h_initial_work = h_initial.to(work_device)
        else:
            h_initial_work = h_initial

        assert weights.ndim == 2
        assert h_initial.ndim == 1

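        # The CUDA backend needs a kernel-occupancy profile for the current
        # output geometry. It is computed once per output size via
        # gpu_occupancy_export, cached in the module-level registries, and
        # re-imported on subsequent calls with the same geometry.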
        sbs_profile = global_sbs_gpu_setting[sbs_gpu_setting_position].clone()
        sbs_size = global_sbs_size[sbs_gpu_setting_position].clone()

        if are_we_on_a_cpu is False:
            if (
                (sbs_profile.numel() == 1)
                or (sbs_size[0] != int(output_work.shape[0]))
                or (sbs_size[1] != int(output_work.shape[1]))
                or (sbs_size[2] != int(output_work.shape[2]))
                or (sbs_size[3] != int(output_work.shape[3]))
            ):
                sbs_profile = torch.zeros(
                    (14, 7), dtype=torch.int64, device=torch.device("cpu")
                )

                global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_export(
                    int(output_work.shape[2]),
                    int(output_work.shape[3]),
                    int(output_work.shape[0]),
                    int(output_work.shape[1]),
                    sbs_profile.data_ptr(),
                    int(sbs_profile.shape[0]),
                    int(sbs_profile.shape[1]),
                )
                global_sbs_gpu_setting[sbs_gpu_setting_position] = sbs_profile.clone()

                sbs_size[0] = int(output_work.shape[0])
                sbs_size[1] = int(output_work.shape[1])
                sbs_size[2] = int(output_work.shape[2])
                sbs_size[3] = int(output_work.shape[3])
                global_sbs_size[sbs_gpu_setting_position] = sbs_size.clone()

            else:
                global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_import(
                    sbs_profile.data_ptr(),
                    int(sbs_profile.shape[0]),
                    int(sbs_profile.shape[1]),
                )

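        # Run the spike-by-spike H-dynamics in the compiled backend. The call
        # works on raw data pointers, which is why all tensors were checked
        # for contiguity (and moved to the work device) above; the result is
        # written in place into output_work.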
        global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].update(
            output_work.data_ptr(),
            int(output_work.shape[0]),
            int(output_work.shape[1]),
            int(output_work.shape[2]),
            int(output_work.shape[3]),
            epsilon_xy_work.data_ptr() if epsilon_xy_work is not None else int(0),
            int(epsilon_xy_work.shape[0]) if epsilon_xy_work is not None else int(0),
            int(epsilon_xy_work.shape[1]) if epsilon_xy_work is not None else int(0),
            int(epsilon_xy_work.shape[2]) if epsilon_xy_work is not None else int(0),
            epsilon_t_0_work.data_ptr(),
            int(epsilon_t_0_work.shape[0]),
            weights_work.data_ptr(),
            int(weights_work.shape[0]),
            int(weights_work.shape[1]),
            spikes_work.data_ptr(),
            int(spikes_work.shape[0]),
            int(spikes_work.shape[1]),
            int(spikes_work.shape[2]),
            int(spikes_work.shape[3]),
            h_initial_work.data_ptr(),
            int(h_initial_work.shape[0]),
            hdyn_number_of_cpu_processes,
            float(forgetting_offset.cpu().item()),
            int(gpu_tuning_factor),
        )

        if data_is_on_the_same_device is False:
            output = output_work.to(target_device)
        else:
            output = output_work

        # print(output)
        # print(output.sum(dim=1))
        # print(output.sum(dim=1).shape)
        # exit()

        # ###########################################################
        # Save the necessary data for the backward pass
        # ###########################################################

        ctx.save_for_backward(
            input,
            weights,
            output,
            parameter_list,
            grad_output_scale,
            labels,
        )

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # ##############################################
        # Get the variables back
        # ##############################################
        (
            input,
            weights,
            output,
            parameter_list,
            last_grad_scale,
            labels,
        ) = ctx.saved_tensors

        assert labels.numel() > 0

        # ##############################################
        # Default output
        # ##############################################
        grad_input = None
        grad_spikes = None
        grad_eps_xy = None
        grad_epsilon_t_0 = None
        grad_weights = None
        grad_h_initial = None
        grad_parameter_list = None
        grad_forgetting_offset = None
        grad_labels = None

        # ##############################################
        # Parameters
        # ##############################################
        parameter_w_trainable: bool = bool(parameter_list[6])
        parameter_disable_scale_grade: bool = bool(parameter_list[7])
        parameter_keep_last_grad_scale: bool = bool(parameter_list[8])
        parameter_skip_gradient_calculation: bool = bool(parameter_list[9])
        parameter_output_layer: bool = bool(parameter_list[10])
        parameter_local_learning: bool = bool(parameter_list[11])

        # ##############################################
        # Dealing with overall scale of the gradient
        # ##############################################
        if parameter_disable_scale_grade is False:
            if parameter_keep_last_grad_scale is True:
                last_grad_scale = torch.tensor(
                    [torch.abs(grad_output).max(), last_grad_scale]
                ).max()
            grad_output /= last_grad_scale
        grad_output_scale = last_grad_scale.clone()

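        # Renormalize the input over the channel dimension (dim 1) so that it
        # sums to one before the gradients are computed.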
        input /= input.sum(dim=1, keepdim=True, dtype=weights.dtype)

        # #################################################
        # User doesn't want us to calculate the gradients
        # #################################################
        if parameter_skip_gradient_calculation is True:
            return (
                grad_input,
                grad_spikes,
                grad_eps_xy,
                grad_epsilon_t_0,
                grad_weights,
                grad_h_initial,
                grad_parameter_list,
                grad_output_scale,
                grad_forgetting_offset,
                grad_labels,
            )

        # #################################################
        # Calculate backprop error (grad_input)
        # #################################################
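        # With weights[i, j] and output[b, j, x, y]:
        #   backprop_r[b, i, j, x, y] = weights[i, j] * output[b, j, x, y]
        #   backprop_bigr[b, i, x, y] = sum_j backprop_r[b, i, j, x, y]
        #     (the input as reconstructed from the output through the weights)
        #   backprop_z                = backprop_r / backprop_bigr
        #   grad_input[b, i, x, y]    = sum_j backprop_z[b, i, j, x, y] * grad_output[b, j, x, y]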
        backprop_r: torch.Tensor = weights.unsqueeze(0).unsqueeze(-1).unsqueeze(
            -1
        ) * output.unsqueeze(1)

        backprop_bigr: torch.Tensor = backprop_r.sum(dim=2)

        backprop_z: torch.Tensor = backprop_r * (
            1.0 / (backprop_bigr + 1e-20)
        ).unsqueeze(2)
        grad_input: torch.Tensor = (backprop_z * grad_output.unsqueeze(1)).sum(2)
        del backprop_z

        # #################################################
        # Calculate weight gradient (grad_weights)
        # #################################################

        if parameter_w_trainable is False:
            # #################################################
            # We don't train this weight
            # #################################################
            grad_weights = None

        elif (parameter_output_layer is False) and (parameter_local_learning is True):
            # #################################################
            # Local learning
            # #################################################
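            # Local learning rule, independent of grad_output:
            #   grad_weights[i, j] = -2 * sum_{b, x, y}
            #       (input - backprop_bigr)[b, i, x, y] * output[b, j, x, y]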
            grad_weights = (
                (-2 * (input - backprop_bigr).unsqueeze(2) * output.unsqueeze(1))
                .sum(0)
                .sum(-1)
                .sum(-1)
            )

        elif (parameter_output_layer is True) and (parameter_local_learning is True):
            target_one_hot: torch.Tensor = torch.zeros(
                (
                    labels.shape[0],
                    output.shape[1],
                ),
                device=input.device,
                dtype=input.dtype,
            )

            target_one_hot.scatter_(
                1,
                labels.to(input.device).unsqueeze(1),
                torch.ones(
                    (labels.shape[0], 1),
                    device=input.device,
                    dtype=input.dtype,
                ),
            )
            target_one_hot = target_one_hot.unsqueeze(-1).unsqueeze(-1)

            # (-2 * (input - backprop_bigr).unsqueeze(2) * (target_one_hot-output).unsqueeze(1))
            # (-2 * input.unsqueeze(2) * (target_one_hot-output).unsqueeze(1))
            grad_weights = (
                (
                    -2
                    * (input - backprop_bigr).unsqueeze(2)
                    * target_one_hot.unsqueeze(1)
                )
                .sum(0)
                .sum(-1)
                .sum(-1)
            )

        else:
            # #################################################
            # Backprop
            # #################################################
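            # Full backpropagation of the weight gradient:
            #   backprop_f[b, i, j, x, y]   = output[b, j, x, y] * input[b, i, x, y]
            #                                 / (backprop_bigr[b, i, x, y] ** 2 + 1e-20)
            #   result_omega[b, i, j, x, y] = (backprop_bigr[b, i, x, y] * grad_output[b, j, x, y]
            #                                  - sum_k backprop_r[b, i, k, x, y] * grad_output[b, k, x, y])
            #                                 * backprop_f[b, i, j, x, y]
            #   grad_weights[i, j]          = sum_{b, x, y} result_omega[b, i, j, x, y]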
            backprop_f: torch.Tensor = output.unsqueeze(1) * (
                input / (backprop_bigr**2 + 1e-20)
            ).unsqueeze(2)

            result_omega: torch.Tensor = backprop_bigr.unsqueeze(
                2
            ) * grad_output.unsqueeze(1)
            result_omega -= (backprop_r * grad_output.unsqueeze(1)).sum(2).unsqueeze(2)
            result_omega *= backprop_f
            del backprop_f
            grad_weights = result_omega.sum(0).sum(-1).sum(-1)
            del result_omega

        del backprop_bigr
        del backprop_r

        return (
            grad_input,
            grad_spikes,
            grad_eps_xy,
            grad_epsilon_t_0,
            grad_weights,
            grad_h_initial,
            grad_parameter_list,
            grad_output_scale,
            grad_forgetting_offset,
            grad_labels,
        )