Add files via upload
This commit is contained in:
parent
0ad3b2c89c
commit
4e0f4f84f1
12 changed files with 2041 additions and 0 deletions
502
NNMFConv2d.py
Normal file
502
NNMFConv2d.py
Normal file
|
@ -0,0 +1,502 @@
|
||||||
|
import torch
|
||||||
|
from non_linear_weigth_function import non_linear_weigth_function
|
||||||
|
|
||||||
|
|
||||||
|
class NNMFConv2d(torch.nn.Module):
|
||||||
|
|
||||||
|
in_channels: int
|
||||||
|
out_channels: int
|
||||||
|
kernel_size: tuple[int, ...]
|
||||||
|
stride: tuple[int, ...]
|
||||||
|
padding: str | tuple[int, ...]
|
||||||
|
dilation: tuple[int, ...]
|
||||||
|
weight: torch.Tensor
|
||||||
|
bias: None | torch.Tensor
|
||||||
|
output_size: None | torch.Tensor = None
|
||||||
|
convolution_contribution_map: None | torch.Tensor = None
|
||||||
|
iterations: int
|
||||||
|
convolution_contribution_map_enable: bool
|
||||||
|
epsilon: float | None
|
||||||
|
init_min: float
|
||||||
|
init_max: float
|
||||||
|
beta: torch.Tensor | None
|
||||||
|
positive_function_type: int
|
||||||
|
use_convolution: bool
|
||||||
|
local_learning: bool
|
||||||
|
local_learning_kl: bool
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: tuple[int, int],
|
||||||
|
stride: tuple[int, int] = (1, 1),
|
||||||
|
padding: str | tuple[int, int] = (0, 0),
|
||||||
|
dilation: tuple[int, int] = (1, 1),
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
iterations: int = 20,
|
||||||
|
convolution_contribution_map_enable: bool = False,
|
||||||
|
epsilon: float | None = None,
|
||||||
|
init_min: float = 0.0,
|
||||||
|
init_max: float = 1.0,
|
||||||
|
beta: float | None = None,
|
||||||
|
positive_function_type: int = 0,
|
||||||
|
use_convolution: bool = False,
|
||||||
|
local_learning: bool = False,
|
||||||
|
local_learning_kl: bool = False,
|
||||||
|
) -> None:
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
valid_padding_strings = {"same", "valid"}
|
||||||
|
if isinstance(padding, str):
|
||||||
|
if padding not in valid_padding_strings:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}"
|
||||||
|
)
|
||||||
|
if padding == "same" and any(s != 1 for s in stride):
|
||||||
|
raise ValueError(
|
||||||
|
"padding='same' is not supported for strided convolutions"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.positive_function_type = positive_function_type
|
||||||
|
self.init_min = init_min
|
||||||
|
self.init_max = init_max
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
self.dilation = dilation
|
||||||
|
|
||||||
|
self.iterations = iterations
|
||||||
|
self.convolution_contribution_map_enable = convolution_contribution_map_enable
|
||||||
|
|
||||||
|
self.local_learning = local_learning
|
||||||
|
self.local_learning_kl = local_learning_kl
|
||||||
|
|
||||||
|
self.weight = torch.nn.parameter.Parameter(
|
||||||
|
torch.empty((out_channels, in_channels, *kernel_size), **factory_kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
|
if beta is not None:
|
||||||
|
self.beta = torch.nn.parameter.Parameter(torch.empty((1), **factory_kwargs))
|
||||||
|
self.beta.data[0] = beta
|
||||||
|
else:
|
||||||
|
self.beta = None
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
self.functional_nnmf_conv2d = FunctionalNNMFConv2d.apply
|
||||||
|
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.use_convolution = use_convolution
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
s: str = f"{self.in_channels}, {self.out_channels}"
|
||||||
|
|
||||||
|
s += f", kernel_size={self.kernel_size}"
|
||||||
|
s += f", stride={self.stride}, iterations={self.iterations}"
|
||||||
|
s += f", epsilon={self.epsilon}"
|
||||||
|
s += f", use_convolution={self.use_convolution}"
|
||||||
|
|
||||||
|
if self.use_convolution:
|
||||||
|
s += f", ccmap={self.convolution_contribution_map_enable}"
|
||||||
|
|
||||||
|
s += f", pfunctype={self.positive_function_type}"
|
||||||
|
s += f", local_learning={self.local_learning}"
|
||||||
|
|
||||||
|
if self.local_learning:
|
||||||
|
s += f", local_learning_kl={self.local_learning_kl}"
|
||||||
|
|
||||||
|
if self.padding != (0,) * len(self.padding):
|
||||||
|
s += f", padding={self.padding}"
|
||||||
|
|
||||||
|
if self.dilation != (1,) * len(self.dilation):
|
||||||
|
s += f", dilation={self.dilation}"
|
||||||
|
|
||||||
|
return s
|
||||||
|
|
||||||
|
def reset_parameters(self) -> None:
|
||||||
|
torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max)
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
|
if input.ndim == 2:
|
||||||
|
input = input.unsqueeze(-1)
|
||||||
|
if input.ndim == 3:
|
||||||
|
input = input.unsqueeze(-1)
|
||||||
|
|
||||||
|
if self.output_size is None:
|
||||||
|
self.output_size = torch.tensor(
|
||||||
|
torch.nn.functional.conv2d(
|
||||||
|
torch.zeros(
|
||||||
|
1,
|
||||||
|
input.shape[1],
|
||||||
|
input.shape[2],
|
||||||
|
input.shape[3],
|
||||||
|
device=self.weight.device,
|
||||||
|
dtype=self.weight.dtype,
|
||||||
|
),
|
||||||
|
torch.zeros_like(self.weight),
|
||||||
|
stride=self.stride,
|
||||||
|
padding=self.padding,
|
||||||
|
dilation=self.dilation,
|
||||||
|
).shape,
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
assert self.output_size is not None
|
||||||
|
|
||||||
|
if self.use_convolution is False:
|
||||||
|
|
||||||
|
input = torch.nn.functional.fold(
|
||||||
|
torch.nn.functional.unfold(
|
||||||
|
input.requires_grad_(True),
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
dilation=self.dilation,
|
||||||
|
padding=self.padding,
|
||||||
|
stride=self.stride,
|
||||||
|
),
|
||||||
|
output_size=self.output_size[-2:],
|
||||||
|
kernel_size=(1, 1),
|
||||||
|
dilation=(1, 1),
|
||||||
|
padding=(0, 0),
|
||||||
|
stride=(1, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
(self.convolution_contribution_map is None)
|
||||||
|
and (self.convolution_contribution_map_enable)
|
||||||
|
and (self.use_convolution)
|
||||||
|
):
|
||||||
|
|
||||||
|
self.convolution_contribution_map = torch.nn.functional.conv_transpose2d(
|
||||||
|
torch.full(
|
||||||
|
self.output_size.tolist(),
|
||||||
|
1.0 / float(self.output_size[1]),
|
||||||
|
dtype=self.weight.dtype,
|
||||||
|
device=self.weight.device,
|
||||||
|
requires_grad=False,
|
||||||
|
),
|
||||||
|
torch.ones_like(self.weight, requires_grad=False),
|
||||||
|
stride=self.stride,
|
||||||
|
padding=self.padding,
|
||||||
|
dilation=self.dilation,
|
||||||
|
) * (
|
||||||
|
(input.shape[1] * input.shape[2] * input.shape[3])
|
||||||
|
/ (self.weight.shape[1] * self.weight.shape[2] * self.weight.shape[3])
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.convolution_contribution_map_enable and self.use_convolution:
|
||||||
|
assert self.convolution_contribution_map is not None
|
||||||
|
|
||||||
|
positive_weights = non_linear_weigth_function(
|
||||||
|
self.weight, self.beta, self.positive_function_type
|
||||||
|
)
|
||||||
|
positive_weights = positive_weights / (
|
||||||
|
positive_weights.sum((1, 2, 3), keepdim=True) + 10e-20
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_convolution is False:
|
||||||
|
positive_weights = positive_weights.reshape(
|
||||||
|
positive_weights.shape[0],
|
||||||
|
positive_weights.shape[1]
|
||||||
|
* positive_weights.shape[2]
|
||||||
|
* positive_weights.shape[3],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare input
|
||||||
|
if self.use_convolution:
|
||||||
|
input = input / (input.sum((1, 2, 3), keepdim=True) + 10e-20)
|
||||||
|
if self.convolution_contribution_map is not None:
|
||||||
|
input = input * self.convolution_contribution_map
|
||||||
|
else:
|
||||||
|
input = input / (input.sum(dim=1, keepdim=True) + 10e-20)
|
||||||
|
|
||||||
|
return self.functional_nnmf_conv2d(
|
||||||
|
input,
|
||||||
|
positive_weights,
|
||||||
|
self.output_size,
|
||||||
|
self.iterations,
|
||||||
|
self.stride,
|
||||||
|
self.padding,
|
||||||
|
self.dilation,
|
||||||
|
self.epsilon,
|
||||||
|
self.use_convolution,
|
||||||
|
self.local_learning,
|
||||||
|
self.local_learning_kl,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionalNNMFConv2d(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward( # type: ignore
|
||||||
|
ctx,
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
output_size: torch.Tensor,
|
||||||
|
iterations: int,
|
||||||
|
stride: tuple[int, int],
|
||||||
|
padding: str | tuple[int, int],
|
||||||
|
dilation: tuple[int, int],
|
||||||
|
epsilon: float | None,
|
||||||
|
use_convolution: bool,
|
||||||
|
local_learning: bool,
|
||||||
|
local_learning_kl: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# Prepare h
|
||||||
|
output_size[0] = input.shape[0]
|
||||||
|
h = torch.full(
|
||||||
|
output_size.tolist(),
|
||||||
|
1.0 / float(output_size[1]),
|
||||||
|
device=input.device,
|
||||||
|
dtype=input.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_convolution:
|
||||||
|
for _ in range(0, iterations):
|
||||||
|
factor_x_div_r: torch.Tensor = input / (
|
||||||
|
torch.nn.functional.conv_transpose2d(
|
||||||
|
h,
|
||||||
|
weight,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
)
|
||||||
|
+ 10e-20
|
||||||
|
)
|
||||||
|
|
||||||
|
if epsilon is None:
|
||||||
|
h *= torch.nn.functional.conv2d(
|
||||||
|
factor_x_div_r,
|
||||||
|
weight,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
h *= 1 + epsilon * torch.nn.functional.conv2d(
|
||||||
|
factor_x_div_r,
|
||||||
|
weight,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
)
|
||||||
|
|
||||||
|
h /= h.sum(1, keepdim=True) + 10e-20
|
||||||
|
else:
|
||||||
|
h = h.movedim(1, -1)
|
||||||
|
input = input.movedim(1, -1)
|
||||||
|
for _ in range(0, iterations):
|
||||||
|
reconstruction = torch.nn.functional.linear(h, weight.T)
|
||||||
|
reconstruction += 1e-20
|
||||||
|
if epsilon is None:
|
||||||
|
h *= torch.nn.functional.linear((input / reconstruction), weight)
|
||||||
|
else:
|
||||||
|
h *= 1 + epsilon * torch.nn.functional.linear(
|
||||||
|
(input / reconstruction), weight
|
||||||
|
)
|
||||||
|
h /= h.sum(-1, keepdim=True) + 10e-20
|
||||||
|
h = h.movedim(-1, 1)
|
||||||
|
input = input.movedim(-1, 1)
|
||||||
|
|
||||||
|
# ###########################################################
|
||||||
|
# Save the necessary data for the backward pass
|
||||||
|
# ###########################################################
|
||||||
|
ctx.save_for_backward(input, weight, h)
|
||||||
|
|
||||||
|
ctx.stride = stride
|
||||||
|
ctx.padding = padding
|
||||||
|
ctx.dilation = dilation
|
||||||
|
ctx.use_convolution = use_convolution
|
||||||
|
ctx.local_learning = local_learning
|
||||||
|
ctx.local_learning_kl = local_learning_kl
|
||||||
|
|
||||||
|
assert torch.isfinite(h).all()
|
||||||
|
return h
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.autograd.function.once_differentiable
|
||||||
|
def backward(ctx, grad_output: torch.Tensor) -> tuple[ # type: ignore
|
||||||
|
torch.Tensor | None,
|
||||||
|
torch.Tensor | None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
]:
|
||||||
|
|
||||||
|
# ##############################################
|
||||||
|
# Default values
|
||||||
|
# ##############################################
|
||||||
|
grad_input: torch.Tensor | None = None
|
||||||
|
grad_weight: torch.Tensor | None = None
|
||||||
|
|
||||||
|
# ##############################################
|
||||||
|
# Get the variables back
|
||||||
|
# ##############################################
|
||||||
|
(input, weight, h) = ctx.saved_tensors
|
||||||
|
|
||||||
|
if ctx.use_convolution:
|
||||||
|
big_r: torch.Tensor = torch.nn.functional.conv_transpose2d(
|
||||||
|
h,
|
||||||
|
weight,
|
||||||
|
stride=ctx.stride,
|
||||||
|
padding=ctx.padding,
|
||||||
|
dilation=ctx.dilation,
|
||||||
|
)
|
||||||
|
big_r_div = 1.0 / (big_r + 1e-20)
|
||||||
|
|
||||||
|
factor_x_div_r: torch.Tensor = input * big_r_div
|
||||||
|
|
||||||
|
grad_input = (
|
||||||
|
torch.nn.functional.conv_transpose2d(
|
||||||
|
(h * grad_output),
|
||||||
|
weight,
|
||||||
|
stride=ctx.stride,
|
||||||
|
padding=ctx.padding,
|
||||||
|
dilation=ctx.dilation,
|
||||||
|
)
|
||||||
|
* big_r_div
|
||||||
|
)
|
||||||
|
|
||||||
|
del big_r_div
|
||||||
|
if ctx.local_learning is False:
|
||||||
|
del big_r
|
||||||
|
grad_weight = -torch.nn.functional.conv2d(
|
||||||
|
(factor_x_div_r * grad_input).movedim(0, 1),
|
||||||
|
h.movedim(0, 1),
|
||||||
|
stride=ctx.dilation,
|
||||||
|
padding=ctx.padding,
|
||||||
|
dilation=ctx.stride,
|
||||||
|
)
|
||||||
|
|
||||||
|
grad_weight += torch.nn.functional.conv2d(
|
||||||
|
factor_x_div_r.movedim(0, 1),
|
||||||
|
(h * grad_output).movedim(0, 1),
|
||||||
|
stride=ctx.dilation,
|
||||||
|
padding=ctx.padding,
|
||||||
|
dilation=ctx.stride,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if ctx.local_learning_kl:
|
||||||
|
grad_weight = -torch.nn.functional.conv2d(
|
||||||
|
factor_x_div_r.movedim(0, 1),
|
||||||
|
h.movedim(0, 1),
|
||||||
|
stride=ctx.dilation,
|
||||||
|
padding=ctx.padding,
|
||||||
|
dilation=ctx.stride,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
grad_weight = -torch.nn.functional.conv2d(
|
||||||
|
(2 * (input - big_r)).movedim(0, 1),
|
||||||
|
h.movedim(0, 1),
|
||||||
|
stride=ctx.dilation,
|
||||||
|
padding=ctx.padding,
|
||||||
|
dilation=ctx.stride,
|
||||||
|
)
|
||||||
|
grad_weight = grad_weight.movedim(0, 1)
|
||||||
|
else:
|
||||||
|
h = h.movedim(1, -1)
|
||||||
|
grad_output = grad_output.movedim(1, -1)
|
||||||
|
input = input.movedim(1, -1)
|
||||||
|
big_r = torch.nn.functional.linear(h, weight.T)
|
||||||
|
big_r_div = 1.0 / (big_r + 1e-20)
|
||||||
|
|
||||||
|
factor_x_div_r = input * big_r_div
|
||||||
|
|
||||||
|
grad_input = (
|
||||||
|
torch.nn.functional.linear(h * grad_output, weight.T) * big_r_div
|
||||||
|
)
|
||||||
|
|
||||||
|
del big_r_div
|
||||||
|
|
||||||
|
if ctx.local_learning is False:
|
||||||
|
del big_r
|
||||||
|
|
||||||
|
grad_weight = -torch.nn.functional.linear(
|
||||||
|
h.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
h.shape[3],
|
||||||
|
).T,
|
||||||
|
(factor_x_div_r * grad_input)
|
||||||
|
.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
grad_input.shape[3],
|
||||||
|
)
|
||||||
|
.T,
|
||||||
|
)
|
||||||
|
|
||||||
|
grad_weight += torch.nn.functional.linear(
|
||||||
|
(h * grad_output)
|
||||||
|
.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
h.shape[3],
|
||||||
|
)
|
||||||
|
.T,
|
||||||
|
factor_x_div_r.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
grad_input.shape[3],
|
||||||
|
).T,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if ctx.local_learning_kl:
|
||||||
|
grad_weight = -torch.nn.functional.linear(
|
||||||
|
h.reshape(
|
||||||
|
grad_input.shape[0]
|
||||||
|
* grad_input.shape[1]
|
||||||
|
* grad_input.shape[2],
|
||||||
|
h.shape[3],
|
||||||
|
).T,
|
||||||
|
factor_x_div_r.reshape(
|
||||||
|
grad_input.shape[0]
|
||||||
|
* grad_input.shape[1]
|
||||||
|
* grad_input.shape[2],
|
||||||
|
grad_input.shape[3],
|
||||||
|
).T,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
grad_weight = -torch.nn.functional.linear(
|
||||||
|
h.reshape(
|
||||||
|
grad_input.shape[0]
|
||||||
|
* grad_input.shape[1]
|
||||||
|
* grad_input.shape[2],
|
||||||
|
h.shape[3],
|
||||||
|
).T,
|
||||||
|
(2 * (input - big_r))
|
||||||
|
.reshape(
|
||||||
|
grad_input.shape[0]
|
||||||
|
* grad_input.shape[1]
|
||||||
|
* grad_input.shape[2],
|
||||||
|
grad_input.shape[3],
|
||||||
|
)
|
||||||
|
.T,
|
||||||
|
)
|
||||||
|
grad_input = grad_input.movedim(-1, 1)
|
||||||
|
assert torch.isfinite(grad_input).all()
|
||||||
|
assert torch.isfinite(grad_weight).all()
|
||||||
|
|
||||||
|
return (
|
||||||
|
grad_input,
|
||||||
|
grad_weight,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
480
NNMFConv2dP.py
Normal file
480
NNMFConv2dP.py
Normal file
|
@ -0,0 +1,480 @@
|
||||||
|
import torch
|
||||||
|
from non_linear_weigth_function import non_linear_weigth_function
|
||||||
|
|
||||||
|
|
||||||
|
class NNMFConv2dP(torch.nn.Module):
|
||||||
|
|
||||||
|
in_channels: int
|
||||||
|
out_channels: int
|
||||||
|
kernel_size: tuple[int, ...]
|
||||||
|
stride: tuple[int, ...]
|
||||||
|
padding: str | tuple[int, ...]
|
||||||
|
dilation: tuple[int, ...]
|
||||||
|
weight: torch.Tensor
|
||||||
|
bias: None | torch.Tensor
|
||||||
|
output_size: None | torch.Tensor = None
|
||||||
|
convolution_contribution_map: None | torch.Tensor = None
|
||||||
|
iterations: int
|
||||||
|
convolution_contribution_map_enable: bool
|
||||||
|
epsilon: float | None
|
||||||
|
init_min: float
|
||||||
|
init_max: float
|
||||||
|
beta: torch.Tensor | None
|
||||||
|
positive_function_type: int
|
||||||
|
use_convolution: bool
|
||||||
|
local_learning: bool
|
||||||
|
local_learning_kl: bool
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: tuple[int, int],
|
||||||
|
stride: tuple[int, int] = (1, 1),
|
||||||
|
padding: str | tuple[int, int] = (0, 0),
|
||||||
|
dilation: tuple[int, int] = (1, 1),
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
iterations: int = 20,
|
||||||
|
convolution_contribution_map_enable: bool = False,
|
||||||
|
epsilon: float | None = None,
|
||||||
|
init_min: float = 0.0,
|
||||||
|
init_max: float = 1.0,
|
||||||
|
beta: float | None = None,
|
||||||
|
positive_function_type: int = 0,
|
||||||
|
use_convolution: bool = False,
|
||||||
|
local_learning: bool = False,
|
||||||
|
local_learning_kl: bool = False,
|
||||||
|
) -> None:
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
valid_padding_strings = {"same", "valid"}
|
||||||
|
if isinstance(padding, str):
|
||||||
|
if padding not in valid_padding_strings:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}"
|
||||||
|
)
|
||||||
|
if padding == "same" and any(s != 1 for s in stride):
|
||||||
|
raise ValueError(
|
||||||
|
"padding='same' is not supported for strided convolutions"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.positive_function_type = positive_function_type
|
||||||
|
self.init_min = init_min
|
||||||
|
self.init_max = init_max
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
self.dilation = dilation
|
||||||
|
|
||||||
|
self.iterations = iterations
|
||||||
|
self.convolution_contribution_map_enable = convolution_contribution_map_enable
|
||||||
|
|
||||||
|
self.local_learning = local_learning
|
||||||
|
self.local_learning_kl = local_learning_kl
|
||||||
|
|
||||||
|
self.weight = torch.nn.parameter.Parameter(
|
||||||
|
torch.empty((out_channels, in_channels, *kernel_size), **factory_kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
|
if beta is not None:
|
||||||
|
self.beta = torch.nn.parameter.Parameter(torch.empty((1), **factory_kwargs))
|
||||||
|
self.beta.data[0] = beta
|
||||||
|
else:
|
||||||
|
self.beta = None
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
self.functional_nnmf_conv2d = FunctionalNNMFConv2dP.apply
|
||||||
|
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.use_convolution = use_convolution
|
||||||
|
|
||||||
|
assert self.use_convolution is False
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
s: str = f"{self.in_channels}, {self.out_channels}"
|
||||||
|
|
||||||
|
s += f", kernel_size={self.kernel_size}"
|
||||||
|
s += f", stride={self.stride}, iterations={self.iterations}"
|
||||||
|
s += f", epsilon={self.epsilon}"
|
||||||
|
s += f", use_convolution={self.use_convolution}"
|
||||||
|
|
||||||
|
if self.use_convolution:
|
||||||
|
s += f", ccmap={self.convolution_contribution_map_enable}"
|
||||||
|
|
||||||
|
s += f", pfunctype={self.positive_function_type}"
|
||||||
|
s += f", local_learning={self.local_learning}"
|
||||||
|
|
||||||
|
if self.local_learning:
|
||||||
|
s += f", local_learning_kl={self.local_learning_kl}"
|
||||||
|
|
||||||
|
if self.padding != (0,) * len(self.padding):
|
||||||
|
s += f", padding={self.padding}"
|
||||||
|
|
||||||
|
if self.dilation != (1,) * len(self.dilation):
|
||||||
|
s += f", dilation={self.dilation}"
|
||||||
|
|
||||||
|
return s
|
||||||
|
|
||||||
|
def reset_parameters(self) -> None:
|
||||||
|
torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max)
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
|
if input.ndim == 2:
|
||||||
|
input = input.unsqueeze(-1)
|
||||||
|
if input.ndim == 3:
|
||||||
|
input = input.unsqueeze(-1)
|
||||||
|
|
||||||
|
if self.output_size is None:
|
||||||
|
self.output_size = torch.tensor(
|
||||||
|
torch.nn.functional.conv2d(
|
||||||
|
torch.zeros(
|
||||||
|
1,
|
||||||
|
input.shape[1],
|
||||||
|
input.shape[2],
|
||||||
|
input.shape[3],
|
||||||
|
device=self.weight.device,
|
||||||
|
dtype=self.weight.dtype,
|
||||||
|
),
|
||||||
|
torch.zeros_like(self.weight),
|
||||||
|
stride=self.stride,
|
||||||
|
padding=self.padding,
|
||||||
|
dilation=self.dilation,
|
||||||
|
).shape,
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
assert self.output_size is not None
|
||||||
|
|
||||||
|
input = torch.nn.functional.fold(
|
||||||
|
torch.nn.functional.unfold(
|
||||||
|
input.requires_grad_(True),
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
dilation=self.dilation,
|
||||||
|
padding=self.padding,
|
||||||
|
stride=self.stride,
|
||||||
|
),
|
||||||
|
output_size=self.output_size[-2:],
|
||||||
|
kernel_size=(1, 1),
|
||||||
|
dilation=(1, 1),
|
||||||
|
padding=(0, 0),
|
||||||
|
stride=(1, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
positive_weights = non_linear_weigth_function(
|
||||||
|
self.weight, self.beta, self.positive_function_type
|
||||||
|
)
|
||||||
|
positive_weights = positive_weights / (
|
||||||
|
positive_weights.sum((1, 2, 3), keepdim=True) + 10e-20
|
||||||
|
)
|
||||||
|
|
||||||
|
positive_weights = positive_weights.reshape(
|
||||||
|
positive_weights.shape[0],
|
||||||
|
positive_weights.shape[1]
|
||||||
|
* positive_weights.shape[2]
|
||||||
|
* positive_weights.shape[3],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare input
|
||||||
|
input = input / (input.sum(dim=1, keepdim=True) + 10e-20)
|
||||||
|
|
||||||
|
h_dyn = self.functional_nnmf_conv2d(
|
||||||
|
input,
|
||||||
|
positive_weights,
|
||||||
|
self.output_size,
|
||||||
|
self.iterations,
|
||||||
|
self.stride,
|
||||||
|
self.padding,
|
||||||
|
self.dilation,
|
||||||
|
self.epsilon,
|
||||||
|
self.use_convolution,
|
||||||
|
self.local_learning,
|
||||||
|
self.local_learning_kl,
|
||||||
|
)
|
||||||
|
self.reco = False
|
||||||
|
if self.reco:
|
||||||
|
print(h_dyn.shape)
|
||||||
|
print(positive_weights.shape)
|
||||||
|
print(input.shape)
|
||||||
|
exit()
|
||||||
|
output = torch.cat((h_dyn, input), dim=1)
|
||||||
|
else:
|
||||||
|
output = torch.cat((h_dyn, input), dim=1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionalNNMFConv2dP(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward( # type: ignore
|
||||||
|
ctx,
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
output_size: torch.Tensor,
|
||||||
|
iterations: int,
|
||||||
|
stride: tuple[int, int],
|
||||||
|
padding: str | tuple[int, int],
|
||||||
|
dilation: tuple[int, int],
|
||||||
|
epsilon: float | None,
|
||||||
|
use_convolution: bool,
|
||||||
|
local_learning: bool,
|
||||||
|
local_learning_kl: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# Prepare h
|
||||||
|
output_size[0] = input.shape[0]
|
||||||
|
h = torch.full(
|
||||||
|
output_size.tolist(),
|
||||||
|
1.0 / float(output_size[1]),
|
||||||
|
device=input.device,
|
||||||
|
dtype=input.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_convolution:
|
||||||
|
for _ in range(0, iterations):
|
||||||
|
factor_x_div_r: torch.Tensor = input / (
|
||||||
|
torch.nn.functional.conv_transpose2d(
|
||||||
|
h,
|
||||||
|
weight,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
)
|
||||||
|
+ 10e-20
|
||||||
|
)
|
||||||
|
|
||||||
|
if epsilon is None:
|
||||||
|
h *= torch.nn.functional.conv2d(
|
||||||
|
factor_x_div_r,
|
||||||
|
weight,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
h *= 1 + epsilon * torch.nn.functional.conv2d(
|
||||||
|
factor_x_div_r,
|
||||||
|
weight,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
)
|
||||||
|
|
||||||
|
h /= h.sum(1, keepdim=True) + 10e-20
|
||||||
|
else:
|
||||||
|
h = h.movedim(1, -1)
|
||||||
|
input = input.movedim(1, -1)
|
||||||
|
for _ in range(0, iterations):
|
||||||
|
reconstruction = torch.nn.functional.linear(h, weight.T)
|
||||||
|
reconstruction += 1e-20
|
||||||
|
if epsilon is None:
|
||||||
|
h *= torch.nn.functional.linear((input / reconstruction), weight)
|
||||||
|
else:
|
||||||
|
h *= 1 + epsilon * torch.nn.functional.linear(
|
||||||
|
(input / reconstruction), weight
|
||||||
|
)
|
||||||
|
h /= h.sum(-1, keepdim=True) + 10e-20
|
||||||
|
h = h.movedim(-1, 1)
|
||||||
|
input = input.movedim(-1, 1)
|
||||||
|
|
||||||
|
# ###########################################################
|
||||||
|
# Save the necessary data for the backward pass
|
||||||
|
# ###########################################################
|
||||||
|
ctx.save_for_backward(input, weight, h)
|
||||||
|
|
||||||
|
ctx.stride = stride
|
||||||
|
ctx.padding = padding
|
||||||
|
ctx.dilation = dilation
|
||||||
|
ctx.use_convolution = use_convolution
|
||||||
|
ctx.local_learning = local_learning
|
||||||
|
ctx.local_learning_kl = local_learning_kl
|
||||||
|
|
||||||
|
assert torch.isfinite(h).all()
|
||||||
|
return h
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.autograd.function.once_differentiable
|
||||||
|
def backward(ctx, grad_output: torch.Tensor) -> tuple[ # type: ignore
|
||||||
|
torch.Tensor | None,
|
||||||
|
torch.Tensor | None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
]:
|
||||||
|
|
||||||
|
# ##############################################
|
||||||
|
# Default values
|
||||||
|
# ##############################################
|
||||||
|
grad_input: torch.Tensor | None = None
|
||||||
|
grad_weight: torch.Tensor | None = None
|
||||||
|
|
||||||
|
# ##############################################
|
||||||
|
# Get the variables back
|
||||||
|
# ##############################################
|
||||||
|
(input, weight, h) = ctx.saved_tensors
|
||||||
|
|
||||||
|
if ctx.use_convolution:
|
||||||
|
big_r: torch.Tensor = torch.nn.functional.conv_transpose2d(
|
||||||
|
h,
|
||||||
|
weight,
|
||||||
|
stride=ctx.stride,
|
||||||
|
padding=ctx.padding,
|
||||||
|
dilation=ctx.dilation,
|
||||||
|
)
|
||||||
|
big_r_div = 1.0 / (big_r + 1e-20)
|
||||||
|
|
||||||
|
factor_x_div_r: torch.Tensor = input * big_r_div
|
||||||
|
|
||||||
|
grad_input = (
|
||||||
|
torch.nn.functional.conv_transpose2d(
|
||||||
|
(h * grad_output),
|
||||||
|
weight,
|
||||||
|
stride=ctx.stride,
|
||||||
|
padding=ctx.padding,
|
||||||
|
dilation=ctx.dilation,
|
||||||
|
)
|
||||||
|
* big_r_div
|
||||||
|
)
|
||||||
|
|
||||||
|
del big_r_div
|
||||||
|
if ctx.local_learning is False:
|
||||||
|
del big_r
|
||||||
|
grad_weight = -torch.nn.functional.conv2d(
|
||||||
|
(factor_x_div_r * grad_input).movedim(0, 1),
|
||||||
|
h.movedim(0, 1),
|
||||||
|
stride=ctx.dilation,
|
||||||
|
padding=ctx.padding,
|
||||||
|
dilation=ctx.stride,
|
||||||
|
)
|
||||||
|
|
||||||
|
grad_weight += torch.nn.functional.conv2d(
|
||||||
|
factor_x_div_r.movedim(0, 1),
|
||||||
|
(h * grad_output).movedim(0, 1),
|
||||||
|
stride=ctx.dilation,
|
||||||
|
padding=ctx.padding,
|
||||||
|
dilation=ctx.stride,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if ctx.local_learning_kl:
|
||||||
|
grad_weight = -torch.nn.functional.conv2d(
|
||||||
|
factor_x_div_r.movedim(0, 1),
|
||||||
|
h.movedim(0, 1),
|
||||||
|
stride=ctx.dilation,
|
||||||
|
padding=ctx.padding,
|
||||||
|
dilation=ctx.stride,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
grad_weight = -torch.nn.functional.conv2d(
|
||||||
|
(2 * (input - big_r)).movedim(0, 1),
|
||||||
|
h.movedim(0, 1),
|
||||||
|
stride=ctx.dilation,
|
||||||
|
padding=ctx.padding,
|
||||||
|
dilation=ctx.stride,
|
||||||
|
)
|
||||||
|
grad_weight = grad_weight.movedim(0, 1)
|
||||||
|
else:
|
||||||
|
h = h.movedim(1, -1)
|
||||||
|
grad_output = grad_output.movedim(1, -1)
|
||||||
|
input = input.movedim(1, -1)
|
||||||
|
big_r = torch.nn.functional.linear(h, weight.T)
|
||||||
|
big_r_div = 1.0 / (big_r + 1e-20)
|
||||||
|
|
||||||
|
factor_x_div_r = input * big_r_div
|
||||||
|
|
||||||
|
grad_input = (
|
||||||
|
torch.nn.functional.linear(h * grad_output, weight.T) * big_r_div
|
||||||
|
)
|
||||||
|
|
||||||
|
del big_r_div
|
||||||
|
|
||||||
|
if ctx.local_learning is False:
|
||||||
|
del big_r
|
||||||
|
|
||||||
|
grad_weight = -torch.nn.functional.linear(
|
||||||
|
h.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
h.shape[3],
|
||||||
|
).T,
|
||||||
|
(factor_x_div_r * grad_input)
|
||||||
|
.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
grad_input.shape[3],
|
||||||
|
)
|
||||||
|
.T,
|
||||||
|
)
|
||||||
|
|
||||||
|
grad_weight += torch.nn.functional.linear(
|
||||||
|
(h * grad_output)
|
||||||
|
.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
h.shape[3],
|
||||||
|
)
|
||||||
|
.T,
|
||||||
|
factor_x_div_r.reshape(
|
||||||
|
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
|
||||||
|
grad_input.shape[3],
|
||||||
|
).T,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if ctx.local_learning_kl:
|
||||||
|
grad_weight = -torch.nn.functional.linear(
|
||||||
|
h.reshape(
|
||||||
|
grad_input.shape[0]
|
||||||
|
* grad_input.shape[1]
|
||||||
|
* grad_input.shape[2],
|
||||||
|
h.shape[3],
|
||||||
|
).T,
|
||||||
|
factor_x_div_r.reshape(
|
||||||
|
grad_input.shape[0]
|
||||||
|
* grad_input.shape[1]
|
||||||
|
* grad_input.shape[2],
|
||||||
|
grad_input.shape[3],
|
||||||
|
).T,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
grad_weight = -torch.nn.functional.linear(
|
||||||
|
h.reshape(
|
||||||
|
grad_input.shape[0]
|
||||||
|
* grad_input.shape[1]
|
||||||
|
* grad_input.shape[2],
|
||||||
|
h.shape[3],
|
||||||
|
).T,
|
||||||
|
(2 * (input - big_r))
|
||||||
|
.reshape(
|
||||||
|
grad_input.shape[0]
|
||||||
|
* grad_input.shape[1]
|
||||||
|
* grad_input.shape[2],
|
||||||
|
grad_input.shape[3],
|
||||||
|
)
|
||||||
|
.T,
|
||||||
|
)
|
||||||
|
grad_input = grad_input.movedim(-1, 1)
|
||||||
|
assert torch.isfinite(grad_input).all()
|
||||||
|
assert torch.isfinite(grad_weight).all()
|
||||||
|
|
||||||
|
return (
|
||||||
|
grad_input,
|
||||||
|
grad_weight,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
23
SplitOnOffLayer.py
Normal file
23
SplitOnOffLayer.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class SplitOnOffLayer(torch.nn.Module):
    """Split a 4-D tensor into rectified "on" and "off" parts around 0.5.

    The output concatenates ``relu(input - 0.5)`` and
    ``relu(-(input - 0.5))`` along dimension 1, so dimension 1 of the
    output is twice as long as dimension 1 of the input.
    """

    def __init__(self) -> None:
        super().__init__()

    ####################################################################
    # Forward                                                          #
    ####################################################################

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the on/off representation of ``input`` (must be 4-D)."""
        assert input.ndim == 4

        centered = input - 0.5
        on_off_parts = (
            torch.nn.functional.relu(centered),
            torch.nn.functional.relu(-centered),
        )

        return torch.cat(on_off_parts, dim=1)
|
29
convert_log_to_numpy.py
Normal file
29
convert_log_to_numpy.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
|
||||||
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
|
|
||||||
|
from tensorboard.backend.event_processing import event_accumulator
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def get_data(
    path: str = "log_cnn", which_scalar: str = "Test Number Correct"
) -> np.ndarray:
    """Read one scalar series from a TensorBoard event log.

    Args:
        path: Log location handed to ``EventAccumulator``.
        which_scalar: Tag of the scalar series to extract.

    Returns:
        A float array of shape ``(number_of_events, 2)``; column 0 holds
        the step and column 1 the value of each event.
    """
    acc = event_accumulator.EventAccumulator(path)
    acc.Reload()

    events = acc.Scalars(which_scalar)

    # The reshape keeps the (0, 2) shape when the series is empty.
    return np.array(
        [(event.step, event.value) for event in events], dtype=np.float64
    ).reshape(-1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
# Export the scalar series of every entry matching "log_*" in the current
# working directory to "data_<entry name>.npy".
for path in glob.glob("log_*"):
    print(path)
    data = get_data(path)
    np.save("data_" + path + ".npy", data)
|
27
data_loader.py
Normal file
27
data_loader.py
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def data_loader(
    pattern: torch.Tensor,
    labels: torch.Tensor,
    batch_size: int = 128,
    shuffle: bool = True,
    torch_device: torch.device = torch.device("cpu"),
) -> torch.utils.data.dataloader.DataLoader:
    """Wrap pattern and label tensors into a ``DataLoader``.

    The pattern is moved to ``torch_device``, converted to float32, given
    a singleton dimension 1 if it is 3-D, and divided by its maximum.
    The labels are moved to ``torch_device`` and converted to int64.

    Args:
        pattern: Input data with at least 3 dimensions. It is never
            modified in place.
        labels: Labels belonging to ``pattern``.
        batch_size: Batch size of the returned loader.
        shuffle: Whether the loader shuffles the samples.
        torch_device: Device on which the data is stored.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` of (pattern, label).
    """
    assert pattern.ndim >= 3

    pattern_storage: torch.Tensor = pattern.to(torch_device).type(torch.float32)
    if pattern_storage.ndim == 3:
        pattern_storage = pattern_storage.unsqueeze(1)

    # ``.to``/``.type`` return the input itself when device and dtype already
    # match, and ``unsqueeze`` returns a view. Dividing out of place keeps the
    # caller's tensor untouched in that case.
    max_value = pattern_storage.max()
    if max_value != 0:
        # Skipped for all-zero data, where the division would produce NaN.
        pattern_storage = pattern_storage / max_value

    label_storage: torch.Tensor = labels.to(torch_device).type(torch.int64)

    return torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(pattern_storage, label_storage),
        batch_size=batch_size,
        shuffle=shuffle,
    )
|
115
get_the_data.py
Normal file
115
get_the_data.py
Normal file
|
@ -0,0 +1,115 @@
|
||||||
|
import torch
|
||||||
|
import torchvision # type: ignore
|
||||||
|
from data_loader import data_loader
|
||||||
|
|
||||||
|
|
||||||
|
def get_the_data(
    dataset: str,
    batch_size_train: int,
    batch_size_test: int,
    torch_device: torch.device,
    input_dim_x: int,
    input_dim_y: int,
    flip_p: float = 0.5,
    jitter_brightness: float = 0.5,
    jitter_contrast: float = 0.1,
    jitter_saturation: float = 0.1,
    jitter_hue: float = 0.15,
) -> tuple[
    torch.utils.data.DataLoader,
    torch.utils.data.DataLoader,
    torchvision.transforms.Compose,
    torchvision.transforms.Compose,
]:
    """Download a torchvision dataset and build loaders and transforms.

    Args:
        dataset: One of "MNIST", "FashionMNIST" or "CIFAR10".
        batch_size_train: Batch size of the training loader.
        batch_size_test: Batch size of the test loader.
        torch_device: Device on which the data is stored.
        input_dim_x: First entry of the crop size.
        input_dim_y: Second entry of the crop size.
        flip_p: Probability of the random horizontal flip (CIFAR10 only).
        jitter_brightness: ``ColorJitter`` brightness (CIFAR10 only).
        jitter_contrast: ``ColorJitter`` contrast (CIFAR10 only).
        jitter_saturation: ``ColorJitter`` saturation (CIFAR10 only).
        jitter_hue: ``ColorJitter`` hue (CIFAR10 only).

    Returns:
        ``(train_dataloader, test_dataloader, test_processing_chain,
        train_processing_chain)`` -- note that the test chain comes
        before the train chain.

    Raises:
        NotImplementedError: If ``dataset`` is not supported.
    """
    dataset_classes = {
        "MNIST": torchvision.datasets.MNIST,
        "FashionMNIST": torchvision.datasets.FashionMNIST,
        "CIFAR10": torchvision.datasets.CIFAR10,
    }
    if dataset not in dataset_classes:
        raise NotImplementedError("This dataset is not implemented.")

    dataset_class = dataset_classes[dataset]
    tv_dataset_train = dataset_class(root="data", train=True, download=True)
    tv_dataset_test = dataset_class(root="data", train=False, download=True)

    is_grayscale = dataset in ("MNIST", "FashionMNIST")

    if is_grayscale:
        train_pattern = tv_dataset_train.data
        train_labels = tv_dataset_train.targets
        test_pattern = tv_dataset_test.data
        test_labels = tv_dataset_test.targets
    else:
        # CIFAR10 data/targets are converted to tensors here and the last
        # data dimension is moved to position 1.
        train_pattern = torch.tensor(tv_dataset_train.data).movedim(-1, 1)
        train_labels = torch.tensor(tv_dataset_train.targets)
        test_pattern = torch.tensor(tv_dataset_test.data).movedim(-1, 1)
        test_labels = torch.tensor(tv_dataset_test.targets)

    train_dataloader = data_loader(
        torch_device=torch_device,
        batch_size=batch_size_train,
        pattern=train_pattern,
        labels=train_labels,
        shuffle=True,
    )

    test_dataloader = data_loader(
        torch_device=torch_device,
        batch_size=batch_size_test,
        pattern=test_pattern,
        labels=test_labels,
        shuffle=False,
    )

    # Data augmentation filter
    test_processing_chain = torchvision.transforms.Compose(
        transforms=[torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))],
    )

    train_transforms: list = [
        torchvision.transforms.RandomCrop((input_dim_x, input_dim_y))
    ]
    if not is_grayscale:
        train_transforms += [
            torchvision.transforms.RandomHorizontalFlip(p=flip_p),
            torchvision.transforms.ColorJitter(
                brightness=jitter_brightness,
                contrast=jitter_contrast,
                saturation=jitter_saturation,
                hue=jitter_hue,
            ),
        ]
    train_processing_chain = torchvision.transforms.Compose(
        transforms=train_transforms,
    )

    return (
        train_dataloader,
        test_dataloader,
        test_processing_chain,
        train_processing_chain,
    )
|
64
loss_function.py
Normal file
64
loss_function.py
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# loss_mode == 0: "normal" SbS loss function mixture
|
||||||
|
# loss_mode == 1: cross_entropy
|
||||||
|
def loss_function(
|
||||||
|
h: torch.Tensor,
|
||||||
|
labels: torch.Tensor,
|
||||||
|
loss_mode: int = 0,
|
||||||
|
number_of_output_neurons: int = 10,
|
||||||
|
loss_coeffs_mse: float = 0.0,
|
||||||
|
loss_coeffs_kldiv: float = 0.0,
|
||||||
|
) -> torch.Tensor | None:
|
||||||
|
|
||||||
|
assert loss_mode >= 0
|
||||||
|
assert loss_mode <= 1
|
||||||
|
|
||||||
|
assert h.ndim == 2
|
||||||
|
|
||||||
|
if loss_mode == 0:
|
||||||
|
|
||||||
|
# Convert label into one hot
|
||||||
|
target_one_hot: torch.Tensor = torch.zeros(
|
||||||
|
(
|
||||||
|
labels.shape[0],
|
||||||
|
number_of_output_neurons,
|
||||||
|
),
|
||||||
|
device=h.device,
|
||||||
|
dtype=h.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
target_one_hot.scatter_(
|
||||||
|
1,
|
||||||
|
labels.to(h.device).unsqueeze(1),
|
||||||
|
torch.ones(
|
||||||
|
(labels.shape[0], 1),
|
||||||
|
device=h.device,
|
||||||
|
dtype=h.dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
my_loss: torch.Tensor = ((h - target_one_hot) ** 2).sum(dim=0).mean(
|
||||||
|
dim=0
|
||||||
|
) * loss_coeffs_mse
|
||||||
|
|
||||||
|
my_loss = (
|
||||||
|
my_loss
|
||||||
|
+ (
|
||||||
|
(target_one_hot * torch.log((target_one_hot + 1e-20) / (h + 1e-20)))
|
||||||
|
.sum(dim=0)
|
||||||
|
.mean(dim=0)
|
||||||
|
)
|
||||||
|
* loss_coeffs_kldiv
|
||||||
|
)
|
||||||
|
|
||||||
|
my_loss = my_loss / (abs(loss_coeffs_kldiv) + abs(loss_coeffs_mse))
|
||||||
|
|
||||||
|
return my_loss
|
||||||
|
|
||||||
|
elif loss_mode == 1:
|
||||||
|
my_loss = torch.nn.functional.cross_entropy(h, labels.to(h.device))
|
||||||
|
return my_loss
|
||||||
|
else:
|
||||||
|
return None
|
354
make_network.py
Normal file
354
make_network.py
Normal file
|
@ -0,0 +1,354 @@
|
||||||
|
import torch
|
||||||
|
from NNMFConv2d import NNMFConv2d
|
||||||
|
from NNMFConv2dP import NNMFConv2dP
|
||||||
|
from SplitOnOffLayer import SplitOnOffLayer
|
||||||
|
|
||||||
|
|
||||||
|
def make_network(
|
||||||
|
use_nnmf: bool,
|
||||||
|
cnn_top: bool,
|
||||||
|
input_dim_x: int,
|
||||||
|
input_dim_y: int,
|
||||||
|
input_number_of_channel: int,
|
||||||
|
iterations: int,
|
||||||
|
init_min: float = 0.0,
|
||||||
|
init_max: float = 1.0,
|
||||||
|
use_convolution: bool = False,
|
||||||
|
convolution_contribution_map_enable: bool = False,
|
||||||
|
epsilon: bool | None = None,
|
||||||
|
positive_function_type: int = 0,
|
||||||
|
beta: float | None = None,
|
||||||
|
number_of_output_channels_conv1: int = 32,
|
||||||
|
number_of_output_channels_conv2: int = 64,
|
||||||
|
number_of_output_channels_flatten2: int = 96,
|
||||||
|
number_of_output_channels_full1: int = 10,
|
||||||
|
kernel_size_conv1: tuple[int, int] = (5, 5),
|
||||||
|
kernel_size_pool1: tuple[int, int] = (2, 2),
|
||||||
|
kernel_size_conv2: tuple[int, int] = (5, 5),
|
||||||
|
kernel_size_pool2: tuple[int, int] = (2, 2),
|
||||||
|
stride_conv1: tuple[int, int] = (1, 1),
|
||||||
|
stride_pool1: tuple[int, int] = (2, 2),
|
||||||
|
stride_conv2: tuple[int, int] = (1, 1),
|
||||||
|
stride_pool2: tuple[int, int] = (2, 2),
|
||||||
|
padding_conv1: int = 0,
|
||||||
|
padding_pool1: int = 0,
|
||||||
|
padding_conv2: int = 0,
|
||||||
|
padding_pool2: int = 0,
|
||||||
|
enable_onoff: bool = False,
|
||||||
|
local_learning_0: bool = False,
|
||||||
|
local_learning_1: bool = False,
|
||||||
|
local_learning_2: bool = False,
|
||||||
|
local_learning_3: bool = False,
|
||||||
|
local_learning_kl: bool = True,
|
||||||
|
p_mode_0: bool = False,
|
||||||
|
p_mode_1: bool = False,
|
||||||
|
p_mode_2: bool = False,
|
||||||
|
p_mode_3: bool = False,
|
||||||
|
) -> tuple[torch.nn.Sequential, list[int], list[int]]:
|
||||||
|
|
||||||
|
if enable_onoff:
|
||||||
|
input_number_of_channel *= 2
|
||||||
|
|
||||||
|
list_cnn_top_id: list[int] = []
|
||||||
|
list_other_id: list[int] = []
|
||||||
|
|
||||||
|
test_image = torch.ones((1, input_number_of_channel, input_dim_x, input_dim_y))
|
||||||
|
|
||||||
|
network = torch.nn.Sequential()
|
||||||
|
|
||||||
|
if enable_onoff:
|
||||||
|
network.append(SplitOnOffLayer())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
list_other_id.append(len(network))
|
||||||
|
if use_nnmf:
|
||||||
|
if p_mode_0:
|
||||||
|
network.append(
|
||||||
|
NNMFConv2dP(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_conv1,
|
||||||
|
kernel_size=kernel_size_conv1,
|
||||||
|
convolution_contribution_map_enable=convolution_contribution_map_enable,
|
||||||
|
epsilon=epsilon,
|
||||||
|
positive_function_type=positive_function_type,
|
||||||
|
init_min=init_min,
|
||||||
|
init_max=init_max,
|
||||||
|
beta=beta,
|
||||||
|
use_convolution=use_convolution,
|
||||||
|
iterations=iterations,
|
||||||
|
local_learning=local_learning_0,
|
||||||
|
local_learning_kl=local_learning_kl,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
network.append(
|
||||||
|
NNMFConv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_conv1,
|
||||||
|
kernel_size=kernel_size_conv1,
|
||||||
|
convolution_contribution_map_enable=convolution_contribution_map_enable,
|
||||||
|
epsilon=epsilon,
|
||||||
|
positive_function_type=positive_function_type,
|
||||||
|
init_min=init_min,
|
||||||
|
init_max=init_max,
|
||||||
|
beta=beta,
|
||||||
|
use_convolution=use_convolution,
|
||||||
|
iterations=iterations,
|
||||||
|
local_learning=local_learning_0,
|
||||||
|
local_learning_kl=local_learning_kl,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
else:
|
||||||
|
network.append(
|
||||||
|
torch.nn.Conv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_conv1,
|
||||||
|
kernel_size=kernel_size_conv1,
|
||||||
|
stride=stride_conv1,
|
||||||
|
padding=padding_conv1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
network.append(torch.nn.ReLU())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
if cnn_top:
|
||||||
|
list_cnn_top_id.append(len(network))
|
||||||
|
network.append(
|
||||||
|
torch.nn.Conv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_conv1,
|
||||||
|
kernel_size=(1, 1),
|
||||||
|
stride=(1, 1),
|
||||||
|
padding=(0, 0),
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
network.append(torch.nn.ReLU())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
network.append(
|
||||||
|
torch.nn.MaxPool2d(
|
||||||
|
kernel_size=kernel_size_pool1, stride=stride_pool1, padding=padding_pool1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
list_other_id.append(len(network))
|
||||||
|
if use_nnmf:
|
||||||
|
if p_mode_1:
|
||||||
|
network.append(
|
||||||
|
NNMFConv2dP(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_conv2,
|
||||||
|
kernel_size=kernel_size_conv2,
|
||||||
|
convolution_contribution_map_enable=convolution_contribution_map_enable,
|
||||||
|
epsilon=epsilon,
|
||||||
|
positive_function_type=positive_function_type,
|
||||||
|
init_min=init_min,
|
||||||
|
init_max=init_max,
|
||||||
|
beta=beta,
|
||||||
|
use_convolution=use_convolution,
|
||||||
|
iterations=iterations,
|
||||||
|
local_learning=local_learning_1,
|
||||||
|
local_learning_kl=local_learning_kl,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
network.append(
|
||||||
|
NNMFConv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_conv2,
|
||||||
|
kernel_size=kernel_size_conv2,
|
||||||
|
convolution_contribution_map_enable=convolution_contribution_map_enable,
|
||||||
|
epsilon=epsilon,
|
||||||
|
positive_function_type=positive_function_type,
|
||||||
|
init_min=init_min,
|
||||||
|
init_max=init_max,
|
||||||
|
beta=beta,
|
||||||
|
use_convolution=use_convolution,
|
||||||
|
iterations=iterations,
|
||||||
|
local_learning=local_learning_1,
|
||||||
|
local_learning_kl=local_learning_kl,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
else:
|
||||||
|
network.append(
|
||||||
|
torch.nn.Conv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_conv2,
|
||||||
|
kernel_size=kernel_size_conv2,
|
||||||
|
stride=stride_conv2,
|
||||||
|
padding=padding_conv2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
network.append(torch.nn.ReLU())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
if cnn_top:
|
||||||
|
list_cnn_top_id.append(len(network))
|
||||||
|
network.append(
|
||||||
|
torch.nn.Conv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_conv2,
|
||||||
|
kernel_size=(1, 1),
|
||||||
|
stride=(1, 1),
|
||||||
|
padding=(0, 0),
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
network.append(torch.nn.ReLU())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
network.append(
|
||||||
|
torch.nn.MaxPool2d(
|
||||||
|
kernel_size=kernel_size_pool2, stride=stride_pool2, padding=padding_pool2
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
list_other_id.append(len(network))
|
||||||
|
if use_nnmf:
|
||||||
|
if p_mode_2:
|
||||||
|
network.append(
|
||||||
|
NNMFConv2dP(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_flatten2,
|
||||||
|
kernel_size=(test_image.shape[2], test_image.shape[3]),
|
||||||
|
convolution_contribution_map_enable=convolution_contribution_map_enable,
|
||||||
|
epsilon=epsilon,
|
||||||
|
positive_function_type=positive_function_type,
|
||||||
|
init_min=init_min,
|
||||||
|
init_max=init_max,
|
||||||
|
beta=beta,
|
||||||
|
use_convolution=use_convolution,
|
||||||
|
iterations=iterations,
|
||||||
|
local_learning=local_learning_2,
|
||||||
|
local_learning_kl=local_learning_kl,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
network.append(
|
||||||
|
NNMFConv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_flatten2,
|
||||||
|
kernel_size=(test_image.shape[2], test_image.shape[3]),
|
||||||
|
convolution_contribution_map_enable=convolution_contribution_map_enable,
|
||||||
|
epsilon=epsilon,
|
||||||
|
positive_function_type=positive_function_type,
|
||||||
|
init_min=init_min,
|
||||||
|
init_max=init_max,
|
||||||
|
beta=beta,
|
||||||
|
use_convolution=use_convolution,
|
||||||
|
iterations=iterations,
|
||||||
|
local_learning=local_learning_2,
|
||||||
|
local_learning_kl=local_learning_kl,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
else:
|
||||||
|
network.append(
|
||||||
|
torch.nn.Conv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_flatten2,
|
||||||
|
kernel_size=(test_image.shape[2], test_image.shape[3]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
network.append(torch.nn.ReLU())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
if cnn_top:
|
||||||
|
list_cnn_top_id.append(len(network))
|
||||||
|
network.append(
|
||||||
|
torch.nn.Conv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_flatten2,
|
||||||
|
kernel_size=(1, 1),
|
||||||
|
stride=(1, 1),
|
||||||
|
padding=(0, 0),
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
network.append(torch.nn.ReLU())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
list_other_id.append(len(network))
|
||||||
|
if use_nnmf:
|
||||||
|
if p_mode_3:
|
||||||
|
network.append(
|
||||||
|
NNMFConv2dP(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_full1,
|
||||||
|
kernel_size=(test_image.shape[2], test_image.shape[3]),
|
||||||
|
convolution_contribution_map_enable=convolution_contribution_map_enable,
|
||||||
|
epsilon=epsilon,
|
||||||
|
positive_function_type=positive_function_type,
|
||||||
|
init_min=init_min,
|
||||||
|
init_max=init_max,
|
||||||
|
beta=beta,
|
||||||
|
use_convolution=use_convolution,
|
||||||
|
iterations=iterations,
|
||||||
|
local_learning=local_learning_3,
|
||||||
|
local_learning_kl=local_learning_kl,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
network.append(
|
||||||
|
NNMFConv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_full1,
|
||||||
|
kernel_size=(test_image.shape[2], test_image.shape[3]),
|
||||||
|
convolution_contribution_map_enable=convolution_contribution_map_enable,
|
||||||
|
epsilon=epsilon,
|
||||||
|
positive_function_type=positive_function_type,
|
||||||
|
init_min=init_min,
|
||||||
|
init_max=init_max,
|
||||||
|
beta=beta,
|
||||||
|
use_convolution=use_convolution,
|
||||||
|
iterations=iterations,
|
||||||
|
local_learning=local_learning_3,
|
||||||
|
local_learning_kl=local_learning_kl,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
else:
|
||||||
|
network.append(
|
||||||
|
torch.nn.Conv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_full1,
|
||||||
|
kernel_size=(test_image.shape[2], test_image.shape[3]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
if cnn_top:
|
||||||
|
network.append(torch.nn.ReLU())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
if cnn_top:
|
||||||
|
list_cnn_top_id.append(len(network))
|
||||||
|
network.append(
|
||||||
|
torch.nn.Conv2d(
|
||||||
|
in_channels=test_image.shape[1],
|
||||||
|
out_channels=number_of_output_channels_full1,
|
||||||
|
kernel_size=(1, 1),
|
||||||
|
stride=(1, 1),
|
||||||
|
padding=(0, 0),
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
network.append(torch.nn.Flatten())
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
network.append(torch.nn.Softmax(dim=1))
|
||||||
|
test_image = network[-1](test_image)
|
||||||
|
|
||||||
|
return network, list_cnn_top_id, list_other_id
|
103
make_optimize.py
Normal file
103
make_optimize.py
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
import torch
|
||||||
|
from NNMFConv2d import NNMFConv2d
|
||||||
|
from NNMFConv2dP import NNMFConv2dP
|
||||||
|
|
||||||
|
|
||||||
|
def make_optimize(
|
||||||
|
network: torch.nn.Sequential,
|
||||||
|
list_cnn_top_id: list[int],
|
||||||
|
list_other_id: list[int],
|
||||||
|
lr_initial_nnmf: float = 0.01,
|
||||||
|
lr_initial_cnn: float = 0.001,
|
||||||
|
lr_initial_cnn_top: float = 0.001,
|
||||||
|
eps=1e-10,
|
||||||
|
) -> tuple[
|
||||||
|
torch.optim.Adam | None,
|
||||||
|
torch.optim.Adam | None,
|
||||||
|
torch.optim.Adam | None,
|
||||||
|
torch.optim.lr_scheduler.ReduceLROnPlateau | None,
|
||||||
|
torch.optim.lr_scheduler.ReduceLROnPlateau | None,
|
||||||
|
torch.optim.lr_scheduler.ReduceLROnPlateau | None,
|
||||||
|
]:
|
||||||
|
|
||||||
|
list_cnn_top: list = []
|
||||||
|
# Init the cnn top layers 1x1 conv2d layers
|
||||||
|
for layerid in list_cnn_top_id:
|
||||||
|
for netp in network[layerid].parameters():
|
||||||
|
with torch.no_grad():
|
||||||
|
if netp.ndim == 1:
|
||||||
|
netp.data *= 0
|
||||||
|
if netp.ndim == 4:
|
||||||
|
assert netp.shape[-2] == 1
|
||||||
|
assert netp.shape[-1] == 1
|
||||||
|
netp[: netp.shape[0], : netp.shape[0], 0, 0] = torch.eye(
|
||||||
|
netp.shape[0], dtype=netp.dtype, device=netp.device
|
||||||
|
)
|
||||||
|
netp[netp.shape[0] :, :, 0, 0] = 0
|
||||||
|
netp[:, netp.shape[0] :, 0, 0] = 0
|
||||||
|
|
||||||
|
list_cnn_top.append(netp)
|
||||||
|
|
||||||
|
list_cnn: list = []
|
||||||
|
list_nnmf: list = []
|
||||||
|
for layerid in list_other_id:
|
||||||
|
if isinstance(network[layerid], torch.nn.Conv2d):
|
||||||
|
for netp in network[layerid].parameters():
|
||||||
|
list_cnn.append(netp)
|
||||||
|
|
||||||
|
if isinstance(network[layerid], (NNMFConv2d, NNMFConv2dP)):
|
||||||
|
for netp in network[layerid].parameters():
|
||||||
|
list_nnmf.append(netp)
|
||||||
|
|
||||||
|
# The optimizer
|
||||||
|
if len(list_nnmf) > 0:
|
||||||
|
optimizer_nnmf: torch.optim.Adam | None = torch.optim.Adam(
|
||||||
|
list_nnmf, lr=lr_initial_nnmf
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
optimizer_nnmf = None
|
||||||
|
|
||||||
|
if len(list_cnn) > 0:
|
||||||
|
optimizer_cnn: torch.optim.Adam | None = torch.optim.Adam(
|
||||||
|
list_cnn, lr=lr_initial_cnn
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
optimizer_cnn = None
|
||||||
|
|
||||||
|
if len(list_cnn_top) > 0:
|
||||||
|
optimizer_cnn_top: torch.optim.Adam | None = torch.optim.Adam(
|
||||||
|
list_cnn_top, lr=lr_initial_cnn_top
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
optimizer_cnn_top = None
|
||||||
|
|
||||||
|
# The LR Scheduler
|
||||||
|
if optimizer_nnmf is not None:
|
||||||
|
lr_scheduler_nnmf: torch.optim.lr_scheduler.ReduceLROnPlateau | None = (
|
||||||
|
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_nnmf, eps=eps)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lr_scheduler_nnmf = None
|
||||||
|
|
||||||
|
if optimizer_cnn is not None:
|
||||||
|
lr_scheduler_cnn: torch.optim.lr_scheduler.ReduceLROnPlateau | None = (
|
||||||
|
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_cnn, eps=eps)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lr_scheduler_cnn = None
|
||||||
|
|
||||||
|
if optimizer_cnn_top is not None:
|
||||||
|
lr_scheduler_cnn_top: torch.optim.lr_scheduler.ReduceLROnPlateau | None = (
|
||||||
|
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_cnn_top, eps=eps)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lr_scheduler_cnn_top = None
|
||||||
|
|
||||||
|
return (
|
||||||
|
optimizer_nnmf,
|
||||||
|
optimizer_cnn,
|
||||||
|
optimizer_cnn_top,
|
||||||
|
lr_scheduler_nnmf,
|
||||||
|
lr_scheduler_cnn,
|
||||||
|
lr_scheduler_cnn_top,
|
||||||
|
)
|
26
non_linear_weigth_function.py
Normal file
26
non_linear_weigth_function.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def non_linear_weigth_function(
|
||||||
|
weight: torch.Tensor, beta: torch.Tensor | None, positive_function_type: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
if positive_function_type == 0:
|
||||||
|
positive_weights = torch.abs(weight)
|
||||||
|
|
||||||
|
elif positive_function_type == 1:
|
||||||
|
assert beta is not None
|
||||||
|
positive_weights = weight
|
||||||
|
max_value = torch.abs(positive_weights).max()
|
||||||
|
if max_value > 80:
|
||||||
|
positive_weights = 80.0 * positive_weights / max_value
|
||||||
|
positive_weights = torch.exp((torch.tanh(beta) + 1.0) * 0.5 * positive_weights)
|
||||||
|
|
||||||
|
elif positive_function_type == 2:
|
||||||
|
assert beta is not None
|
||||||
|
positive_weights = (torch.tanh(beta * weight) + 1.0) * 0.5
|
||||||
|
|
||||||
|
else:
|
||||||
|
positive_weights = weight
|
||||||
|
|
||||||
|
return positive_weights
|
46
plot.py
Normal file
46
plot.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
# One entry per curve: (result file, line style, legend label).
# Each file holds (step, value) rows; the value is turned into an error
# percentage relative to 10000.
curves = [
    (
        "data_log_cnn_20_True_0.001_0.01_True_True_True_True.npy",
        "k",
        "CNN + CNN Top",
    ),
    (
        "data_log_cnn_20_False_0.001_0.01_True_True_True_True.npy",
        "k--",
        "CNN",
    ),
    (
        "data_log_nnmf_20_True_0.001_0.01_True_True_True_True.npy",
        "r",
        "NNMF + CNN Top (Iter 20, KL)",
    ),
    (
        "data_log_nnmf_20_False_0.001_0.01_True_True_True_True.npy",
        "r--",
        "NNMF (Iter 20, KL)",
    ),
    (
        "data_log_nnmf_20_True_0.001_0.01_True_True_True_False.npy",
        "b",
        "NNMF + CNN Top (Iter 20, MSE)",
    ),
    (
        "data_log_nnmf_20_False_0.001_0.01_True_True_True_False.npy",
        "b--",
        "NNMF (Iter 20, MSE)",
    ),
]

for file_name, line_style, curve_label in curves:
    data = np.load(file_name)
    plt.loglog(
        data[:, 0],
        100.0 * (1.0 - data[:, 1] / 10000.0),
        line_style,
        label=curve_label,
    )

plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Error [%]")
plt.show()
|
272
run_network.py
Normal file
272
run_network.py
Normal file
|
@ -0,0 +1,272 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
|
|
||||||
|
import argh
|
||||||
|
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from make_network import make_network
|
||||||
|
from get_the_data import get_the_data
|
||||||
|
from loss_function import loss_function
|
||||||
|
from make_optimize import make_optimize
|
||||||
|
|
||||||
|
|
||||||
|
def main(
    lr_initial_nnmf: float = 0.01,
    lr_initial_cnn: float = 0.001,
    lr_initial_cnn_top: float = 0.001,
    iterations: int = 20,
    cnn_top: bool = True,
    use_nnmf: bool = True,
    dataset: str = "CIFAR10",  # "CIFAR10", "FashionMNIST", "MNIST"
    rand_seed: int = 21,
    enable_onoff: bool = False,
    local_learning_0: bool = False,
    local_learning_1: bool = False,
    local_learning_2: bool = False,
    local_learning_3: bool = False,
    local_learning_kl: bool = False,
    p_mode_0: bool = False,
    p_mode_1: bool = False,
    p_mode_2: bool = False,
    p_mode_3: bool = False,
) -> None:
    """Train the network, test it after every epoch and log to TensorBoard.

    The run ends after a fixed number of epochs, or earlier once the learning
    rate of every active scheduler has dropped below an internal limit. In
    both cases the whole network object is written to "Model_<run name>.pt".

    Args:
        lr_initial_nnmf: Initial learning rate handed to make_optimize for
            the NNMF parameters; also part of the run name.
        lr_initial_cnn: Initial learning rate handed to make_optimize for
            the CNN parameters; also part of the run name.
        lr_initial_cnn_top: Initial learning rate handed to make_optimize
            for the CNN-top parameters.
        iterations: Forwarded to make_network; also part of the run name.
        cnn_top: Forwarded to make_network; also part of the run name.
        use_nnmf: Forwarded to make_network; selects the run name prefix
            ("nnmf" or "cnn").
        dataset: Data set name forwarded to get_the_data. "MNIST" and
            "FashionMNIST" use 1 input channel and 24x24 inputs, every other
            value uses 3 channels and 28x28 inputs.
        rand_seed: Seed passed to torch.manual_seed.
        enable_onoff: Forwarded to make_network.
        local_learning_0: Forwarded to make_network; part of the run name.
        local_learning_1: Forwarded to make_network; part of the run name.
        local_learning_2: Forwarded to make_network; part of the run name.
        local_learning_3: Forwarded to make_network.
        local_learning_kl: Forwarded to make_network; part of the run name.
        p_mode_0: Forwarded to make_network.
        p_mode_1: Forwarded to make_network.
        p_mode_2: Forwarded to make_network.
        p_mode_3: Forwarded to make_network.
    """

    # Training stops early once every scheduler reports a rate below this.
    lr_limit: float = 1e-9

    torch.manual_seed(rand_seed)

    torch_device: torch.device = (
        torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    )
    torch.set_default_dtype(torch.float32)

    # Some parameters
    batch_size_train: int = 500
    batch_size_test: int = 500
    number_of_epoch: int = 500

    prefix: str = "nnmf" if use_nnmf else "cnn"

    # NOTE: the run name does not contain dataset, rand_seed, enable_onoff,
    # lr_initial_cnn_top, local_learning_3 or the p_mode_* flags, so runs that
    # differ only in those settings share one log directory and model file.
    default_path: str = (
        f"{prefix}_{iterations}_{cnn_top}_{lr_initial_cnn}_{lr_initial_nnmf}_{local_learning_0}_{local_learning_1}_{local_learning_2}_{local_learning_kl}"
    )
    log_dir: str = f"log_{default_path}"
    model_path: str = f"Model_{default_path}.pt"

    loss_mode: int = 0
    loss_coeffs_mse: float = 0.5
    loss_coeffs_kldiv: float = 1.0
    print(
        "loss_mode: ",
        loss_mode,
        "loss_coeffs_mse: ",
        loss_coeffs_mse,
        "loss_coeffs_kldiv: ",
        loss_coeffs_kldiv,
    )

    if dataset == "MNIST" or dataset == "FashionMNIST":
        input_number_of_channel: int = 1
        input_dim_x: int = 24
        input_dim_y: int = 24
    else:
        input_number_of_channel = 3
        input_dim_x = 28
        input_dim_y = 28

    train_dataloader, test_dataloader, test_processing_chain, train_processing_chain = (
        get_the_data(
            dataset,
            batch_size_train,
            batch_size_test,
            torch_device,
            input_dim_x,
            input_dim_y,
            flip_p=0.5,
            jitter_brightness=0.5,
            jitter_contrast=0.1,
            jitter_saturation=0.1,
            jitter_hue=0.15,
        )
    )

    network, list_cnn_top_id, list_other_id = make_network(
        use_nnmf=use_nnmf,
        cnn_top=cnn_top,
        input_dim_x=input_dim_x,
        input_dim_y=input_dim_y,
        input_number_of_channel=input_number_of_channel,
        iterations=iterations,
        enable_onoff=enable_onoff,
        local_learning_0=local_learning_0,
        local_learning_1=local_learning_1,
        local_learning_2=local_learning_2,
        local_learning_3=local_learning_3,
        local_learning_kl=local_learning_kl,
        p_mode_0=p_mode_0,
        p_mode_1=p_mode_1,
        p_mode_2=p_mode_2,
        p_mode_3=p_mode_3,
    )
    network = network.to(torch_device)

    print(network)

    (
        optimizer_nnmf,
        optimizer_cnn,
        optimizer_cnn_top,
        lr_scheduler_nnmf,
        lr_scheduler_cnn,
        lr_scheduler_cnn_top,
    ) = make_optimize(
        network=network,
        list_cnn_top_id=list_cnn_top_id,
        list_other_id=list_other_id,
        lr_initial_nnmf=lr_initial_nnmf,
        lr_initial_cnn=lr_initial_cnn,
        lr_initial_cnn_top=lr_initial_cnn_top,
    )

    # make_optimize may return None for any optimizer / scheduler. Collect the
    # existing ones once (order: nnmf, cnn, cnn top) instead of re-checking
    # for None at every use.
    active_optimizers = [
        optimizer
        for optimizer in (optimizer_nnmf, optimizer_cnn, optimizer_cnn_top)
        if optimizer is not None
    ]
    active_schedulers = [
        scheduler
        for scheduler in (lr_scheduler_nnmf, lr_scheduler_cnn, lr_scheduler_cnn_top)
        if scheduler is not None
    ]

    tb = SummaryWriter(log_dir=log_dir)

    # try/finally guarantees that the TensorBoard writer is closed on every
    # exit path, including an exception or Ctrl-C in the middle of training.
    try:
        for epoch_id in range(0, number_of_epoch):
            print()
            print(f"Epoch: {epoch_id}")
            t_start: float = time.perf_counter()

            train_loss: float = 0.0
            train_correct: int = 0
            train_number: int = 0
            test_correct: int = 0
            test_number: int = 0

            # Switch the network into training mode
            network.train()

            # This runs in total for one epoch split up into mini-batches
            for image, target in train_dataloader:
                # Clean the gradient
                for optimizer in active_optimizers:
                    optimizer.zero_grad()

                output = network(train_processing_chain(image))

                loss = loss_function(
                    h=output,
                    labels=target,
                    number_of_output_neurons=output.shape[1],
                    loss_mode=loss_mode,
                    loss_coeffs_mse=loss_coeffs_mse,
                    loss_coeffs_kldiv=loss_coeffs_kldiv,
                )

                assert loss is not None
                train_loss += loss.item()
                # .item() yields a plain Python number, so the counter stays
                # an int instead of silently turning into a numpy array.
                train_correct += int((output.argmax(dim=1) == target).sum().item())
                train_number += target.shape[0]

                # Calculate backprop
                loss.backward()

                # Update the parameter
                for optimizer in active_optimizers:
                    optimizer.step()

            perfomance_train_correct: float = 100.0 * train_correct / train_number

            # Update the learning rate. Every scheduler is stepped once per
            # epoch with the loss summed over all mini-batches of that epoch.
            for scheduler in active_schedulers:
                scheduler.step(train_loss)

            print(
                "Actual lr: ",
                "nnmf: ",
                lr_scheduler_nnmf.get_last_lr() if lr_scheduler_nnmf is not None else -1.0,
                "cnn: ",
                lr_scheduler_cnn.get_last_lr() if lr_scheduler_cnn is not None else -1.0,
                "cnn top: ",
                (
                    lr_scheduler_cnn_top.get_last_lr()
                    if lr_scheduler_cnn_top is not None
                    else -1.0
                ),
            )
            t_training: float = time.perf_counter()

            # Switch the network into evaluation mode
            network.eval()

            with torch.no_grad():
                for image, target in test_dataloader:
                    output = network(test_processing_chain(image))

                    test_correct += int((output.argmax(dim=1) == target).sum().item())
                    test_number += target.shape[0]

            t_testing: float = time.perf_counter()

            perfomance_test_correct: float = 100.0 * test_correct / test_number

            tb.add_scalar("Train Loss", train_loss / float(train_number), epoch_id)
            tb.add_scalar("Train Number Correct", train_correct, epoch_id)
            tb.add_scalar("Test Number Correct", test_correct, epoch_id)

            print(
                f"Training: Loss={train_loss / float(train_number):.5f} Correct={perfomance_train_correct:.2f}%"
            )
            print(f"Testing: Correct={perfomance_test_correct:.2f}%")
            print(
                f"Time: Training={(t_training - t_start):.1f}sec, Testing={(t_testing - t_training):.1f}sec"
            )

            tb.flush()

            # Stop early once the largest of the current learning rates is
            # below the limit. Without any scheduler there is nothing to
            # compare, so the run simply continues to the last epoch (taking
            # the maximum of an empty list would raise).
            lr_check: list[float] = [
                scheduler.get_last_lr()[0] for scheduler in active_schedulers
            ]
            if lr_check and max(lr_check) < lr_limit:
                torch.save(network, model_path)
                print("Done (lr_limit)")
                return

        torch.save(network, model_path)
        print()
        print("Done (loop end)")
    finally:
        tb.close()

    return
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: only runs when the file is executed directly.
# argh.dispatch_command turns main() into the command-line interface.
if __name__ == "__main__":
    argh.dispatch_command(main)
|
Loading…
Reference in a new issue