14 lines
319 B
Python
14 lines
319 B
Python
|
import torch
|
||
|
|
||
|
|
||
|
class L1NormLayer(torch.nn.Module):
    """Scale a tensor so that its L1 norm along one dimension is (about) one.

    Every element is divided by the sum of absolute values taken along
    ``dim``, plus a small ``epsilon`` that keeps the denominator non-zero
    when a slice is entirely zero.

    For non-negative inputs the result is the same as dividing by the plain
    sum. For mixed-sign inputs the absolute value keeps the denominator
    non-negative, so element signs are preserved and slices whose signed sum
    cancels to zero no longer blow up.
    """

    epsilon: float
    dim: int

    def __init__(self, epsilon: float = 10e-20, dim: int = 1) -> None:
        """Store the normalisation settings.

        Args:
            epsilon: Constant added to the norm before dividing. Note that
                the default literal ``10e-20`` equals ``1e-19``.
            dim: Dimension along which the norm is computed. Defaults to 1.
        """
        super().__init__()
        self.epsilon = epsilon
        self.dim = dim

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` divided by its L1 norm along ``self.dim``.

        Args:
            input: Tensor to normalise; it must have the dimension ``self.dim``.

        Returns:
            A tensor of the same shape as ``input``.
        """
        # keepdim=True leaves a size-1 axis so the norm broadcasts against input.
        l1_norm = input.abs().sum(dim=self.dim, keepdim=True)
        return input / (l1_norm + self.epsilon)
|