Add files via upload
This commit is contained in:
parent
e848d49a7c
commit
0b8dd174ce
3 changed files with 131 additions and 24 deletions
21
NNMF2d.py
21
NNMF2d.py
|
@ -104,16 +104,17 @@ class NNMF2d(torch.nn.Module):
|
|||
self.local_learning,
|
||||
self.local_learning_kl,
|
||||
)
|
||||
if self.skip_connection:
|
||||
if self.use_reconstruction:
|
||||
reconstruction = torch.nn.functional.linear(
|
||||
h_dyn.movedim(1, -1), positive_weights.T
|
||||
).movedim(-1, 1)
|
||||
output = torch.cat((h_dyn, input - reconstruction), dim=1)
|
||||
else:
|
||||
output = torch.cat((h_dyn, input), dim=1)
|
||||
return output
|
||||
else:
|
||||
# if self.skip_connection:
|
||||
# if self.use_reconstruction:
|
||||
# reconstruction = torch.nn.functional.linear(
|
||||
# h_dyn.movedim(1, -1), positive_weights.T
|
||||
# ).movedim(-1, 1)
|
||||
# output = torch.cat((h_dyn, input - reconstruction), dim=1)
|
||||
# else:
|
||||
# output = torch.cat((h_dyn, input), dim=1)
|
||||
# return output
|
||||
# else:
|
||||
# return h_dyn
|
||||
return h_dyn
|
||||
|
||||
|
||||
|
|
82
Y.py
Normal file
82
Y.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
import torch
|
||||
from typing import Callable
|
||||
|
||||
|
||||
class Y(torch.nn.Module):
|
||||
"""
|
||||
A PyTorch module that splits the processing path of a input tensor
|
||||
and processes it through multiple torch.nn.Sequential segments,
|
||||
then combines the outputs using a specified methods.
|
||||
|
||||
This module allows for creating split paths within a `torch.nn.Sequential`
|
||||
model, making it possible to implement architectures with skip connections
|
||||
or parallel paths without abandoning the sequential model structure.
|
||||
|
||||
Attributes:
|
||||
segments (torch.nn.Sequential[torch.nn.Sequential]): A list of sequential modules to
|
||||
process the input tensor.
|
||||
combine_func (Callable | None): A function to combine the outputs
|
||||
from the segments.
|
||||
dim (int | None): The dimension along which to concatenate
|
||||
the outputs if `combine_func` is `torch.cat`.
|
||||
|
||||
Args:
|
||||
segments (torch.nn.Sequential[torch.nn.Sequential]): A torch.nn.Sequential
|
||||
with a list of sequential modules to process the input tensor.
|
||||
combine (str, optional): The method to combine the outputs.
|
||||
"cat" for concatenation (default), or "func" to use a
|
||||
custom combine function.
|
||||
dim (int | None, optional): The dimension along which to
|
||||
concatenate the outputs if `combine` is "cat".
|
||||
Defaults to 1.
|
||||
combine_func (Callable | None, optional): A custom function
|
||||
to combine the outputs if `combine` is "func".
|
||||
Defaults to None.
|
||||
|
||||
Example:
|
||||
A simple example for the `Y` module with two sub-torch.nn.Sequential:
|
||||
|
||||
----- segment_a -----
|
||||
main_Sequential ----| |---- main_Sequential
|
||||
----- segment_b -----
|
||||
|
||||
segments = [segment_a, segment_b]
|
||||
y_split = Y(segments)
|
||||
result = y_split(input_tensor)
|
||||
|
||||
Methods:
|
||||
forward(input: torch.Tensor) -> torch.Tensor:
|
||||
Processes the input tensor through the segments and
|
||||
combines the results.
|
||||
"""
|
||||
|
||||
segments: torch.nn.Sequential
|
||||
combine_func: Callable
|
||||
dim: int | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
segments: torch.nn.Sequential,
|
||||
combine: str = "cat", # "cat", "func"
|
||||
dim: int | None = 1,
|
||||
combine_func: Callable | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.segments = segments
|
||||
self.dim = dim
|
||||
|
||||
if combine.upper() == "CAT":
|
||||
self.combine_func = torch.cat
|
||||
else:
|
||||
assert combine_func is not None
|
||||
self.combine_func = combine_func
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
results: list[torch.Tensor] = []
|
||||
for segment in self.segments:
|
||||
results.append(segment(input))
|
||||
|
||||
if self.dim is None:
|
||||
return self.combine_func(results)
|
||||
else:
|
||||
return self.combine_func(results, dim=self.dim)
|
|
@ -2,6 +2,7 @@ import torch
|
|||
from append_input_conv2d import append_input_conv2d
|
||||
from L1NormLayer import L1NormLayer
|
||||
from NNMF2d import NNMF2d
|
||||
from Y import Y
|
||||
|
||||
|
||||
def append_nnmf_block(
|
||||
|
@ -44,6 +45,29 @@ def append_nnmf_block(
|
|||
test_image = network[-1](test_image)
|
||||
|
||||
list_other_id.append(len(network))
|
||||
if skip_connection:
|
||||
network.append(
|
||||
Y(
|
||||
torch.nn.Sequential(
|
||||
torch.nn.Sequential(
|
||||
NNMF2d(
|
||||
in_channels=test_image.shape[1],
|
||||
out_channels=out_channels,
|
||||
epsilon=epsilon,
|
||||
positive_function_type=positive_function_type,
|
||||
beta=beta,
|
||||
iterations=iterations,
|
||||
local_learning=local_learning,
|
||||
local_learning_kl=local_learning_kl,
|
||||
use_reconstruction=use_reconstruction,
|
||||
skip_connection=skip_connection,
|
||||
)
|
||||
),
|
||||
torch.nn.Sequential(torch.nn.Identity()),
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
network.append(
|
||||
NNMF2d(
|
||||
in_channels=test_image.shape[1],
|
||||
|
|
Loading…
Reference in a new issue