nnmf_24a/SplitOnOffLayer.py

24 lines
642 B
Python
Raw Normal View History

2024-07-10 15:47:16 +02:00
import torch
class SplitOnOffLayer(torch.nn.Module):
    """Split every input channel into a rectified "on" and "off" channel.

    The input is shifted by ``threshold``. The positive part of the shifted
    signal forms the "on" channels, and the positive part of its negation
    forms the "off" channels. Both groups are concatenated along dim 1, so
    an input with C channels along dim 1 yields 2*C output channels:
    ``[0, C)`` are "on" and ``[C, 2C)`` are "off".

    Args:
        threshold: Value subtracted from the input before rectification.
            Defaults to 0.5.
    """

    def __init__(self, threshold: float = 0.5) -> None:
        super().__init__()
        self.threshold = threshold

    def extra_repr(self) -> str:
        # Shown by ``print(module)`` so non-default thresholds are visible.
        return f"threshold={self.threshold}"

    ####################################################################
    # Forward                                                          #
    ####################################################################
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the on/off split of ``input`` concatenated along dim 1.

        Args:
            input: 4-D tensor. Dim 1 is treated as the channel axis;
                presumably laid out as (N, C, H, W) — only the rank is
                enforced here.

        Returns:
            Tensor with the same shape as ``input`` except that dim 1 is
            doubled. Both halves are non-negative.

        Raises:
            ValueError: If ``input`` is not 4-dimensional.
        """
        # Explicit check rather than ``assert`` so validation is not
        # stripped when Python runs with ``-O``.
        if input.ndim != 4:
            raise ValueError(
                f"SplitOnOffLayer expects a 4-D input tensor, got {input.ndim}-D"
            )
        centered = input - self.threshold
        on = torch.nn.functional.relu(centered)
        off = torch.nn.functional.relu(-centered)
        return torch.cat((on, off), dim=1)