pytorch-sbs/network/SortSpikesLayer.py

174 lines
5.3 KiB
Python
Raw Normal View History

2023-02-04 14:24:47 +01:00
import torch
from network.PySortSpikesCPU import SortSpikesCPU
class SortSpikesLayer(torch.nn.Module):
_kernel_size: list[int]
_stride: list[int]
_dilation: list[int]
_padding: list[int]
_output_size: torch.Tensor
_number_of_cpu_processes: int
_input_shape: list[int]
order: torch.Tensor | None = None
order_convoled: torch.Tensor | None = None
indices: torch.Tensor | None = None
def __init__(
self,
kernel_size: list[int],
input_shape: list[int],
output_size: torch.Tensor,
strides: list[int] = [1, 1],
dilation: list[int] = [0, 0],
padding: list[int] = [0, 0],
number_of_cpu_processes: int = 1,
) -> None:
super().__init__()
self._stride = strides
self._dilation = dilation
self._padding = padding
self._kernel_size = kernel_size
self._output_size = output_size
self._number_of_cpu_processes = number_of_cpu_processes
self._input_shape = input_shape
self.sort_spikes = SortSpikesCPU()
self.order = (
torch.arange(
0,
self._input_shape[0] * self._input_shape[1] * self._input_shape[2],
device=torch.device("cpu"),
)
.reshape(
(
1,
self._input_shape[0],
self._input_shape[1],
self._input_shape[2],
)
)
.type(dtype=torch.float32)
)
self.order_convoled = torch.nn.functional.fold(
torch.nn.functional.unfold(
self.order,
kernel_size=(
int(self._kernel_size[0]),
int(self._kernel_size[1]),
),
dilation=(int(self._dilation[0]), int(self._dilation[1])),
padding=(int(self._padding[0]), int(self._padding[1])),
stride=(int(self._stride[0]), int(self._stride[1])),
),
output_size=tuple(self._output_size.tolist()),
kernel_size=(1, 1),
dilation=(1, 1),
padding=(0, 0),
stride=(1, 1),
).type(dtype=torch.int64)
assert self.order_convoled is not None
self.order_convoled = self.order_convoled.reshape(
(
self.order_convoled.shape[1]
* self.order_convoled.shape[2]
* self.order_convoled.shape[3]
)
)
max_length: int = 0
max_range: int = (
self._input_shape[0] * self._input_shape[1] * self._input_shape[2]
)
for id in range(0, max_range):
idx = torch.where(self.order_convoled == id)[0]
max_length = max(max_length, int(idx.shape[0]))
self.indices = torch.full(
(max_range, max_length),
-1,
dtype=torch.int64,
device=torch.device("cpu"),
)
for id in range(0, max_range):
idx = torch.where(self.order_convoled == id)[0]
self.indices[id, 0 : int(idx.shape[0])] = idx
####################################################################
# Forward #
####################################################################
def forward(
self,
input: torch.Tensor,
) -> torch.Tensor:
assert len(self._input_shape) == 3
assert input.shape[-2] == 1
assert input.shape[-1] == 1
assert self.indices is not None
spikes_count = torch.zeros(
(input.shape[0], int(self._output_size[0]), int(self._output_size[1])),
device=torch.device("cpu"),
dtype=torch.int64,
)
input_cpu = input.clone().cpu()
self.sort_spikes.count(
input_cpu.data_ptr(), # Input
int(input_cpu.shape[0]),
int(input_cpu.shape[1]),
int(input_cpu.shape[2]),
int(input_cpu.shape[3]),
spikes_count.data_ptr(), # Output
int(spikes_count.shape[0]),
int(spikes_count.shape[1]),
int(spikes_count.shape[2]),
self.indices.data_ptr(), # Positions
int(self.indices.shape[0]),
int(self.indices.shape[1]),
int(self._number_of_cpu_processes),
)
spikes_output = torch.full(
(
input.shape[0],
int(spikes_count.max()),
int(self._output_size[0]),
int(self._output_size[1]),
),
-1,
dtype=torch.int64,
device=torch.device("cpu"),
)
self.sort_spikes.process(
input_cpu.data_ptr(), # Input
int(input_cpu.shape[0]),
int(input_cpu.shape[1]),
int(input_cpu.shape[2]),
int(input_cpu.shape[3]),
spikes_output.data_ptr(), # Output
int(spikes_output.shape[0]),
int(spikes_output.shape[1]),
int(spikes_output.shape[2]),
int(spikes_output.shape[3]),
self.indices.data_ptr(), # Positions
int(self.indices.shape[0]),
int(self.indices.shape[1]),
int(self._number_of_cpu_processes),
)
return spikes_output