122 lines
3.5 KiB
Python
122 lines
3.5 KiB
Python
|
import torch
|
||
|
from non_linear_weigth_function import non_linear_weigth_function
|
||
|
|
||
|
|
||
|
class NNMF2dAutograd(torch.nn.Module):
|
||
|
|
||
|
in_channels: int
|
||
|
out_channels: int
|
||
|
weight: torch.Tensor
|
||
|
iterations: int
|
||
|
epsilon: float | None
|
||
|
init_min: float
|
||
|
init_max: float
|
||
|
beta: torch.Tensor | None
|
||
|
positive_function_type: int
|
||
|
local_learning: bool
|
||
|
local_learning_kl: bool
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
in_channels: int,
|
||
|
out_channels: int,
|
||
|
device=None,
|
||
|
dtype=None,
|
||
|
iterations: int = 20,
|
||
|
epsilon: float | None = None,
|
||
|
init_min: float = 0.0,
|
||
|
init_max: float = 1.0,
|
||
|
beta: float | None = None,
|
||
|
positive_function_type: int = 0,
|
||
|
local_learning: bool = False,
|
||
|
local_learning_kl: bool = False,
|
||
|
) -> None:
|
||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||
|
|
||
|
super().__init__()
|
||
|
|
||
|
self.positive_function_type = positive_function_type
|
||
|
self.init_min = init_min
|
||
|
self.init_max = init_max
|
||
|
|
||
|
self.in_channels = in_channels
|
||
|
self.out_channels = out_channels
|
||
|
|
||
|
self.iterations = iterations
|
||
|
self.local_learning = local_learning
|
||
|
self.local_learning_kl = local_learning_kl
|
||
|
|
||
|
self.weight = torch.nn.parameter.Parameter(
|
||
|
torch.empty((out_channels, in_channels), **factory_kwargs)
|
||
|
)
|
||
|
|
||
|
if beta is not None:
|
||
|
self.beta = torch.nn.parameter.Parameter(torch.empty((1), **factory_kwargs))
|
||
|
self.beta.data[0] = beta
|
||
|
else:
|
||
|
self.beta = None
|
||
|
|
||
|
self.reset_parameters()
|
||
|
|
||
|
self.epsilon = epsilon
|
||
|
|
||
|
def extra_repr(self) -> str:
|
||
|
s: str = f"{self.in_channels}, {self.out_channels}"
|
||
|
|
||
|
if self.epsilon is not None:
|
||
|
s += f", epsilon={self.epsilon}"
|
||
|
s += f", pfunctype={self.positive_function_type}"
|
||
|
s += f", local_learning={self.local_learning}"
|
||
|
|
||
|
if self.local_learning:
|
||
|
s += f", local_learning_kl={self.local_learning_kl}"
|
||
|
|
||
|
return s
|
||
|
|
||
|
def reset_parameters(self) -> None:
|
||
|
torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max)
|
||
|
|
||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||
|
|
||
|
positive_weights = non_linear_weigth_function(
|
||
|
self.weight, self.beta, self.positive_function_type
|
||
|
)
|
||
|
positive_weights = positive_weights / (
|
||
|
positive_weights.sum(dim=1, keepdim=True) + 10e-20
|
||
|
)
|
||
|
|
||
|
# ---------------------
|
||
|
|
||
|
# Prepare h
|
||
|
h = torch.full(
|
||
|
(input.shape[0], self.out_channels, input.shape[-2], input.shape[-1]),
|
||
|
1.0 / float(self.out_channels),
|
||
|
device=input.device,
|
||
|
dtype=input.dtype,
|
||
|
)
|
||
|
|
||
|
h = h.movedim(1, -1)
|
||
|
input = input.movedim(1, -1)
|
||
|
for _ in range(0, self.iterations):
|
||
|
reconstruction = torch.nn.functional.linear(h, positive_weights.T)
|
||
|
reconstruction = reconstruction + 1e-20
|
||
|
if self.epsilon is None:
|
||
|
h = h * torch.nn.functional.linear(
|
||
|
(input / reconstruction), positive_weights
|
||
|
)
|
||
|
else:
|
||
|
h = h * (
|
||
|
1
|
||
|
+ self.epsilon
|
||
|
* torch.nn.functional.linear(
|
||
|
(input / reconstruction), positive_weights
|
||
|
)
|
||
|
)
|
||
|
h = h / (h.sum(-1, keepdim=True) + 10e-20)
|
||
|
h = h.movedim(-1, 1)
|
||
|
input = input.movedim(-1, 1)
|
||
|
|
||
|
assert torch.isfinite(h).all()
|
||
|
|
||
|
return h
|