Bernstein_Poster_2024/basis_nnmf_groups5/L1NormLayer.py
David Rotermund a540a3f271 Initial
2024-10-21 16:43:42 +02:00

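import torch


class L1NormLayer(torch.nn.Module):
    # Divides each sample by the sum of its entries along dim=1 (channels).
    # This equals L1 normalization only for non-negative input (presumably the case in this NNMF code).
    epsilon: float

    def __init__(self, epsilon: float = 10e-20) -> None:
        super().__init__()
        # Small constant that avoids division by zero; note that 10e-20 == 1e-19.
        self.epsilon = epsilon

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input / (input.sum(dim=1, keepdim=True) + self.epsilon)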
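A minimal usage sketch, assuming a 4D input of shape (batch, channels, height, width) that has already been made non-negative, as an NNMF pipeline would presumably provide. The layer is repeated so the snippet runs on its own; the tensor shapes and the `abs()` step are illustrative assumptions, not taken from the repository. It checks two properties of the layer as written: values summed over dim=1 come out at about 1, and an all-zero input yields zeros instead of NaN thanks to `epsilon`.

```python
import torch


class L1NormLayer(torch.nn.Module):
    """Same layer as above, repeated so this snippet is self-contained."""

    epsilon: float

    def __init__(self, epsilon: float = 10e-20) -> None:
        super().__init__()
        self.epsilon = epsilon

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input / (input.sum(dim=1, keepdim=True) + self.epsilon)


# Illustrative input: 2 samples, 3 channels, 4x4 spatial map.
# abs() makes it non-negative (assumption: mimics NNMF activations).
torch.manual_seed(0)
x = torch.randn(2, 3, 4, 4).abs()

layer = L1NormLayer()
y = layer(x)

# Summing the output over the channel dimension gives ~1 at every position.
channel_sums = y.sum(dim=1)
print(channel_sums.shape)  # torch.Size([2, 4, 4])
print(torch.allclose(channel_sums, torch.ones_like(channel_sums)))  # True

# epsilon keeps an all-zero input from producing 0 / 0 = NaN.
z = layer(torch.zeros(1, 3, 2, 2))
print(torch.isnan(z).any().item())  # False
```

Because the layer uses a plain sum rather than `input.abs().sum(...)`, negative entries can cancel and the denominator may approach zero; the design is only a true L1 normalization when the input is guaranteed non-negative.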