Add files via upload

This commit is contained in:
David Rotermund 2024-05-31 17:56:34 +02:00 committed by GitHub
parent 945f02d2e7
commit f907b5f601
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 548 additions and 347 deletions

13
L1NormLayer.py Normal file
View file

@ -0,0 +1,13 @@
import torch
class L1NormLayer(torch.nn.Module):
epsilon: float
def __init__(self, epsilon: float = 10e-20) -> None:
super().__init__()
self.epsilon = epsilon
def forward(self, input: torch.Tensor) -> torch.Tensor:
return input / (input.sum(dim=1, keepdim=True) + self.epsilon)

269
NNMF2d.py Normal file
View file

@ -0,0 +1,269 @@
import torch
from non_linear_weigth_function import non_linear_weigth_function
class NNMF2d(torch.nn.Module):
in_channels: int
out_channels: int
weight: torch.Tensor
bias: None | torch.Tensor
iterations: int
epsilon: float | None
init_min: float
init_max: float
beta: torch.Tensor | None
positive_function_type: int
local_learning: bool
local_learning_kl: bool
use_reconstruction: bool
skip_connection: bool
def __init__(
self,
in_channels: int,
out_channels: int,
device=None,
dtype=None,
iterations: int = 20,
epsilon: float | None = None,
init_min: float = 0.0,
init_max: float = 1.0,
beta: float | None = None,
positive_function_type: int = 0,
local_learning: bool = False,
local_learning_kl: bool = False,
use_reconstruction: bool = False,
skip_connection: bool = False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.positive_function_type = positive_function_type
self.init_min = init_min
self.init_max = init_max
self.in_channels = in_channels
self.out_channels = out_channels
self.iterations = iterations
self.local_learning = local_learning
self.local_learning_kl = local_learning_kl
self.use_reconstruction = use_reconstruction
self.weight = torch.nn.parameter.Parameter(
torch.empty((out_channels, in_channels), **factory_kwargs)
)
if beta is not None:
self.beta = torch.nn.parameter.Parameter(torch.empty((1), **factory_kwargs))
self.beta.data[0] = beta
else:
self.beta = None
self.reset_parameters()
self.functional_nnmf2d = FunctionalNNMF2d.apply
self.epsilon = epsilon
self.skip_connection = skip_connection
def extra_repr(self) -> str:
s: str = f"{self.in_channels}, {self.out_channels}"
if self.epsilon is not None:
s += f", epsilon={self.epsilon}"
s += f", pfunctype={self.positive_function_type}"
s += f", local_learning={self.local_learning}"
if self.local_learning:
s += f", local_learning_kl={self.local_learning_kl}"
return s
def reset_parameters(self) -> None:
torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max)
def forward(self, input: torch.Tensor) -> torch.Tensor:
positive_weights = non_linear_weigth_function(
self.weight, self.beta, self.positive_function_type
)
positive_weights = positive_weights / (
positive_weights.sum(dim=1, keepdim=True) + 10e-20
)
h_dyn = self.functional_nnmf2d(
input,
positive_weights,
self.out_channels,
self.iterations,
self.epsilon,
self.local_learning,
self.local_learning_kl,
)
if self.skip_connection:
if self.use_reconstruction:
reconstruction = torch.nn.functional.linear(
h_dyn.movedim(1, -1), positive_weights.T
).movedim(-1, 1)
output = torch.cat((h_dyn, input - reconstruction), dim=1)
else:
output = torch.cat((h_dyn, input), dim=1)
return output
else:
return h_dyn
class FunctionalNNMF2d(torch.autograd.Function):
@staticmethod
def forward( # type: ignore
ctx,
input: torch.Tensor,
weight: torch.Tensor,
out_channels: int,
iterations: int,
epsilon: float | None,
local_learning: bool,
local_learning_kl: bool,
) -> torch.Tensor:
# Prepare h
h = torch.full(
(input.shape[0], out_channels, input.shape[-2], input.shape[-1]),
1.0 / float(out_channels),
device=input.device,
dtype=input.dtype,
)
h = h.movedim(1, -1)
input = input.movedim(1, -1)
for _ in range(0, iterations):
reconstruction = torch.nn.functional.linear(h, weight.T)
reconstruction += 1e-20
if epsilon is None:
h *= torch.nn.functional.linear((input / reconstruction), weight)
else:
h *= 1 + epsilon * torch.nn.functional.linear(
(input / reconstruction), weight
)
h /= h.sum(-1, keepdim=True) + 10e-20
h = h.movedim(-1, 1)
input = input.movedim(-1, 1)
# ###########################################################
# Save the necessary data for the backward pass
# ###########################################################
ctx.save_for_backward(input, weight, h)
ctx.local_learning = local_learning
ctx.local_learning_kl = local_learning_kl
assert torch.isfinite(h).all()
return h
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output: torch.Tensor) -> tuple[ # type: ignore
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
None,
]:
# ##############################################
# Default values
# ##############################################
grad_input: torch.Tensor | None = None
grad_weight: torch.Tensor | None = None
# ##############################################
# Get the variables back
# ##############################################
(input, weight, h) = ctx.saved_tensors
# The back prop gradient
h = h.movedim(1, -1)
grad_output = grad_output.movedim(1, -1)
input = input.movedim(1, -1)
big_r = torch.nn.functional.linear(h, weight.T)
big_r_div = 1.0 / (big_r + 1e-20)
factor_x_div_r = input * big_r_div
grad_input = torch.nn.functional.linear(h * grad_output, weight.T) * big_r_div
del big_r_div
# The weight gradient
if ctx.local_learning is False:
del big_r
grad_weight = -torch.nn.functional.linear(
h.reshape(
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
h.shape[3],
).T,
(factor_x_div_r * grad_input)
.reshape(
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
grad_input.shape[3],
)
.T,
)
grad_weight += torch.nn.functional.linear(
(h * grad_output)
.reshape(
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
h.shape[3],
)
.T,
factor_x_div_r.reshape(
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
grad_input.shape[3],
).T,
)
else:
if ctx.local_learning_kl:
grad_weight = -torch.nn.functional.linear(
h.reshape(
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
h.shape[3],
).T,
factor_x_div_r.reshape(
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
grad_input.shape[3],
).T,
)
else:
grad_weight = -torch.nn.functional.linear(
h.reshape(
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
h.shape[3],
).T,
(2 * (input - big_r))
.reshape(
grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2],
grad_input.shape[3],
)
.T,
)
grad_input = grad_input.movedim(-1, 1)
assert torch.isfinite(grad_input).all()
assert torch.isfinite(grad_weight).all()
return (
grad_input,
grad_weight,
None,
None,
None,
None,
None,
)

48
append_input_conv2d.py Normal file
View file

@ -0,0 +1,48 @@
import torch
def append_input_conv2d(
network: torch.nn.Sequential,
test_image: torch.tensor,
dilation: int = 1,
padding: int = 0,
stride: int = 1,
kernel_size: list[int] = [5, 5],
) -> torch.Tensor:
mock_output = (
torch.nn.functional.conv2d(
torch.zeros(
1,
1,
test_image.shape[2],
test_image.shape[3],
),
torch.zeros((1, 1, kernel_size[0], kernel_size[1])),
stride=stride,
padding=padding,
dilation=dilation,
)
.squeeze(0)
.squeeze(0)
)
network.append(
torch.nn.Unfold(
kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride
)
)
test_image = network[-1](test_image)
network.append(
torch.nn.Fold(
output_size=mock_output.shape,
kernel_size=(1, 1),
dilation=1,
padding=0,
stride=1,
)
)
test_image = network[-1](test_image)
return test_image

63
append_nnmf_block.py Normal file
View file

@ -0,0 +1,63 @@
import torch
from append_input_conv2d import append_input_conv2d
from L1NormLayer import L1NormLayer
from NNMF2d import NNMF2d
def append_nnmf_block(
network: torch.nn.Sequential,
out_channels: int,
test_image: torch.tensor,
list_other_id: list[int],
dilation: int = 1,
padding: int = 0,
stride: int = 1,
kernel_size: list[int] = [5, 5],
epsilon: float | None = None,
positive_function_type: int = 0,
beta: float | None = None,
iterations: int = 20,
local_learning: bool = False,
local_learning_kl: bool = False,
use_reconstruction: bool = False,
skip_connection: bool = False,
) -> torch.Tensor:
kernel_size_internal: list[int] = list(kernel_size)
if kernel_size[0] < 1:
kernel_size_internal[0] = test_image.shape[-2]
if kernel_size[1] < 1:
kernel_size_internal[1] = test_image.shape[-1]
test_image = append_input_conv2d(
network=network,
test_image=test_image,
dilation=dilation,
padding=padding,
stride=stride,
kernel_size=kernel_size_internal,
)
network.append(L1NormLayer())
test_image = network[-1](test_image)
list_other_id.append(len(network))
network.append(
NNMF2d(
in_channels=test_image.shape[1],
out_channels=out_channels,
epsilon=epsilon,
positive_function_type=positive_function_type,
beta=beta,
iterations=iterations,
local_learning=local_learning,
local_learning_kl=local_learning_kl,
use_reconstruction=use_reconstruction,
skip_connection=skip_connection,
)
)
test_image = network[-1](test_image)
return test_image

View file

@ -1,7 +1,6 @@
import torch
from NNMFConv2d import NNMFConv2d
from NNMFConv2dP import NNMFConv2dP
from SplitOnOffLayer import SplitOnOffLayer
from append_nnmf_block import append_nnmf_block
def make_network(
@ -11,43 +10,79 @@ def make_network(
input_dim_y: int,
input_number_of_channel: int,
iterations: int,
init_min: float = 0.0,
init_max: float = 1.0,
use_convolution: bool = False,
convolution_contribution_map_enable: bool = False,
epsilon: bool | None = None,
positive_function_type: int = 0,
beta: float | None = None,
number_of_output_channels_conv1: int = 32,
number_of_output_channels_conv2: int = 64,
number_of_output_channels_flatten2: int = 96,
number_of_output_channels_full1: int = 10,
kernel_size_conv1: tuple[int, int] = (5, 5),
kernel_size_pool1: tuple[int, int] = (2, 2),
kernel_size_conv2: tuple[int, int] = (5, 5),
kernel_size_pool2: tuple[int, int] = (2, 2),
stride_conv1: tuple[int, int] = (1, 1),
stride_pool1: tuple[int, int] = (2, 2),
stride_conv2: tuple[int, int] = (1, 1),
stride_pool2: tuple[int, int] = (2, 2),
padding_conv1: int = 0,
padding_pool1: int = 0,
padding_conv2: int = 0,
padding_pool2: int = 0,
enable_onoff: bool = False,
local_learning_0: bool = False,
local_learning_1: bool = False,
local_learning_2: bool = False,
local_learning_3: bool = False,
# Conv:
number_of_output_channels: list[int] = [32, 64, 96, 10],
kernel_size_conv: list[tuple[int, int]] = [
(5, 5),
(5, 5),
(-1, -1), # Take the whole input image x and y size
(1, 1),
],
stride_conv: list[tuple[int, int]] = [
(1, 1),
(1, 1),
(1, 1),
(1, 1),
],
padding_conv: list[tuple[int, int]] = [
(0, 0),
(0, 0),
(0, 0),
(0, 0),
],
dilation_conv: list[tuple[int, int]] = [
(1, 1),
(1, 1),
(1, 1),
(1, 1),
],
# Pool:
kernel_size_pool: list[tuple[int, int]] = [
(2, 2),
(2, 2),
(-1, -1), # No pooling layer
(-1, -1), # No pooling layer
],
stride_pool: list[tuple[int, int]] = [
(2, 2),
(2, 2),
(-1, -1),
(-1, -1),
],
padding_pool: list[tuple[int, int]] = [
(0, 0),
(0, 0),
(0, 0),
(0, 0),
],
dilation_pool: list[tuple[int, int]] = [
(1, 1),
(1, 1),
(1, 1),
(1, 1),
],
local_learning: list[bool] = [False, False, False, False],
skip_connection: list[bool] = [False, False, False, False],
local_learning_kl: bool = True,
p_mode_0: bool = False,
p_mode_1: bool = False,
p_mode_2: bool = False,
p_mode_3: bool = False,
use_reconstruction: bool = False,
max_pool: bool = True,
enable_onoff: bool = False,
) -> tuple[torch.nn.Sequential, list[int], list[int]]:
assert len(number_of_output_channels) == len(kernel_size_conv)
assert len(number_of_output_channels) == len(stride_conv)
assert len(number_of_output_channels) == len(padding_conv)
assert len(number_of_output_channels) == len(dilation_conv)
assert len(number_of_output_channels) == len(kernel_size_pool)
assert len(number_of_output_channels) == len(stride_pool)
assert len(number_of_output_channels) == len(padding_pool)
assert len(number_of_output_channels) == len(dilation_pool)
assert len(number_of_output_channels) == len(local_learning)
assert len(number_of_output_channels) == len(skip_connection)
if enable_onoff:
input_number_of_channel *= 2
@ -62,57 +97,48 @@ def make_network(
network.append(SplitOnOffLayer())
test_image = network[-1](test_image)
list_other_id.append(len(network))
for block_id in range(0, len(number_of_output_channels)):
if use_nnmf:
if p_mode_0:
network.append(
NNMFConv2dP(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_conv1,
kernel_size=kernel_size_conv1,
convolution_contribution_map_enable=convolution_contribution_map_enable,
test_image = append_nnmf_block(
network=network,
out_channels=number_of_output_channels[block_id],
test_image=test_image,
list_other_id=list_other_id,
dilation=dilation_conv[block_id],
padding=padding_conv[block_id],
stride=stride_conv[block_id],
kernel_size=kernel_size_conv[block_id],
epsilon=epsilon,
positive_function_type=positive_function_type,
init_min=init_min,
init_max=init_max,
beta=beta,
use_convolution=use_convolution,
iterations=iterations,
local_learning=local_learning_0,
local_learning=local_learning[block_id],
local_learning_kl=local_learning_kl,
use_reconstruction=use_reconstruction,
)
skip_connection=skip_connection[block_id],
)
else:
network.append(
NNMFConv2d(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_conv1,
kernel_size=kernel_size_conv1,
convolution_contribution_map_enable=convolution_contribution_map_enable,
epsilon=epsilon,
positive_function_type=positive_function_type,
init_min=init_min,
init_max=init_max,
beta=beta,
use_convolution=use_convolution,
iterations=iterations,
local_learning=local_learning_0,
local_learning_kl=local_learning_kl,
)
)
test_image = network[-1](test_image)
else:
list_other_id.append(len(network))
kernel_size_conv_internal = list(kernel_size_conv[block_id])
if kernel_size_conv[block_id][0] == -1:
kernel_size_conv_internal[0] = test_image.shape[-2]
if kernel_size_conv[block_id][1] == -1:
kernel_size_conv_internal[1] = test_image.shape[-1]
network.append(
torch.nn.Conv2d(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_conv1,
kernel_size=kernel_size_conv1,
stride=stride_conv1,
padding=padding_conv1,
out_channels=number_of_output_channels[block_id],
kernel_size=kernel_size_conv_internal,
stride=1,
padding=0,
)
)
test_image = network[-1](test_image)
if cnn_top or block_id < len(number_of_output_channels) - 1:
network.append(torch.nn.ReLU())
test_image = network[-1](test_image)
@ -121,7 +147,7 @@ def make_network(
network.append(
torch.nn.Conv2d(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_conv1,
out_channels=number_of_output_channels[block_id],
kernel_size=(1, 1),
stride=(1, 1),
padding=(0, 0),
@ -129,246 +155,25 @@ def make_network(
)
)
test_image = network[-1](test_image)
if block_id < len(number_of_output_channels) - 1:
network.append(torch.nn.ReLU())
test_image = network[-1](test_image)
if (kernel_size_pool[block_id][0] > 0) and (kernel_size_pool[block_id][1] > 0):
if max_pool:
network.append(
torch.nn.MaxPool2d(
kernel_size=kernel_size_pool1,
stride=stride_pool1,
padding=padding_pool1,
kernel_size=kernel_size_pool[block_id],
stride=stride_pool[block_id],
padding=padding_pool[block_id],
)
)
else:
network.append(
torch.nn.AvgPool2d(
kernel_size=kernel_size_pool1,
stride=stride_pool1,
padding=padding_pool1,
)
)
test_image = network[-1](test_image)
list_other_id.append(len(network))
if use_nnmf:
if p_mode_1:
network.append(
NNMFConv2dP(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_conv2,
kernel_size=kernel_size_conv2,
convolution_contribution_map_enable=convolution_contribution_map_enable,
epsilon=epsilon,
positive_function_type=positive_function_type,
init_min=init_min,
init_max=init_max,
beta=beta,
use_convolution=use_convolution,
iterations=iterations,
local_learning=local_learning_1,
local_learning_kl=local_learning_kl,
use_reconstruction=use_reconstruction,
)
)
else:
network.append(
NNMFConv2d(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_conv2,
kernel_size=kernel_size_conv2,
convolution_contribution_map_enable=convolution_contribution_map_enable,
epsilon=epsilon,
positive_function_type=positive_function_type,
init_min=init_min,
init_max=init_max,
beta=beta,
use_convolution=use_convolution,
iterations=iterations,
local_learning=local_learning_1,
local_learning_kl=local_learning_kl,
)
)
test_image = network[-1](test_image)
else:
network.append(
torch.nn.Conv2d(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_conv2,
kernel_size=kernel_size_conv2,
stride=stride_conv2,
padding=padding_conv2,
)
)
test_image = network[-1](test_image)
network.append(torch.nn.ReLU())
test_image = network[-1](test_image)
if cnn_top:
list_cnn_top_id.append(len(network))
network.append(
torch.nn.Conv2d(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_conv2,
kernel_size=(1, 1),
stride=(1, 1),
padding=(0, 0),
bias=True,
)
)
test_image = network[-1](test_image)
network.append(torch.nn.ReLU())
test_image = network[-1](test_image)
if max_pool:
network.append(
torch.nn.MaxPool2d(
kernel_size=kernel_size_pool2,
stride=stride_pool2,
padding=padding_pool2,
)
)
else:
network.append(
torch.nn.AvgPool2d(
kernel_size=kernel_size_pool2,
stride=stride_pool2,
padding=padding_pool2,
)
)
test_image = network[-1](test_image)
list_other_id.append(len(network))
if use_nnmf:
if p_mode_2:
network.append(
NNMFConv2dP(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_flatten2,
kernel_size=(test_image.shape[2], test_image.shape[3]),
convolution_contribution_map_enable=convolution_contribution_map_enable,
epsilon=epsilon,
positive_function_type=positive_function_type,
init_min=init_min,
init_max=init_max,
beta=beta,
use_convolution=use_convolution,
iterations=iterations,
local_learning=local_learning_2,
local_learning_kl=local_learning_kl,
use_reconstruction=use_reconstruction,
)
)
else:
network.append(
NNMFConv2d(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_flatten2,
kernel_size=(test_image.shape[2], test_image.shape[3]),
convolution_contribution_map_enable=convolution_contribution_map_enable,
epsilon=epsilon,
positive_function_type=positive_function_type,
init_min=init_min,
init_max=init_max,
beta=beta,
use_convolution=use_convolution,
iterations=iterations,
local_learning=local_learning_2,
local_learning_kl=local_learning_kl,
)
)
test_image = network[-1](test_image)
else:
network.append(
torch.nn.Conv2d(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_flatten2,
kernel_size=(test_image.shape[2], test_image.shape[3]),
)
)
test_image = network[-1](test_image)
network.append(torch.nn.ReLU())
test_image = network[-1](test_image)
if cnn_top:
list_cnn_top_id.append(len(network))
network.append(
torch.nn.Conv2d(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_flatten2,
kernel_size=(1, 1),
stride=(1, 1),
padding=(0, 0),
bias=True,
)
)
test_image = network[-1](test_image)
network.append(torch.nn.ReLU())
test_image = network[-1](test_image)
list_other_id.append(len(network))
if use_nnmf:
if p_mode_3:
network.append(
NNMFConv2dP(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_full1,
kernel_size=(test_image.shape[2], test_image.shape[3]),
convolution_contribution_map_enable=convolution_contribution_map_enable,
epsilon=epsilon,
positive_function_type=positive_function_type,
init_min=init_min,
init_max=init_max,
beta=beta,
use_convolution=use_convolution,
iterations=iterations,
local_learning=local_learning_3,
local_learning_kl=local_learning_kl,
use_reconstruction=use_reconstruction,
)
)
else:
network.append(
NNMFConv2d(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_full1,
kernel_size=(test_image.shape[2], test_image.shape[3]),
convolution_contribution_map_enable=convolution_contribution_map_enable,
epsilon=epsilon,
positive_function_type=positive_function_type,
init_min=init_min,
init_max=init_max,
beta=beta,
use_convolution=use_convolution,
iterations=iterations,
local_learning=local_learning_3,
local_learning_kl=local_learning_kl,
)
)
test_image = network[-1](test_image)
else:
network.append(
torch.nn.Conv2d(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_full1,
kernel_size=(test_image.shape[2], test_image.shape[3]),
)
)
test_image = network[-1](test_image)
if cnn_top:
network.append(torch.nn.ReLU())
test_image = network[-1](test_image)
if cnn_top:
list_cnn_top_id.append(len(network))
network.append(
torch.nn.Conv2d(
in_channels=test_image.shape[1],
out_channels=number_of_output_channels_full1,
kernel_size=(1, 1),
stride=(1, 1),
padding=(0, 0),
bias=True,
kernel_size=kernel_size_pool[block_id],
stride=stride_pool[block_id],
padding=padding_pool[block_id],
)
)
test_image = network[-1](test_image)

View file

@ -1,6 +1,5 @@
import torch
from NNMFConv2d import NNMFConv2d
from NNMFConv2dP import NNMFConv2dP
from NNMF2d import NNMF2d
def make_optimize(
@ -45,7 +44,7 @@ def make_optimize(
for netp in network[layerid].parameters():
list_cnn.append(netp)
if isinstance(network[layerid], (NNMFConv2d, NNMFConv2dP)):
if isinstance(network[layerid], NNMF2d):
for netp in network[layerid].parameters():
list_nnmf.append(netp)

View file

@ -30,10 +30,10 @@ def main(
local_learning_2: bool = False,
local_learning_3: bool = False,
local_learning_kl: bool = False,
p_mode_0: bool = False,
p_mode_1: bool = False,
p_mode_2: bool = False,
p_mode_3: bool = False,
skip_connection_0: bool = True,
skip_connection_1: bool = True,
skip_connection_2: bool = True,
skip_connection_3: bool = True,
use_reconstruction: bool = False,
max_pool: bool = True,
) -> None:
@ -107,15 +107,19 @@ def main(
input_number_of_channel=input_number_of_channel,
iterations=iterations,
enable_onoff=enable_onoff,
local_learning_0=local_learning_0,
local_learning_1=local_learning_1,
local_learning_2=local_learning_2,
local_learning_3=local_learning_3,
local_learning=[
local_learning_0,
local_learning_1,
local_learning_2,
local_learning_3,
],
local_learning_kl=local_learning_kl,
p_mode_0=p_mode_0,
p_mode_1=p_mode_1,
p_mode_2=p_mode_2,
p_mode_3=p_mode_3,
skip_connection=[
skip_connection_0,
skip_connection_1,
skip_connection_2,
skip_connection_3,
],
use_reconstruction=use_reconstruction,
max_pool=max_pool,
)