pytorch-sbs/network/InputSpikeImage.py

112 lines
3.3 KiB
Python
Raw Normal View History

2023-02-04 14:24:47 +01:00
import torch
from network.SpikeLayer import SpikeLayer
from network.SpikeCountLayer import SpikeCountLayer
class InputSpikeImage(torch.nn.Module):
_reshape: bool
_normalize: bool
_device: torch.device
number_of_spikes: int
def __init__(
self,
number_of_spikes: int = -1,
number_of_cpu_processes: int = 1,
reshape: bool = False,
normalize: bool = True,
device: torch.device | None = None,
) -> None:
super().__init__()
assert device is not None
self._device = device
self._reshape = bool(reshape)
self._normalize = bool(normalize)
self.number_of_spikes = int(number_of_spikes)
if device != torch.device("cpu"):
number_of_cpu_processes_spike_generator = 0
else:
number_of_cpu_processes_spike_generator = number_of_cpu_processes
self.spike_generator = SpikeLayer(
number_of_cpu_processes=number_of_cpu_processes_spike_generator,
device=device,
)
self.spike_count = SpikeCountLayer(
number_of_cpu_processes=number_of_cpu_processes
)
####################################################################
# Forward #
####################################################################
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.number_of_spikes < 1:
2023-02-21 14:37:51 +01:00
output = input
output = output.type(dtype=input.dtype)
if self._normalize is True:
output = output * output.shape[-1] * output.shape[-2] * output.shape[-3] / output.sum(dim=-1, keepdim=True).sum(
dim=-2, keepdim=True
).sum(dim=-3, keepdim=True)
return output
2023-02-04 14:24:47 +01:00
input_shape: list[int] = [
int(input.shape[0]),
int(input.shape[1]),
int(input.shape[2]),
int(input.shape[3]),
]
if self._reshape is True:
input_work = (
input.detach()
.clone()
.to(self._device)
.reshape(
(input_shape[0], input_shape[1] * input_shape[2] * input_shape[3])
)
.unsqueeze(-1)
.unsqueeze(-1)
)
else:
input_work = input.detach().clone().to(self._device)
spikes = self.spike_generator(
input=input_work, number_of_spikes=self.number_of_spikes
)
if self._reshape is True:
dim_s: int = input_shape[1] * input_shape[2] * input_shape[3]
else:
dim_s = input_shape[1]
output: torch.Tensor = self.spike_count(spikes, dim_s)
if self._reshape is True:
output = (
output.squeeze(-1)
.squeeze(-1)
.reshape(
(input_shape[0], input_shape[1], input_shape[2], input_shape[3])
)
)
2023-02-21 14:37:51 +01:00
output = output.type(dtype=input_work.dtype)
2023-02-04 14:24:47 +01:00
if self._normalize is True:
2023-02-21 14:37:51 +01:00
output = output * output.shape[-1] * output.shape[-2] * output.shape[-3] / output.sum(dim=-1, keepdim=True).sum(
2023-02-04 14:24:47 +01:00
dim=-2, keepdim=True
).sum(dim=-3, keepdim=True)
return output