367 lines
10 KiB
Python
367 lines
10 KiB
Python
import torch
|
|
from PositionalEncoding import PositionalEncoding
|
|
from SequentialSplit import SequentialSplit
|
|
from NNMF2dGrouped import NNMF2dGrouped
|
|
from Functional2Layer import Functional2Layer
|
|
|
|
|
|
def add_block(
|
|
network: torch.nn.Sequential,
|
|
embed_dim: int,
|
|
num_heads: int,
|
|
dtype: torch.dtype,
|
|
device: torch.device,
|
|
example_image: torch.Tensor,
|
|
mlp_ratio: int = 4,
|
|
block_id: int = 0,
|
|
iterations: int = 20,
|
|
padding: int = 1,
|
|
kernel_size: tuple[int, int] = (3, 3),
|
|
) -> torch.Tensor | None:
|
|
|
|
# ###########
|
|
# Attention #
|
|
# ###########
|
|
|
|
example_image_a: torch.Tensor = example_image.clone()
|
|
example_image_b: torch.Tensor = example_image.clone()
|
|
|
|
attention_a_sequential = torch.nn.Sequential()
|
|
|
|
attention_a_sequential.add_module(
|
|
"Attention Layer Norm 1 [Pre-Permute]",
|
|
Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)),
|
|
)
|
|
example_image_a = attention_a_sequential[-1](example_image_a)
|
|
|
|
attention_a_sequential.add_module(
|
|
"Attention Layer Norm 1",
|
|
torch.nn.LayerNorm(
|
|
normalized_shape=example_image_a.shape[-1],
|
|
eps=1e-06,
|
|
bias=True,
|
|
dtype=dtype,
|
|
device=device,
|
|
),
|
|
)
|
|
example_image_a = attention_a_sequential[-1](example_image_a)
|
|
|
|
attention_a_sequential.add_module(
|
|
"Attention Layer Norm 1 [Post-Permute]",
|
|
Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2)),
|
|
)
|
|
example_image_a = attention_a_sequential[-1](example_image_a)
|
|
|
|
attention_a_sequential.add_module(
|
|
"Attention Clamp Layer", Functional2Layer(func=torch.clamp, min=1e-6)
|
|
)
|
|
example_image_a = attention_a_sequential[-1](example_image_a)
|
|
|
|
backup_image_dim = example_image_a.shape[1]
|
|
|
|
attention_a_sequential.add_module(
|
|
"Attention Zero Padding Layer", torch.nn.ZeroPad2d(padding=padding)
|
|
)
|
|
example_image_a = attention_a_sequential[-1](example_image_a)
|
|
|
|
# I need the output size
|
|
mock_output_shape = (
|
|
torch.nn.functional.conv2d(
|
|
torch.zeros(
|
|
1,
|
|
1,
|
|
example_image_a.shape[2],
|
|
example_image_a.shape[3],
|
|
),
|
|
torch.zeros((1, 1, kernel_size[0], kernel_size[1])),
|
|
stride=1,
|
|
padding=0,
|
|
dilation=1,
|
|
)
|
|
.squeeze(0)
|
|
.squeeze(0)
|
|
).shape
|
|
|
|
attention_a_sequential.add_module(
|
|
"Attention Windowing [Part 1]",
|
|
torch.nn.Unfold(
|
|
kernel_size=(kernel_size[-2], kernel_size[-1]),
|
|
dilation=1,
|
|
padding=0,
|
|
stride=1,
|
|
),
|
|
)
|
|
example_image_a = attention_a_sequential[-1](example_image_a)
|
|
|
|
attention_a_sequential.add_module(
|
|
"Attention Windowing [Part 2]",
|
|
torch.nn.Fold(
|
|
output_size=mock_output_shape,
|
|
kernel_size=(1, 1),
|
|
dilation=1,
|
|
padding=0,
|
|
stride=1,
|
|
),
|
|
)
|
|
example_image_a = attention_a_sequential[-1](example_image_a)
|
|
|
|
attention_a_sequential.add_module("Attention NNMFConv2d", torch.nn.ReLU())
|
|
example_image_a = attention_a_sequential[-1](example_image_a)
|
|
|
|
attention_a_sequential.add_module(
|
|
"Attention NNMFConv2d",
|
|
NNMF2dGrouped(
|
|
in_channels=example_image_a.shape[1],
|
|
out_channels=embed_dim,
|
|
groups=num_heads,
|
|
device=device,
|
|
dtype=dtype,
|
|
iterations=iterations,
|
|
),
|
|
)
|
|
example_image_a = attention_a_sequential[-1](example_image_a)
|
|
|
|
attention_a_sequential.add_module(
|
|
"Attention Layer Norm 2 [Pre-Permute]",
|
|
Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)),
|
|
)
|
|
example_image_a = attention_a_sequential[-1](example_image_a)
|
|
|
|
attention_a_sequential.add_module(
|
|
"Attention Layer Norm 2",
|
|
torch.nn.LayerNorm(
|
|
normalized_shape=example_image_a.shape[-1],
|
|
eps=1e-06,
|
|
bias=True,
|
|
dtype=dtype,
|
|
device=device,
|
|
),
|
|
)
|
|
example_image_a = attention_a_sequential[-1](example_image_a)
|
|
|
|
attention_a_sequential.add_module(
|
|
"Attention Layer Norm 2 [Post-Permute]",
|
|
Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2)),
|
|
)
|
|
example_image_a = attention_a_sequential[-1](example_image_a)
|
|
|
|
attention_a_sequential.add_module(
|
|
"Attention Conv2d Layer ",
|
|
torch.nn.Conv2d(
|
|
in_channels=example_image_a.shape[1],
|
|
out_channels=backup_image_dim,
|
|
kernel_size=1,
|
|
dtype=dtype,
|
|
device=device,
|
|
),
|
|
)
|
|
example_image_a = attention_a_sequential[-1](example_image_a)
|
|
|
|
attention_b_sequential = torch.nn.Sequential()
|
|
attention_b_sequential.add_module(
|
|
"Attention Identity for the skip", torch.nn.Identity()
|
|
)
|
|
example_image_b = attention_b_sequential[-1](example_image_b)
|
|
|
|
assert example_image_b.shape == example_image_a.shape
|
|
|
|
network.add_module(
|
|
f"Block Number {block_id} [Attention]",
|
|
SequentialSplit(
|
|
torch.nn.Sequential(
|
|
attention_a_sequential,
|
|
attention_b_sequential,
|
|
),
|
|
combine="SUM",
|
|
),
|
|
)
|
|
example_image = network[-1](example_image)
|
|
|
|
# ######
|
|
# MLP #
|
|
# #####
|
|
|
|
example_image_a = example_image.clone()
|
|
example_image_b = example_image.clone()
|
|
|
|
mlp_a_sequential = torch.nn.Sequential()
|
|
|
|
mlp_a_sequential.add_module(
|
|
"MLP [Pre-Permute]", Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1))
|
|
)
|
|
example_image_a = mlp_a_sequential[-1](example_image_a)
|
|
|
|
mlp_a_sequential.add_module(
|
|
"MLP Layer Norm",
|
|
torch.nn.LayerNorm(
|
|
normalized_shape=example_image_a.shape[-1],
|
|
eps=1e-06,
|
|
bias=True,
|
|
dtype=dtype,
|
|
device=device,
|
|
),
|
|
)
|
|
example_image_a = mlp_a_sequential[-1](example_image_a)
|
|
|
|
mlp_a_sequential.add_module(
|
|
"MLP Linear Layer A",
|
|
torch.nn.Linear(
|
|
example_image_a.shape[-1],
|
|
int(example_image_a.shape[-1] * mlp_ratio),
|
|
dtype=dtype,
|
|
device=device,
|
|
),
|
|
)
|
|
example_image_a = mlp_a_sequential[-1](example_image_a)
|
|
|
|
mlp_a_sequential.add_module("MLP GELU", torch.nn.GELU())
|
|
example_image_a = mlp_a_sequential[-1](example_image_a)
|
|
|
|
mlp_a_sequential.add_module(
|
|
"MLP Linear Layer B",
|
|
torch.nn.Linear(
|
|
example_image_a.shape[-1],
|
|
int(example_image_a.shape[-1] // mlp_ratio),
|
|
dtype=dtype,
|
|
device=device,
|
|
),
|
|
)
|
|
example_image_a = mlp_a_sequential[-1](example_image_a)
|
|
|
|
mlp_a_sequential.add_module(
|
|
"MLP [Post-Permute]", Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2))
|
|
)
|
|
example_image_a = mlp_a_sequential[-1](example_image_a)
|
|
|
|
mlp_b_sequential = torch.nn.Sequential()
|
|
mlp_b_sequential.add_module("MLP Identity for the skip", torch.nn.Identity())
|
|
|
|
example_image_b = attention_b_sequential[-1](example_image_b)
|
|
|
|
assert example_image_b.shape == example_image_a.shape
|
|
|
|
network.add_module(
|
|
f"Block Number {block_id} [MLP]",
|
|
SequentialSplit(
|
|
torch.nn.Sequential(
|
|
mlp_a_sequential,
|
|
mlp_b_sequential,
|
|
),
|
|
combine="SUM",
|
|
),
|
|
)
|
|
example_image = network[-1](example_image)
|
|
|
|
return example_image
|
|
|
|
|
|
def make_network(
|
|
in_channels: int = 3,
|
|
dims: list[int] = [72, 72, 72],
|
|
embed_dims: list[int] = [192, 192, 192],
|
|
n_classes: int = 10,
|
|
heads: int = 12,
|
|
example_image_shape: list[int] = [1, 3, 28, 28],
|
|
dtype: torch.dtype = torch.float32,
|
|
device: torch.device | None = None,
|
|
iterations: int = 20,
|
|
) -> torch.nn.Sequential:
|
|
|
|
assert device is not None
|
|
|
|
network = torch.nn.Sequential()
|
|
|
|
example_image: torch.Tensor = torch.zeros(
|
|
example_image_shape, dtype=dtype, device=device
|
|
)
|
|
|
|
network.add_module(
|
|
"Encode Conv2d",
|
|
torch.nn.Conv2d(
|
|
in_channels,
|
|
dims[0],
|
|
kernel_size=4,
|
|
stride=4,
|
|
padding=0,
|
|
dtype=dtype,
|
|
device=device,
|
|
),
|
|
)
|
|
example_image = network[-1](example_image)
|
|
|
|
network.add_module(
|
|
"Encode Offset",
|
|
PositionalEncoding(
|
|
[example_image.shape[-3], example_image.shape[-2], example_image.shape[-1]]
|
|
).to(device=device),
|
|
)
|
|
example_image = network[-1](example_image)
|
|
|
|
network.add_module(
|
|
"Encode Layer Norm [Pre-Permute]",
|
|
Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)),
|
|
)
|
|
example_image = network[-1](example_image)
|
|
|
|
network.add_module(
|
|
"Encode Layer Norm",
|
|
torch.nn.LayerNorm(
|
|
normalized_shape=example_image.shape[-1],
|
|
eps=1e-06,
|
|
bias=True,
|
|
dtype=dtype,
|
|
device=device,
|
|
),
|
|
)
|
|
example_image = network[-1](example_image)
|
|
|
|
network.add_module(
|
|
"Encode Layer Norm [Post-Permute]",
|
|
Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2)),
|
|
)
|
|
example_image = network[-1](example_image)
|
|
|
|
for i in range(len(dims)):
|
|
example_image = add_block(
|
|
network=network,
|
|
embed_dim=embed_dims[i],
|
|
num_heads=heads,
|
|
mlp_ratio=2,
|
|
block_id=i,
|
|
example_image=example_image,
|
|
dtype=dtype,
|
|
device=device,
|
|
iterations=iterations,
|
|
)
|
|
|
|
network.add_module(
|
|
"Spatial Mean Layer", Functional2Layer(func=torch.mean, dim=(-1, -2))
|
|
)
|
|
example_image = network[-1](example_image)
|
|
|
|
network.add_module(
|
|
"Final Linear Layer",
|
|
torch.nn.Linear(example_image.shape[-1], n_classes, dtype=dtype, device=device),
|
|
)
|
|
example_image = network[-1](example_image)
|
|
|
|
network.add_module("Final Softmax Layer", torch.nn.Softmax(dim=-1))
|
|
example_image = network[-1](example_image)
|
|
|
|
assert example_image.ndim == 2
|
|
assert example_image.shape[0] == example_image_shape[0]
|
|
assert example_image.shape[1] == n_classes
|
|
|
|
return network
|
|
|
|
|
|
if __name__ == "__main__":
|
|
network = make_network(device=torch.device("cuda:0"))
|
|
print(network)
|
|
|
|
number_of_parameter: int = 0
|
|
for name, param in network.named_parameters():
|
|
print(f"Parameter name: {name}, Shape: {param.shape}")
|
|
number_of_parameter += param.numel()
|
|
|
|
print("Number of total parameters:", number_of_parameter)
|