23 lines
642 B
Python
23 lines
642 B
Python
import torch
|
|
|
|
class SplitOnOffLayer(torch.nn.Module):
    """Split every entry along dim 1 into an "on" and an "off" rectified part.

    For a threshold ``t`` (0.5 by default) the layer computes::

        on  = relu(x - t)   # how far x lies above t
        off = relu(t - x)   # how far x lies below t

    and concatenates them along dim 1. A 4-D input with ``C`` entries on
    dim 1 therefore yields an output with ``2 * C`` entries on dim 1:
    indices ``[0, C)`` hold ``on`` and ``[C, 2C)`` hold ``off``. Because
    ``on - off == x - t`` element-wise, the split loses no information.
    The layer has no learnable parameters.

    Args:
        threshold: Value around which inputs are split. Defaults to 0.5,
            which presumably assumes inputs normalized to [0, 1] — confirm
            with callers.
    """

    def __init__(self, *, threshold: float = 0.5) -> None:
        super().__init__()
        self.threshold = threshold

    def extra_repr(self) -> str:
        """Show the threshold when the module is printed."""
        return f"threshold={self.threshold}"

    ####################################################################
    #                              Forward                             #
    ####################################################################

    # NOTE: the parameter keeps the name ``input`` (shadowing the builtin)
    # so that existing ``forward(input=...)`` calls keep working.
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``cat((relu(input - t), relu(t - input)), dim=1)``.

        Args:
            input: 4-D tensor; dim 1 is treated as the channel axis
                (presumably (N, C, H, W) — confirm with callers).

        Returns:
            Tensor with the same dims as ``input`` except dim 1, which is
            doubled.

        Raises:
            ValueError: If ``input`` is not 4-D.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if input.ndim != 4:
            raise ValueError(
                f"SplitOnOffLayer expects a 4-D tensor, got {input.ndim}-D "
                f"input with shape {tuple(input.shape)}"
            )

        centered = input - self.threshold
        on = torch.nn.functional.relu(centered)
        off = torch.nn.functional.relu(-centered)
        return torch.cat((on, off), dim=1)
|