diff --git a/NNMF2d.py b/NNMF2d.py index 15f169d..c4ae2ab 100644 --- a/NNMF2d.py +++ b/NNMF2d.py @@ -104,17 +104,18 @@ class NNMF2d(torch.nn.Module): self.local_learning, self.local_learning_kl, ) - if self.skip_connection: - if self.use_reconstruction: - reconstruction = torch.nn.functional.linear( - h_dyn.movedim(1, -1), positive_weights.T - ).movedim(-1, 1) - output = torch.cat((h_dyn, input - reconstruction), dim=1) - else: - output = torch.cat((h_dyn, input), dim=1) - return output - else: - return h_dyn + # if self.skip_connection: + # if self.use_reconstruction: + # reconstruction = torch.nn.functional.linear( + # h_dyn.movedim(1, -1), positive_weights.T + # ).movedim(-1, 1) + # output = torch.cat((h_dyn, input - reconstruction), dim=1) + # else: + # output = torch.cat((h_dyn, input), dim=1) + # return output + # else: + # return h_dyn + return h_dyn class FunctionalNNMF2d(torch.autograd.Function): diff --git a/Y.py b/Y.py new file mode 100644 index 0000000..d30a869 --- /dev/null +++ b/Y.py @@ -0,0 +1,82 @@ +import torch +from typing import Callable + + +class Y(torch.nn.Module): + """ + A PyTorch module that splits the processing path of an input tensor + and processes it through multiple torch.nn.Sequential segments, + then combines the outputs using a specified method. + + This module allows for creating split paths within a `torch.nn.Sequential` + model, making it possible to implement architectures with skip connections + or parallel paths without abandoning the sequential model structure. + + Attributes: + segments (torch.nn.Sequential[torch.nn.Sequential]): A list of sequential modules to + process the input tensor. + combine_func (Callable | None): A function to combine the outputs + from the segments. + dim (int | None): The dimension along which to concatenate + the outputs if `combine_func` is `torch.cat`.
+ + Args: + segments (torch.nn.Sequential[torch.nn.Sequential]): A torch.nn.Sequential + with a list of sequential modules to process the input tensor. + combine (str, optional): The method to combine the outputs. + "cat" for concatenation (default), or "func" to use a + custom combine function. + dim (int | None, optional): The dimension along which to + concatenate the outputs if `combine` is "cat". + Defaults to 1. + combine_func (Callable | None, optional): A custom function + to combine the outputs if `combine` is "func". + Defaults to None. + + Example: + A simple example for the `Y` module with two sub-torch.nn.Sequential: + + ----- segment_a ----- + main_Sequential ----| |---- main_Sequential + ----- segment_b ----- + + segments = torch.nn.Sequential(segment_a, segment_b) + y_split = Y(segments) + result = y_split(input_tensor) + + Methods: + forward(input: torch.Tensor) -> torch.Tensor: + Processes the input tensor through the segments and + combines the results. + """ + + segments: torch.nn.Sequential + combine_func: Callable + dim: int | None + + def __init__( + self, + segments: torch.nn.Sequential, + combine: str = "cat", # "cat", "func" + dim: int | None = 1, + combine_func: Callable | None = None, + ): + super().__init__() + self.segments = segments + self.dim = dim + + if combine.upper() == "CAT": + self.combine_func = torch.cat + else: + assert combine_func is not None + self.combine_func = combine_func + + def forward(self, input: torch.Tensor) -> torch.Tensor: + results: list[torch.Tensor] = [] + for segment in self.segments: + results.append(segment(input)) + + if self.dim is None: + return self.combine_func(results) + else: + return self.combine_func(results, dim=self.dim) diff --git a/append_nnmf_block.py b/append_nnmf_block.py index 538d421..ed36f5f 100644 --- a/append_nnmf_block.py +++ b/append_nnmf_block.py @@ -2,6 +2,7 @@ import torch from append_input_conv2d import append_input_conv2d from L1NormLayer import L1NormLayer from NNMF2d import NNMF2d +from Y
import Y def append_nnmf_block( @@ -44,20 +45,43 @@ def append_nnmf_block( test_image = network[-1](test_image) list_other_id.append(len(network)) - network.append( - NNMF2d( - in_channels=test_image.shape[1], - out_channels=out_channels, - epsilon=epsilon, - positive_function_type=positive_function_type, - beta=beta, - iterations=iterations, - local_learning=local_learning, - local_learning_kl=local_learning_kl, - use_reconstruction=use_reconstruction, - skip_connection=skip_connection, + if skip_connection: + network.append( + Y( + torch.nn.Sequential( + torch.nn.Sequential( + NNMF2d( + in_channels=test_image.shape[1], + out_channels=out_channels, + epsilon=epsilon, + positive_function_type=positive_function_type, + beta=beta, + iterations=iterations, + local_learning=local_learning, + local_learning_kl=local_learning_kl, + use_reconstruction=use_reconstruction, + skip_connection=skip_connection, + ) + ), + torch.nn.Sequential(torch.nn.Identity()), + ) + ) + ) + else: + network.append( + NNMF2d( + in_channels=test_image.shape[1], + out_channels=out_channels, + epsilon=epsilon, + positive_function_type=positive_function_type, + beta=beta, + iterations=iterations, + local_learning=local_learning, + local_learning_kl=local_learning_kl, + use_reconstruction=use_reconstruction, + skip_connection=skip_connection, + ) ) - ) test_image = network[-1](test_image) return test_image