import torch

from non_linear_weigth_function import non_linear_weigth_function


class NNMF2dConvGrouped(torch.nn.Module):
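    """Grouped 2D convolutional non-negative matrix factorisation (NNMF) layer.

    The non-negative input is approximated as a transposed convolution of a
    latent map ``h`` with non-negative weights; ``forward`` refines ``h``
    through multiplicative updates and returns it.
    """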

    in_channels: int
    out_channels: int
    weight: torch.Tensor
    iterations: int
    epsilon: float | None
    init_min: float
    init_max: float
    beta: torch.Tensor | None
    positive_function_type: int
    convolution_contribution_map: None | torch.Tensor = None
    convolution_contribution_map_enable: bool
    convolution_ip_norm: bool
    kernel_size: tuple[int, ...]
    stride: tuple[int, ...]
    padding: str | tuple[int, ...]
    dilation: tuple[int, ...]
    output_size: None | torch.Tensor = None
    groups: int

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: tuple[int, int],
        groups: int = 1,
        device=None,
        dtype=None,
        iterations: int = 20,
        epsilon: float | None = None,
        init_min: float = 0.0,
        init_max: float = 1.0,
        beta: float | None = None,
        positive_function_type: int = 0,
        convolution_contribution_map_enable: bool = False,
        stride: tuple[int, int] = (1, 1),
        padding: str | tuple[int, int] = (0, 0),
        dilation: tuple[int, int] = (1, 1),
        convolution_ip_norm: bool = True,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
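
        # Reject unknown padding strings and forbid padding="same" with a
        # stride other than 1, mirroring the checks in torch.nn.Conv2d.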
        valid_padding_strings = {"same", "valid"}
        if isinstance(padding, str):
            if padding not in valid_padding_strings:
                raise ValueError(
                    f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}"
                )
            if padding == "same" and any(s != 1 for s in stride):
                raise ValueError(
                    "padding='same' is not supported for strided convolutions"
                )

        self.positive_function_type = positive_function_type
        self.init_min = init_min
        self.init_max = init_max
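
        # in_channels must divide evenly into the requested groups; each
        # group then works on in_channels // groups input channels.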
        self.groups = groups
        assert (
            in_channels % self.groups == 0
        ), f"Can't divide without rest {in_channels} / {self.groups}"
        self.in_channels = in_channels // self.groups
        self.out_channels = out_channels

        self.iterations = iterations
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.convolution_contribution_map_enable = convolution_contribution_map_enable
        self.convolution_ip_norm = convolution_ip_norm
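
        # One kernel per output channel; with groups > 1 each kernel only
        # spans the in_channels // groups channels of its group.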
        self.weight = torch.nn.parameter.Parameter(
            torch.empty(
                (out_channels, self.in_channels, *kernel_size), **factory_kwargs
            )
        )
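
        # Optional learnable scalar passed to non_linear_weigth_function
        # together with the raw weights.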
        if beta is not None:
            self.beta = torch.nn.parameter.Parameter(
                torch.empty((1,), **factory_kwargs)
            )
            self.beta.data[0] = beta
        else:
            self.beta = None

        self.reset_parameters()

        self.epsilon = epsilon

    def extra_repr(self) -> str:
        s: str = f"{self.in_channels}, {self.out_channels}, "
        s += f"kernel_size={self.kernel_size}, "
        s += f"stride={self.stride}, iterations={self.iterations}"
        if self.epsilon is not None:
            s += f", epsilon={self.epsilon}"
        s += f", pfunctype={self.positive_function_type}"
        s += f", groups={self.groups}"

        if self.padding != (0,) * len(self.padding):
            s += f", padding={self.padding}"
        if self.dilation != (1,) * len(self.dilation):
            s += f", dilation={self.dilation}"
        return s

    def reset_parameters(self) -> None:
        torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max)

    def forward(self, input: torch.Tensor) -> torch.Tensor:

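        # Promote 2D (N, C) and 3D (N, C, H) inputs to 4D (N, C, H, W).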
        if input.ndim == 2:
            input = input.unsqueeze(-1)
        if input.ndim == 3:
            input = input.unsqueeze(-1)
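
        # Determine the output shape once by pushing a zero-filled dummy
        # input through a convolution with zeroed weights.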
        if self.output_size is None:
            self.output_size = torch.tensor(
                torch.nn.functional.conv2d(
                    torch.zeros(
                        1,
                        input.shape[1],
                        input.shape[2],
                        input.shape[3],
                        device=self.weight.device,
                        dtype=self.weight.dtype,
                    ),
                    torch.zeros_like(self.weight),
                    stride=self.stride,
                    padding=self.padding,
                    dilation=self.dilation,
                    groups=self.groups,
                ).shape,
                requires_grad=False,
            )
        assert self.output_size is not None
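
        # Map the raw weights onto non-negative values, then normalise
        # them along the last (kernel width) dimension.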
        positive_weights = non_linear_weigth_function(
            self.weight, self.beta, self.positive_function_type
        )
        positive_weights = positive_weights / (
            positive_weights.sum(dim=-1, keepdim=True) + 10e-20
        )
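
        # Normalise every sample of the batch so its entries sum to one.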
        input = input / (input.sum((1, 2, 3), keepdim=True) + 10e-20)

        # Prepare h: a uniform latent map with the precomputed output
        # shape, adjusted to the current batch size.
        self.output_size[0] = input.shape[0]
        h = torch.full(
            self.output_size.tolist(),
            1.0 / float(self.output_size[1]),
            device=input.device,
            dtype=input.dtype,
        )
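
        # Under the "ip norm" scheme h keeps its uniform start and is
        # re-normalised per spatial position inside the loop; otherwise
        # it is normalised once per sample here.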
        if self.convolution_ip_norm:
            pass
        else:
            h = h / (h.sum((1, 2, 3), keepdim=True) + 10e-20)
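
        # Multiplicative NNMF updates: reconstruct the input from h with
        # a transposed convolution, take the ratio input / reconstruction,
        # and scale h by the convolution of that ratio with the weights.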
        for _ in range(0, self.iterations):

            factor_x_div_r: torch.Tensor = input / (
                torch.nn.functional.conv_transpose2d(
                    h,
                    positive_weights,
                    stride=self.stride,
                    padding=self.padding,  # type: ignore
                    dilation=self.dilation,
                    groups=self.groups,
                )
                + 10e-20
            )
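
            # Update h, either with the plain multiplicative rule or, when
            # epsilon is set, with the damped variant h * (1 + eps * update).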
            if self.epsilon is None:
                h = h * torch.nn.functional.conv2d(
                    factor_x_div_r,
                    positive_weights,
                    stride=self.stride,
                    padding=self.padding,
                    dilation=self.dilation,
                    groups=self.groups,
                )
            else:
                h = h * (
                    1
                    + self.epsilon
                    * torch.nn.functional.conv2d(
                        factor_x_div_r,
                        positive_weights,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups,
                    )
                )
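
            # Re-normalise h: per spatial position across channels
            # ("ip norm") or per sample across channels and space.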
            if self.convolution_ip_norm:
                h = h / (h.sum(1, keepdim=True) + 10e-20)
            else:
                h = h / (h.sum((1, 2, 3), keepdim=True) + 10e-20)

        assert torch.isfinite(h).all()
        return h
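

# Minimal usage sketch (an illustration, not part of the original module).
# It assumes non_linear_weigth_function is importable alongside this file
# and that positive_function_type=0 yields non-negative weights; the
# channel, kernel, and image sizes below are arbitrary examples.
if __name__ == "__main__":
    layer = NNMF2dConvGrouped(
        in_channels=8,
        out_channels=16,
        kernel_size=(3, 3),
        groups=2,
        iterations=20,
    )
    x = torch.rand(4, 8, 28, 28)  # non-negative input, shape (N, C, H, W)
    h = layer(x)
    print(h.shape)  # torch.Size([4, 16, 26, 26]) for 3x3 kernels, no padding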