import torch

from non_linear_weigth_function import non_linear_weigth_function


class NNMF2dAutograd(torch.nn.Module):
    in_channels: int
    out_channels: int
    weight: torch.Tensor
    iterations: int
    epsilon: float | None
    init_min: float
    init_max: float
    beta: torch.Tensor | None
    positive_function_type: int
    local_learning: bool
    local_learning_kl: bool

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        device=None,
        dtype=None,
        iterations: int = 20,
        epsilon: float | None = None,
        init_min: float = 0.0,
        init_max: float = 1.0,
        beta: float | None = None,
        positive_function_type: int = 0,
        local_learning: bool = False,
        local_learning_kl: bool = False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()

        self.positive_function_type = positive_function_type
        self.init_min = init_min
        self.init_max = init_max

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.iterations = iterations
        self.local_learning = local_learning
        self.local_learning_kl = local_learning_kl

        # Learnable factorization weights of shape (out_channels, in_channels).
        self.weight = torch.nn.parameter.Parameter(
            torch.empty((out_channels, in_channels), **factory_kwargs)
        )

        # Optional learnable parameter passed to the non-linear weight transform.
        if beta is not None:
            self.beta = torch.nn.parameter.Parameter(
                torch.empty((1,), **factory_kwargs)
            )
            self.beta.data[0] = beta
        else:
            self.beta = None

        self.reset_parameters()

        self.epsilon = epsilon

    def extra_repr(self) -> str:
        s: str = f"{self.in_channels}, {self.out_channels}"

        if self.epsilon is not None:
            s += f", epsilon={self.epsilon}"
        s += f", pfunctype={self.positive_function_type}"
        s += f", local_learning={self.local_learning}"

        if self.local_learning:
            s += f", local_learning_kl={self.local_learning_kl}"

        return s

    def reset_parameters(self) -> None:
        torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Map the raw weights to non-negative values and normalize them so that
        # each output channel's weights sum to one over the input channels.
        positive_weights = non_linear_weigth_function(
            self.weight, self.beta, self.positive_function_type
        )
        positive_weights = positive_weights / (
            positive_weights.sum(dim=1, keepdim=True) + 10e-20
        )

        # ---------------------
        # Prepare h: a uniform latent activation per spatial position.
        h = torch.full(
            (input.shape[0], self.out_channels, input.shape[-2], input.shape[-1]),
            1.0 / float(self.out_channels),
            device=input.device,
            dtype=input.dtype,
        )

        # Move the channel dimension to the end so linear() operates on channels.
        h = h.movedim(1, -1)
        input = input.movedim(1, -1)

        # Multiplicative NMF update iterations.
        for _ in range(0, self.iterations):
            reconstruction = torch.nn.functional.linear(h, positive_weights.T)
            reconstruction = reconstruction + 1e-20
            if self.epsilon is None:
                h = h * torch.nn.functional.linear(
                    (input / reconstruction), positive_weights
                )
            else:
                h = h * (
                    1
                    + self.epsilon
                    * torch.nn.functional.linear(
                        (input / reconstruction), positive_weights
                    )
                )
            h = h / (h.sum(-1, keepdim=True) + 10e-20)

        # Restore the channel dimension to position 1.
        h = h.movedim(-1, 1)
        input = input.movedim(-1, 1)

        assert torch.isfinite(h).all()
        return h
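

# A minimal usage sketch, not part of the original module: it assumes that
# non_linear_weigth_function is importable from this repository and that the
# layer receives a non-negative, dense NCHW tensor. The batch size, channel
# counts, and spatial size below are hypothetical and purely illustrative.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Hypothetical example dimensions.
    batch_size, in_channels, out_channels = 2, 3, 8
    height, width = 5, 5

    layer = NNMF2dAutograd(
        in_channels=in_channels,
        out_channels=out_channels,
        iterations=20,
        beta=1.0,
    )

    # The multiplicative update divides the input by its reconstruction,
    # so the input should be non-negative.
    x = torch.rand((batch_size, in_channels, height, width))
    h = layer(x)

    print(layer)
    print(h.shape)  # expected: (batch_size, out_channels, height, width)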