104 lines
2.9 KiB
Python
104 lines
2.9 KiB
Python
import torch
|
|
|
|
from network.SpikeLayer import SpikeLayer
|
|
from network.SpikeCountLayer import SpikeCountLayer
|
|
|
|
|
|
class InputSpikeImage(torch.nn.Module):
|
|
|
|
_reshape: bool
|
|
_normalize: bool
|
|
_device: torch.device
|
|
|
|
number_of_spikes: int
|
|
|
|
def __init__(
|
|
self,
|
|
number_of_spikes: int = -1,
|
|
number_of_cpu_processes: int = 1,
|
|
reshape: bool = False,
|
|
normalize: bool = True,
|
|
device: torch.device | None = None,
|
|
) -> None:
|
|
super().__init__()
|
|
|
|
assert device is not None
|
|
self._device = device
|
|
|
|
self._reshape = bool(reshape)
|
|
self._normalize = bool(normalize)
|
|
|
|
self.number_of_spikes = int(number_of_spikes)
|
|
|
|
if device != torch.device("cpu"):
|
|
number_of_cpu_processes_spike_generator = 0
|
|
else:
|
|
number_of_cpu_processes_spike_generator = number_of_cpu_processes
|
|
|
|
self.spike_generator = SpikeLayer(
|
|
number_of_cpu_processes=number_of_cpu_processes_spike_generator,
|
|
device=device,
|
|
)
|
|
|
|
self.spike_count = SpikeCountLayer(
|
|
number_of_cpu_processes=number_of_cpu_processes
|
|
)
|
|
|
|
####################################################################
|
|
# Forward #
|
|
####################################################################
|
|
|
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
|
|
if self.number_of_spikes < 1:
|
|
return input
|
|
|
|
input_shape: list[int] = [
|
|
int(input.shape[0]),
|
|
int(input.shape[1]),
|
|
int(input.shape[2]),
|
|
int(input.shape[3]),
|
|
]
|
|
|
|
if self._reshape is True:
|
|
input_work = (
|
|
input.detach()
|
|
.clone()
|
|
.to(self._device)
|
|
.reshape(
|
|
(input_shape[0], input_shape[1] * input_shape[2] * input_shape[3])
|
|
)
|
|
.unsqueeze(-1)
|
|
.unsqueeze(-1)
|
|
)
|
|
else:
|
|
input_work = input.detach().clone().to(self._device)
|
|
|
|
spikes = self.spike_generator(
|
|
input=input_work, number_of_spikes=self.number_of_spikes
|
|
)
|
|
|
|
if self._reshape is True:
|
|
dim_s: int = input_shape[1] * input_shape[2] * input_shape[3]
|
|
else:
|
|
dim_s = input_shape[1]
|
|
|
|
output: torch.Tensor = self.spike_count(spikes, dim_s)
|
|
|
|
if self._reshape is True:
|
|
|
|
output = (
|
|
output.squeeze(-1)
|
|
.squeeze(-1)
|
|
.reshape(
|
|
(input_shape[0], input_shape[1], input_shape[2], input_shape[3])
|
|
)
|
|
)
|
|
|
|
if self._normalize is True:
|
|
output = output.type(dtype=input_work.dtype)
|
|
output = output / output.sum(dim=-1, keepdim=True).sum(
|
|
dim=-2, keepdim=True
|
|
).sum(dim=-3, keepdim=True)
|
|
|
|
return output
|