pytorch-sbs/network/PoissonLayer.py

75 lines
1.6 KiB
Python
Raw Permalink Normal View History

2023-03-15 16:35:13 +01:00
import torch
class PoissonLayer(torch.nn.Module):
    """Module wrapper that turns a 4-D activation tensor into a Poisson spike pattern.

    The sampling itself is done by ``FunctionalPoisson`` (a custom autograd
    function). This wrapper only validates the input and packs the layer's
    configuration into the int64 tensor that the function reads back.

    NOTE(review): ``torch.poisson`` presumably needs non-negative rates, so the
    input is expected to be non-negative. Confirm against the callers.
    """

    # Poisson rate given to the largest element of each slice before sampling.
    _number_of_spikes: int

    def __init__(
        self,
        number_of_spikes: int = 1,
    ) -> None:
        """Store the spike budget and bind the autograd entry point.

        Args:
            number_of_spikes: Rate assigned to the per-slice maximum of the
                input before Poisson sampling. This value is validated (must be
                > 0) on every ``forward`` call, not here.
        """
        super().__init__()
        self._number_of_spikes = number_of_spikes
        # Keep a handle on ``apply`` so forward() can call it like a plain function.
        self.functional_poisson = FunctionalPoisson.apply

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Sample Poisson spikes for ``input`` and return them normalized.

        Args:
            input: A 4-D tensor. The spike function reduces over its last three
                dimensions, so each index along dim 0 is handled independently
                (presumably the batch dimension; confirm).

        Returns:
            A tensor shaped like ``input`` holding Poisson counts divided by
            their sum over the last three dimensions.

        Raises:
            AssertionError: if ``input`` is not 4-D or the configured spike
                count is not positive. These are plain ``assert`` statements, so
                the checks are skipped under ``python -O``.
        """
        assert input.ndim == 4
        assert self._number_of_spikes > 0

        spike_count = int(self._number_of_spikes)
        # FunctionalPoisson reads the spike count back from index 0 of this tensor.
        packed_parameters = torch.tensor([spike_count], dtype=torch.int64)

        return self.functional_poisson(input, packed_parameters)
class FunctionalPoisson(torch.autograd.Function):
    """Poisson spike sampling with a straight-through gradient.

    Forward:
        1. Rescale ``input`` by its maximum over the last three dimensions, so
           that the largest value in each slice becomes the requested spike
           count.
        2. Draw Poisson counts using those values as rates.
        3. Divide the counts by their sum over the last three dimensions.

    Backward: the incoming gradient is passed to ``input`` unchanged, and
    ``parameter_list`` receives no gradient.
    """

    @staticmethod
    def forward(  # type: ignore
        ctx,
        input: torch.Tensor,
        parameter_list: torch.Tensor,
    ) -> torch.Tensor:
        """Draw normalized Poisson counts from a max-scaled copy of ``input``.

        Args:
            ctx: Autograd context (nothing is saved on it).
            input: Tensor of rates before scaling. It has at least three
                dimensions, and the last three are reduced.
            parameter_list: Tensor whose element 0 is the spike count used to
                scale each slice's maximum.

        Returns:
            Tensor shaped like ``input``. Each slice sums to 1, or is all zeros
            if no spike was drawn in it.
        """
        spike_budget: float = float(parameter_list[0])
        # Reduce one axis at a time. keepdim=True keeps negative indices stable,
        # and the reduction order matches a chained max(-1).max(-2).max(-3).
        trailing_dims = (-1, -2, -3)

        peak = input
        for axis in trailing_dims:
            peak = peak.max(dim=axis, keepdim=True).values
        # The tiny epsilon keeps an all-zero slice at 0 instead of 0/0.
        rates = spike_budget * input / (peak + 1e-20)

        counts = torch.poisson(rates)

        total = counts
        for axis in trailing_dims:
            total = total.sum(dim=axis, keepdim=True)
        # Same epsilon guard, for slices where no spike was drawn at all.
        return counts / (total + 1e-20)

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator.

        Returns ``grad_output`` as the gradient for ``input``, and ``None``
        (no gradient) for ``parameter_list``.
        """
        return grad_output, None