diff --git a/network/SpikeLayer.py b/network/SpikeLayer.py new file mode 100644 index 0000000..f7f7060 --- /dev/null +++ b/network/SpikeLayer.py @@ -0,0 +1,237 @@ +import torch + +from network.PySpikeGenerationCPU import SpikeGenerationCPU +from network.PySpikeGenerationGPU import SpikeGenerationGPU + +# from PyCountSpikesCPU import CountSpikesCPU + +global_spike_generation_gpu_setting: list[torch.Tensor] = [] +global_spike_size: list[torch.Tensor] = [] +global_spike_generation_cpp: list[SpikeGenerationCPU | SpikeGenerationGPU] = [] + + +class SpikeLayer(torch.nn.Module): + + _spike_generation_cpp_position: int + _spike_generation_gpu_setting_position: int + _number_of_cpu_processes: int + _number_of_spikes: int + + _spikes: torch.Tensor | None = None + _store_spikes: bool + + def __init__( + self, + number_of_spikes: int = 1, + number_of_cpu_processes: int = 1, + device: torch.device | None = None, + default_dtype: torch.dtype | None = None, + store_spikes: bool = False, + ) -> None: + super().__init__() + + assert device is not None + assert default_dtype is not None + self.device = device + self.default_dtype = default_dtype + + self._number_of_cpu_processes = number_of_cpu_processes + self._number_of_spikes = number_of_spikes + self._store_spikes = store_spikes + + global_spike_generation_gpu_setting.append(torch.tensor([0])) + global_spike_size.append(torch.tensor([0, 0, 0, 0])) + + if device == torch.device("cpu"): + global_spike_generation_cpp.append(SpikeGenerationCPU()) + else: + global_spike_generation_cpp.append(SpikeGenerationGPU()) + + self._spike_generation_cpp_position = len(global_spike_generation_cpp) - 1 + self._spike_generation_gpu_setting_position = ( + len(global_spike_generation_gpu_setting) - 1 + ) + + self.functional_spike_generation = FunctionalSpikeGeneration.apply + + #################################################################### + # Forward # + #################################################################### + + def forward( + self, + input: torch.Tensor, + number_of_spikes: int | None = None, + store_spikes: bool | None = None, + ) -> torch.Tensor: + + if number_of_spikes is None: + number_of_spikes = self._number_of_spikes + + assert number_of_spikes > 0 + + parameter_list = torch.tensor( + [ + int(self._number_of_cpu_processes), # 0 + int(self._spike_generation_cpp_position), # 1 + int(self._spike_generation_gpu_setting_position), # 2 + int(number_of_spikes), # 3 + ], + dtype=torch.int64, + ) + + spikes = self.functional_spike_generation(input, parameter_list) + + if (store_spikes is not None) and (store_spikes is True): + self._spikes = spikes.detach().clone() + elif (store_spikes is not None) and (store_spikes is False): + self._spikes = None + elif self._store_spikes is True: + self._spikes = spikes.detach().clone() + else: + self._spikes = None + + return spikes + + +class FunctionalSpikeGeneration(torch.autograd.Function): + @staticmethod + def forward( # type: ignore + ctx, + input: torch.Tensor, + parameter_list: torch.Tensor, + ) -> torch.Tensor: + + assert input.dim() == 4 + + if input.device == torch.device("cpu"): + spike_number_of_cpu_processes: int = int(parameter_list[0]) + else: + spike_number_of_cpu_processes = -1 + + spike_generation_cpp_position = int(parameter_list[1]) + spike_generation_gpu_setting_position = int(parameter_list[2]) + number_of_spikes: int = int(parameter_list[3]) + + # ########################################################### + # Spike generation + # ########################################################### + + # ############################################ + # Normalized cumsum + # (beware of the pytorch bug! Thus .clone()!) + # ############################################ + input_cumsum: torch.Tensor = torch.cumsum(input, dim=1, dtype=input.dtype) + input_cumsum_last: torch.Tensor = input_cumsum[:, -1, :, :].unsqueeze(1).clone() + input_cumsum /= input_cumsum_last + + # ############################################ + # Get the required random numbers + # ############################################ + random_values = torch.rand( + size=[ + input_cumsum.shape[0], + number_of_spikes, + input_cumsum.shape[2], + input_cumsum.shape[3], + ], + dtype=input.dtype, + device=input.device, + ) + + # ############################################ + # Make space for the results + # ############################################ + spikes = torch.empty_like(random_values, dtype=torch.int64, device=input.device) + + assert input_cumsum.is_contiguous() is True + assert random_values.is_contiguous() is True + assert spikes.is_contiguous() is True + + # time_start: float = time.perf_counter() + spike_generation_profile = global_spike_generation_gpu_setting[ + spike_generation_gpu_setting_position + ].clone() + + spike_generation_size = global_spike_size[ + spike_generation_gpu_setting_position + ].clone() + + if ( + isinstance( + global_spike_generation_cpp[spike_generation_cpp_position], + SpikeGenerationGPU, + ) + is True + ): + if ( + (spike_generation_profile.numel() == 1) + or (spike_generation_size[0] != int(spikes.shape[0])) + or (spike_generation_size[1] != int(spikes.shape[1])) + or (spike_generation_size[2] != int(spikes.shape[2])) + or (spike_generation_size[3] != int(spikes.shape[3])) + ): + + spike_generation_profile = torch.zeros( + (1, 7), dtype=torch.int64, device=torch.device("cpu") + ) + global_spike_generation_cpp[ + spike_generation_cpp_position + ].gpu_occupancy_export( + int(spikes.shape[2]), + int(spikes.shape[3]), + int(spikes.shape[0]), + int(spikes.shape[1]), + spike_generation_profile.data_ptr(), + int(spike_generation_profile.shape[0]), + int(spike_generation_profile.shape[1]), + ) + global_spike_generation_gpu_setting[ + spike_generation_gpu_setting_position + ] = spike_generation_profile.clone() + + spike_generation_size[0] = int(spikes.shape[0]) + spike_generation_size[1] = int(spikes.shape[1]) + spike_generation_size[2] = int(spikes.shape[2]) + spike_generation_size[3] = int(spikes.shape[3]) + global_spike_size[ + spike_generation_gpu_setting_position + ] = spike_generation_size.clone() + + else: + global_spike_generation_cpp[ + spike_generation_cpp_position + ].gpu_occupancy_import( + spike_generation_profile.data_ptr(), + int(spike_generation_profile.shape[0]), + int(spike_generation_profile.shape[1]), + ) + + global_spike_generation_cpp[spike_generation_cpp_position].spike_generation( + input_cumsum.data_ptr(), + int(input_cumsum.shape[0]), + int(input_cumsum.shape[1]), + int(input_cumsum.shape[2]), + int(input_cumsum.shape[3]), + random_values.data_ptr(), + int(random_values.shape[0]), + int(random_values.shape[1]), + int(random_values.shape[2]), + int(random_values.shape[3]), + spikes.data_ptr(), + int(spikes.shape[0]), + int(spikes.shape[1]), + int(spikes.shape[2]), + int(spikes.shape[3]), + int(spike_number_of_cpu_processes), + ) + del random_values + del input_cumsum + + return spikes + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output + grad_parameter_list = None + return (grad_input, grad_parameter_list)