Add files via upload
parent 945f02d2e7
commit f907b5f601
7 changed files with 548 additions and 347 deletions
13
L1NormLayer.py
Normal file
@@ -0,0 +1,13 @@
import torch


class L1NormLayer(torch.nn.Module):

    epsilon: float

    def __init__(self, epsilon: float = 10e-20) -> None:
        super().__init__()
        self.epsilon = epsilon

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input / (input.sum(dim=1, keepdim=True) + self.epsilon)
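A minimal usage sketch (not part of the commit; the shapes are illustrative): L1NormLayer divides each spatial position's channel vector by its channel sum, so the channels at every position sum to one up to epsilon.

import torch
from L1NormLayer import L1NormLayer

layer = L1NormLayer()
x = torch.rand((2, 8, 4, 4))  # (batch, channels, height, width)
y = layer(x)
print(torch.allclose(y.sum(dim=1), torch.ones(2, 4, 4)))  # True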
269
NNMF2d.py
Normal file
@@ -0,0 +1,269 @@
import torch

from non_linear_weigth_function import non_linear_weigth_function


class NNMF2d(torch.nn.Module):

    in_channels: int
    out_channels: int
    weight: torch.Tensor
    bias: None | torch.Tensor
    iterations: int
    epsilon: float | None
    init_min: float
    init_max: float
    beta: torch.Tensor | None
    positive_function_type: int
    local_learning: bool
    local_learning_kl: bool
    use_reconstruction: bool
    skip_connection: bool

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        device=None,
        dtype=None,
        iterations: int = 20,
        epsilon: float | None = None,
        init_min: float = 0.0,
        init_max: float = 1.0,
        beta: float | None = None,
        positive_function_type: int = 0,
        local_learning: bool = False,
        local_learning_kl: bool = False,
        use_reconstruction: bool = False,
        skip_connection: bool = False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()

        self.positive_function_type = positive_function_type
        self.init_min = init_min
        self.init_max = init_max

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.iterations = iterations
        self.local_learning = local_learning
        self.local_learning_kl = local_learning_kl

        self.use_reconstruction = use_reconstruction

        self.weight = torch.nn.parameter.Parameter(
            torch.empty((out_channels, in_channels), **factory_kwargs)
        )

        if beta is not None:
            self.beta = torch.nn.parameter.Parameter(torch.empty((1), **factory_kwargs))
            self.beta.data[0] = beta
        else:
            self.beta = None

        self.reset_parameters()
        self.functional_nnmf2d = FunctionalNNMF2d.apply

        self.epsilon = epsilon

        self.skip_connection = skip_connection

    def extra_repr(self) -> str:
        s: str = f"{self.in_channels}, {self.out_channels}"

        if self.epsilon is not None:
            s += f", epsilon={self.epsilon}"
        s += f", pfunctype={self.positive_function_type}"
        s += f", local_learning={self.local_learning}"

        if self.local_learning:
            s += f", local_learning_kl={self.local_learning_kl}"

        return s

    def reset_parameters(self) -> None:
        torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max)

    def forward(self, input: torch.Tensor) -> torch.Tensor:

        positive_weights = non_linear_weigth_function(
            self.weight, self.beta, self.positive_function_type
        )
        positive_weights = positive_weights / (
            positive_weights.sum(dim=1, keepdim=True) + 10e-20
        )

        h_dyn = self.functional_nnmf2d(
            input,
            positive_weights,
            self.out_channels,
            self.iterations,
            self.epsilon,
            self.local_learning,
            self.local_learning_kl,
        )
        if self.skip_connection:
            if self.use_reconstruction:
                reconstruction = torch.nn.functional.linear(
                    h_dyn.movedim(1, -1), positive_weights.T
                ).movedim(-1, 1)
                output = torch.cat((h_dyn, input - reconstruction), dim=1)
            else:
                output = torch.cat((h_dyn, input), dim=1)
            return output
        else:
            return h_dyn


class FunctionalNNMF2d(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore
        ctx,
        input: torch.Tensor,
        weight: torch.Tensor,
        out_channels: int,
        iterations: int,
        epsilon: float | None,
        local_learning: bool,
        local_learning_kl: bool,
    ) -> torch.Tensor:

        # Prepare h
        h = torch.full(
            (input.shape[0], out_channels, input.shape[-2], input.shape[-1]),
            1.0 / float(out_channels),
            device=input.device,
            dtype=input.dtype,
        )

        h = h.movedim(1, -1)
        input = input.movedim(1, -1)
        for _ in range(0, iterations):
            reconstruction = torch.nn.functional.linear(h, weight.T)
            reconstruction += 1e-20
            if epsilon is None:
                h *= torch.nn.functional.linear((input / reconstruction), weight)
            else:
                h *= 1 + epsilon * torch.nn.functional.linear(
                    (input / reconstruction), weight
                )
            h /= h.sum(-1, keepdim=True) + 10e-20
        h = h.movedim(-1, 1)
        input = input.movedim(-1, 1)

        # ###########################################################
        # Save the necessary data for the backward pass
        # ###########################################################
        ctx.save_for_backward(input, weight, h)
        ctx.local_learning = local_learning
        ctx.local_learning_kl = local_learning_kl

        assert torch.isfinite(h).all()
        return h

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output: torch.Tensor) -> tuple[  # type: ignore
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
        None,
        None,
        None,
    ]:

        # ##############################################
        # Default values
        # ##############################################
        grad_input: torch.Tensor | None = None
        grad_weight: torch.Tensor | None = None

        # ##############################################
        # Get the variables back
        # ##############################################
        (input, weight, h) = ctx.saved_tensors

        # The back prop gradient
        h = h.movedim(1, -1)
        grad_output = grad_output.movedim(1, -1)
        input = input.movedim(1, -1)
        big_r = torch.nn.functional.linear(h, weight.T)
        big_r_div = 1.0 / (big_r + 1e-20)

        factor_x_div_r = input * big_r_div

        grad_input = torch.nn.functional.linear(h * grad_output, weight.T) * big_r_div

        del big_r_div

        # The weight gradient
        if ctx.local_learning is False:
            del big_r

            grad_weight = -torch.nn.functional.linear(
                h.reshape(
                    grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
                    h.shape[3],
                ).T,
                (factor_x_div_r * grad_input)
                .reshape(
                    grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
                    grad_input.shape[3],
                )
                .T,
            )

            grad_weight += torch.nn.functional.linear(
                (h * grad_output)
                .reshape(
                    grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
                    h.shape[3],
                )
                .T,
                factor_x_div_r.reshape(
                    grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
                    grad_input.shape[3],
                ).T,
            )

        else:
            if ctx.local_learning_kl:
                grad_weight = -torch.nn.functional.linear(
                    h.reshape(
                        grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
                        h.shape[3],
                    ).T,
                    factor_x_div_r.reshape(
                        grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
                        grad_input.shape[3],
                    ).T,
                )
            else:
                grad_weight = -torch.nn.functional.linear(
                    h.reshape(
                        grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
                        h.shape[3],
                    ).T,
                    (2 * (input - big_r))
                    .reshape(
                        grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
                        grad_input.shape[3],
                    )
                    .T,
                )
        grad_input = grad_input.movedim(-1, 1)
        assert torch.isfinite(grad_input).all()
        assert torch.isfinite(grad_weight).all()

        return (
            grad_input,
            grad_weight,
            None,
            None,
            None,
            None,
            None,
        )
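A usage sketch (not part of the commit; the channel and size values are illustrative, and the repository's non_linear_weigth_function module must be importable). The layer expects a non-negative input that is L1-normalized along the channel dimension, e.g. by L1NormLayer:

import torch
from NNMF2d import NNMF2d

layer = NNMF2d(in_channels=75, out_channels=32, iterations=20)
x = torch.rand((4, 75, 24, 24))
x = x / (x.sum(dim=1, keepdim=True) + 10e-20)  # channel-wise L1 normalization
h = layer(x)
print(h.shape)  # torch.Size([4, 32, 24, 24])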
48
append_input_conv2d.py
Normal file
@@ -0,0 +1,48 @@
import torch


def append_input_conv2d(
    network: torch.nn.Sequential,
    test_image: torch.tensor,
    dilation: int = 1,
    padding: int = 0,
    stride: int = 1,
    kernel_size: list[int] = [5, 5],
) -> torch.Tensor:

    mock_output = (
        torch.nn.functional.conv2d(
            torch.zeros(
                1,
                1,
                test_image.shape[2],
                test_image.shape[3],
            ),
            torch.zeros((1, 1, kernel_size[0], kernel_size[1])),
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        .squeeze(0)
        .squeeze(0)
    )

    network.append(
        torch.nn.Unfold(
            kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride
        )
    )
    test_image = network[-1](test_image)

    network.append(
        torch.nn.Fold(
            output_size=mock_output.shape,
            kernel_size=(1, 1),
            dilation=1,
            padding=0,
            stride=1,
        )
    )
    test_image = network[-1](test_image)

    return test_image
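A usage sketch (not part of the commit; the image size and kernel size are illustrative): the Unfold/Fold pair rearranges each 5x5 input patch into the channel dimension, so a following 1x1-style layer such as NNMF2d sees one flattened patch per output position.

import torch
from append_input_conv2d import append_input_conv2d

network = torch.nn.Sequential()
test_image = torch.rand((1, 3, 28, 28))
test_image = append_input_conv2d(network=network, test_image=test_image, kernel_size=[5, 5])
print(test_image.shape)  # torch.Size([1, 75, 24, 24]) -> 3 * 5 * 5 patch values per position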
63
append_nnmf_block.py
Normal file
@@ -0,0 +1,63 @@
import torch
from append_input_conv2d import append_input_conv2d
from L1NormLayer import L1NormLayer
from NNMF2d import NNMF2d


def append_nnmf_block(
    network: torch.nn.Sequential,
    out_channels: int,
    test_image: torch.tensor,
    list_other_id: list[int],
    dilation: int = 1,
    padding: int = 0,
    stride: int = 1,
    kernel_size: list[int] = [5, 5],
    epsilon: float | None = None,
    positive_function_type: int = 0,
    beta: float | None = None,
    iterations: int = 20,
    local_learning: bool = False,
    local_learning_kl: bool = False,
    use_reconstruction: bool = False,
    skip_connection: bool = False,
) -> torch.Tensor:

    kernel_size_internal: list[int] = list(kernel_size)

    if kernel_size[0] < 1:
        kernel_size_internal[0] = test_image.shape[-2]

    if kernel_size[1] < 1:
        kernel_size_internal[1] = test_image.shape[-1]

    test_image = append_input_conv2d(
        network=network,
        test_image=test_image,
        dilation=dilation,
        padding=padding,
        stride=stride,
        kernel_size=kernel_size_internal,
    )

    network.append(L1NormLayer())
    test_image = network[-1](test_image)

    list_other_id.append(len(network))
    network.append(
        NNMF2d(
            in_channels=test_image.shape[1],
            out_channels=out_channels,
            epsilon=epsilon,
            positive_function_type=positive_function_type,
            beta=beta,
            iterations=iterations,
            local_learning=local_learning,
            local_learning_kl=local_learning_kl,
            use_reconstruction=use_reconstruction,
            skip_connection=skip_connection,
        )
    )
    test_image = network[-1](test_image)

    return test_image
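A usage sketch (not part of the commit; all values are illustrative) building a single block:

import torch
from append_nnmf_block import append_nnmf_block

network = torch.nn.Sequential()
list_other_id: list[int] = []
test_image = torch.rand((1, 3, 28, 28))
test_image = append_nnmf_block(
    network=network,
    out_channels=32,
    test_image=test_image,
    list_other_id=list_other_id,
    kernel_size=[5, 5],
)
print(test_image.shape)  # torch.Size([1, 32, 24, 24])
print(list_other_id)     # [3] -> index of the NNMF2d layer (after Unfold, Fold, L1NormLayer)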
469
make_network.py
@@ -1,7 +1,6 @@
 import torch
-from NNMFConv2d import NNMFConv2d
-from NNMFConv2dP import NNMFConv2dP
 from SplitOnOffLayer import SplitOnOffLayer
+from append_nnmf_block import append_nnmf_block


 def make_network(
@@ -11,43 +10,79 @@ def make_network(
     input_dim_y: int,
     input_number_of_channel: int,
     iterations: int,
-    init_min: float = 0.0,
-    init_max: float = 1.0,
-    use_convolution: bool = False,
-    convolution_contribution_map_enable: bool = False,
     epsilon: bool | None = None,
     positive_function_type: int = 0,
     beta: float | None = None,
-    number_of_output_channels_conv1: int = 32,
-    number_of_output_channels_conv2: int = 64,
-    number_of_output_channels_flatten2: int = 96,
-    number_of_output_channels_full1: int = 10,
-    kernel_size_conv1: tuple[int, int] = (5, 5),
-    kernel_size_pool1: tuple[int, int] = (2, 2),
-    kernel_size_conv2: tuple[int, int] = (5, 5),
-    kernel_size_pool2: tuple[int, int] = (2, 2),
-    stride_conv1: tuple[int, int] = (1, 1),
-    stride_pool1: tuple[int, int] = (2, 2),
-    stride_conv2: tuple[int, int] = (1, 1),
-    stride_pool2: tuple[int, int] = (2, 2),
-    padding_conv1: int = 0,
-    padding_pool1: int = 0,
-    padding_conv2: int = 0,
-    padding_pool2: int = 0,
-    enable_onoff: bool = False,
-    local_learning_0: bool = False,
-    local_learning_1: bool = False,
-    local_learning_2: bool = False,
-    local_learning_3: bool = False,
+    # Conv:
+    number_of_output_channels: list[int] = [32, 64, 96, 10],
+    kernel_size_conv: list[tuple[int, int]] = [
+        (5, 5),
+        (5, 5),
+        (-1, -1),  # Take the whole input image x and y size
+        (1, 1),
+    ],
+    stride_conv: list[tuple[int, int]] = [
+        (1, 1),
+        (1, 1),
+        (1, 1),
+        (1, 1),
+    ],
+    padding_conv: list[tuple[int, int]] = [
+        (0, 0),
+        (0, 0),
+        (0, 0),
+        (0, 0),
+    ],
+    dilation_conv: list[tuple[int, int]] = [
+        (1, 1),
+        (1, 1),
+        (1, 1),
+        (1, 1),
+    ],
+    # Pool:
+    kernel_size_pool: list[tuple[int, int]] = [
+        (2, 2),
+        (2, 2),
+        (-1, -1),  # No pooling layer
+        (-1, -1),  # No pooling layer
+    ],
+    stride_pool: list[tuple[int, int]] = [
+        (2, 2),
+        (2, 2),
+        (-1, -1),
+        (-1, -1),
+    ],
+    padding_pool: list[tuple[int, int]] = [
+        (0, 0),
+        (0, 0),
+        (0, 0),
+        (0, 0),
+    ],
+    dilation_pool: list[tuple[int, int]] = [
+        (1, 1),
+        (1, 1),
+        (1, 1),
+        (1, 1),
+    ],
+    local_learning: list[bool] = [False, False, False, False],
+    skip_connection: list[bool] = [False, False, False, False],
     local_learning_kl: bool = True,
-    p_mode_0: bool = False,
-    p_mode_1: bool = False,
-    p_mode_2: bool = False,
-    p_mode_3: bool = False,
     use_reconstruction: bool = False,
     max_pool: bool = True,
+    enable_onoff: bool = False,
 ) -> tuple[torch.nn.Sequential, list[int], list[int]]:

+    assert len(number_of_output_channels) == len(kernel_size_conv)
+    assert len(number_of_output_channels) == len(stride_conv)
+    assert len(number_of_output_channels) == len(padding_conv)
+    assert len(number_of_output_channels) == len(dilation_conv)
+    assert len(number_of_output_channels) == len(kernel_size_pool)
+    assert len(number_of_output_channels) == len(stride_pool)
+    assert len(number_of_output_channels) == len(padding_pool)
+    assert len(number_of_output_channels) == len(dilation_pool)
+    assert len(number_of_output_channels) == len(local_learning)
+    assert len(number_of_output_channels) == len(skip_connection)
+
     if enable_onoff:
         input_number_of_channel *= 2
@@ -62,316 +97,86 @@ def make_network(
         network.append(SplitOnOffLayer())
         test_image = network[-1](test_image)

-    list_other_id.append(len(network))
+    for block_id in range(0, len(number_of_output_channels)):
         if use_nnmf:
-        if p_mode_0:
-            network.append(
-                NNMFConv2dP(
-                    in_channels=test_image.shape[1],
-                    out_channels=number_of_output_channels_conv1,
-                    kernel_size=kernel_size_conv1,
-                    convolution_contribution_map_enable=convolution_contribution_map_enable,
-                    epsilon=epsilon,
-                    positive_function_type=positive_function_type,
-                    init_min=init_min,
-                    init_max=init_max,
+            test_image = append_nnmf_block(
+                network=network,
+                out_channels=number_of_output_channels[block_id],
+                test_image=test_image,
+                list_other_id=list_other_id,
+                dilation=dilation_conv[block_id],
+                padding=padding_conv[block_id],
+                stride=stride_conv[block_id],
+                kernel_size=kernel_size_conv[block_id],
+                epsilon=epsilon,
+                positive_function_type=positive_function_type,
                 beta=beta,
-                    use_convolution=use_convolution,
-                    iterations=iterations,
-                    local_learning=local_learning_0,
-                    local_learning_kl=local_learning_kl,
-                    use_reconstruction=use_reconstruction,
-                )
+                iterations=iterations,
+                local_learning=local_learning[block_id],
+                local_learning_kl=local_learning_kl,
+                use_reconstruction=use_reconstruction,
+                skip_connection=skip_connection[block_id],
             )
         else:
+            list_other_id.append(len(network))
+
+            kernel_size_conv_internal = list(kernel_size_conv[block_id])
+
+            if kernel_size_conv[block_id][0] == -1:
+                kernel_size_conv_internal[0] = test_image.shape[-2]
+
+            if kernel_size_conv[block_id][1] == -1:
+                kernel_size_conv_internal[1] = test_image.shape[-1]
+
             network.append(
-                NNMFConv2d(
+                torch.nn.Conv2d(
                     in_channels=test_image.shape[1],
-                    out_channels=number_of_output_channels_conv1,
-                    kernel_size=kernel_size_conv1,
-                    convolution_contribution_map_enable=convolution_contribution_map_enable,
-                    epsilon=epsilon,
-                    positive_function_type=positive_function_type,
-                    init_min=init_min,
-                    init_max=init_max,
-                    beta=beta,
-                    use_convolution=use_convolution,
-                    iterations=iterations,
-                    local_learning=local_learning_0,
-                    local_learning_kl=local_learning_kl,
+                    out_channels=number_of_output_channels[block_id],
+                    kernel_size=kernel_size_conv_internal,
+                    stride=1,
+                    padding=0,
                 )
             )
-        test_image = network[-1](test_image)
-    else:
-        network.append(
-            torch.nn.Conv2d(
-                in_channels=test_image.shape[1],
-                out_channels=number_of_output_channels_conv1,
-                kernel_size=kernel_size_conv1,
-                stride=stride_conv1,
-                padding=padding_conv1,
-            )
-        )
-        test_image = network[-1](test_image)
-        network.append(torch.nn.ReLU())
-        test_image = network[-1](test_image)
-
-    if cnn_top:
-        list_cnn_top_id.append(len(network))
-        network.append(
-            torch.nn.Conv2d(
-                in_channels=test_image.shape[1],
-                out_channels=number_of_output_channels_conv1,
-                kernel_size=(1, 1),
-                stride=(1, 1),
-                padding=(0, 0),
-                bias=True,
-            )
-        )
-        test_image = network[-1](test_image)
-        network.append(torch.nn.ReLU())
-        test_image = network[-1](test_image)
-
-    if max_pool:
-        network.append(
-            torch.nn.MaxPool2d(
-                kernel_size=kernel_size_pool1,
-                stride=stride_pool1,
-                padding=padding_pool1,
-            )
-        )
-    else:
-        network.append(
-            torch.nn.AvgPool2d(
-                kernel_size=kernel_size_pool1,
-                stride=stride_pool1,
-                padding=padding_pool1,
-            )
-        )
-    test_image = network[-1](test_image)
-
-    list_other_id.append(len(network))
-    if use_nnmf:
-        if p_mode_1:
-            network.append(
-                NNMFConv2dP(
-                    in_channels=test_image.shape[1],
-                    out_channels=number_of_output_channels_conv2,
-                    kernel_size=kernel_size_conv2,
-                    convolution_contribution_map_enable=convolution_contribution_map_enable,
-                    epsilon=epsilon,
-                    positive_function_type=positive_function_type,
-                    init_min=init_min,
-                    init_max=init_max,
-                    beta=beta,
-                    use_convolution=use_convolution,
-                    iterations=iterations,
-                    local_learning=local_learning_1,
-                    local_learning_kl=local_learning_kl,
-                    use_reconstruction=use_reconstruction,
-                )
-            )
-        else:
-            network.append(
-                NNMFConv2d(
-                    in_channels=test_image.shape[1],
-                    out_channels=number_of_output_channels_conv2,
-                    kernel_size=kernel_size_conv2,
-                    convolution_contribution_map_enable=convolution_contribution_map_enable,
-                    epsilon=epsilon,
-                    positive_function_type=positive_function_type,
-                    init_min=init_min,
-                    init_max=init_max,
-                    beta=beta,
-                    use_convolution=use_convolution,
-                    iterations=iterations,
-                    local_learning=local_learning_1,
-                    local_learning_kl=local_learning_kl,
-                )
-            )
-        test_image = network[-1](test_image)
-    else:
-        network.append(
-            torch.nn.Conv2d(
-                in_channels=test_image.shape[1],
-                out_channels=number_of_output_channels_conv2,
-                kernel_size=kernel_size_conv2,
-                stride=stride_conv2,
-                padding=padding_conv2,
-            )
-        )
-        test_image = network[-1](test_image)
-        network.append(torch.nn.ReLU())
-        test_image = network[-1](test_image)
-
-    if cnn_top:
-        list_cnn_top_id.append(len(network))
-        network.append(
-            torch.nn.Conv2d(
-                in_channels=test_image.shape[1],
-                out_channels=number_of_output_channels_conv2,
-                kernel_size=(1, 1),
-                stride=(1, 1),
-                padding=(0, 0),
-                bias=True,
-            )
-        )
-        test_image = network[-1](test_image)
-        network.append(torch.nn.ReLU())
-        test_image = network[-1](test_image)
-
-    if max_pool:
-        network.append(
-            torch.nn.MaxPool2d(
-                kernel_size=kernel_size_pool2,
-                stride=stride_pool2,
-                padding=padding_pool2,
-            )
-        )
-    else:
-        network.append(
-            torch.nn.AvgPool2d(
-                kernel_size=kernel_size_pool2,
-                stride=stride_pool2,
-                padding=padding_pool2,
-            )
-        )
-    test_image = network[-1](test_image)
-
-    list_other_id.append(len(network))
-    if use_nnmf:
-        if p_mode_2:
-            network.append(
-                NNMFConv2dP(
-                    in_channels=test_image.shape[1],
-                    out_channels=number_of_output_channels_flatten2,
-                    kernel_size=(test_image.shape[2], test_image.shape[3]),
-                    convolution_contribution_map_enable=convolution_contribution_map_enable,
-                    epsilon=epsilon,
-                    positive_function_type=positive_function_type,
-                    init_min=init_min,
-                    init_max=init_max,
-                    beta=beta,
-                    use_convolution=use_convolution,
-                    iterations=iterations,
-                    local_learning=local_learning_2,
-                    local_learning_kl=local_learning_kl,
-                    use_reconstruction=use_reconstruction,
-                )
-            )
-        else:
-            network.append(
-                NNMFConv2d(
-                    in_channels=test_image.shape[1],
-                    out_channels=number_of_output_channels_flatten2,
-                    kernel_size=(test_image.shape[2], test_image.shape[3]),
-                    convolution_contribution_map_enable=convolution_contribution_map_enable,
-                    epsilon=epsilon,
-                    positive_function_type=positive_function_type,
-                    init_min=init_min,
-                    init_max=init_max,
-                    beta=beta,
-                    use_convolution=use_convolution,
-                    iterations=iterations,
-                    local_learning=local_learning_2,
-                    local_learning_kl=local_learning_kl,
-                )
-            )
-        test_image = network[-1](test_image)
-    else:
-        network.append(
-            torch.nn.Conv2d(
-                in_channels=test_image.shape[1],
-                out_channels=number_of_output_channels_flatten2,
-                kernel_size=(test_image.shape[2], test_image.shape[3]),
-            )
-        )
-        test_image = network[-1](test_image)
-        network.append(torch.nn.ReLU())
-        test_image = network[-1](test_image)
-
-    if cnn_top:
-        list_cnn_top_id.append(len(network))
-        network.append(
-            torch.nn.Conv2d(
-                in_channels=test_image.shape[1],
-                out_channels=number_of_output_channels_flatten2,
-                kernel_size=(1, 1),
-                stride=(1, 1),
-                padding=(0, 0),
-                bias=True,
-            )
-        )
-        test_image = network[-1](test_image)
-        network.append(torch.nn.ReLU())
-        test_image = network[-1](test_image)
-
-    list_other_id.append(len(network))
-    if use_nnmf:
-        if p_mode_3:
-            network.append(
-                NNMFConv2dP(
-                    in_channels=test_image.shape[1],
-                    out_channels=number_of_output_channels_full1,
-                    kernel_size=(test_image.shape[2], test_image.shape[3]),
-                    convolution_contribution_map_enable=convolution_contribution_map_enable,
-                    epsilon=epsilon,
-                    positive_function_type=positive_function_type,
-                    init_min=init_min,
-                    init_max=init_max,
-                    beta=beta,
-                    use_convolution=use_convolution,
-                    iterations=iterations,
-                    local_learning=local_learning_3,
-                    local_learning_kl=local_learning_kl,
-                    use_reconstruction=use_reconstruction,
-                )
-            )
-        else:
-            network.append(
-                NNMFConv2d(
-                    in_channels=test_image.shape[1],
-                    out_channels=number_of_output_channels_full1,
-                    kernel_size=(test_image.shape[2], test_image.shape[3]),
-                    convolution_contribution_map_enable=convolution_contribution_map_enable,
-                    epsilon=epsilon,
-                    positive_function_type=positive_function_type,
-                    init_min=init_min,
-                    init_max=init_max,
-                    beta=beta,
-                    use_convolution=use_convolution,
-                    iterations=iterations,
-                    local_learning=local_learning_3,
-                    local_learning_kl=local_learning_kl,
-                )
-            )
-        test_image = network[-1](test_image)
-    else:
-        network.append(
-            torch.nn.Conv2d(
-                in_channels=test_image.shape[1],
-                out_channels=number_of_output_channels_full1,
-                kernel_size=(test_image.shape[2], test_image.shape[3]),
-            )
-        )
-        test_image = network[-1](test_image)
-    if cnn_top:
-        network.append(torch.nn.ReLU())
             test_image = network[-1](test_image)
+            if cnn_top or block_id < len(number_of_output_channels) - 1:
+                network.append(torch.nn.ReLU())
+                test_image = network[-1](test_image)

         if cnn_top:
             list_cnn_top_id.append(len(network))
             network.append(
                 torch.nn.Conv2d(
                     in_channels=test_image.shape[1],
-                    out_channels=number_of_output_channels_full1,
+                    out_channels=number_of_output_channels[block_id],
                     kernel_size=(1, 1),
                     stride=(1, 1),
                     padding=(0, 0),
                     bias=True,
+                )
             )
-        )
-        test_image = network[-1](test_image)
+            test_image = network[-1](test_image)
+            if block_id < len(number_of_output_channels) - 1:
+                network.append(torch.nn.ReLU())
+                test_image = network[-1](test_image)

+        if (kernel_size_pool[block_id][0] > 0) and (kernel_size_pool[block_id][1] > 0):
+            if max_pool:
+                network.append(
+                    torch.nn.MaxPool2d(
+                        kernel_size=kernel_size_pool[block_id],
+                        stride=stride_pool[block_id],
+                        padding=padding_pool[block_id],
+                    )
+                )
+            else:
+                network.append(
+                    torch.nn.AvgPool2d(
+                        kernel_size=kernel_size_pool[block_id],
+                        stride=stride_pool[block_id],
+                        padding=padding_pool[block_id],
+                    )
+                )
+            test_image = network[-1](test_image)

     network.append(torch.nn.Flatten())
     test_image = network[-1](test_image)
@@ -1,6 +1,5 @@
 import torch
-from NNMFConv2d import NNMFConv2d
-from NNMFConv2dP import NNMFConv2dP
+from NNMF2d import NNMF2d


 def make_optimize(
@@ -45,7 +44,7 @@ def make_optimize(
             for netp in network[layerid].parameters():
                 list_cnn.append(netp)

-        if isinstance(network[layerid], (NNMFConv2d, NNMFConv2dP)):
+        if isinstance(network[layerid], NNMF2d):
             for netp in network[layerid].parameters():
                 list_nnmf.append(netp)
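A hypothetical sketch (the helper function and its grouping below are assumptions for illustration, not the file's actual code) of how the updated isinstance check can be used to collect NNMF2d parameters separately from all remaining parameters:

import torch
from NNMF2d import NNMF2d

def split_parameters(
    network: torch.nn.Sequential,
) -> tuple[list[torch.nn.Parameter], list[torch.nn.Parameter]]:
    list_nnmf: list[torch.nn.Parameter] = []
    list_other: list[torch.nn.Parameter] = []
    for module in network:
        target = list_nnmf if isinstance(module, NNMF2d) else list_other
        for netp in module.parameters():
            target.append(netp)
    return list_nnmf, list_other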
@@ -30,10 +30,10 @@ def main(
     local_learning_2: bool = False,
     local_learning_3: bool = False,
     local_learning_kl: bool = False,
-    p_mode_0: bool = False,
-    p_mode_1: bool = False,
-    p_mode_2: bool = False,
-    p_mode_3: bool = False,
+    skip_connection_0: bool = True,
+    skip_connection_1: bool = True,
+    skip_connection_2: bool = True,
+    skip_connection_3: bool = True,
     use_reconstruction: bool = False,
     max_pool: bool = True,
 ) -> None:
@@ -107,15 +107,19 @@ def main(
         input_number_of_channel=input_number_of_channel,
         iterations=iterations,
         enable_onoff=enable_onoff,
-        local_learning_0=local_learning_0,
-        local_learning_1=local_learning_1,
-        local_learning_2=local_learning_2,
-        local_learning_3=local_learning_3,
+        local_learning=[
+            local_learning_0,
+            local_learning_1,
+            local_learning_2,
+            local_learning_3,
+        ],
         local_learning_kl=local_learning_kl,
-        p_mode_0=p_mode_0,
-        p_mode_1=p_mode_1,
-        p_mode_2=p_mode_2,
-        p_mode_3=p_mode_3,
+        skip_connection=[
+            skip_connection_0,
+            skip_connection_1,
+            skip_connection_2,
+            skip_connection_3,
+        ],
         use_reconstruction=use_reconstruction,
         max_pool=max_pool,
     )