Add files via upload
parent 8983d63330
commit 41df07230d
1 changed file with 237 additions and 0 deletions
network/SpikeLayer.py | 237 (new file)
@@ -0,0 +1,237 @@
import torch

from network.PySpikeGenerationCPU import SpikeGenerationCPU
from network.PySpikeGenerationGPU import SpikeGenerationGPU

# from PyCountSpikesCPU import CountSpikesCPU

global_spike_generation_gpu_setting: list[torch.Tensor] = []
global_spike_size: list[torch.Tensor] = []
global_spike_generation_cpp: list[SpikeGenerationCPU | SpikeGenerationGPU] = []
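
# The three module-level lists form a registry shared by every SpikeLayer
# instance: each layer appends one entry per list in __init__ and later
# addresses its entries by index, which lets the autograd.Function below
# reach the compiled C++/CUDA object through a plain int64 parameter tensor.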

class SpikeLayer(torch.nn.Module):

    _spike_generation_cpp_position: int
    _spike_generation_gpu_setting_position: int
    _number_of_cpu_processes: int
    _number_of_spikes: int

    _spikes: torch.Tensor | None = None
    _store_spikes: bool
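
    # When spike storing is enabled, _spikes keeps a detached copy of the
    # most recently generated spike tensor; forward() refreshes or clears
    # it on every call.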

    def __init__(
        self,
        number_of_spikes: int = 1,
        number_of_cpu_processes: int = 1,
        device: torch.device | None = None,
        default_dtype: torch.dtype | None = None,
        store_spikes: bool = False,
    ) -> None:
        super().__init__()

        assert device is not None
        assert default_dtype is not None
        self.device = device
        self.default_dtype = default_dtype

        self._number_of_cpu_processes = number_of_cpu_processes
        self._number_of_spikes = number_of_spikes
        self._store_spikes = store_spikes

        # Register this layer's settings and size slots.
        global_spike_generation_gpu_setting.append(torch.tensor([0]))
        global_spike_size.append(torch.tensor([0, 0, 0, 0]))

        # Pick the compiled backend that matches the target device.
        if device == torch.device("cpu"):
            global_spike_generation_cpp.append(SpikeGenerationCPU())
        else:
            global_spike_generation_cpp.append(SpikeGenerationGPU())

        self._spike_generation_cpp_position = len(global_spike_generation_cpp) - 1
        self._spike_generation_gpu_setting_position = (
            len(global_spike_generation_gpu_setting) - 1
        )

        # FunctionalSpikeGeneration is defined further down in this file;
        # the name is resolved only when __init__ runs, after the whole
        # module has been executed.
        self.functional_spike_generation = FunctionalSpikeGeneration.apply

    ####################################################################
    # Forward                                                          #
    ####################################################################

    def forward(
        self,
        input: torch.Tensor,
        number_of_spikes: int | None = None,
        store_spikes: bool | None = None,
    ) -> torch.Tensor:

        if number_of_spikes is None:
            number_of_spikes = self._number_of_spikes

        assert number_of_spikes > 0
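
        # autograd.Function.apply only accepts tensors, so the integer
        # settings are smuggled through as an int64 tensor and unpacked
        # again inside FunctionalSpikeGeneration.forward.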
        parameter_list = torch.tensor(
            [
                int(self._number_of_cpu_processes),  # 0
                int(self._spike_generation_cpp_position),  # 1
                int(self._spike_generation_gpu_setting_position),  # 2
                int(number_of_spikes),  # 3
            ],
            dtype=torch.int64,
        )

        spikes = self.functional_spike_generation(input, parameter_list)

        # A per-call store_spikes argument overrides the constructor
        # default; store_spikes=None falls back to self._store_spikes.
        if store_spikes is True:
            self._spikes = spikes.detach().clone()
        elif store_spikes is False:
            self._spikes = None
        elif self._store_spikes is True:
            self._spikes = spikes.detach().clone()
        else:
            self._spikes = None

        return spikes
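

# Custom autograd wrapper around the compiled spike-generation kernels.
# The forward pass calls into the C++/CUDA extension through raw data
# pointers; the backward pass hands the incoming gradient to the rate
# input unchanged.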
class FunctionalSpikeGeneration(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore
        ctx,
        input: torch.Tensor,
        parameter_list: torch.Tensor,
    ) -> torch.Tensor:

        assert input.dim() == 4

        # The CPU process count is only meaningful for the CPU backend;
        # -1 signals the GPU path.
        if input.device == torch.device("cpu"):
            spike_number_of_cpu_processes: int = int(parameter_list[0])
        else:
            spike_number_of_cpu_processes = -1

        spike_generation_cpp_position = int(parameter_list[1])
        spike_generation_gpu_setting_position = int(parameter_list[2])
        number_of_spikes: int = int(parameter_list[3])

        # ###########################################################
        # Spike generation
        # ###########################################################

        # ############################################
        # Normalized cumsum
        # (beware of the PyTorch aliasing pitfall: the
        # slice below is a view into input_cumsum, so
        # without the .clone() the in-place division
        # would read memory it is writing to)
        # ############################################
        input_cumsum: torch.Tensor = torch.cumsum(input, dim=1, dtype=input.dtype)
        input_cumsum_last: torch.Tensor = input_cumsum[:, -1, :, :].unsqueeze(1).clone()
        input_cumsum /= input_cumsum_last
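
        # After normalization, each (batch, x, y) column of input_cumsum is
        # a cumulative distribution over the channel dimension; drawing a
        # uniform random number and finding the first channel whose CDF
        # value reaches it samples a spike index proportional to the input.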

        # ############################################
        # Get the required random numbers
        # ############################################
        random_values = torch.rand(
            size=[
                input_cumsum.shape[0],
                number_of_spikes,
                input_cumsum.shape[2],
                input_cumsum.shape[3],
            ],
            dtype=input.dtype,
            device=input.device,
        )

        # ############################################
        # Make space for the results
        # ############################################
        spikes = torch.empty_like(random_values, dtype=torch.int64, device=input.device)

        assert input_cumsum.is_contiguous() is True
        assert random_values.is_contiguous() is True
        assert spikes.is_contiguous() is True
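
        # The extension receives raw data_ptr() addresses, so all three
        # tensors must be contiguous in memory; hence the asserts above.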

        # time_start: float = time.perf_counter()
        spike_generation_profile = global_spike_generation_gpu_setting[
            spike_generation_gpu_setting_position
        ].clone()

        spike_generation_size = global_spike_size[
            spike_generation_gpu_setting_position
        ].clone()
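
        # GPU path only: a kernel occupancy profile is computed once per
        # output shape via gpu_occupancy_export and cached in the global
        # settings list; as long as the spike tensor shape is unchanged,
        # the cached profile is handed back via gpu_occupancy_import.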
        if isinstance(
            global_spike_generation_cpp[spike_generation_cpp_position],
            SpikeGenerationGPU,
        ):
            if (
                (spike_generation_profile.numel() == 1)
                or (spike_generation_size[0] != int(spikes.shape[0]))
                or (spike_generation_size[1] != int(spikes.shape[1]))
                or (spike_generation_size[2] != int(spikes.shape[2]))
                or (spike_generation_size[3] != int(spikes.shape[3]))
            ):
                spike_generation_profile = torch.zeros(
                    (1, 7), dtype=torch.int64, device=torch.device("cpu")
                )
                global_spike_generation_cpp[
                    spike_generation_cpp_position
                ].gpu_occupancy_export(
                    int(spikes.shape[2]),
                    int(spikes.shape[3]),
                    int(spikes.shape[0]),
                    int(spikes.shape[1]),
                    spike_generation_profile.data_ptr(),
                    int(spike_generation_profile.shape[0]),
                    int(spike_generation_profile.shape[1]),
                )
                global_spike_generation_gpu_setting[
                    spike_generation_gpu_setting_position
                ] = spike_generation_profile.clone()

                spike_generation_size[0] = int(spikes.shape[0])
                spike_generation_size[1] = int(spikes.shape[1])
                spike_generation_size[2] = int(spikes.shape[2])
                spike_generation_size[3] = int(spikes.shape[3])
                global_spike_size[
                    spike_generation_gpu_setting_position
                ] = spike_generation_size.clone()

            else:
                global_spike_generation_cpp[
                    spike_generation_cpp_position
                ].gpu_occupancy_import(
                    spike_generation_profile.data_ptr(),
                    int(spike_generation_profile.shape[0]),
                    int(spike_generation_profile.shape[1]),
                )
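
        # The actual sampling: the kernel consumes the per-pixel CDF and
        # the uniform draws and writes one int64 channel index per
        # requested spike into the preallocated spikes tensor.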
        global_spike_generation_cpp[spike_generation_cpp_position].spike_generation(
            input_cumsum.data_ptr(),
            int(input_cumsum.shape[0]),
            int(input_cumsum.shape[1]),
            int(input_cumsum.shape[2]),
            int(input_cumsum.shape[3]),
            random_values.data_ptr(),
            int(random_values.shape[0]),
            int(random_values.shape[1]),
            int(random_values.shape[2]),
            int(random_values.shape[3]),
            spikes.data_ptr(),
            int(spikes.shape[0]),
            int(spikes.shape[1]),
            int(spikes.shape[2]),
            int(spikes.shape[3]),
            int(spike_number_of_cpu_processes),
        )
        del random_values
        del input_cumsum

        return spikes
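
    # The backward pass is an identity / straight-through estimator: the
    # discrete sampling itself is not differentiated, and the gradient
    # arriving at the spike output is passed to the rate input unchanged.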
    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output
        grad_parameter_list = None
        return (grad_input, grad_parameter_list)
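

# ----------------------------------------------------------------------
# Illustrative sketch (not part of the extension's API): the pure-PyTorch
# function below shows the inverse-CDF sampling that the compiled
# spike_generation kernel appears to implement, under the assumption that
# the kernel picks, per pixel and per draw, the first channel whose
# normalized cumsum reaches the random value. The helper name
# `reference_spike_generation` is hypothetical.
# ----------------------------------------------------------------------
def reference_spike_generation(
    input: torch.Tensor, number_of_spikes: int
) -> torch.Tensor:
    assert input.dim() == 4  # (batch, channel, x, y), non-negative rates

    # Normalized cumulative sum over the channel dimension -> per-pixel CDF.
    cdf = torch.cumsum(input, dim=1)
    cdf = cdf / cdf[:, -1:, :, :]

    # Uniform draws, one per requested spike.
    r = torch.rand(
        input.shape[0], number_of_spikes, input.shape[2], input.shape[3],
        dtype=input.dtype, device=input.device,
    )

    # searchsorted needs the sorted axis last: (batch, x, y, channel).
    cdf_last = cdf.permute(0, 2, 3, 1).contiguous()
    r_last = r.permute(0, 2, 3, 1).contiguous()

    # First channel index whose CDF value is >= the draw; the clamp guards
    # against floating-point round-off at the upper end.
    idx = torch.searchsorted(cdf_last, r_last)
    idx = idx.clamp(max=input.shape[1] - 1)

    return idx.permute(0, 3, 1, 2).contiguous()  # (batch, spikes, x, y)

# Example with hypothetical shapes: for rates of shape (2, 10, 4, 4) and
# number_of_spikes=25, the result has shape (2, 25, 4, 4) with channel
# indices in [0, 9].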