pytorch-sbs/network/SpikeCountLayer.py
2023-02-04 14:24:47 +01:00

52 lines
1.4 KiB
Python

import torch
from network.PyCountSpikesCPU import CountSpikesCPU
class SpikeCountLayer(torch.nn.Module):
    """Run the ``CountSpikesCPU`` kernel on a 4-D tensor and return its output.

    The layer copies the input to the CPU, allocates an int64 output tensor,
    hands the raw memory addresses of both tensors to
    ``CountSpikesCPU.process`` and finally moves the output back to the
    device the input came from.
    """

    # Forwarded unchanged as the last argument of ``CountSpikesCPU.process``.
    _number_of_cpu_processes: int

    def __init__(
        self,
        number_of_cpu_processes: int = 1,
    ) -> None:
        """Store the process count that is handed to the kernel.

        Args:
            number_of_cpu_processes: Value passed to
                ``CountSpikesCPU.process`` on every forward call.
        """
        super().__init__()
        self._number_of_cpu_processes = number_of_cpu_processes

    ####################################################################
    # Forward                                                          #
    ####################################################################
    def forward(self, input: torch.Tensor, dim_s: int) -> torch.Tensor:
        """Fill and return the histogram tensor for ``input``.

        Args:
            input: 4-D tensor on any device.
                NOTE(review): the element type the kernel expects is not
                visible here (presumably int64 spike indices) -- confirm
                against ``CountSpikesCPU``.
            dim_s: Size of dimension 1 of the returned tensor; must be > 0.

        Returns:
            An int64 tensor of shape
            ``(input.shape[0], dim_s, input.shape[2], input.shape[3])``
            located on ``input.device``. Its contents are whatever
            ``CountSpikesCPU.process`` writes into the zero-initialised
            buffer.

        Raises:
            AssertionError: If ``input`` is not 4-D or ``dim_s`` is not
                positive.
        """
        assert input.ndim == 4, "input must be a 4-D tensor"
        assert dim_s > 0, "dim_s must be positive"

        # The kernel receives only a raw address and the four sizes, never
        # the strides, so the buffer has to be densely packed. A permuted,
        # sliced or expanded view would otherwise be read with the wrong
        # layout. ``contiguous()`` is a no-op for already-dense tensors.
        input_cpu = input.cpu().contiguous()

        histogram = torch.zeros(
            (
                int(input.shape[0]),
                int(dim_s),
                int(input.shape[-2]),
                int(input.shape[-1]),
            ),
            dtype=torch.int64,
            device=input_cpu.device,
        )

        count_spikes = CountSpikesCPU()
        # ``input_cpu`` and ``histogram`` stay referenced for the whole call,
        # which keeps the memory behind both addresses alive.
        count_spikes.process(
            input_cpu.data_ptr(),
            int(input_cpu.shape[0]),
            int(input_cpu.shape[1]),
            int(input_cpu.shape[2]),
            int(input_cpu.shape[3]),
            histogram.data_ptr(),
            int(histogram.shape[1]),
            int(self._number_of_cpu_processes),
        )

        # Hand the result back on the caller's original device.
        return histogram.to(device=input.device)