import torch

from NNMFConv2d import NNMFConv2d
from NNMFConv2dP import NNMFConv2dP
from SplitOnOffLayer import SplitOnOffLayer

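
# make_network assembles a small four-stage convolutional classifier as a
# torch.nn.Sequential. Each stage is either an NNMF-based layer (NNMFConv2d,
# or NNMFConv2dP when the corresponding p_mode_* flag is set) or a plain
# torch.nn.Conv2d followed by ReLU, optionally topped by a trainable 1x1
# convolution ("CNN top"); the first two stages are followed by max or average
# pooling, while the last two use kernels spanning the full remaining spatial
# extent. Besides the network, the function returns two lists of module
# indices: the 1x1 CNN-top layers (list_cnn_top_id) and the main layer at the
# start of each stage (list_other_id), so the two groups can be addressed
# separately, e.g. for different optimizers or learning rules.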
def make_network(
    use_nnmf: bool,
    cnn_top: bool,
    input_dim_x: int,
    input_dim_y: int,
    input_number_of_channel: int,
    iterations: int,
    init_min: float = 0.0,
    init_max: float = 1.0,
    use_convolution: bool = False,
    convolution_contribution_map_enable: bool = False,
    epsilon: float | None = None,
    positive_function_type: int = 0,
    beta: float | None = None,
    number_of_output_channels_conv1: int = 32,
    number_of_output_channels_conv2: int = 64,
    number_of_output_channels_flatten2: int = 96,
    number_of_output_channels_full1: int = 10,
    kernel_size_conv1: tuple[int, int] = (5, 5),
    kernel_size_pool1: tuple[int, int] = (2, 2),
    kernel_size_conv2: tuple[int, int] = (5, 5),
    kernel_size_pool2: tuple[int, int] = (2, 2),
    stride_conv1: tuple[int, int] = (1, 1),
    stride_pool1: tuple[int, int] = (2, 2),
    stride_conv2: tuple[int, int] = (1, 1),
    stride_pool2: tuple[int, int] = (2, 2),
    padding_conv1: int = 0,
    padding_pool1: int = 0,
    padding_conv2: int = 0,
    padding_pool2: int = 0,
    enable_onoff: bool = False,
    local_learning_0: bool = False,
    local_learning_1: bool = False,
    local_learning_2: bool = False,
    local_learning_3: bool = False,
    local_learning_kl: bool = True,
    p_mode_0: bool = False,
    p_mode_1: bool = False,
    p_mode_2: bool = False,
    p_mode_3: bool = False,
    use_reconstruction: bool = False,
    max_pool: bool = True,
) -> tuple[torch.nn.Sequential, list[int], list[int]]:
    # The on/off split doubles the effective number of input channels.
    if enable_onoff:
        input_number_of_channel *= 2

    list_cnn_top_id: list[int] = []
    list_other_id: list[int] = []

    # A dummy image is pushed through each appended layer so that later layers
    # can read their input dimensions from its current shape.
    test_image = torch.ones((1, input_number_of_channel, input_dim_x, input_dim_y))

    network = torch.nn.Sequential()

    if enable_onoff:
        network.append(SplitOnOffLayer())
        test_image = network[-1](test_image)

    # Block 1: first convolution stage (NNMF or standard CNN), optional 1x1
    # CNN top layer, then pooling.
    list_other_id.append(len(network))
    if use_nnmf:
        if p_mode_0:
            network.append(
                NNMFConv2dP(
                    in_channels=test_image.shape[1],
                    out_channels=number_of_output_channels_conv1,
                    kernel_size=kernel_size_conv1,
                    convolution_contribution_map_enable=convolution_contribution_map_enable,
                    epsilon=epsilon,
                    positive_function_type=positive_function_type,
                    init_min=init_min,
                    init_max=init_max,
                    beta=beta,
                    use_convolution=use_convolution,
                    iterations=iterations,
                    local_learning=local_learning_0,
                    local_learning_kl=local_learning_kl,
                    use_reconstruction=use_reconstruction,
                )
            )
        else:
            network.append(
                NNMFConv2d(
                    in_channels=test_image.shape[1],
                    out_channels=number_of_output_channels_conv1,
                    kernel_size=kernel_size_conv1,
                    convolution_contribution_map_enable=convolution_contribution_map_enable,
                    epsilon=epsilon,
                    positive_function_type=positive_function_type,
                    init_min=init_min,
                    init_max=init_max,
                    beta=beta,
                    use_convolution=use_convolution,
                    iterations=iterations,
                    local_learning=local_learning_0,
                    local_learning_kl=local_learning_kl,
                )
            )
        test_image = network[-1](test_image)
    else:
        network.append(
            torch.nn.Conv2d(
                in_channels=test_image.shape[1],
                out_channels=number_of_output_channels_conv1,
                kernel_size=kernel_size_conv1,
                stride=stride_conv1,
                padding=padding_conv1,
            )
        )
        test_image = network[-1](test_image)
        network.append(torch.nn.ReLU())
        test_image = network[-1](test_image)

    if cnn_top:
        list_cnn_top_id.append(len(network))
        network.append(
            torch.nn.Conv2d(
                in_channels=test_image.shape[1],
                out_channels=number_of_output_channels_conv1,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
                bias=True,
            )
        )
        test_image = network[-1](test_image)
        network.append(torch.nn.ReLU())
        test_image = network[-1](test_image)

    if max_pool:
        network.append(
            torch.nn.MaxPool2d(
                kernel_size=kernel_size_pool1,
                stride=stride_pool1,
                padding=padding_pool1,
            )
        )
    else:
        network.append(
            torch.nn.AvgPool2d(
                kernel_size=kernel_size_pool1,
                stride=stride_pool1,
                padding=padding_pool1,
            )
        )
    test_image = network[-1](test_image)

    # Block 2: second convolution stage, optional 1x1 CNN top layer, pooling.
    list_other_id.append(len(network))
    if use_nnmf:
        if p_mode_1:
            network.append(
                NNMFConv2dP(
                    in_channels=test_image.shape[1],
                    out_channels=number_of_output_channels_conv2,
                    kernel_size=kernel_size_conv2,
                    convolution_contribution_map_enable=convolution_contribution_map_enable,
                    epsilon=epsilon,
                    positive_function_type=positive_function_type,
                    init_min=init_min,
                    init_max=init_max,
                    beta=beta,
                    use_convolution=use_convolution,
                    iterations=iterations,
                    local_learning=local_learning_1,
                    local_learning_kl=local_learning_kl,
                    use_reconstruction=use_reconstruction,
                )
            )
        else:
            network.append(
                NNMFConv2d(
                    in_channels=test_image.shape[1],
                    out_channels=number_of_output_channels_conv2,
                    kernel_size=kernel_size_conv2,
                    convolution_contribution_map_enable=convolution_contribution_map_enable,
                    epsilon=epsilon,
                    positive_function_type=positive_function_type,
                    init_min=init_min,
                    init_max=init_max,
                    beta=beta,
                    use_convolution=use_convolution,
                    iterations=iterations,
                    local_learning=local_learning_1,
                    local_learning_kl=local_learning_kl,
                )
            )
        test_image = network[-1](test_image)
    else:
        network.append(
            torch.nn.Conv2d(
                in_channels=test_image.shape[1],
                out_channels=number_of_output_channels_conv2,
                kernel_size=kernel_size_conv2,
                stride=stride_conv2,
                padding=padding_conv2,
            )
        )
        test_image = network[-1](test_image)
        network.append(torch.nn.ReLU())
        test_image = network[-1](test_image)

    if cnn_top:
        list_cnn_top_id.append(len(network))
        network.append(
            torch.nn.Conv2d(
                in_channels=test_image.shape[1],
                out_channels=number_of_output_channels_conv2,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
                bias=True,
            )
        )
        test_image = network[-1](test_image)
        network.append(torch.nn.ReLU())
        test_image = network[-1](test_image)

    if max_pool:
        network.append(
            torch.nn.MaxPool2d(
                kernel_size=kernel_size_pool2,
                stride=stride_pool2,
                padding=padding_pool2,
            )
        )
    else:
        network.append(
            torch.nn.AvgPool2d(
                kernel_size=kernel_size_pool2,
                stride=stride_pool2,
                padding=padding_pool2,
            )
        )
    test_image = network[-1](test_image)

    # Block 3: "flatten" stage; the kernel spans the full remaining spatial
    # extent, so the output is 1x1 in space.
    list_other_id.append(len(network))
    if use_nnmf:
        if p_mode_2:
            network.append(
                NNMFConv2dP(
                    in_channels=test_image.shape[1],
                    out_channels=number_of_output_channels_flatten2,
                    kernel_size=(test_image.shape[2], test_image.shape[3]),
                    convolution_contribution_map_enable=convolution_contribution_map_enable,
                    epsilon=epsilon,
                    positive_function_type=positive_function_type,
                    init_min=init_min,
                    init_max=init_max,
                    beta=beta,
                    use_convolution=use_convolution,
                    iterations=iterations,
                    local_learning=local_learning_2,
                    local_learning_kl=local_learning_kl,
                    use_reconstruction=use_reconstruction,
                )
            )
        else:
            network.append(
                NNMFConv2d(
                    in_channels=test_image.shape[1],
                    out_channels=number_of_output_channels_flatten2,
                    kernel_size=(test_image.shape[2], test_image.shape[3]),
                    convolution_contribution_map_enable=convolution_contribution_map_enable,
                    epsilon=epsilon,
                    positive_function_type=positive_function_type,
                    init_min=init_min,
                    init_max=init_max,
                    beta=beta,
                    use_convolution=use_convolution,
                    iterations=iterations,
                    local_learning=local_learning_2,
                    local_learning_kl=local_learning_kl,
                )
            )
        test_image = network[-1](test_image)
    else:
        network.append(
            torch.nn.Conv2d(
                in_channels=test_image.shape[1],
                out_channels=number_of_output_channels_flatten2,
                kernel_size=(test_image.shape[2], test_image.shape[3]),
            )
        )
        test_image = network[-1](test_image)
        network.append(torch.nn.ReLU())
        test_image = network[-1](test_image)

    if cnn_top:
        list_cnn_top_id.append(len(network))
        network.append(
            torch.nn.Conv2d(
                in_channels=test_image.shape[1],
                out_channels=number_of_output_channels_flatten2,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
                bias=True,
            )
        )
        test_image = network[-1](test_image)
        network.append(torch.nn.ReLU())
        test_image = network[-1](test_image)

    # Block 4: classifier stage producing number_of_output_channels_full1
    # outputs, followed by flatten and softmax.
    list_other_id.append(len(network))
    if use_nnmf:
        if p_mode_3:
            network.append(
                NNMFConv2dP(
                    in_channels=test_image.shape[1],
                    out_channels=number_of_output_channels_full1,
                    kernel_size=(test_image.shape[2], test_image.shape[3]),
                    convolution_contribution_map_enable=convolution_contribution_map_enable,
                    epsilon=epsilon,
                    positive_function_type=positive_function_type,
                    init_min=init_min,
                    init_max=init_max,
                    beta=beta,
                    use_convolution=use_convolution,
                    iterations=iterations,
                    local_learning=local_learning_3,
                    local_learning_kl=local_learning_kl,
                    use_reconstruction=use_reconstruction,
                )
            )
        else:
            network.append(
                NNMFConv2d(
                    in_channels=test_image.shape[1],
                    out_channels=number_of_output_channels_full1,
                    kernel_size=(test_image.shape[2], test_image.shape[3]),
                    convolution_contribution_map_enable=convolution_contribution_map_enable,
                    epsilon=epsilon,
                    positive_function_type=positive_function_type,
                    init_min=init_min,
                    init_max=init_max,
                    beta=beta,
                    use_convolution=use_convolution,
                    iterations=iterations,
                    local_learning=local_learning_3,
                    local_learning_kl=local_learning_kl,
                )
            )
        test_image = network[-1](test_image)
    else:
        network.append(
            torch.nn.Conv2d(
                in_channels=test_image.shape[1],
                out_channels=number_of_output_channels_full1,
                kernel_size=(test_image.shape[2], test_image.shape[3]),
            )
        )
        test_image = network[-1](test_image)
        # Only add a ReLU here when a 1x1 CNN top layer follows.
        if cnn_top:
            network.append(torch.nn.ReLU())
            test_image = network[-1](test_image)

    if cnn_top:
        list_cnn_top_id.append(len(network))
        network.append(
            torch.nn.Conv2d(
                in_channels=test_image.shape[1],
                out_channels=number_of_output_channels_full1,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
                bias=True,
            )
        )
        test_image = network[-1](test_image)

    network.append(torch.nn.Flatten())
    test_image = network[-1](test_image)

    network.append(torch.nn.Softmax(dim=1))
    test_image = network[-1](test_image)

    return network, list_cnn_top_id, list_other_id
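

# Minimal usage sketch, assuming an MNIST-sized input (1 x 28 x 28) and the
# NNMF layer implementations shipped alongside this module; the argument
# values below are illustrative, not required defaults.
if __name__ == "__main__":
    network, cnn_top_ids, other_ids = make_network(
        use_nnmf=True,
        cnn_top=False,
        input_dim_x=28,
        input_dim_y=28,
        input_number_of_channel=1,
        iterations=20,
    )
    print(network)
    print("CNN top layer indices:", cnn_top_ids)
    print("other layer indices:", other_ids)

    # NNMF operates on non-negative input, so use torch.rand for the demo batch.
    x = torch.rand((4, 1, 28, 28))
    y = network(x)
    print("output shape:", y.shape)  # (4, 10) with the default channel settings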