# pytorch-sbs/network/SplitOnOffLayer.py
import torch
class SplitOnOffLayer(torch.nn.Module):
device: torch.device
default_dtype: torch.dtype
mean: torch.Tensor | None = None
epsilon: float = 0.01
def __init__(
self,
device: torch.device | None = None,
default_dtype: torch.dtype | None = None,
) -> None:
super().__init__()
assert device is not None
assert default_dtype is not None
self.device = device
self.default_dtype = default_dtype
####################################################################
# Forward #
####################################################################
def forward(self, input: torch.Tensor) -> torch.Tensor:
assert input.ndim == 4
2023-02-21 14:37:51 +01:00
# # self.training is switched by network.eval() and network.train()
# if self.training is True:
# mean_temp = (
# input.mean(dim=0, keepdim=True)
# .mean(dim=1, keepdim=True)
# .detach()
# .clone()
# )
#
# if self.mean is None:
# self.mean = mean_temp
# else:
# self.mean = (1.0 - self.epsilon) * self.mean + self.epsilon * mean_temp
#
# assert self.mean is not None
# temp = input - self.mean.detach().clone()
temp = input - 0.5
2023-01-05 13:23:58 +01:00
temp_a = torch.nn.functional.relu(temp)
temp_b = torch.nn.functional.relu(-temp)
output = torch.cat((temp_a, temp_b), dim=1)
2023-02-21 14:37:51 +01:00
#output /= output.sum(dim=1, keepdim=True) + 1e-20
2023-01-05 13:23:58 +01:00
return output