import torch
from network.SbSLayer import SbSLayer
class SbSReconstruction(torch.nn.Module):
    """Linear reconstruction module built on top of an existing ``SbSLayer``.

    Given the (normalized) output activity of an SbS layer, ``forward``
    projects it back through the layer's weight matrix to produce a
    normalized reconstruction of the layer's input.
    """

    # The SbS layer whose learned weights are used for reconstruction.
    _the_sbs_layer: SbSLayer

    def __init__(
        self,
        the_sbs_layer: SbSLayer,
    ) -> None:
        """Store the SbS layer and mirror its device / dtype settings.

        Args:
            the_sbs_layer: An already-constructed ``SbSLayer`` whose
                ``_weights`` tensor will be used in ``forward``.
        """
        super().__init__()

        self._the_sbs_layer = the_sbs_layer
        # Mirror device / dtype from the wrapped layer so callers can
        # move data consistently with this module.
        self.device = self._the_sbs_layer.device
        self.default_dtype = self._the_sbs_layer.default_dtype

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Reconstruct the SbS layer's input from its output activity.

        Args:
            input: Activity tensor; presumably shaped
                ``(batch, channels, height, width)`` — TODO confirm against
                the SbSLayer output convention.

        Returns:
            A tensor normalized along dim 1, obtained by weighting the
            normalized input with the layer's weights and summing over the
            input-channel axis.
        """
        # The wrapped layer must have initialized weights before we can
        # reconstruct. NOTE(review): ``assert`` is stripped under ``-O``;
        # kept as-is to preserve the original exception type.
        assert self._the_sbs_layer._weights_exists is True

        # Normalize activity along the channel axis so it sums to 1.
        # NOTE(review): a zero-sum channel slice would yield NaN here —
        # presumably SbS activities are strictly positive; verify.
        input_norm = input / input.sum(dim=1, keepdim=True)

        # Broadcast-multiply: weights (d0, d1) -> (1, d0, d1, 1, 1),
        # input_norm (b, c, h, w) -> (b, 1, c, h, w); summing over dim 2
        # contracts d1 against the channel axis (requires d1 == c),
        # producing (b, d0, h, w).
        output = (
            self._the_sbs_layer._weights.data.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            * input_norm.unsqueeze(1)
        ).sum(dim=2)

        # Renormalize the reconstruction along dim 1 (in place).
        output /= output.sum(dim=1, keepdim=True)

        return output