From f907b5f601a7c97a8fb48c72c6555c99456339c2 Mon Sep 17 00:00:00 2001 From: David Rotermund Date: Fri, 31 May 2024 17:56:34 +0200 Subject: [PATCH] Add files via upload --- L1NormLayer.py | 13 ++ NNMF2d.py | 269 +++++++++++++++++++++++ append_input_conv2d.py | 48 +++++ append_nnmf_block.py | 63 ++++++ make_network.py | 469 ++++++++++++----------------------------- make_optimize.py | 5 +- run_network.py | 28 +-- 7 files changed, 548 insertions(+), 347 deletions(-) create mode 100644 L1NormLayer.py create mode 100644 NNMF2d.py create mode 100644 append_input_conv2d.py create mode 100644 append_nnmf_block.py diff --git a/L1NormLayer.py b/L1NormLayer.py new file mode 100644 index 0000000..6816b3a --- /dev/null +++ b/L1NormLayer.py @@ -0,0 +1,13 @@ +import torch + + +class L1NormLayer(torch.nn.Module): + + epsilon: float + + def __init__(self, epsilon: float = 10e-20) -> None: + super().__init__() + self.epsilon = epsilon + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return input / (input.sum(dim=1, keepdim=True) + self.epsilon) diff --git a/NNMF2d.py b/NNMF2d.py new file mode 100644 index 0000000..15f169d --- /dev/null +++ b/NNMF2d.py @@ -0,0 +1,269 @@ +import torch +from non_linear_weigth_function import non_linear_weigth_function + + +class NNMF2d(torch.nn.Module): + + in_channels: int + out_channels: int + weight: torch.Tensor + bias: None | torch.Tensor + iterations: int + epsilon: float | None + init_min: float + init_max: float + beta: torch.Tensor | None + positive_function_type: int + local_learning: bool + local_learning_kl: bool + use_reconstruction: bool + skip_connection: bool + + def __init__( + self, + in_channels: int, + out_channels: int, + device=None, + dtype=None, + iterations: int = 20, + epsilon: float | None = None, + init_min: float = 0.0, + init_max: float = 1.0, + beta: float | None = None, + positive_function_type: int = 0, + local_learning: bool = False, + local_learning_kl: bool = False, + use_reconstruction: bool = False, + skip_connection: bool = False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + + super().__init__() + + self.positive_function_type = positive_function_type + self.init_min = init_min + self.init_max = init_max + + self.in_channels = in_channels + self.out_channels = out_channels + + self.iterations = iterations + self.local_learning = local_learning + self.local_learning_kl = local_learning_kl + + self.use_reconstruction = use_reconstruction + + self.weight = torch.nn.parameter.Parameter( + torch.empty((out_channels, in_channels), **factory_kwargs) + ) + + if beta is not None: + self.beta = torch.nn.parameter.Parameter(torch.empty((1), **factory_kwargs)) + self.beta.data[0] = beta + else: + self.beta = None + + self.reset_parameters() + self.functional_nnmf2d = FunctionalNNMF2d.apply + + self.epsilon = epsilon + + self.skip_connection = skip_connection + + def extra_repr(self) -> str: + s: str = f"{self.in_channels}, {self.out_channels}" + + if self.epsilon is not None: + s += f", epsilon={self.epsilon}" + s += f", pfunctype={self.positive_function_type}" + s += f", local_learning={self.local_learning}" + + if self.local_learning: + s += f", local_learning_kl={self.local_learning_kl}" + + return s + + def reset_parameters(self) -> None: + torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + + positive_weights = non_linear_weigth_function( + self.weight, self.beta, self.positive_function_type + ) + positive_weights = positive_weights / ( + positive_weights.sum(dim=1, keepdim=True) + 10e-20 + ) + + h_dyn = self.functional_nnmf2d( + input, + positive_weights, + self.out_channels, + self.iterations, + self.epsilon, + self.local_learning, + self.local_learning_kl, + ) + if self.skip_connection: + if self.use_reconstruction: + reconstruction = torch.nn.functional.linear( + h_dyn.movedim(1, -1), positive_weights.T + ).movedim(-1, 1) + output = torch.cat((h_dyn, input - reconstruction), dim=1) + else: + output = torch.cat((h_dyn, input), dim=1) + return output + else: + return h_dyn + + +class FunctionalNNMF2d(torch.autograd.Function): + @staticmethod + def forward( # type: ignore + ctx, + input: torch.Tensor, + weight: torch.Tensor, + out_channels: int, + iterations: int, + epsilon: float | None, + local_learning: bool, + local_learning_kl: bool, + ) -> torch.Tensor: + + # Prepare h + h = torch.full( + (input.shape[0], out_channels, input.shape[-2], input.shape[-1]), + 1.0 / float(out_channels), + device=input.device, + dtype=input.dtype, + ) + + h = h.movedim(1, -1) + input = input.movedim(1, -1) + for _ in range(0, iterations): + reconstruction = torch.nn.functional.linear(h, weight.T) + reconstruction += 1e-20 + if epsilon is None: + h *= torch.nn.functional.linear((input / reconstruction), weight) + else: + h *= 1 + epsilon * torch.nn.functional.linear( + (input / reconstruction), weight + ) + h /= h.sum(-1, keepdim=True) + 10e-20 + h = h.movedim(-1, 1) + input = input.movedim(-1, 1) + + # ########################################################### + # Save the necessary data for the backward pass + # ########################################################### + ctx.save_for_backward(input, weight, h) + ctx.local_learning = local_learning + ctx.local_learning_kl = local_learning_kl + + assert torch.isfinite(h).all() + return h + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output: torch.Tensor) -> tuple[ # type: ignore + torch.Tensor | None, + torch.Tensor | None, + None, + None, + None, + None, + None, + ]: + + # ############################################## + # Default values + # ############################################## + grad_input: torch.Tensor | None = None + grad_weight: torch.Tensor | None = None + + # ############################################## + # Get the variables back + # ############################################## + (input, weight, h) = ctx.saved_tensors + + # The back prop gradient + h = h.movedim(1, -1) + grad_output = grad_output.movedim(1, -1) + input = input.movedim(1, -1) + big_r = torch.nn.functional.linear(h, weight.T) + big_r_div = 1.0 / (big_r + 1e-20) + + factor_x_div_r = input * big_r_div + + grad_input = torch.nn.functional.linear(h * grad_output, weight.T) * big_r_div + + del big_r_div + + # The weight gradient + if ctx.local_learning is False: + del big_r + + grad_weight = -torch.nn.functional.linear( + h.reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + h.shape[3], + ).T, + (factor_x_div_r * grad_input) + .reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + grad_input.shape[3], + ) + .T, + ) + + grad_weight += torch.nn.functional.linear( + (h * grad_output) + .reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + h.shape[3], + ) + .T, + factor_x_div_r.reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + grad_input.shape[3], + ).T, + ) + + else: + if ctx.local_learning_kl: + grad_weight = -torch.nn.functional.linear( + h.reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + h.shape[3], + ).T, + factor_x_div_r.reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + grad_input.shape[3], + ).T, + ) + else: + grad_weight = -torch.nn.functional.linear( + h.reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + h.shape[3], + ).T, + (2 * (input - big_r)) + .reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + grad_input.shape[3], + ) + .T, + ) + grad_input = grad_input.movedim(-1, 1) + assert torch.isfinite(grad_input).all() + assert torch.isfinite(grad_weight).all() + + return ( + grad_input, + grad_weight, + None, + None, + None, + None, + None, + ) diff --git a/append_input_conv2d.py b/append_input_conv2d.py new file mode 100644 index 0000000..5faf4f7 --- /dev/null +++ b/append_input_conv2d.py @@ -0,0 +1,48 @@ +import torch + + +def append_input_conv2d( + network: torch.nn.Sequential, + test_image: torch.tensor, + dilation: int = 1, + padding: int = 0, + stride: int = 1, + kernel_size: list[int] = [5, 5], +) -> torch.Tensor: + + mock_output = ( + torch.nn.functional.conv2d( + torch.zeros( + 1, + 1, + test_image.shape[2], + test_image.shape[3], + ), + torch.zeros((1, 1, kernel_size[0], kernel_size[1])), + stride=stride, + padding=padding, + dilation=dilation, + ) + .squeeze(0) + .squeeze(0) + ) + + network.append( + torch.nn.Unfold( + kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride + ) + ) + test_image = network[-1](test_image) + + network.append( + torch.nn.Fold( + output_size=mock_output.shape, + kernel_size=(1, 1), + dilation=1, + padding=0, + stride=1, + ) + ) + test_image = network[-1](test_image) + + return test_image diff --git a/append_nnmf_block.py b/append_nnmf_block.py new file mode 100644 index 0000000..538d421 --- /dev/null +++ b/append_nnmf_block.py @@ -0,0 +1,63 @@ +import torch +from append_input_conv2d import append_input_conv2d +from L1NormLayer import L1NormLayer +from NNMF2d import NNMF2d + + +def append_nnmf_block( + network: torch.nn.Sequential, + out_channels: int, + test_image: torch.tensor, + list_other_id: list[int], + dilation: int = 1, + padding: int = 0, + stride: int = 1, + kernel_size: list[int] = [5, 5], + epsilon: float | None = None, + positive_function_type: int = 0, + beta: float | None = None, + iterations: int = 20, + local_learning: bool = False, + local_learning_kl: bool = False, + use_reconstruction: bool = False, + skip_connection: bool = False, +) -> torch.Tensor: + + kernel_size_internal: list[int] = list(kernel_size) + + if kernel_size[0] < 1: + kernel_size_internal[0] = test_image.shape[-2] + + if kernel_size[1] < 1: + kernel_size_internal[1] = test_image.shape[-1] + + test_image = append_input_conv2d( + network=network, + test_image=test_image, + dilation=dilation, + padding=padding, + stride=stride, + kernel_size=kernel_size_internal, + ) + + network.append(L1NormLayer()) + test_image = network[-1](test_image) + + list_other_id.append(len(network)) + network.append( + NNMF2d( + in_channels=test_image.shape[1], + out_channels=out_channels, + epsilon=epsilon, + positive_function_type=positive_function_type, + beta=beta, + iterations=iterations, + local_learning=local_learning, + local_learning_kl=local_learning_kl, + use_reconstruction=use_reconstruction, + skip_connection=skip_connection, + ) + ) + test_image = network[-1](test_image) + + return test_image diff --git a/make_network.py b/make_network.py index d978e26..957672d 100644 --- a/make_network.py +++ b/make_network.py @@ -1,7 +1,6 @@ import torch -from NNMFConv2d import NNMFConv2d -from NNMFConv2dP import NNMFConv2dP from SplitOnOffLayer import SplitOnOffLayer +from append_nnmf_block import append_nnmf_block def make_network( @@ -11,43 +10,79 @@ def make_network( input_dim_y: int, input_number_of_channel: int, iterations: int, - init_min: float = 0.0, - init_max: float = 1.0, - use_convolution: bool = False, - convolution_contribution_map_enable: bool = False, epsilon: bool | None = None, positive_function_type: int = 0, beta: float | None = None, - number_of_output_channels_conv1: int = 32, - number_of_output_channels_conv2: int = 64, - number_of_output_channels_flatten2: int = 96, - number_of_output_channels_full1: int = 10, - kernel_size_conv1: tuple[int, int] = (5, 5), - kernel_size_pool1: tuple[int, int] = (2, 2), - kernel_size_conv2: tuple[int, int] = (5, 5), - kernel_size_pool2: tuple[int, int] = (2, 2), - stride_conv1: tuple[int, int] = (1, 1), - stride_pool1: tuple[int, int] = (2, 2), - stride_conv2: tuple[int, int] = (1, 1), - stride_pool2: tuple[int, int] = (2, 2), - padding_conv1: int = 0, - padding_pool1: int = 0, - padding_conv2: int = 0, - padding_pool2: int = 0, - enable_onoff: bool = False, - local_learning_0: bool = False, - local_learning_1: bool = False, - local_learning_2: bool = False, - local_learning_3: bool = False, + # Conv: + number_of_output_channels: list[int] = [32, 64, 96, 10], + kernel_size_conv: list[tuple[int, int]] = [ + (5, 5), + (5, 5), + (-1, -1), # Take the whole input image x and y size + (1, 1), + ], + stride_conv: list[tuple[int, int]] = [ + (1, 1), + (1, 1), + (1, 1), + (1, 1), + ], + padding_conv: list[tuple[int, int]] = [ + (0, 0), + (0, 0), + (0, 0), + (0, 0), + ], + dilation_conv: list[tuple[int, int]] = [ + (1, 1), + (1, 1), + (1, 1), + (1, 1), + ], + # Pool: + kernel_size_pool: list[tuple[int, int]] = [ + (2, 2), + (2, 2), + (-1, -1), # No pooling layer + (-1, -1), # No pooling layer + ], + stride_pool: list[tuple[int, int]] = [ + (2, 2), + (2, 2), + (-1, -1), + (-1, -1), + ], + padding_pool: list[tuple[int, int]] = [ + (0, 0), + (0, 0), + (0, 0), + (0, 0), + ], + dilation_pool: list[tuple[int, int]] = [ + (1, 1), + (1, 1), + (1, 1), + (1, 1), + ], + local_learning: list[bool] = [False, False, False, False], + skip_connection: list[bool] = [False, False, False, False], local_learning_kl: bool = True, - p_mode_0: bool = False, - p_mode_1: bool = False, - p_mode_2: bool = False, - p_mode_3: bool = False, use_reconstruction: bool = False, max_pool: bool = True, + enable_onoff: bool = False, ) -> tuple[torch.nn.Sequential, list[int], list[int]]: + assert len(number_of_output_channels) == len(kernel_size_conv) + assert len(number_of_output_channels) == len(stride_conv) + assert len(number_of_output_channels) == len(padding_conv) + assert len(number_of_output_channels) == len(dilation_conv) + assert len(number_of_output_channels) == len(kernel_size_pool) + assert len(number_of_output_channels) == len(stride_pool) + assert len(number_of_output_channels) == len(padding_pool) + assert len(number_of_output_channels) == len(dilation_pool) + assert len(number_of_output_channels) == len(local_learning) + assert len(number_of_output_channels) == len(skip_connection) + if enable_onoff: input_number_of_channel *= 2 @@ -62,316 +97,86 @@ def make_network( network.append(SplitOnOffLayer()) test_image = network[-1](test_image) - list_other_id.append(len(network)) - if use_nnmf: - if p_mode_0: - network.append( - NNMFConv2dP( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_conv1, - kernel_size=kernel_size_conv1, - convolution_contribution_map_enable=convolution_contribution_map_enable, - epsilon=epsilon, - positive_function_type=positive_function_type, - init_min=init_min, - init_max=init_max, - beta=beta, - use_convolution=use_convolution, - iterations=iterations, - local_learning=local_learning_0, - local_learning_kl=local_learning_kl, - use_reconstruction=use_reconstruction, - ) + for block_id in range(0, len(number_of_output_channels)): + if use_nnmf: + test_image = append_nnmf_block( + network=network, + out_channels=number_of_output_channels[block_id], + test_image=test_image, + list_other_id=list_other_id, + dilation=dilation_conv[block_id], + padding=padding_conv[block_id], + stride=stride_conv[block_id], + kernel_size=kernel_size_conv[block_id], + epsilon=epsilon, + positive_function_type=positive_function_type, + beta=beta, + iterations=iterations, + local_learning=local_learning[block_id], + local_learning_kl=local_learning_kl, + use_reconstruction=use_reconstruction, + skip_connection=skip_connection[block_id], ) else: + list_other_id.append(len(network)) + + kernel_size_conv_internal = list(kernel_size_conv[block_id]) + + if kernel_size_conv[block_id][0] == -1: + kernel_size_conv_internal[0] = test_image.shape[-2] + + if kernel_size_conv[block_id][1] == -1: + kernel_size_conv_internal[1] = test_image.shape[-1] + network.append( - NNMFConv2d( + torch.nn.Conv2d( in_channels=test_image.shape[1], - out_channels=number_of_output_channels_conv1, - kernel_size=kernel_size_conv1, - convolution_contribution_map_enable=convolution_contribution_map_enable, - epsilon=epsilon, - positive_function_type=positive_function_type, - init_min=init_min, - init_max=init_max, - beta=beta, - use_convolution=use_convolution, - iterations=iterations, - local_learning=local_learning_0, - local_learning_kl=local_learning_kl, + out_channels=number_of_output_channels[block_id], + kernel_size=kernel_size_conv_internal, + stride=1, + padding=0, ) ) - test_image = network[-1](test_image) - else: - network.append( - torch.nn.Conv2d( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_conv1, - kernel_size=kernel_size_conv1, - stride=stride_conv1, - padding=padding_conv1, - ) - ) - test_image = network[-1](test_image) - network.append(torch.nn.ReLU()) - test_image = network[-1](test_image) - - if cnn_top: - list_cnn_top_id.append(len(network)) - network.append( - torch.nn.Conv2d( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_conv1, - kernel_size=(1, 1), - stride=(1, 1), - padding=(0, 0), - bias=True, - ) - ) - test_image = network[-1](test_image) - network.append(torch.nn.ReLU()) - test_image = network[-1](test_image) - - if max_pool: - network.append( - torch.nn.MaxPool2d( - kernel_size=kernel_size_pool1, - stride=stride_pool1, - padding=padding_pool1, - ) - ) - else: - network.append( - torch.nn.AvgPool2d( - kernel_size=kernel_size_pool1, - stride=stride_pool1, - padding=padding_pool1, - ) - ) - test_image = network[-1](test_image) - - list_other_id.append(len(network)) - if use_nnmf: - if p_mode_1: - network.append( - NNMFConv2dP( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_conv2, - kernel_size=kernel_size_conv2, - convolution_contribution_map_enable=convolution_contribution_map_enable, - epsilon=epsilon, - positive_function_type=positive_function_type, - init_min=init_min, - init_max=init_max, - beta=beta, - use_convolution=use_convolution, - iterations=iterations, - local_learning=local_learning_1, - local_learning_kl=local_learning_kl, - use_reconstruction=use_reconstruction, - ) - ) - else: - network.append( - NNMFConv2d( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_conv2, - kernel_size=kernel_size_conv2, - convolution_contribution_map_enable=convolution_contribution_map_enable, - epsilon=epsilon, - positive_function_type=positive_function_type, - init_min=init_min, - init_max=init_max, - beta=beta, - use_convolution=use_convolution, - iterations=iterations, - local_learning=local_learning_1, - local_learning_kl=local_learning_kl, - ) - ) - test_image = network[-1](test_image) - else: - network.append( - torch.nn.Conv2d( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_conv2, - kernel_size=kernel_size_conv2, - stride=stride_conv2, - padding=padding_conv2, - ) - ) - test_image = network[-1](test_image) - network.append(torch.nn.ReLU()) - test_image = network[-1](test_image) - - if cnn_top: - list_cnn_top_id.append(len(network)) - network.append( - torch.nn.Conv2d( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_conv2, - kernel_size=(1, 1), - stride=(1, 1), - padding=(0, 0), - bias=True, - ) - ) - test_image = network[-1](test_image) - network.append(torch.nn.ReLU()) - test_image = network[-1](test_image) - - if max_pool: - network.append( - torch.nn.MaxPool2d( - kernel_size=kernel_size_pool2, - stride=stride_pool2, - padding=padding_pool2, - ) - ) - else: - network.append( - torch.nn.AvgPool2d( - kernel_size=kernel_size_pool2, - stride=stride_pool2, - padding=padding_pool2, - ) - ) - test_image = network[-1](test_image) - - list_other_id.append(len(network)) - if use_nnmf: - if p_mode_2: - network.append( - NNMFConv2dP( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_flatten2, - kernel_size=(test_image.shape[2], test_image.shape[3]), - convolution_contribution_map_enable=convolution_contribution_map_enable, - epsilon=epsilon, - positive_function_type=positive_function_type, - init_min=init_min, - init_max=init_max, - beta=beta, - use_convolution=use_convolution, - iterations=iterations, - local_learning=local_learning_2, - local_learning_kl=local_learning_kl, - use_reconstruction=use_reconstruction, - ) - ) - else: - network.append( - NNMFConv2d( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_flatten2, - kernel_size=(test_image.shape[2], test_image.shape[3]), - convolution_contribution_map_enable=convolution_contribution_map_enable, - epsilon=epsilon, - positive_function_type=positive_function_type, - init_min=init_min, - init_max=init_max, - beta=beta, - use_convolution=use_convolution, - iterations=iterations, - local_learning=local_learning_2, - local_learning_kl=local_learning_kl, - ) - ) - test_image = network[-1](test_image) - else: - network.append( - torch.nn.Conv2d( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_flatten2, - kernel_size=(test_image.shape[2], test_image.shape[3]), - ) - ) - test_image = network[-1](test_image) - network.append(torch.nn.ReLU()) - test_image = network[-1](test_image) - - if cnn_top: - list_cnn_top_id.append(len(network)) - network.append( - torch.nn.Conv2d( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_flatten2, - kernel_size=(1, 1), - stride=(1, 1), - padding=(0, 0), - bias=True, - ) - ) - test_image = network[-1](test_image) - network.append(torch.nn.ReLU()) - test_image = network[-1](test_image) - - list_other_id.append(len(network)) - if use_nnmf: - if p_mode_3: - network.append( - NNMFConv2dP( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_full1, - kernel_size=(test_image.shape[2], test_image.shape[3]), - convolution_contribution_map_enable=convolution_contribution_map_enable, - epsilon=epsilon, - positive_function_type=positive_function_type, - init_min=init_min, - init_max=init_max, - beta=beta, - use_convolution=use_convolution, - iterations=iterations, - local_learning=local_learning_3, - local_learning_kl=local_learning_kl, - use_reconstruction=use_reconstruction, - ) - ) - else: - network.append( - NNMFConv2d( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_full1, - kernel_size=(test_image.shape[2], test_image.shape[3]), - convolution_contribution_map_enable=convolution_contribution_map_enable, - epsilon=epsilon, - positive_function_type=positive_function_type, - init_min=init_min, - init_max=init_max, - beta=beta, - use_convolution=use_convolution, - iterations=iterations, - local_learning=local_learning_3, - local_learning_kl=local_learning_kl, - ) - ) - test_image = network[-1](test_image) - else: - network.append( - torch.nn.Conv2d( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_full1, - kernel_size=(test_image.shape[2], test_image.shape[3]), - ) - ) - test_image = network[-1](test_image) - if cnn_top: - network.append(torch.nn.ReLU()) test_image = network[-1](test_image) + if cnn_top or block_id < len(number_of_output_channels) - 1: + network.append(torch.nn.ReLU()) + test_image = network[-1](test_image) - if cnn_top: - list_cnn_top_id.append(len(network)) - network.append( - torch.nn.Conv2d( - in_channels=test_image.shape[1], - out_channels=number_of_output_channels_full1, - kernel_size=(1, 1), - stride=(1, 1), - padding=(0, 0), - bias=True, + if cnn_top: + list_cnn_top_id.append(len(network)) + network.append( + torch.nn.Conv2d( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels[block_id], + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ) ) - ) - test_image = network[-1](test_image) + test_image = network[-1](test_image) + if block_id < len(number_of_output_channels) - 1: + network.append(torch.nn.ReLU()) + test_image = network[-1](test_image) + + if (kernel_size_pool[block_id][0] > 0) and (kernel_size_pool[block_id][1] > 0): + if max_pool: + network.append( + torch.nn.MaxPool2d( + kernel_size=kernel_size_pool[block_id], + stride=stride_pool[block_id], + padding=padding_pool[block_id], + ) + ) + else: + network.append( + torch.nn.AvgPool2d( + kernel_size=kernel_size_pool[block_id], + stride=stride_pool[block_id], + padding=padding_pool[block_id], + ) + ) + test_image = network[-1](test_image) network.append(torch.nn.Flatten()) test_image = network[-1](test_image) diff --git a/make_optimize.py b/make_optimize.py index 54b3072..dc5d4f8 100644 --- a/make_optimize.py +++ b/make_optimize.py @@ -1,6 +1,5 @@ import torch -from NNMFConv2d import NNMFConv2d -from NNMFConv2dP import NNMFConv2dP +from NNMF2d import NNMF2d def make_optimize( @@ -45,7 +44,7 @@ def make_optimize( for netp in network[layerid].parameters(): list_cnn.append(netp) - if isinstance(network[layerid], (NNMFConv2d, NNMFConv2dP)): + if isinstance(network[layerid], NNMF2d): for netp in network[layerid].parameters(): list_nnmf.append(netp) diff --git a/run_network.py b/run_network.py index 14099bc..d68e0a0 100644 --- a/run_network.py +++ b/run_network.py @@ -30,10 +30,10 @@ def main( local_learning_2: bool = False, local_learning_3: bool = False, local_learning_kl: bool = False, - p_mode_0: bool = False, - p_mode_1: bool = False, - p_mode_2: bool = False, - p_mode_3: bool = False, + skip_connection_0: bool = True, + skip_connection_1: bool = True, + skip_connection_2: bool = True, + skip_connection_3: bool = True, use_reconstruction: bool = False, max_pool: bool = True, ) -> None: @@ -107,15 +107,19 @@ def main( input_number_of_channel=input_number_of_channel, iterations=iterations, enable_onoff=enable_onoff, - local_learning_0=local_learning_0, - local_learning_1=local_learning_1, - local_learning_2=local_learning_2, - local_learning_3=local_learning_3, + local_learning=[ + local_learning_0, + local_learning_1, + local_learning_2, + local_learning_3, + ], local_learning_kl=local_learning_kl, - p_mode_0=p_mode_0, - p_mode_1=p_mode_1, - p_mode_2=p_mode_2, - p_mode_3=p_mode_3, + skip_connection=[ + skip_connection_0, + skip_connection_1, + skip_connection_2, + skip_connection_3, + ], use_reconstruction=use_reconstruction, max_pool=max_pool, )