pynnmf/SplitOnOffLayer.py

24 lines
642 B
Python
Raw Permalink Normal View History

2024-05-30 14:08:44 +02:00
import torch
class SplitOnOffLayer(torch.nn.Module):
    """Split a 4-D tensor into rectified "on" and "off" parts around 0.5.

    The input is shifted by -0.5. The positive part of the shifted signal
    and the positive part of its negation are computed separately and then
    concatenated along dimension 1, so the size of dimension 1 doubles.
    The layer has no parameters and no buffers.
    """

    def __init__(
        self,
    ) -> None:
        super().__init__()

    ####################################################################
    # Forward                                                          #
    ####################################################################
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``cat(relu(input - 0.5), relu(0.5 - input))`` along dim 1.

        Args:
            input: Tensor with exactly four dimensions.

        Returns:
            Tensor whose dimension 1 is twice as large as that of ``input``.
            The first half holds ``relu(input - 0.5)``, the second half
            holds ``relu(-(input - 0.5))``; all other dimensions are
            unchanged.

        Raises:
            ValueError: If ``input`` does not have exactly four dimensions.
        """
        # Explicit check instead of `assert`: assertions are removed when
        # Python runs with -O, which would let wrong-rank input through.
        if input.ndim != 4:
            raise ValueError(
                f"SplitOnOffLayer expects a 4-D input, got {input.ndim}-D"
            )

        centered = input - 0.5
        on_part = torch.nn.functional.relu(centered)
        off_part = torch.nn.functional.relu(-centered)
        return torch.cat((on_part, off_part), dim=1)