import torch from append_block import append_block from L1NormLayer import L1NormLayer from NNMF2d import NNMF2d from append_parameter import append_parameter def make_network( input_dim_x: int, input_dim_y: int, input_number_of_channel: int, iterations: int, torch_device: torch.device, epsilon: bool | None = None, positive_function_type: int = 0, beta: float | None = None, # Conv: number_of_output_channels: list[int] = [32*2, 64*2, 96*2, 10], kernel_size_conv: list[tuple[int, int]] = [ (5, 5), (5, 5), (-1, -1), # Take the whole input image x and y size (1, 1), ], stride_conv: list[tuple[int, int]] = [ (1, 1), (1, 1), (1, 1), (1, 1), ], padding_conv: list[tuple[int, int]] = [ (0, 0), (0, 0), (0, 0), (0, 0), ], dilation_conv: list[tuple[int, int]] = [ (1, 1), (1, 1), (1, 1), (1, 1), ], # Pool: kernel_size_pool: list[tuple[int, int]] = [ (2, 2), (2, 2), (-1, -1), # No pooling layer (-1, -1), # No pooling layer ], stride_pool: list[tuple[int, int]] = [ (2, 2), (2, 2), (-1, -1), (-1, -1), ], padding_pool: list[tuple[int, int]] = [ (0, 0), (0, 0), (0, 0), (0, 0), ], dilation_pool: list[tuple[int, int]] = [ (1, 1), (1, 1), (1, 1), (1, 1), ], enable_onoff: bool = False, ) -> tuple[ torch.nn.Sequential, list[list[torch.nn.parameter.Parameter]], list[str], ]: assert len(number_of_output_channels) == len(kernel_size_conv) assert len(number_of_output_channels) == len(stride_conv) assert len(number_of_output_channels) == len(padding_conv) assert len(number_of_output_channels) == len(dilation_conv) assert len(number_of_output_channels) == len(kernel_size_pool) assert len(number_of_output_channels) == len(stride_pool) assert len(number_of_output_channels) == len(padding_pool) assert len(number_of_output_channels) == len(dilation_pool) if enable_onoff: input_number_of_channel *= 2 parameter_cnn_top: list[torch.nn.parameter.Parameter] = [] parameter_nnmf: list[torch.nn.parameter.Parameter] = [] parameter_norm: list[torch.nn.parameter.Parameter] = [] test_image = torch.ones( (1, input_number_of_channel, input_dim_x, input_dim_y), device=torch_device ) network = torch.nn.Sequential() network = network.to(torch_device) for block_id in range(0, len(number_of_output_channels)): test_image = append_block( network=network, out_channels=number_of_output_channels[block_id], test_image=test_image, dilation=dilation_conv[block_id], padding=padding_conv[block_id], stride=stride_conv[block_id], kernel_size=kernel_size_conv[block_id], epsilon=epsilon, positive_function_type=positive_function_type, beta=beta, iterations=iterations, torch_device=torch_device, parameter_cnn_top=parameter_cnn_top, parameter_nnmf=parameter_nnmf, parameter_norm=parameter_norm, ) if (kernel_size_pool[block_id][0] > 0) and (kernel_size_pool[block_id][1] > 0): network.append(torch.nn.ReLU()) test_image = network[-1](test_image) mock_output = ( torch.nn.functional.conv2d( torch.zeros( 1, 1, test_image.shape[2], test_image.shape[3], ), torch.zeros((1, 1, 2, 2)), stride=(2, 2), padding=(0, 0), dilation=(1, 1), ) .squeeze(0) .squeeze(0) ) network.append( torch.nn.Unfold( kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), dilation=(1, 1), ) ) test_image = network[-1](test_image) network.append( torch.nn.Fold( output_size=mock_output.shape, kernel_size=(1, 1), dilation=1, padding=0, stride=1, ) ) test_image = network[-1](test_image) network.append(L1NormLayer()) test_image = network[-1](test_image) network.append( NNMF2d( in_channels=test_image.shape[1], out_channels=test_image.shape[1] // 4, epsilon=epsilon, positive_function_type=positive_function_type, beta=beta, iterations=iterations, local_learning=False, local_learning_kl=False, ).to(torch_device) ) test_image = network[-1](test_image) append_parameter(module=network[-1], parameter_list=parameter_nnmf) network.append( torch.nn.BatchNorm2d( num_features=test_image.shape[1], device=torch_device, momentum=0.1, track_running_stats=False, ) ) test_image = network[-1](test_image) append_parameter(module=network[-1], parameter_list=parameter_norm) network.append(torch.nn.Softmax(dim=1)) test_image = network[-1](test_image) network.append(torch.nn.Flatten()) test_image = network[-1](test_image) parameters: list[list[torch.nn.parameter.Parameter]] = [ parameter_cnn_top, parameter_nnmf, parameter_norm, ] name_list: list[str] = [ "cnn_top", "nnmf", "batchnorm2d", ] return ( network, parameters, name_list, )