Bernstein_Poster_2024/avg_pooling_nnmf_fmask/MaskLayerFeatures.py

36 lines
847 B
Python
Raw Normal View History

2024-11-05 18:20:02 +01:00
import torch
class MaskLayerFeatures(torch.nn.Module):
    """Scale each channel of a channel-first tensor by a learnable weight.

    The layer holds one parameter per channel. In ``forward`` the absolute
    value of that parameter is multiplied onto the input along dimension 1,
    so the effective per-channel factor is never negative.

    Attributes:
        weight: 1-D parameter with ``in_channels`` entries.
        init_min: Lower bound of the uniform initialisation (0.0).
        init_max: Upper bound of the uniform initialisation (1.0).
    """

    weight: torch.Tensor
    init_min: float
    init_max: float

    def __init__(
        self,
        in_channels: int,
        device=None,
        dtype=None,
    ) -> None:
        """Create the per-channel weight and initialise it.

        Args:
            in_channels: Number of entries in ``weight``.
            device: Forwarded to ``torch.ones`` when allocating ``weight``.
            dtype: Forwarded to ``torch.ones`` when allocating ``weight``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.weight = torch.nn.parameter.Parameter(
            torch.ones(in_channels, **factory_kwargs)
        )

        # The bounds must exist before reset_parameters() reads them.
        self.init_min = 0.0
        self.init_max = 1.0

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-draw ``weight`` uniformly from ``[init_min, init_max]``."""
        torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Multiply ``input`` by ``abs(weight)`` along dimension 1.

        Args:
            input: Tensor with at least two dimensions; dimension 1 is the
                one the weight is broadcast against.

        Returns:
            ``input`` scaled per channel; same shape as ``input`` whenever
            dimension 1 is broadcast-compatible with ``weight``.

        Raises:
            ValueError: If ``input`` has fewer than two dimensions.
        """
        if input.dim() < 2:
            raise ValueError(
                "MaskLayerFeatures expects an input with at least 2 "
                f"dimensions (batch, channels, ...), got {input.dim()}"
            )

        positive_weights = torch.abs(self.weight)

        # Shape (1, C, 1, ..., 1) with one trailing singleton per dimension
        # after the channel axis. For a 4-D input this is (1, C, 1, 1).
        broadcast_shape = (1, -1) + (1,) * (input.dim() - 2)
        return input * positive_weights.reshape(broadcast_shape)