nnmf_24b/make_network.py

368 lines
10 KiB
Python
Raw Permalink Normal View History

2024-07-26 12:55:02 +02:00
import torch
from PositionalEncoding import PositionalEncoding
from SequentialSplit import SequentialSplit
from NNMF2dGrouped import NNMF2dGrouped
from Functional2Layer import Functional2Layer
def _append_layer(
    sequential: torch.nn.Sequential,
    name: str,
    module: torch.nn.Module,
    example: torch.Tensor,
) -> torch.Tensor:
    """Register ``module`` as ``name`` in ``sequential`` and run it on ``example``.

    The returned tensor lets the caller size the next layer from its shape.

    Args:
        sequential: Container the layer is appended to (mutated in place).
        name: Registered module name; it also prefixes the layer's state_dict keys.
        module: Layer to append.
        example: Input fed through ``module``.

    Returns:
        ``module(example)``.

    Raises:
        KeyError: If ``name`` is already registered in ``sequential``.
            ``add_module`` would otherwise silently replace the earlier layer.
    """
    if name in sequential._modules:
        raise KeyError(f"Layer name {name!r} is already used in this Sequential")
    sequential.add_module(name, module)
    return module(example)


def _append_channel_layer_norm(
    sequential: torch.nn.Sequential,
    name: str,
    example: torch.Tensor,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Append a LayerNorm over axis 1 of a 4-D tensor, wrapped in two permutes.

    Registers ``f"{name} [Pre-Permute]"`` (moves axis 1 last), ``name`` (the
    LayerNorm over that last axis) and ``f"{name} [Post-Permute]"`` (restores
    the original axis order).

    Returns:
        ``example`` after all three layers.
    """
    example = _append_layer(
        sequential,
        f"{name} [Pre-Permute]",
        Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)),
        example,
    )
    example = _append_layer(
        sequential,
        name,
        torch.nn.LayerNorm(
            normalized_shape=example.shape[-1],
            eps=1e-06,
            bias=True,
            dtype=dtype,
            device=device,
        ),
        example,
    )
    return _append_layer(
        sequential,
        f"{name} [Post-Permute]",
        Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2)),
        example,
    )


def add_block(
    network: torch.nn.Sequential,
    embed_dim: int,
    num_heads: int,
    dtype: torch.dtype,
    device: torch.device,
    example_image: torch.Tensor,
    mlp_ratio: int = 4,
    block_id: int = 0,
    iterations: int = 20,
    padding: int = 1,
    kernel_size: tuple[int, int] = (3, 3),
) -> torch.Tensor:
    """Append one residual attention block and one residual MLP block to ``network``.

    Two modules are registered on ``network``:

    * ``f"Block Number {block_id} [Attention]"``. The branch applies, in order:
      LayerNorm over channels, a clamp to ``>= 1e-6``, zero padding,
      unfolding of every ``kernel_size`` window into the channel axis, ReLU,
      ``NNMF2dGrouped`` (``embed_dim`` outputs, ``num_heads`` groups),
      LayerNorm, and a 1x1 ``Conv2d`` back to the input channel count.
    * ``f"Block Number {block_id} [MLP]"``. The branch applies, channels-last:
      LayerNorm, ``Linear(C, int(C * mlp_ratio))``, GELU, then
      ``Linear(..., C)``.

    Each module is a ``SequentialSplit`` of its branch and an identity skip
    with ``combine="SUM"``. This is presumably an elementwise residual sum;
    confirm in ``SequentialSplit``.

    Every new layer is run on an example tensor as it is added, so the size
    of the next layer can be read from the result.

    Args:
        network: Sequential the two blocks are appended to (mutated in place).
        embed_dim: Output channels of the NNMF layer.
        num_heads: ``groups`` argument of the NNMF layer.
        dtype: dtype of every parameterized layer created here.
        device: Device of every parameterized layer created here.
        example_image: Current network output. It must be 4-D, since the
            permutes below use four axes; axis 1 is treated as channels.
        mlp_ratio: Hidden-width multiplier of the MLP.
        block_id: Used in the two registered block names.
        iterations: Passed to ``NNMF2dGrouped``.
        padding: Zero padding applied before windowing.
        kernel_size: ``(height, width)`` of the windows.

    Returns:
        The output of the two new blocks on ``example_image``.

    Raises:
        KeyError: If a block with the same ``block_id`` is already in ``network``.
        AssertionError: If a branch output shape differs from its skip path.
    """
    kernel_height, kernel_width = kernel_size

    # ###########
    # Attention #
    # ###########
    attention_a_sequential = torch.nn.Sequential()
    example_image_a = _append_channel_layer_norm(
        attention_a_sequential,
        "Attention Layer Norm 1",
        example_image.clone(),
        dtype,
        device,
    )
    # Keep values strictly positive ahead of the NNMF layer.
    # NOTE(review): presumably NNMF2dGrouped needs non-negative input; confirm.
    example_image_a = _append_layer(
        attention_a_sequential,
        "Attention Clamp Layer",
        Functional2Layer(func=torch.clamp, min=1e-6),
        example_image_a,
    )
    # The branch is projected back to this width so it can be summed with the skip.
    backup_image_dim = example_image_a.shape[1]
    example_image_a = _append_layer(
        attention_a_sequential,
        "Attention Zero Padding Layer",
        torch.nn.ZeroPad2d(padding=padding),
        example_image_a,
    )
    # Unfold + Fold(kernel_size=1) moves every window into the channel axis:
    # (N, C, H, W) -> (N, C * kh * kw, H - kh + 1, W - kw + 1)
    # for stride 1, dilation 1 and no extra padding.
    window_grid_shape = (
        example_image_a.shape[2] - kernel_height + 1,
        example_image_a.shape[3] - kernel_width + 1,
    )
    example_image_a = _append_layer(
        attention_a_sequential,
        "Attention Windowing [Part 1]",
        torch.nn.Unfold(
            kernel_size=(kernel_height, kernel_width),
            dilation=1,
            padding=0,
            stride=1,
        ),
        example_image_a,
    )
    example_image_a = _append_layer(
        attention_a_sequential,
        "Attention Windowing [Part 2]",
        torch.nn.Fold(
            output_size=window_grid_shape,
            kernel_size=(1, 1),
            dilation=1,
            padding=0,
            stride=1,
        ),
        example_image_a,
    )
    # The input is already >= 0 here (clamp + zero padding), so this ReLU is a
    # safeguard. It needs its own name: it used to share the NNMF layer's
    # name, which made add_module silently replace it.
    example_image_a = _append_layer(
        attention_a_sequential,
        "Attention ReLU",
        torch.nn.ReLU(),
        example_image_a,
    )
    example_image_a = _append_layer(
        attention_a_sequential,
        "Attention NNMFConv2d",
        NNMF2dGrouped(
            in_channels=example_image_a.shape[1],
            out_channels=embed_dim,
            groups=num_heads,
            device=device,
            dtype=dtype,
            iterations=iterations,
        ),
        example_image_a,
    )
    example_image_a = _append_channel_layer_norm(
        attention_a_sequential,
        "Attention Layer Norm 2",
        example_image_a,
        dtype,
        device,
    )
    # 1x1 projection back to the input width. The trailing space is part of
    # the registered name, and so of the state_dict keys; keep it.
    example_image_a = _append_layer(
        attention_a_sequential,
        "Attention Conv2d Layer ",
        torch.nn.Conv2d(
            in_channels=example_image_a.shape[1],
            out_channels=backup_image_dim,
            kernel_size=1,
            dtype=dtype,
            device=device,
        ),
        example_image_a,
    )

    attention_b_sequential = torch.nn.Sequential()
    example_image_b = _append_layer(
        attention_b_sequential,
        "Attention Identity for the skip",
        torch.nn.Identity(),
        example_image.clone(),
    )
    assert example_image_b.shape == example_image_a.shape, (
        f"Attention branch shape {tuple(example_image_a.shape)} does not match "
        f"skip shape {tuple(example_image_b.shape)}"
    )
    example_image = _append_layer(
        network,
        f"Block Number {block_id} [Attention]",
        SequentialSplit(
            torch.nn.Sequential(
                attention_a_sequential,
                attention_b_sequential,
            ),
            combine="SUM",
        ),
        example_image,
    )

    # ######
    # MLP #
    # #####
    mlp_a_sequential = torch.nn.Sequential()
    example_image_a = _append_layer(
        mlp_a_sequential,
        "MLP [Pre-Permute]",
        Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)),
        example_image.clone(),
    )
    mlp_features = example_image_a.shape[-1]
    example_image_a = _append_layer(
        mlp_a_sequential,
        "MLP Layer Norm",
        torch.nn.LayerNorm(
            normalized_shape=mlp_features,
            eps=1e-06,
            bias=True,
            dtype=dtype,
            device=device,
        ),
        example_image_a,
    )
    example_image_a = _append_layer(
        mlp_a_sequential,
        "MLP Linear Layer A",
        torch.nn.Linear(
            mlp_features,
            int(mlp_features * mlp_ratio),
            dtype=dtype,
            device=device,
        ),
        example_image_a,
    )
    example_image_a = _append_layer(
        mlp_a_sequential, "MLP GELU", torch.nn.GELU(), example_image_a
    )
    # Project back to the exact input width instead of hidden // mlp_ratio.
    # The two are identical for integer ratios, but only this form
    # guarantees the residual sum lines up.
    example_image_a = _append_layer(
        mlp_a_sequential,
        "MLP Linear Layer B",
        torch.nn.Linear(
            example_image_a.shape[-1],
            mlp_features,
            dtype=dtype,
            device=device,
        ),
        example_image_a,
    )
    example_image_a = _append_layer(
        mlp_a_sequential,
        "MLP [Post-Permute]",
        Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2)),
        example_image_a,
    )

    mlp_b_sequential = torch.nn.Sequential()
    example_image_b = _append_layer(
        mlp_b_sequential,
        "MLP Identity for the skip",
        torch.nn.Identity(),
        example_image.clone(),
    )
    assert example_image_b.shape == example_image_a.shape, (
        f"MLP branch shape {tuple(example_image_a.shape)} does not match "
        f"skip shape {tuple(example_image_b.shape)}"
    )
    example_image = _append_layer(
        network,
        f"Block Number {block_id} [MLP]",
        SequentialSplit(
            torch.nn.Sequential(
                mlp_a_sequential,
                mlp_b_sequential,
            ),
            combine="SUM",
        ),
        example_image,
    )
    return example_image
def make_network(
    in_channels: int = 3,
    dims: list[int] | tuple[int, ...] = (72, 72, 72),
    embed_dims: list[int] | tuple[int, ...] = (192, 192, 192),
    n_classes: int = 10,
    heads: int = 12,
    example_image_shape: list[int] | tuple[int, ...] = (1, 3, 28, 28),
    dtype: torch.dtype = torch.float32,
    device: torch.device | None = None,
    iterations: int = 20,
    mlp_ratio: int = 2,
) -> torch.nn.Sequential:
    """Build the image classifier as a flat ``torch.nn.Sequential``.

    The network is built from the following layers, in order:

    1. A patch-embedding ``Conv2d`` (kernel 4, stride 4) from ``in_channels``
       to ``dims[0]``.
    2. A ``PositionalEncoding`` sized to its ``(C, H, W)`` output.
    3. A LayerNorm over channels.
    4. ``len(dims)`` blocks appended by ``add_block``.
    5. A mean over the last two (spatial) axes.
    6. A ``Linear`` to ``n_classes``.
    7. A ``Softmax`` over the last axis.

    A zero tensor of ``example_image_shape`` is pushed through each layer as
    it is added, so later layer sizes can be read off its shape.

    NOTE(review): the network ends in Softmax. If it is trained with
    ``torch.nn.CrossEntropyLoss``, which expects logits, softmax is applied
    twice; confirm this is intended.

    Args:
        in_channels: Channels of the input image.
        dims: Only ``dims[0]`` (the embedding width) and ``len(dims)`` (the
            number of blocks) are read. The blocks preserve the channel
            count, so ``dims[1:]`` has no effect.
        embed_dims: NNMF output width of each block; needs at least
            ``len(dims)`` entries (extra entries are ignored).
        n_classes: Number of output classes.
        heads: ``num_heads`` passed to every block.
        example_image_shape: Shape of the dummy input used for shape
            inference; entry 0 is the batch size.
        dtype: dtype of all created layers.
        device: Device of all created layers; required.
        iterations: NNMF iterations passed to every block.
        mlp_ratio: MLP hidden-width multiplier passed to every block.

    Returns:
        The assembled network.

    Raises:
        ValueError: If ``device`` is None or ``embed_dims`` is shorter than ``dims``.
    """
    if device is None:
        raise ValueError("make_network requires an explicit device")
    if len(embed_dims) < len(dims):
        raise ValueError(
            f"embed_dims has {len(embed_dims)} entries but {len(dims)} blocks "
            "were requested via dims"
        )

    network = torch.nn.Sequential()
    example_image: torch.Tensor = torch.zeros(
        example_image_shape, dtype=dtype, device=device
    )
    network.add_module(
        "Encode Conv2d",
        torch.nn.Conv2d(
            in_channels,
            dims[0],
            kernel_size=4,
            stride=4,
            padding=0,
            dtype=dtype,
            device=device,
        ),
    )
    example_image = network[-1](example_image)
    network.add_module(
        "Encode Offset",
        PositionalEncoding(
            [example_image.shape[-3], example_image.shape[-2], example_image.shape[-1]]
        ).to(device=device),
    )
    example_image = network[-1](example_image)
    # LayerNorm over channels: move axis 1 last, normalize, move it back.
    network.add_module(
        "Encode Layer Norm [Pre-Permute]",
        Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)),
    )
    example_image = network[-1](example_image)
    network.add_module(
        "Encode Layer Norm",
        torch.nn.LayerNorm(
            normalized_shape=example_image.shape[-1],
            eps=1e-06,
            bias=True,
            dtype=dtype,
            device=device,
        ),
    )
    example_image = network[-1](example_image)
    network.add_module(
        "Encode Layer Norm [Post-Permute]",
        Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2)),
    )
    example_image = network[-1](example_image)

    for block_id, embed_dim in enumerate(embed_dims[: len(dims)]):
        example_image = add_block(
            network=network,
            embed_dim=embed_dim,
            num_heads=heads,
            mlp_ratio=mlp_ratio,
            block_id=block_id,
            example_image=example_image,
            dtype=dtype,
            device=device,
            iterations=iterations,
        )

    network.add_module(
        "Spatial Mean Layer", Functional2Layer(func=torch.mean, dim=(-1, -2))
    )
    example_image = network[-1](example_image)
    network.add_module(
        "Final Linear Layer",
        torch.nn.Linear(example_image.shape[-1], n_classes, dtype=dtype, device=device),
    )
    example_image = network[-1](example_image)
    network.add_module("Final Softmax Layer", torch.nn.Softmax(dim=-1))
    example_image = network[-1](example_image)

    # Internal consistency checks on the (batch, n_classes) output.
    assert example_image.ndim == 2
    assert example_image.shape[0] == example_image_shape[0]
    assert example_image.shape[1] == n_classes
    return network
if __name__ == "__main__":
    # Print the architecture and parameter count. Fall back to the CPU so the
    # summary also works on machines without CUDA.
    main_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    network = make_network(device=main_device)
    print(network)
    number_of_parameter: int = 0
    for name, param in network.named_parameters():
        print(f"Parameter name: {name}, Shape: {param.shape}")
        number_of_parameter += param.numel()
    print("Number of total parameters:", number_of_parameter)