import torch
from non_linear_weigth_function import non_linear_weigth_function


class NNMF2dConvGrouped(torch.nn.Module):
    """Grouped 2D convolutional NNMF layer.

    The forward pass maps the weights to non-negative values, normalizes the
    weights and the input, and then runs `iterations` multiplicative update
    steps on the latent activity h, using conv2d / conv_transpose2d for the
    grouped (de)convolutions.
    """

    in_channels: int
    out_channels: int
    weight: torch.Tensor
    iterations: int
    epsilon: float | None
    init_min: float
    init_max: float
    beta: torch.Tensor | None
    positive_function_type: int
    convolution_contribution_map: None | torch.Tensor = None
    convolution_contribution_map_enable: bool
    convolution_ip_norm: bool
    kernel_size: tuple[int, ...]
    stride: tuple[int, ...]
    padding: str | tuple[int, ...]
    dilation: tuple[int, ...]
    output_size: None | torch.Tensor = None
    groups: int

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: tuple[int, int],
        groups: int = 1,
        device=None,
        dtype=None,
        iterations: int = 20,
        epsilon: float | None = None,
        init_min: float = 0.0,
        init_max: float = 1.0,
        beta: float | None = None,
        positive_function_type: int = 0,
        convolution_contribution_map_enable: bool = False,
        stride: tuple[int, int] = (1, 1),
        padding: str | tuple[int, int] = (0, 0),
        dilation: tuple[int, int] = (1, 1),
        convolution_ip_norm: bool = True,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()

        valid_padding_strings = {"same", "valid"}
        if isinstance(padding, str):
            if padding not in valid_padding_strings:
                raise ValueError(
                    f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}"
                )
            if padding == "same" and any(s != 1 for s in stride):
                raise ValueError(
                    "padding='same' is not supported for strided convolutions"
                )

        self.positive_function_type = positive_function_type
        self.init_min = init_min
        self.init_max = init_max

        self.groups = groups
        assert (
            in_channels % self.groups == 0
        ), f"Can't divide without rest: {in_channels} / {self.groups}"
        # Per-group number of input channels, as expected by grouped conv2d.
        self.in_channels = in_channels // self.groups
        self.out_channels = out_channels
        self.iterations = iterations
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.convolution_contribution_map_enable = convolution_contribution_map_enable
        self.convolution_ip_norm = convolution_ip_norm

        self.weight = torch.nn.parameter.Parameter(
            torch.empty(
                (out_channels, self.in_channels, *kernel_size), **factory_kwargs
            )
        )

        if beta is not None:
            self.beta = torch.nn.parameter.Parameter(
                torch.empty((1,), **factory_kwargs)
            )
            self.beta.data[0] = beta
        else:
            self.beta = None

        self.reset_parameters()

        self.epsilon = epsilon

    def extra_repr(self) -> str:
        s: str = f"{self.in_channels}, {self.out_channels}, "
        s += f"kernel_size={self.kernel_size}, "
        s += f"stride={self.stride}, iterations={self.iterations}"

        if self.epsilon is not None:
            s += f", epsilon={self.epsilon}"
        s += f", pfunctype={self.positive_function_type}"
        s += f", groups={self.groups}"

        if self.padding != (0,) * len(self.padding):
            s += f", padding={self.padding}"
        if self.dilation != (1,) * len(self.dilation):
            s += f", dilation={self.dilation}"

        return s

    def reset_parameters(self) -> None:
        torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Accept 2D / 3D inputs by appending singleton spatial dimensions.
        if input.ndim == 2:
            input = input.unsqueeze(-1)
        if input.ndim == 3:
            input = input.unsqueeze(-1)

        # Infer the output shape once with a dummy convolution.
        if self.output_size is None:
            self.output_size = torch.tensor(
                torch.nn.functional.conv2d(
                    torch.zeros(
                        1,
                        input.shape[1],
                        input.shape[2],
                        input.shape[3],
                        device=self.weight.device,
                        dtype=self.weight.dtype,
                    ),
                    torch.zeros_like(self.weight),
                    stride=self.stride,
                    padding=self.padding,
                    dilation=self.dilation,
                    groups=self.groups,
                ).shape,
                requires_grad=False,
            )
        assert self.output_size is not None
        # Map the weights to non-negative values and normalize them
        # (sum over the last kernel dimension).
        positive_weights = non_linear_weigth_function(
            self.weight, self.beta, self.positive_function_type
        )
        positive_weights = positive_weights / (
            positive_weights.sum(dim=-1, keepdim=True) + 10e-20
        )

        # Normalize each input sample to unit sum.
        input = input / (input.sum((1, 2, 3), keepdim=True) + 10e-20)

        # Prepare h: uniform initialization over the output channels.
        self.output_size[0] = input.shape[0]
        h = torch.full(
            self.output_size.tolist(),
            1.0 / float(self.output_size[1]),
            device=input.device,
            dtype=input.dtype,
        )

        if not self.convolution_ip_norm:
            h = h / (h.sum((1, 2, 3), keepdim=True) + 10e-20)

        for _ in range(0, self.iterations):
            # Ratio of the input to its reconstruction from h.
            factor_x_div_r: torch.Tensor = input / (
                torch.nn.functional.conv_transpose2d(
                    h,
                    positive_weights,
                    stride=self.stride,
                    padding=self.padding,  # type: ignore
                    dilation=self.dilation,
                    groups=self.groups,
                )
                + 10e-20
            )

            # Multiplicative update of h; with epsilon, a relaxed update
            # h * (1 + epsilon * step).
            if self.epsilon is None:
                h = h * torch.nn.functional.conv2d(
                    factor_x_div_r,
                    positive_weights,
                    stride=self.stride,
                    padding=self.padding,
                    dilation=self.dilation,
                    groups=self.groups,
                )
            else:
                h = h * (
                    1
                    + self.epsilon
                    * torch.nn.functional.conv2d(
                        factor_x_div_r,
                        positive_weights,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups,
                    )
                )

            # Re-normalize h, either per spatial position (ip norm) or per sample.
            if self.convolution_ip_norm:
                h = h / (h.sum(1, keepdim=True) + 10e-20)
            else:
                h = h / (h.sum((1, 2, 3), keepdim=True) + 10e-20)

        assert torch.isfinite(h).all()
        return h
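

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module). It assumes
# that `non_linear_weigth_function` is importable from this repository and that
# the input is non-negative; the layer sizes below are hypothetical.
if __name__ == "__main__":
    torch.manual_seed(0)

    # 8 input channels split into 2 groups, 16 latent/output channels,
    # 3x3 kernels, 5 NNMF iterations.
    layer = NNMF2dConvGrouped(
        in_channels=8,
        out_channels=16,
        kernel_size=(3, 3),
        groups=2,
        iterations=5,
    )

    x = torch.rand(4, 8, 32, 32)  # batch of 4 non-negative inputs
    h = layer(x)
    print(layer)
    print(h.shape)  # expected: torch.Size([4, 16, 30, 30]) with the default padding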