35 lines
847 B
Python
35 lines
847 B
Python
import torch
|
|
|
|
|
|
class MaskLayerFeatures(torch.nn.Module):
    """Learnable per-channel scaling mask for feature maps.

    Holds one learnable weight per channel. In ``forward`` each weight is
    passed through ``abs`` so the effective scale is never negative, and it
    multiplies the corresponding channel of the input.

    Args:
        in_channels: Number of channels, i.e. the length of ``weight``.
        device: Optional device for ``weight`` (forwarded to ``torch.ones``).
        dtype: Optional dtype for ``weight`` (forwarded to ``torch.ones``).
        init_min: Lower bound of the uniform initialisation range used by
            ``reset_parameters``. Defaults to ``0.0``.
        init_max: Upper bound of the uniform initialisation range used by
            ``reset_parameters``. Defaults to ``1.0``.
    """

    weight: torch.Tensor
    init_min: float
    init_max: float

    def __init__(
        self,
        in_channels: int,
        device=None,
        dtype=None,
        *,
        init_min: float = 0.0,
        init_max: float = 1.0,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # One weight per channel; the placeholder ones are overwritten
        # immediately by reset_parameters() below.
        self.weight = torch.nn.parameter.Parameter(
            torch.ones(in_channels, **factory_kwargs)
        )
        self.init_min = init_min
        self.init_max = init_max
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-draw ``weight`` in place from U[``init_min``, ``init_max``)."""
        torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Scale each channel of ``input`` by ``abs(weight)``.

        The weights are broadcast with shape ``(1, C, 1, 1)``, so this
        assumes a 4-D ``(N, C, H, W)`` input. Inputs of other ranks are
        broadcast against that fixed shape as-is.
        NOTE(review): confirm no caller passes non-4-D tensors — e.g. a 5-D
        ``(N, C, D, H, W)`` input would align the weights with dim 2, not 1.
        """
        # abs() keeps the effective scale non-negative even if training
        # pushes a raw weight below zero.
        positive_weights = torch.abs(self.weight)
        return input * positive_weights[None, :, None, None]
|