# pytorch-sbs/network/SplitOnOffLayer.py
# 2023-02-21 14:37:51 +01:00
#
# 55 lines
# 1.6 KiB
# Python

import torch
class SplitOnOffLayer(torch.nn.Module):
device: torch.device
default_dtype: torch.dtype
mean: torch.Tensor | None = None
epsilon: float = 0.01
def __init__(
self,
device: torch.device | None = None,
default_dtype: torch.dtype | None = None,
) -> None:
super().__init__()
assert device is not None
assert default_dtype is not None
self.device = device
self.default_dtype = default_dtype
####################################################################
# Forward #
####################################################################
def forward(self, input: torch.Tensor) -> torch.Tensor:
assert input.ndim == 4
# # self.training is switched by network.eval() and network.train()
# if self.training is True:
# mean_temp = (
# input.mean(dim=0, keepdim=True)
# .mean(dim=1, keepdim=True)
# .detach()
# .clone()
# )
#
# if self.mean is None:
# self.mean = mean_temp
# else:
# self.mean = (1.0 - self.epsilon) * self.mean + self.epsilon * mean_temp
#
# assert self.mean is not None
# temp = input - self.mean.detach().clone()
temp = input - 0.5
temp_a = torch.nn.functional.relu(temp)
temp_b = torch.nn.functional.relu(-temp)
output = torch.cat((temp_a, temp_b), dim=1)
#output /= output.sum(dim=1, keepdim=True) + 1e-20
return output