# pytorch-sbs/network/SpikeLayer.py

import torch
from network.PySpikeGenerationCPU import SpikeGenerationCPU
from network.PySpikeGenerationGPU import SpikeGenerationGPU
# from PyCountSpikesCPU import CountSpikesCPU
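
# Module-level registries shared by all SpikeLayer instances: each layer
# appends its GPU launch profile, cached output shape, and C++/CUDA
# backend object here at construction time and later addresses its own
# entries by index (the *_position attributes below).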
global_spike_generation_gpu_setting: list[torch.Tensor] = []
global_spike_size: list[torch.Tensor] = []
global_spike_generation_cpp: list[SpikeGenerationCPU | SpikeGenerationGPU] = []


class SpikeLayer(torch.nn.Module):
    """Draws ``number_of_spikes`` categorical samples per spatial position
    from a (N, C, H, W) input, treating dim 1 as an unnormalized
    distribution; sampling runs in a C++ (CPU) or CUDA (GPU) backend."""

    _spike_generation_cpp_position: int
    _spike_generation_gpu_setting_position: int
    _number_of_cpu_processes: int
    _number_of_spikes: int
    _spikes: torch.Tensor | None = None
    _store_spikes: bool

    def __init__(
        self,
        number_of_spikes: int = 1,
        number_of_cpu_processes: int = 1,
        device: torch.device | None = None,
        default_dtype: torch.dtype | None = None,
        store_spikes: bool = False,
    ) -> None:
        super().__init__()
        assert device is not None
        assert default_dtype is not None
        self.device = device
        self.default_dtype = default_dtype

        self._number_of_cpu_processes = number_of_cpu_processes
        self._number_of_spikes = number_of_spikes
        self._store_spikes = store_spikes

        # Register this instance's slots in the module-level registries.
        global_spike_generation_gpu_setting.append(torch.tensor([0]))
        global_spike_size.append(torch.tensor([0, 0, 0, 0]))
        if device == torch.device("cpu"):
            global_spike_generation_cpp.append(SpikeGenerationCPU())
        else:
            global_spike_generation_cpp.append(SpikeGenerationGPU())
        self._spike_generation_cpp_position = len(global_spike_generation_cpp) - 1
        self._spike_generation_gpu_setting_position = (
            len(global_spike_generation_gpu_setting) - 1
        )

        self.functional_spike_generation = FunctionalSpikeGeneration.apply

    ####################################################################
    # Forward #
    ####################################################################
    def forward(
        self,
        input: torch.Tensor,
        number_of_spikes: int | None = None,
        store_spikes: bool | None = None,
    ) -> torch.Tensor:
        if number_of_spikes is None:
            number_of_spikes = self._number_of_spikes
        assert number_of_spikes > 0

        # Pack the integer settings into one tensor so they survive the
        # trip through the custom autograd function.
        parameter_list = torch.tensor(
            [
                int(self._number_of_cpu_processes),  # 0
                int(self._spike_generation_cpp_position),  # 1
                int(self._spike_generation_gpu_setting_position),  # 2
                int(number_of_spikes),  # 3
            ],
            dtype=torch.int64,
        )

        spikes = self.functional_spike_generation(input, parameter_list)

        # A per-call store_spikes overrides the constructor default; when
        # storing, keep a detached copy of shape (N, number_of_spikes, H, W).
        if store_spikes is None:
            store_spikes = self._store_spikes
        self._spikes = spikes.detach().clone() if store_spikes else None

        return spikes


class FunctionalSpikeGeneration(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore
        ctx,
        input: torch.Tensor,
        parameter_list: torch.Tensor,
    ) -> torch.Tensor:
        assert input.dim() == 4

        if input.device == torch.device("cpu"):
            spike_number_of_cpu_processes: int = int(parameter_list[0])
        else:
            spike_number_of_cpu_processes = -1

        spike_generation_cpp_position = int(parameter_list[1])
        spike_generation_gpu_setting_position = int(parameter_list[2])
        number_of_spikes: int = int(parameter_list[3])

        # ###########################################################
        # Spike generation
        # ###########################################################

        # ############################################
        # Normalized cumsum
        # (input_cumsum_last aliases input_cumsum, so it
        # must be materialized with .clone() before the
        # in-place division below)
        # ############################################
        input_cumsum: torch.Tensor = torch.cumsum(input, dim=1, dtype=input.dtype)
        input_cumsum_last: torch.Tensor = (
            input_cumsum[:, -1, :, :].unsqueeze(1).clone()
        )
        input_cumsum /= input_cumsum_last

        # ############################################
        # Get the required random numbers
        # ############################################
        random_values = torch.rand(
            size=[
                input_cumsum.shape[0],
                number_of_spikes,
                input_cumsum.shape[2],
                input_cumsum.shape[3],
            ],
            dtype=input.dtype,
            device=input.device,
        )

        # ############################################
        # Make space for the results
        # ############################################
        spikes = torch.empty_like(
            random_values, dtype=torch.int64, device=input.device
        )

        # The backend works on raw pointers, so all buffers must be contiguous.
        assert input_cumsum.is_contiguous()
        assert random_values.is_contiguous()
        assert spikes.is_contiguous()
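
        # What the backend computes is, in effect, inverse-CDF (categorical)
        # sampling along dim 1. A pure-PyTorch reference would be roughly
        # (a sketch, not the code path actually used):
        #
        #   cdf = input_cumsum.movedim(1, -1).contiguous()   # (N, H, W, C)
        #   u = random_values.movedim(1, -1).contiguous()    # (N, H, W, S)
        #   ref = torch.searchsorted(cdf, u).movedim(-1, 1)  # (N, S, H, W)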

        spike_generation_profile = global_spike_generation_gpu_setting[
            spike_generation_gpu_setting_position
        ].clone()
        spike_generation_size = global_spike_size[
            spike_generation_gpu_setting_position
        ].clone()

        # On the GPU, the kernel launch configuration ("occupancy profile")
        # is computed once per output shape and cached in the module-level
        # registries; later calls with the same shape re-import it.
        if isinstance(
            global_spike_generation_cpp[spike_generation_cpp_position],
            SpikeGenerationGPU,
        ):
            if (
                (spike_generation_profile.numel() == 1)
                or (spike_generation_size[0] != int(spikes.shape[0]))
                or (spike_generation_size[1] != int(spikes.shape[1]))
                or (spike_generation_size[2] != int(spikes.shape[2]))
                or (spike_generation_size[3] != int(spikes.shape[3]))
            ):
                # First call, or the output shape changed: let the backend
                # compute a fresh launch profile and cache it.
                spike_generation_profile = torch.zeros(
                    (1, 7), dtype=torch.int64, device=torch.device("cpu")
                )
                global_spike_generation_cpp[
                    spike_generation_cpp_position
                ].gpu_occupancy_export(
                    int(spikes.shape[2]),
                    int(spikes.shape[3]),
                    int(spikes.shape[0]),
                    int(spikes.shape[1]),
                    spike_generation_profile.data_ptr(),
                    int(spike_generation_profile.shape[0]),
                    int(spike_generation_profile.shape[1]),
                )
                global_spike_generation_gpu_setting[
                    spike_generation_gpu_setting_position
                ] = spike_generation_profile.clone()

                spike_generation_size[0] = int(spikes.shape[0])
                spike_generation_size[1] = int(spikes.shape[1])
                spike_generation_size[2] = int(spikes.shape[2])
                spike_generation_size[3] = int(spikes.shape[3])
                global_spike_size[
                    spike_generation_gpu_setting_position
                ] = spike_generation_size.clone()
            else:
                # Shape unchanged: reuse the cached launch profile.
                global_spike_generation_cpp[
                    spike_generation_cpp_position
                ].gpu_occupancy_import(
                    spike_generation_profile.data_ptr(),
                    int(spike_generation_profile.shape[0]),
                    int(spike_generation_profile.shape[1]),
                )

        # Run the C++/CUDA kernel: fills `spikes` with the sampled indices.
        global_spike_generation_cpp[spike_generation_cpp_position].spike_generation(
            input_cumsum.data_ptr(),
            int(input_cumsum.shape[0]),
            int(input_cumsum.shape[1]),
            int(input_cumsum.shape[2]),
            int(input_cumsum.shape[3]),
            random_values.data_ptr(),
            int(random_values.shape[0]),
            int(random_values.shape[1]),
            int(random_values.shape[2]),
            int(random_values.shape[3]),
            spikes.data_ptr(),
            int(spikes.shape[0]),
            int(spikes.shape[1]),
            int(spikes.shape[2]),
            int(spikes.shape[3]),
            int(spike_number_of_cpu_processes),
        )

        del random_values
        del input_cumsum

        return spikes

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: the discrete sampling step is treated
        # as the identity for gradients; parameter_list gets no gradient.
        grad_input = grad_output
        grad_parameter_list = None

        return (grad_input, grad_parameter_list)
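

if __name__ == "__main__":
    # Minimal CPU smoke test (a sketch: it assumes the compiled
    # SpikeGenerationCPU extension is built and importable, and that the
    # backend writes channel indices in [0, C) into the output tensor).
    torch.manual_seed(0)

    layer = SpikeLayer(
        number_of_spikes=16,
        number_of_cpu_processes=2,
        device=torch.device("cpu"),
        default_dtype=torch.float32,
        store_spikes=True,
    )

    # (N, C, H, W) non-negative input; dim 1 is the sampling dimension.
    probabilities = torch.rand((2, 10, 4, 4), dtype=torch.float32)
    spikes = layer(probabilities)

    assert spikes.shape == (2, 16, 4, 4)
    assert spikes.dtype == torch.int64
    assert layer._spikes is not None
    print("spikes:", tuple(spikes.shape), spikes.dtype)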