diff --git a/network/SbS.py b/network/SbS.py index 9ff9ac9..72b141d 100644 --- a/network/SbS.py +++ b/network/SbS.py @@ -1,17 +1,21 @@ import torch -from network.CPP.PySpikeGeneration2DManyIP import SpikeGeneration2DManyIP -from network.CPP.PyHDynamicCNNManyIP import HDynamicCNNManyIP +from network.PySpikeGenerationCPU import SpikeGenerationCPU +from network.PySpikeGenerationGPU import SpikeGenerationGPU + +from network.PyHDynamicCNNCPU import HDynamicCNNCPU +from network.PyHDynamicCNNGPU import HDynamicCNNGPU + from network.calculate_output_size import calculate_output_size import os import numpy as np global_sbs_gpu_setting: list[torch.Tensor] = [] global_sbs_size: list[torch.Tensor] = [] -global_sbs_hdynamic_cpp: list[HDynamicCNNManyIP] = [] +global_sbs_hdynamic_cpp: list[HDynamicCNNCPU | HDynamicCNNGPU] = [] global_spike_generation_gpu_setting: list[torch.Tensor] = [] global_spike_size: list[torch.Tensor] = [] -global_spike_generation_cpp: list[SpikeGeneration2DManyIP] = [] +global_spike_generation_cpp: list[SpikeGenerationCPU | SpikeGenerationGPU] = [] class SbS(torch.nn.Module): @@ -61,6 +65,7 @@ class SbS(torch.nn.Module): _cooldown_after_number_of_spikes: int = -1 _reduction_cooldown: float = 1.0 + _layer_id: int = (-1,) def __init__( self, @@ -114,6 +119,7 @@ class SbS(torch.nn.Module): self._is_pooling_layer = bool(is_pooling_layer) self._cooldown_after_number_of_spikes = int(cooldown_after_number_of_spikes) self.reduction_cooldown = float(reduction_cooldown) + self._layer_id = layer_id assert len(input_size) == 2 self._input_size = input_size @@ -123,8 +129,15 @@ class SbS(torch.nn.Module): global_sbs_size.append(torch.tensor([0, 0, 0, 0])) global_spike_size.append(torch.tensor([0, 0, 0, 0])) - global_sbs_hdynamic_cpp.append(HDynamicCNNManyIP()) - global_spike_generation_cpp.append(SpikeGeneration2DManyIP()) + if device == torch.device("cpu"): + global_sbs_hdynamic_cpp.append(HDynamicCNNGPU()) + else: + global_sbs_hdynamic_cpp.append(HDynamicCNNCPU()) + + if device == torch.device("cpu"): + global_spike_generation_cpp.append(SpikeGenerationCPU()) + else: + global_spike_generation_cpp.append(SpikeGenerationGPU()) self.sbs_gpu_setting_position = len(global_sbs_gpu_setting) - 1 self.sbs_hdynamic_cpp_position = len(global_sbs_hdynamic_cpp) - 1