pynnmf/make_network.py

188 lines
6.1 KiB
Python
Raw Permalink Normal View History

2024-05-30 14:08:44 +02:00
import torch
from SplitOnOffLayer import SplitOnOffLayer
2024-05-31 17:56:34 +02:00
from append_nnmf_block import append_nnmf_block
2024-05-30 14:08:44 +02:00
def make_network(
use_nnmf: bool,
cnn_top: bool,
input_dim_x: int,
input_dim_y: int,
input_number_of_channel: int,
iterations: int,
epsilon: bool | None = None,
positive_function_type: int = 0,
beta: float | None = None,
2024-05-31 17:56:34 +02:00
# Conv:
number_of_output_channels: list[int] = [32, 64, 96, 10],
kernel_size_conv: list[tuple[int, int]] = [
(5, 5),
(5, 5),
(-1, -1), # Take the whole input image x and y size
(1, 1),
],
stride_conv: list[tuple[int, int]] = [
(1, 1),
(1, 1),
(1, 1),
(1, 1),
],
padding_conv: list[tuple[int, int]] = [
(0, 0),
(0, 0),
(0, 0),
(0, 0),
],
dilation_conv: list[tuple[int, int]] = [
(1, 1),
(1, 1),
(1, 1),
(1, 1),
],
# Pool:
kernel_size_pool: list[tuple[int, int]] = [
(2, 2),
(2, 2),
(-1, -1), # No pooling layer
(-1, -1), # No pooling layer
],
stride_pool: list[tuple[int, int]] = [
(2, 2),
(2, 2),
(-1, -1),
(-1, -1),
],
padding_pool: list[tuple[int, int]] = [
(0, 0),
(0, 0),
(0, 0),
(0, 0),
],
dilation_pool: list[tuple[int, int]] = [
(1, 1),
(1, 1),
(1, 1),
(1, 1),
],
local_learning: list[bool] = [False, False, False, False],
skip_connection: list[bool] = [False, False, False, False],
2024-05-30 14:08:44 +02:00
local_learning_kl: bool = True,
2024-05-30 15:53:53 +02:00
use_reconstruction: bool = False,
max_pool: bool = True,
2024-05-31 17:56:34 +02:00
enable_onoff: bool = False,
2024-05-30 14:08:44 +02:00
) -> tuple[torch.nn.Sequential, list[int], list[int]]:
2024-05-31 17:56:34 +02:00
assert len(number_of_output_channels) == len(kernel_size_conv)
assert len(number_of_output_channels) == len(stride_conv)
assert len(number_of_output_channels) == len(padding_conv)
assert len(number_of_output_channels) == len(dilation_conv)
assert len(number_of_output_channels) == len(kernel_size_pool)
assert len(number_of_output_channels) == len(stride_pool)
assert len(number_of_output_channels) == len(padding_pool)
assert len(number_of_output_channels) == len(dilation_pool)
assert len(number_of_output_channels) == len(local_learning)
assert len(number_of_output_channels) == len(skip_connection)
2024-05-30 14:08:44 +02:00
if enable_onoff:
input_number_of_channel *= 2
list_cnn_top_id: list[int] = []
list_other_id: list[int] = []
test_image = torch.ones((1, input_number_of_channel, input_dim_x, input_dim_y))
network = torch.nn.Sequential()
if enable_onoff:
network.append(SplitOnOffLayer())
test_image = network[-1](test_image)
2024-05-31 17:56:34 +02:00
for block_id in range(0, len(number_of_output_channels)):
if use_nnmf:
test_image = append_nnmf_block(
network=network,
out_channels=number_of_output_channels[block_id],
test_image=test_image,
list_other_id=list_other_id,
dilation=dilation_conv[block_id],
padding=padding_conv[block_id],
stride=stride_conv[block_id],
kernel_size=kernel_size_conv[block_id],
epsilon=epsilon,
positive_function_type=positive_function_type,
beta=beta,
iterations=iterations,
local_learning=local_learning[block_id],
local_learning_kl=local_learning_kl,
use_reconstruction=use_reconstruction,
skip_connection=skip_connection[block_id],
2024-05-30 14:08:44 +02:00
)
else:
2024-05-31 17:56:34 +02:00
list_other_id.append(len(network))
2024-05-30 14:08:44 +02:00
2024-05-31 17:56:34 +02:00
kernel_size_conv_internal = list(kernel_size_conv[block_id])
2024-05-30 14:08:44 +02:00
2024-05-31 17:56:34 +02:00
if kernel_size_conv[block_id][0] == -1:
kernel_size_conv_internal[0] = test_image.shape[-2]
2024-05-30 14:08:44 +02:00
2024-05-31 17:56:34 +02:00
if kernel_size_conv[block_id][1] == -1:
kernel_size_conv_internal[1] = test_image.shape[-1]
2024-05-30 14:08:44 +02:00
network.append(
2024-05-31 17:56:34 +02:00
torch.nn.Conv2d(
2024-05-30 14:08:44 +02:00
in_channels=test_image.shape[1],
2024-05-31 17:56:34 +02:00
out_channels=number_of_output_channels[block_id],
kernel_size=kernel_size_conv_internal,
stride=1,
padding=0,
2024-05-30 14:08:44 +02:00
)
)
2024-05-31 17:56:34 +02:00
test_image = network[-1](test_image)
if cnn_top or block_id < len(number_of_output_channels) - 1:
network.append(torch.nn.ReLU())
test_image = network[-1](test_image)
2024-05-30 14:08:44 +02:00
2024-05-31 17:56:34 +02:00
if cnn_top:
list_cnn_top_id.append(len(network))
2024-05-30 14:08:44 +02:00
network.append(
2024-05-31 17:56:34 +02:00
torch.nn.Conv2d(
2024-05-30 14:08:44 +02:00
in_channels=test_image.shape[1],
2024-05-31 17:56:34 +02:00
out_channels=number_of_output_channels[block_id],
kernel_size=(1, 1),
stride=(1, 1),
padding=(0, 0),
bias=True,
2024-05-30 14:08:44 +02:00
)
)
2024-05-31 17:56:34 +02:00
test_image = network[-1](test_image)
if block_id < len(number_of_output_channels) - 1:
network.append(torch.nn.ReLU())
test_image = network[-1](test_image)
if (kernel_size_pool[block_id][0] > 0) and (kernel_size_pool[block_id][1] > 0):
if max_pool:
network.append(
torch.nn.MaxPool2d(
kernel_size=kernel_size_pool[block_id],
stride=stride_pool[block_id],
padding=padding_pool[block_id],
)
)
else:
network.append(
torch.nn.AvgPool2d(
kernel_size=kernel_size_pool[block_id],
stride=stride_pool[block_id],
padding=padding_pool[block_id],
)
2024-05-30 14:08:44 +02:00
)
test_image = network[-1](test_image)
network.append(torch.nn.Flatten())
test_image = network[-1](test_image)
network.append(torch.nn.Softmax(dim=1))
test_image = network[-1](test_image)
return network, list_cnn_top_id, list_other_id