173 lines
5.3 KiB
Python
173 lines
5.3 KiB
Python
import torch
|
|
|
|
from network.PySortSpikesCPU import SortSpikesCPU
|
|
|
|
|
|
class SortSpikesLayer(torch.nn.Module):
|
|
|
|
_kernel_size: list[int]
|
|
_stride: list[int]
|
|
_dilation: list[int]
|
|
_padding: list[int]
|
|
_output_size: torch.Tensor
|
|
_number_of_cpu_processes: int
|
|
_input_shape: list[int]
|
|
|
|
order: torch.Tensor | None = None
|
|
order_convoled: torch.Tensor | None = None
|
|
indices: torch.Tensor | None = None
|
|
|
|
def __init__(
|
|
self,
|
|
kernel_size: list[int],
|
|
input_shape: list[int],
|
|
output_size: torch.Tensor,
|
|
strides: list[int] = [1, 1],
|
|
dilation: list[int] = [0, 0],
|
|
padding: list[int] = [0, 0],
|
|
number_of_cpu_processes: int = 1,
|
|
) -> None:
|
|
|
|
super().__init__()
|
|
|
|
self._stride = strides
|
|
self._dilation = dilation
|
|
self._padding = padding
|
|
self._kernel_size = kernel_size
|
|
self._output_size = output_size
|
|
self._number_of_cpu_processes = number_of_cpu_processes
|
|
self._input_shape = input_shape
|
|
|
|
self.sort_spikes = SortSpikesCPU()
|
|
|
|
self.order = (
|
|
torch.arange(
|
|
0,
|
|
self._input_shape[0] * self._input_shape[1] * self._input_shape[2],
|
|
device=torch.device("cpu"),
|
|
)
|
|
.reshape(
|
|
(
|
|
1,
|
|
self._input_shape[0],
|
|
self._input_shape[1],
|
|
self._input_shape[2],
|
|
)
|
|
)
|
|
.type(dtype=torch.float32)
|
|
)
|
|
|
|
self.order_convoled = torch.nn.functional.fold(
|
|
torch.nn.functional.unfold(
|
|
self.order,
|
|
kernel_size=(
|
|
int(self._kernel_size[0]),
|
|
int(self._kernel_size[1]),
|
|
),
|
|
dilation=(int(self._dilation[0]), int(self._dilation[1])),
|
|
padding=(int(self._padding[0]), int(self._padding[1])),
|
|
stride=(int(self._stride[0]), int(self._stride[1])),
|
|
),
|
|
output_size=tuple(self._output_size.tolist()),
|
|
kernel_size=(1, 1),
|
|
dilation=(1, 1),
|
|
padding=(0, 0),
|
|
stride=(1, 1),
|
|
).type(dtype=torch.int64)
|
|
|
|
assert self.order_convoled is not None
|
|
|
|
self.order_convoled = self.order_convoled.reshape(
|
|
(
|
|
self.order_convoled.shape[1]
|
|
* self.order_convoled.shape[2]
|
|
* self.order_convoled.shape[3]
|
|
)
|
|
)
|
|
|
|
max_length: int = 0
|
|
max_range: int = (
|
|
self._input_shape[0] * self._input_shape[1] * self._input_shape[2]
|
|
)
|
|
for id in range(0, max_range):
|
|
idx = torch.where(self.order_convoled == id)[0]
|
|
max_length = max(max_length, int(idx.shape[0]))
|
|
|
|
self.indices = torch.full(
|
|
(max_range, max_length),
|
|
-1,
|
|
dtype=torch.int64,
|
|
device=torch.device("cpu"),
|
|
)
|
|
|
|
for id in range(0, max_range):
|
|
idx = torch.where(self.order_convoled == id)[0]
|
|
self.indices[id, 0 : int(idx.shape[0])] = idx
|
|
|
|
####################################################################
|
|
# Forward #
|
|
####################################################################
|
|
def forward(
|
|
self,
|
|
input: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
|
|
assert len(self._input_shape) == 3
|
|
assert input.shape[-2] == 1
|
|
assert input.shape[-1] == 1
|
|
assert self.indices is not None
|
|
|
|
spikes_count = torch.zeros(
|
|
(input.shape[0], int(self._output_size[0]), int(self._output_size[1])),
|
|
device=torch.device("cpu"),
|
|
dtype=torch.int64,
|
|
)
|
|
|
|
input_cpu = input.clone().cpu()
|
|
|
|
self.sort_spikes.count(
|
|
input_cpu.data_ptr(), # Input
|
|
int(input_cpu.shape[0]),
|
|
int(input_cpu.shape[1]),
|
|
int(input_cpu.shape[2]),
|
|
int(input_cpu.shape[3]),
|
|
spikes_count.data_ptr(), # Output
|
|
int(spikes_count.shape[0]),
|
|
int(spikes_count.shape[1]),
|
|
int(spikes_count.shape[2]),
|
|
self.indices.data_ptr(), # Positions
|
|
int(self.indices.shape[0]),
|
|
int(self.indices.shape[1]),
|
|
int(self._number_of_cpu_processes),
|
|
)
|
|
|
|
spikes_output = torch.full(
|
|
(
|
|
input.shape[0],
|
|
int(spikes_count.max()),
|
|
int(self._output_size[0]),
|
|
int(self._output_size[1]),
|
|
),
|
|
-1,
|
|
dtype=torch.int64,
|
|
device=torch.device("cpu"),
|
|
)
|
|
|
|
self.sort_spikes.process(
|
|
input_cpu.data_ptr(), # Input
|
|
int(input_cpu.shape[0]),
|
|
int(input_cpu.shape[1]),
|
|
int(input_cpu.shape[2]),
|
|
int(input_cpu.shape[3]),
|
|
spikes_output.data_ptr(), # Output
|
|
int(spikes_output.shape[0]),
|
|
int(spikes_output.shape[1]),
|
|
int(spikes_output.shape[2]),
|
|
int(spikes_output.shape[3]),
|
|
self.indices.data_ptr(), # Positions
|
|
int(self.indices.shape[0]),
|
|
int(self.indices.shape[1]),
|
|
int(self._number_of_cpu_processes),
|
|
)
|
|
|
|
return spikes_output
|