Delete NNMFConv2dP.py
This commit is contained in:
parent
06d82786ab
commit
e848d49a7c
1 changed files with 0 additions and 483 deletions
483
NNMFConv2dP.py
483
NNMFConv2dP.py
|
@ -1,483 +0,0 @@
|
||||||
import torch
|
|
||||||
from non_linear_weigth_function import non_linear_weigth_function
|
|
||||||
|
|
||||||
|
|
||||||
class NNMFConv2dP(torch.nn.Module):
|
|
||||||
|
|
||||||
in_channels: int
|
|
||||||
out_channels: int
|
|
||||||
kernel_size: tuple[int, ...]
|
|
||||||
stride: tuple[int, ...]
|
|
||||||
padding: str | tuple[int, ...]
|
|
||||||
dilation: tuple[int, ...]
|
|
||||||
weight: torch.Tensor
|
|
||||||
bias: None | torch.Tensor
|
|
||||||
output_size: None | torch.Tensor = None
|
|
||||||
convolution_contribution_map: None | torch.Tensor = None
|
|
||||||
iterations: int
|
|
||||||
convolution_contribution_map_enable: bool
|
|
||||||
epsilon: float | None
|
|
||||||
init_min: float
|
|
||||||
init_max: float
|
|
||||||
beta: torch.Tensor | None
|
|
||||||
positive_function_type: int
|
|
||||||
use_convolution: bool
|
|
||||||
local_learning: bool
|
|
||||||
local_learning_kl: bool
|
|
||||||
use_reconstruction: bool
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels: int,
|
|
||||||
out_channels: int,
|
|
||||||
kernel_size: tuple[int, int],
|
|
||||||
stride: tuple[int, int] = (1, 1),
|
|
||||||
padding: str | tuple[int, int] = (0, 0),
|
|
||||||
dilation: tuple[int, int] = (1, 1),
|
|
||||||
device=None,
|
|
||||||
dtype=None,
|
|
||||||
iterations: int = 20,
|
|
||||||
convolution_contribution_map_enable: bool = False,
|
|
||||||
epsilon: float | None = None,
|
|
||||||
init_min: float = 0.0,
|
|
||||||
init_max: float = 1.0,
|
|
||||||
beta: float | None = None,
|
|
||||||
positive_function_type: int = 0,
|
|
||||||
use_convolution: bool = False,
|
|
||||||
local_learning: bool = False,
|
|
||||||
local_learning_kl: bool = False,
|
|
||||||
use_reconstruction: bool = False,
|
|
||||||
) -> None:
|
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
valid_padding_strings = {"same", "valid"}
|
|
||||||
if isinstance(padding, str):
|
|
||||||
if padding not in valid_padding_strings:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}"
|
|
||||||
)
|
|
||||||
if padding == "same" and any(s != 1 for s in stride):
|
|
||||||
raise ValueError(
|
|
||||||
"padding='same' is not supported for strided convolutions"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.positive_function_type = positive_function_type
|
|
||||||
self.init_min = init_min
|
|
||||||
self.init_max = init_max
|
|
||||||
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
|
|
||||||
self.stride = stride
|
|
||||||
self.padding = padding
|
|
||||||
self.dilation = dilation
|
|
||||||
|
|
||||||
self.iterations = iterations
|
|
||||||
self.convolution_contribution_map_enable = convolution_contribution_map_enable
|
|
||||||
|
|
||||||
self.local_learning = local_learning
|
|
||||||
self.local_learning_kl = local_learning_kl
|
|
||||||
|
|
||||||
self.use_reconstruction = use_reconstruction
|
|
||||||
|
|
||||||
self.weight = torch.nn.parameter.Parameter(
|
|
||||||
torch.empty((out_channels, in_channels, *kernel_size), **factory_kwargs)
|
|
||||||
)
|
|
||||||
|
|
||||||
if beta is not None:
|
|
||||||
self.beta = torch.nn.parameter.Parameter(torch.empty((1), **factory_kwargs))
|
|
||||||
self.beta.data[0] = beta
|
|
||||||
else:
|
|
||||||
self.beta = None
|
|
||||||
|
|
||||||
self.reset_parameters()
|
|
||||||
self.functional_nnmf_conv2d = FunctionalNNMFConv2dP.apply
|
|
||||||
|
|
||||||
self.epsilon = epsilon
|
|
||||||
self.use_convolution = use_convolution
|
|
||||||
|
|
||||||
assert self.use_convolution is False
|
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
|
||||||
s: str = f"{self.in_channels}, {self.out_channels}"
|
|
||||||
|
|
||||||
s += f", kernel_size={self.kernel_size}"
|
|
||||||
s += f", stride={self.stride}, iterations={self.iterations}"
|
|
||||||
s += f", epsilon={self.epsilon}"
|
|
||||||
s += f", use_convolution={self.use_convolution}"
|
|
||||||
|
|
||||||
if self.use_convolution:
|
|
||||||
s += f", ccmap={self.convolution_contribution_map_enable}"
|
|
||||||
|
|
||||||
s += f", pfunctype={self.positive_function_type}"
|
|
||||||
s += f", local_learning={self.local_learning}"
|
|
||||||
|
|
||||||
if self.local_learning:
|
|
||||||
s += f", local_learning_kl={self.local_learning_kl}"
|
|
||||||
|
|
||||||
if self.padding != (0,) * len(self.padding):
|
|
||||||
s += f", padding={self.padding}"
|
|
||||||
|
|
||||||
if self.dilation != (1,) * len(self.dilation):
|
|
||||||
s += f", dilation={self.dilation}"
|
|
||||||
|
|
||||||
return s
|
|
||||||
|
|
||||||
def reset_parameters(self) -> None:
|
|
||||||
torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max)
|
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
||||||
|
|
||||||
if input.ndim == 2:
|
|
||||||
input = input.unsqueeze(-1)
|
|
||||||
if input.ndim == 3:
|
|
||||||
input = input.unsqueeze(-1)
|
|
||||||
|
|
||||||
if self.output_size is None:
|
|
||||||
self.output_size = torch.tensor(
|
|
||||||
torch.nn.functional.conv2d(
|
|
||||||
torch.zeros(
|
|
||||||
1,
|
|
||||||
input.shape[1],
|
|
||||||
input.shape[2],
|
|
||||||
input.shape[3],
|
|
||||||
device=self.weight.device,
|
|
||||||
dtype=self.weight.dtype,
|
|
||||||
),
|
|
||||||
torch.zeros_like(self.weight),
|
|
||||||
stride=self.stride,
|
|
||||||
padding=self.padding,
|
|
||||||
dilation=self.dilation,
|
|
||||||
).shape,
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
assert self.output_size is not None
|
|
||||||
|
|
||||||
input = torch.nn.functional.fold(
|
|
||||||
torch.nn.functional.unfold(
|
|
||||||
input.requires_grad_(True),
|
|
||||||
kernel_size=self.kernel_size,
|
|
||||||
dilation=self.dilation,
|
|
||||||
padding=self.padding,
|
|
||||||
stride=self.stride,
|
|
||||||
),
|
|
||||||
output_size=self.output_size[-2:],
|
|
||||||
kernel_size=(1, 1),
|
|
||||||
dilation=(1, 1),
|
|
||||||
padding=(0, 0),
|
|
||||||
stride=(1, 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
positive_weights = non_linear_weigth_function(
|
|
||||||
self.weight, self.beta, self.positive_function_type
|
|
||||||
)
|
|
||||||
positive_weights = positive_weights / (
|
|
||||||
positive_weights.sum((1, 2, 3), keepdim=True) + 10e-20
|
|
||||||
)
|
|
||||||
|
|
||||||
positive_weights = positive_weights.reshape(
|
|
||||||
positive_weights.shape[0],
|
|
||||||
positive_weights.shape[1]
|
|
||||||
* positive_weights.shape[2]
|
|
||||||
* positive_weights.shape[3],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare input
|
|
||||||
input = input / (input.sum(dim=1, keepdim=True) + 10e-20)
|
|
||||||
|
|
||||||
h_dyn = self.functional_nnmf_conv2d(
|
|
||||||
input,
|
|
||||||
positive_weights,
|
|
||||||
self.output_size,
|
|
||||||
self.iterations,
|
|
||||||
self.stride,
|
|
||||||
self.padding,
|
|
||||||
self.dilation,
|
|
||||||
self.epsilon,
|
|
||||||
self.use_convolution,
|
|
||||||
self.local_learning,
|
|
||||||
self.local_learning_kl,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.use_reconstruction:
|
|
||||||
reconstruction = torch.nn.functional.linear(
|
|
||||||
h_dyn.movedim(1, -1), positive_weights.T
|
|
||||||
).movedim(-1, 1)
|
|
||||||
output = torch.cat((h_dyn, input - reconstruction), dim=1)
|
|
||||||
else:
|
|
||||||
output = torch.cat((h_dyn, input), dim=1)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionalNNMFConv2dP(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward( # type: ignore
|
|
||||||
ctx,
|
|
||||||
input: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
output_size: torch.Tensor,
|
|
||||||
iterations: int,
|
|
||||||
stride: tuple[int, int],
|
|
||||||
padding: str | tuple[int, int],
|
|
||||||
dilation: tuple[int, int],
|
|
||||||
epsilon: float | None,
|
|
||||||
use_convolution: bool,
|
|
||||||
local_learning: bool,
|
|
||||||
local_learning_kl: bool,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
# Prepare h
|
|
||||||
output_size[0] = input.shape[0]
|
|
||||||
h = torch.full(
|
|
||||||
output_size.tolist(),
|
|
||||||
1.0 / float(output_size[1]),
|
|
||||||
device=input.device,
|
|
||||||
dtype=input.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_convolution:
|
|
||||||
for _ in range(0, iterations):
|
|
||||||
factor_x_div_r: torch.Tensor = input / (
|
|
||||||
torch.nn.functional.conv_transpose2d(
|
|
||||||
h,
|
|
||||||
weight,
|
|
||||||
stride=stride,
|
|
||||||
padding=padding,
|
|
||||||
dilation=dilation,
|
|
||||||
)
|
|
||||||
+ 10e-20
|
|
||||||
)
|
|
||||||
|
|
||||||
if epsilon is None:
|
|
||||||
h *= torch.nn.functional.conv2d(
|
|
||||||
factor_x_div_r,
|
|
||||||
weight,
|
|
||||||
stride=stride,
|
|
||||||
padding=padding,
|
|
||||||
dilation=dilation,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
h *= 1 + epsilon * torch.nn.functional.conv2d(
|
|
||||||
factor_x_div_r,
|
|
||||||
weight,
|
|
||||||
stride=stride,
|
|
||||||
padding=padding,
|
|
||||||
dilation=dilation,
|
|
||||||
)
|
|
||||||
|
|
||||||
h /= h.sum(1, keepdim=True) + 10e-20
|
|
||||||
else:
|
|
||||||
h = h.movedim(1, -1)
|
|
||||||
input = input.movedim(1, -1)
|
|
||||||
for _ in range(0, iterations):
|
|
||||||
reconstruction = torch.nn.functional.linear(h, weight.T)
|
|
||||||
reconstruction += 1e-20
|
|
||||||
if epsilon is None:
|
|
||||||
h *= torch.nn.functional.linear((input / reconstruction), weight)
|
|
||||||
else:
|
|
||||||
h *= 1 + epsilon * torch.nn.functional.linear(
|
|
||||||
(input / reconstruction), weight
|
|
||||||
)
|
|
||||||
h /= h.sum(-1, keepdim=True) + 10e-20
|
|
||||||
h = h.movedim(-1, 1)
|
|
||||||
input = input.movedim(-1, 1)
|
|
||||||
|
|
||||||
# ###########################################################
|
|
||||||
# Save the necessary data for the backward pass
|
|
||||||
# ###########################################################
|
|
||||||
ctx.save_for_backward(input, weight, h)
|
|
||||||
|
|
||||||
ctx.stride = stride
|
|
||||||
ctx.padding = padding
|
|
||||||
ctx.dilation = dilation
|
|
||||||
ctx.use_convolution = use_convolution
|
|
||||||
ctx.local_learning = local_learning
|
|
||||||
ctx.local_learning_kl = local_learning_kl
|
|
||||||
|
|
||||||
assert torch.isfinite(h).all()
|
|
||||||
return h
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@torch.autograd.function.once_differentiable
|
|
||||||
def backward(ctx, grad_output: torch.Tensor) -> tuple[ # type: ignore
|
|
||||||
torch.Tensor | None,
|
|
||||||
torch.Tensor | None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
]:
|
|
||||||
|
|
||||||
# ##############################################
|
|
||||||
# Default values
|
|
||||||
# ##############################################
|
|
||||||
grad_input: torch.Tensor | None = None
|
|
||||||
grad_weight: torch.Tensor | None = None
|
|
||||||
|
|
||||||
# ##############################################
|
|
||||||
# Get the variables back
|
|
||||||
# ##############################################
|
|
||||||
(input, weight, h) = ctx.saved_tensors
|
|
||||||
|
|
||||||
if ctx.use_convolution:
|
|
||||||
big_r: torch.Tensor = torch.nn.functional.conv_transpose2d(
|
|
||||||
h,
|
|
||||||
weight,
|
|
||||||
stride=ctx.stride,
|
|
||||||
padding=ctx.padding,
|
|
||||||
dilation=ctx.dilation,
|
|
||||||
)
|
|
||||||
big_r_div = 1.0 / (big_r + 1e-20)
|
|
||||||
|
|
||||||
factor_x_div_r: torch.Tensor = input * big_r_div
|
|
||||||
|
|
||||||
grad_input = (
|
|
||||||
torch.nn.functional.conv_transpose2d(
|
|
||||||
(h * grad_output),
|
|
||||||
weight,
|
|
||||||
stride=ctx.stride,
|
|
||||||
padding=ctx.padding,
|
|
||||||
dilation=ctx.dilation,
|
|
||||||
)
|
|
||||||
* big_r_div
|
|
||||||
)
|
|
||||||
|
|
||||||
del big_r_div
|
|
||||||
if ctx.local_learning is False:
|
|
||||||
del big_r
|
|
||||||
grad_weight = -torch.nn.functional.conv2d(
|
|
||||||
(factor_x_div_r * grad_input).movedim(0, 1),
|
|
||||||
h.movedim(0, 1),
|
|
||||||
stride=ctx.dilation,
|
|
||||||
padding=ctx.padding,
|
|
||||||
dilation=ctx.stride,
|
|
||||||
)
|
|
||||||
|
|
||||||
grad_weight += torch.nn.functional.conv2d(
|
|
||||||
factor_x_div_r.movedim(0, 1),
|
|
||||||
(h * grad_output).movedim(0, 1),
|
|
||||||
stride=ctx.dilation,
|
|
||||||
padding=ctx.padding,
|
|
||||||
dilation=ctx.stride,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
if ctx.local_learning_kl:
|
|
||||||
grad_weight = -torch.nn.functional.conv2d(
|
|
||||||
factor_x_div_r.movedim(0, 1),
|
|
||||||
h.movedim(0, 1),
|
|
||||||
stride=ctx.dilation,
|
|
||||||
padding=ctx.padding,
|
|
||||||
dilation=ctx.stride,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
grad_weight = -torch.nn.functional.conv2d(
|
|
||||||
(2 * (input - big_r)).movedim(0, 1),
|
|
||||||
h.movedim(0, 1),
|
|
||||||
stride=ctx.dilation,
|
|
||||||
padding=ctx.padding,
|
|
||||||
dilation=ctx.stride,
|
|
||||||
)
|
|
||||||
grad_weight = grad_weight.movedim(0, 1)
|
|
||||||
else:
|
|
||||||
h = h.movedim(1, -1)
|
|
||||||
grad_output = grad_output.movedim(1, -1)
|
|
||||||
input = input.movedim(1, -1)
|
|
||||||
big_r = torch.nn.functional.linear(h, weight.T)
|
|
||||||
big_r_div = 1.0 / (big_r + 1e-20)
|
|
||||||
|
|
||||||
factor_x_div_r = input * big_r_div
|
|
||||||
|
|
||||||
grad_input = (
|
|
||||||
torch.nn.functional.linear(h * grad_output, weight.T) * big_r_div
|
|
||||||
)
|
|
||||||
|
|
||||||
del big_r_div
|
|
||||||
|
|
||||||
if ctx.local_learning is False:
|
|
||||||
del big_r
|
|
||||||
|
|
||||||
grad_weight = -torch.nn.functional.linear(
|
|
||||||
h.reshape(
|
|
||||||
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
|
||||||
h.shape[3],
|
|
||||||
).T,
|
|
||||||
(factor_x_div_r * grad_input)
|
|
||||||
.reshape(
|
|
||||||
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
|
||||||
grad_input.shape[3],
|
|
||||||
)
|
|
||||||
.T,
|
|
||||||
)
|
|
||||||
|
|
||||||
grad_weight += torch.nn.functional.linear(
|
|
||||||
(h * grad_output)
|
|
||||||
.reshape(
|
|
||||||
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
|
||||||
h.shape[3],
|
|
||||||
)
|
|
||||||
.T,
|
|
||||||
factor_x_div_r.reshape(
|
|
||||||
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
|
||||||
grad_input.shape[3],
|
|
||||||
).T,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
if ctx.local_learning_kl:
|
|
||||||
grad_weight = -torch.nn.functional.linear(
|
|
||||||
h.reshape(
|
|
||||||
grad_input.shape[0]
|
|
||||||
* grad_input.shape[1]
|
|
||||||
* grad_input.shape[2],
|
|
||||||
h.shape[3],
|
|
||||||
).T,
|
|
||||||
factor_x_div_r.reshape(
|
|
||||||
grad_input.shape[0]
|
|
||||||
* grad_input.shape[1]
|
|
||||||
* grad_input.shape[2],
|
|
||||||
grad_input.shape[3],
|
|
||||||
).T,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
grad_weight = -torch.nn.functional.linear(
|
|
||||||
h.reshape(
|
|
||||||
grad_input.shape[0]
|
|
||||||
* grad_input.shape[1]
|
|
||||||
* grad_input.shape[2],
|
|
||||||
h.shape[3],
|
|
||||||
).T,
|
|
||||||
(2 * (input - big_r))
|
|
||||||
.reshape(
|
|
||||||
grad_input.shape[0]
|
|
||||||
* grad_input.shape[1]
|
|
||||||
* grad_input.shape[2],
|
|
||||||
grad_input.shape[3],
|
|
||||||
)
|
|
||||||
.T,
|
|
||||||
)
|
|
||||||
grad_input = grad_input.movedim(-1, 1)
|
|
||||||
assert torch.isfinite(grad_input).all()
|
|
||||||
assert torch.isfinite(grad_weight).all()
|
|
||||||
|
|
||||||
return (
|
|
||||||
grad_input,
|
|
||||||
grad_weight,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
)
|
|
Loading…
Reference in a new issue