55 lines
1.6 KiB
Python
55 lines
1.6 KiB
Python
|
import torch
|
||
|
|
||
|
|
||
|
class SplitOnOffLayer(torch.nn.Module):
|
||
|
|
||
|
device: torch.device
|
||
|
default_dtype: torch.dtype
|
||
|
|
||
|
mean: torch.Tensor | None = None
|
||
|
epsilon: float = 0.01
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
device: torch.device | None = None,
|
||
|
default_dtype: torch.dtype | None = None,
|
||
|
) -> None:
|
||
|
super().__init__()
|
||
|
|
||
|
assert device is not None
|
||
|
assert default_dtype is not None
|
||
|
self.device = device
|
||
|
self.default_dtype = default_dtype
|
||
|
|
||
|
####################################################################
|
||
|
# Forward #
|
||
|
####################################################################
|
||
|
|
||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||
|
assert input.ndim == 4
|
||
|
|
||
|
# self.training is switched by network.eval() and network.train()
|
||
|
if self.training is True:
|
||
|
mean_temp = (
|
||
|
input.mean(dim=0, keepdim=True)
|
||
|
.mean(dim=1, keepdim=True)
|
||
|
.detach()
|
||
|
.clone()
|
||
|
)
|
||
|
|
||
|
if self.mean is None:
|
||
|
self.mean = mean_temp
|
||
|
else:
|
||
|
self.mean = (1.0 - self.epsilon) * self.mean + self.epsilon * mean_temp
|
||
|
|
||
|
assert self.mean is not None
|
||
|
|
||
|
temp = input - self.mean.detach().clone()
|
||
|
temp_a = torch.nn.functional.relu(temp)
|
||
|
temp_b = torch.nn.functional.relu(-temp)
|
||
|
output = torch.cat((temp_a, temp_b), dim=1)
|
||
|
|
||
|
output /= output.sum(dim=1, keepdim=True) + 1e-20
|
||
|
|
||
|
return output
|