Add files via upload

David Rotermund 2024-05-31 18:43:36 +02:00 committed by GitHub
parent e848d49a7c
commit 0b8dd174ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 131 additions and 24 deletions


@ -104,16 +104,17 @@ class NNMF2d(torch.nn.Module):
             self.local_learning,
             self.local_learning_kl,
         )
-        if self.skip_connection:
-            if self.use_reconstruction:
-                reconstruction = torch.nn.functional.linear(
-                    h_dyn.movedim(1, -1), positive_weights.T
-                ).movedim(-1, 1)
-                output = torch.cat((h_dyn, input - reconstruction), dim=1)
-            else:
-                output = torch.cat((h_dyn, input), dim=1)
-            return output
-        else:
-            return h_dyn
+        # if self.skip_connection:
+        #     if self.use_reconstruction:
+        #         reconstruction = torch.nn.functional.linear(
+        #             h_dyn.movedim(1, -1), positive_weights.T
+        #         ).movedim(-1, 1)
+        #         output = torch.cat((h_dyn, input - reconstruction), dim=1)
+        #     else:
+        #         output = torch.cat((h_dyn, input), dim=1)
+        #     return output
+        # else:
+        #     return h_dyn
+        return h_dyn

Y.py (new file, +82 lines)

@ -0,0 +1,82 @@
import torch
from typing import Callable


class Y(torch.nn.Module):
    """
    A PyTorch module that splits the processing path of an input tensor,
    processes it through multiple torch.nn.Sequential segments,
    and then combines the outputs using a specified method.

    This module allows for creating split paths within a `torch.nn.Sequential`
    model, making it possible to implement architectures with skip connections
    or parallel paths without abandoning the sequential model structure.

    Attributes:
        segments (torch.nn.Sequential[torch.nn.Sequential]): A torch.nn.Sequential
            of sequential sub-modules that process the input tensor.
        combine_func (Callable | None): A function to combine the outputs
            from the segments.
        dim (int | None): The dimension along which to concatenate
            the outputs if `combine_func` is `torch.cat`.

    Args:
        segments (torch.nn.Sequential[torch.nn.Sequential]): A torch.nn.Sequential
            containing the sequential sub-modules that process the input tensor.
        combine (str, optional): The method used to combine the outputs.
            "cat" for concatenation (default), or "func" to use a
            custom combine function.
        dim (int | None, optional): The dimension along which to
            concatenate the outputs if `combine` is "cat".
            Defaults to 1.
        combine_func (Callable | None, optional): A custom function
            to combine the outputs if `combine` is "func".
            Defaults to None.

    Example:
        A simple example for the `Y` module with two sub-torch.nn.Sequential:

                                 ----- segment_a -----
            main_Sequential ----|                     |---- main_Sequential
                                 ----- segment_b -----

            segments = torch.nn.Sequential(segment_a, segment_b)
            y_split = Y(segments)
            result = y_split(input_tensor)

    Methods:
        forward(input: torch.Tensor) -> torch.Tensor:
            Processes the input tensor through the segments and
            combines the results.
    """

    segments: torch.nn.Sequential
    combine_func: Callable
    dim: int | None

    def __init__(
        self,
        segments: torch.nn.Sequential,
        combine: str = "cat",  # "cat", "func"
        dim: int | None = 1,
        combine_func: Callable | None = None,
    ):
        super().__init__()
        self.segments = segments
        self.dim = dim

        if combine.upper() == "CAT":
            self.combine_func = torch.cat
        else:
            assert combine_func is not None
            self.combine_func = combine_func

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Run every segment on the same input tensor.
        results: list[torch.Tensor] = []
        for segment in self.segments:
            results.append(segment(input))

        # Merge the branch outputs, passing dim only when one is configured.
        if self.dim is None:
            return self.combine_func(results)
        else:
            return self.combine_func(results, dim=self.dim)
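
A minimal usage sketch for the new Y module (not part of this commit; the Conv2d shapes, the input tensor, and the sum_func helper below are illustrative assumptions): one branch transforms the input while an Identity branch carries it through unchanged, and the branch outputs are merged either by concatenation ("cat") or by a user-supplied callable ("func").

    import torch

    from Y import Y

    # Branch A transforms the input; branch B passes it through unchanged.
    segment_a = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=1))
    segment_b = torch.nn.Sequential(torch.nn.Identity())

    # combine="cat" (default) concatenates along dim=1: 8 + 3 = 11 channels.
    y_cat = Y(torch.nn.Sequential(segment_a, segment_b))

    # combine="func" merges with a custom callable; dim=None makes forward()
    # call combine_func(results) without a dim argument.
    def sum_func(results: list[torch.Tensor]) -> torch.Tensor:
        return torch.stack(results, dim=0).sum(dim=0)

    y_sum = Y(
        torch.nn.Sequential(
            torch.nn.Sequential(torch.nn.Conv2d(3, 3, kernel_size=1)),
            torch.nn.Sequential(torch.nn.Identity()),
        ),
        combine="func",
        dim=None,
        combine_func=sum_func,
    )

    x = torch.rand((2, 3, 16, 16))
    print(y_cat(x).shape)  # torch.Size([2, 11, 16, 16])
    print(y_sum(x).shape)  # torch.Size([2, 3, 16, 16])

The "cat" path mirrors how append_nnmf_block below uses Y for its skip connection: an NNMF2d branch concatenated with a torch.nn.Identity branch along the channel dimension.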


@ -2,6 +2,7 @@ import torch
 from append_input_conv2d import append_input_conv2d
 from L1NormLayer import L1NormLayer
 from NNMF2d import NNMF2d
+from Y import Y

 def append_nnmf_block(
@ -44,6 +45,29 @@ def append_nnmf_block(
     test_image = network[-1](test_image)
     list_other_id.append(len(network))
+    if skip_connection:
+        network.append(
+            Y(
+                torch.nn.Sequential(
+                    torch.nn.Sequential(
+                        NNMF2d(
+                            in_channels=test_image.shape[1],
+                            out_channels=out_channels,
+                            epsilon=epsilon,
+                            positive_function_type=positive_function_type,
+                            beta=beta,
+                            iterations=iterations,
+                            local_learning=local_learning,
+                            local_learning_kl=local_learning_kl,
+                            use_reconstruction=use_reconstruction,
+                            skip_connection=skip_connection,
+                        )
+                    ),
+                    torch.nn.Sequential(torch.nn.Identity()),
+                )
+            )
+        )
+    else:
+        network.append(
+            NNMF2d(
+                in_channels=test_image.shape[1],