64 lines
1.7 KiB
Python
64 lines
1.7 KiB
Python
|
import torch
|
||
|
from append_input_conv2d import append_input_conv2d
|
||
|
from L1NormLayer import L1NormLayer
|
||
|
from NNMF2d import NNMF2d
|
||
|
|
||
|
|
||
|
def append_nnmf_block(
    network: torch.nn.Sequential,
    out_channels: int,
    test_image: torch.Tensor,
    list_other_id: list[int],
    dilation: int = 1,
    padding: int = 0,
    stride: int = 1,
    kernel_size: list[int] | tuple[int, int] = (5, 5),
    epsilon: float | None = None,
    positive_function_type: int = 0,
    beta: float | None = None,
    iterations: int = 20,
    local_learning: bool = False,
    local_learning_kl: bool = False,
    use_reconstruction: bool = False,
    skip_connection: bool = False,
) -> torch.Tensor:
    """Append an input-conv / L1-norm / NNMF2d block to ``network``.

    The function mutates ``network`` in place. It first lets
    ``append_input_conv2d`` add its layer(s). It then appends an
    ``L1NormLayer`` and finally an ``NNMF2d`` layer. After each step,
    ``test_image`` is pushed through the newly added module(s). This
    probes the resulting shape, so the ``NNMF2d`` input channel count can
    be read from the probe.

    Args:
        network: Sequential container that receives the new modules
            (mutated in place).
        out_channels: Number of output channels of the ``NNMF2d`` layer.
        test_image: Probe tensor used to infer shapes. Its last two dims
            supply the kernel size when a ``kernel_size`` entry is < 1.
            After the input stage, dim 1 is used as the channel count
            (presumably an NCHW layout -- confirm against callers).
        list_other_id: Index list that receives the position of the new
            ``NNMF2d`` module inside ``network`` (mutated in place).
        dilation: Forwarded to ``append_input_conv2d``.
        padding: Forwarded to ``append_input_conv2d``.
        stride: Forwarded to ``append_input_conv2d``.
        kernel_size: Two-element (height, width) kernel size forwarded to
            ``append_input_conv2d``. An entry < 1 means "use the full
            extent of ``test_image`` along that axis".
        epsilon: Forwarded to ``NNMF2d``.
        positive_function_type: Forwarded to ``NNMF2d``.
        beta: Forwarded to ``NNMF2d``.
        iterations: Forwarded to ``NNMF2d``.
        local_learning: Forwarded to ``NNMF2d``.
        local_learning_kl: Forwarded to ``NNMF2d``.
        use_reconstruction: Forwarded to ``NNMF2d``.
        skip_connection: Forwarded to ``NNMF2d``.

    Returns:
        ``test_image`` after it has passed through every module this call
        appended. Callers can use it to size the next block.
    """
    # Copy into a fresh list so the caller's sequence (or the default
    # tuple) is never modified and entries can be overwritten below.
    kernel_size_internal: list[int] = list(kernel_size)

    # A non-positive entry selects the full spatial extent of the probe
    # along that axis (height from dim -2, width from dim -1).
    if kernel_size_internal[0] < 1:
        kernel_size_internal[0] = test_image.shape[-2]
    if kernel_size_internal[1] < 1:
        kernel_size_internal[1] = test_image.shape[-1]

    # append_input_conv2d adds its own layer(s) to ``network`` and returns
    # the probe after passing through them.
    test_image = append_input_conv2d(
        network=network,
        test_image=test_image,
        dilation=dilation,
        padding=padding,
        stride=stride,
        kernel_size=kernel_size_internal,
    )

    network.append(L1NormLayer())
    test_image = network[-1](test_image)

    # Record the index the NNMF2d module is about to occupy, i.e. the
    # current length of the container before the append.
    list_other_id.append(len(network))
    network.append(
        NNMF2d(
            in_channels=test_image.shape[1],
            out_channels=out_channels,
            epsilon=epsilon,
            positive_function_type=positive_function_type,
            beta=beta,
            iterations=iterations,
            local_learning=local_learning,
            local_learning_kl=local_learning_kl,
            use_reconstruction=use_reconstruction,
            skip_connection=skip_connection,
        )
    )
    test_image = network[-1](test_image)

    return test_image
|