import torch


class MaskLayerFeatures(torch.nn.Module):
    """Learnable per-channel scaling applied to a feature tensor.

    The module holds one weight per channel. In ``forward`` the absolute value
    of each weight multiplies every element of the matching channel
    (dimension 1) of the input, so the effective scale factors are always
    non-negative.

    Args:
        in_channels: Number of channels; expected to match ``input.shape[1]``.
        device: Optional device for the weight parameter.
        dtype: Optional dtype for the weight parameter.
    """

    weight: torch.Tensor
    init_min: float
    init_max: float

    def __init__(
        self,
        in_channels: int,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Allocate with the right shape/device/dtype; the values are
        # immediately overwritten by reset_parameters().
        self.weight = torch.nn.parameter.Parameter(
            torch.ones((in_channels,), **factory_kwargs)
        )
        self.init_min = 0.0
        self.init_max = 1.0
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise the weights uniformly in ``[init_min, init_max]``."""
        torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max)

    def extra_repr(self) -> str:
        """Show the channel count when the module is printed."""
        return f"in_channels={self.weight.shape[0]}"

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Scale each channel of ``input`` by ``abs(weight)``.

        Args:
            input: Tensor of shape ``(N, C, *spatial)`` with any number
                (including zero) of trailing dimensions.

        Returns:
            Tensor with the same shape as ``input`` (after broadcasting),
            where channel ``c`` is multiplied by ``abs(weight[c])``.

        Raises:
            ValueError: If ``input`` has fewer than 2 dimensions, since there
                is then no channel axis to scale.
        """
        if input.dim() < 2:
            raise ValueError(
                "MaskLayerFeatures expects an input with at least 2 dims "
                f"(batch, channels, ...), got shape {tuple(input.shape)}"
            )
        positive_weights = torch.abs(self.weight)
        # Reshape to (1, C, 1, ..., 1) so the scale broadcasts along dim 1
        # regardless of how many trailing dims the input has. For 4-D input
        # this equals the former unsqueeze(0).unsqueeze(-1).unsqueeze(-1).
        broadcast_shape = (1, -1) + (1,) * (input.dim() - 2)
        return input * positive_weights.view(broadcast_shape)