Add files via upload

This commit is contained in:
David Rotermund 2023-02-02 19:16:06 +01:00 committed by GitHub
parent e7eb98edc5
commit d6469f3ee8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,17 +1,21 @@
import torch
from network.CPP.PySpikeGeneration2DManyIP import SpikeGeneration2DManyIP
from network.CPP.PyHDynamicCNNManyIP import HDynamicCNNManyIP
from network.PySpikeGenerationCPU import SpikeGenerationCPU
from network.PySpikeGenerationGPU import SpikeGenerationGPU
from network.PyHDynamicCNNCPU import HDynamicCNNCPU
from network.PyHDynamicCNNGPU import HDynamicCNNGPU
from network.calculate_output_size import calculate_output_size
import os
import numpy as np
global_sbs_gpu_setting: list[torch.Tensor] = []
global_sbs_size: list[torch.Tensor] = []
global_sbs_hdynamic_cpp: list[HDynamicCNNManyIP] = []
global_sbs_hdynamic_cpp: list[HDynamicCNNCPU | HDynamicCNNGPU] = []
global_spike_generation_gpu_setting: list[torch.Tensor] = []
global_spike_size: list[torch.Tensor] = []
global_spike_generation_cpp: list[SpikeGeneration2DManyIP] = []
global_spike_generation_cpp: list[SpikeGenerationCPU | SpikeGenerationGPU] = []
class SbS(torch.nn.Module):
@ -61,6 +65,7 @@ class SbS(torch.nn.Module):
_cooldown_after_number_of_spikes: int = -1
_reduction_cooldown: float = 1.0
_layer_id: int = (-1,)
def __init__(
self,
@ -114,6 +119,7 @@ class SbS(torch.nn.Module):
self._is_pooling_layer = bool(is_pooling_layer)
self._cooldown_after_number_of_spikes = int(cooldown_after_number_of_spikes)
self.reduction_cooldown = float(reduction_cooldown)
self._layer_id = layer_id
assert len(input_size) == 2
self._input_size = input_size
@ -123,8 +129,15 @@ class SbS(torch.nn.Module):
global_sbs_size.append(torch.tensor([0, 0, 0, 0]))
global_spike_size.append(torch.tensor([0, 0, 0, 0]))
global_sbs_hdynamic_cpp.append(HDynamicCNNManyIP())
global_spike_generation_cpp.append(SpikeGeneration2DManyIP())
if device == torch.device("cpu"):
global_sbs_hdynamic_cpp.append(HDynamicCNNGPU())
else:
global_sbs_hdynamic_cpp.append(HDynamicCNNCPU())
if device == torch.device("cpu"):
global_spike_generation_cpp.append(SpikeGenerationCPU())
else:
global_spike_generation_cpp.append(SpikeGenerationGPU())
self.sbs_gpu_setting_position = len(global_sbs_gpu_setting) - 1
self.sbs_hdynamic_cpp_position = len(global_sbs_hdynamic_cpp) - 1