diff --git a/Functional2Layer.py b/Functional2Layer.py new file mode 100644 index 0000000..4ca3d30 --- /dev/null +++ b/Functional2Layer.py @@ -0,0 +1,41 @@ +import torch +from typing import Callable, Any + + +class Functional2Layer(torch.nn.Module): + def __init__( + self, func: Callable[..., torch.Tensor], *args: Any, **kwargs: Any + ) -> None: + super().__init__() + self.func = func + self.args = args + self.kwargs = kwargs + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return self.func(input, *self.args, **self.kwargs) + + def extra_repr(self) -> str: + func_name = ( + self.func.__name__ if hasattr(self.func, "__name__") else str(self.func) + ) + args_repr = ", ".join(map(repr, self.args)) + kwargs_repr = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items()) + return f"func={func_name}, args=({args_repr}), kwargs={{{kwargs_repr}}}" + + +if __name__ == "__main__": + + print("Permute Example") + test_layer_permute = Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)) + input = torch.zeros((10, 11, 12, 13)) + output = test_layer_permute(input) + print(input.shape) + print(output.shape) + print(test_layer_permute) + + print() + print("Clamp Example") + test_layer_clamp = Functional2Layer(func=torch.clamp, min=5, max=100) + output = test_layer_permute(input) + print(output[0, 0, 0, 0]) + print(test_layer_clamp) diff --git a/NNMF2dGrouped.py b/NNMF2dGrouped.py new file mode 100644 index 0000000..35fe0a6 --- /dev/null +++ b/NNMF2dGrouped.py @@ -0,0 +1,277 @@ +import torch +from non_linear_weigth_function import non_linear_weigth_function + + +class NNMF2dGrouped(torch.nn.Module): + + in_channels: int + out_channels: int + weight: torch.Tensor + iterations: int + epsilon: float | None + init_min: float + init_max: float + beta: torch.Tensor | None + positive_function_type: int + local_learning: bool + local_learning_kl: bool + groups: int + + def __init__( + self, + in_channels: int, + out_channels: int, + groups: int = 1, + device=None, + dtype=None, + iterations: int = 20, + epsilon: float | None = None, + init_min: float = 0.0, + init_max: float = 1.0, + beta: float | None = None, + positive_function_type: int = 0, + local_learning: bool = False, + local_learning_kl: bool = False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + + super().__init__() + + self.positive_function_type = positive_function_type + self.init_min = init_min + self.init_max = init_max + + self.groups = groups + assert ( + in_channels % self.groups == 0 + ), f"Can't divide without rest {in_channels} / {self.groups}" + self.in_channels = in_channels // self.groups + assert ( + out_channels % self.groups == 0 + ), f"Can't divide without rest {out_channels} / {self.groups}" + self.out_channels = out_channels // self.groups + + self.iterations = iterations + self.local_learning = local_learning + self.local_learning_kl = local_learning_kl + + self.weight = torch.nn.parameter.Parameter( + torch.empty( + (self.groups, self.out_channels, self.in_channels), **factory_kwargs + ) + ) + + if beta is not None: + self.beta = torch.nn.parameter.Parameter(torch.empty((1), **factory_kwargs)) + self.beta.data[0] = beta + else: + self.beta = None + + self.reset_parameters() + self.functional_nnmf2d_grouped = FunctionalNNMF2dGrouped.apply + + self.epsilon = epsilon + + def extra_repr(self) -> str: + s: str = f"{self.in_channels}, {self.out_channels}" + + if self.epsilon is not None: + s += f", epsilon={self.epsilon}" + s += f", pfunctype={self.positive_function_type}" + s += f", 
local_learning={self.local_learning}" + s += f", groups={self.groups}" + + if self.local_learning: + s += f", local_learning_kl={self.local_learning_kl}" + + return s + + def reset_parameters(self) -> None: + torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + + positive_weights = non_linear_weigth_function( + self.weight, self.beta, self.positive_function_type + ) + positive_weights = positive_weights / ( + positive_weights.sum(dim=-1, keepdim=True) + 10e-20 + ) + assert self.groups * self.in_channels == input.shape[1] + + input = input.reshape( + ( + input.shape[0], + self.groups, + self.in_channels, + input.shape[-2], + input.shape[-1], + ) + ) + input = input / (input.sum(dim=2, keepdim=True) + 10e-20) + + h_dyn = self.functional_nnmf2d_grouped( + input, + positive_weights, + self.out_channels, + self.iterations, + self.epsilon, + self.local_learning, + self.local_learning_kl, + ) + + h_dyn = h_dyn.reshape( + ( + h_dyn.shape[0], + h_dyn.shape[1] * h_dyn.shape[2], + h_dyn.shape[3], + h_dyn.shape[4], + ) + ) + h_dyn = h_dyn / (h_dyn.sum(dim=1, keepdim=True) + 10e-20) + + return h_dyn + + +@torch.jit.script +def grouped_linear_einsum_h_weights(h, weights): + return torch.einsum("bgoxy,goi->bgixy", h, weights) + + +@torch.jit.script +def grouped_linear_einsum_reconstruction_weights(reconstruction, weights): + return torch.einsum("bgixy,goi->bgoxy", reconstruction, weights) + + +@torch.jit.script +def grouped_linear_einsum_h_input(h, reconstruction): + return torch.einsum("bgoxy,bgixy->goi", h, reconstruction) + + +class FunctionalNNMF2dGrouped(torch.autograd.Function): + + @staticmethod + def forward( # type: ignore + ctx, + input: torch.Tensor, + weight: torch.Tensor, + out_channels: int, + iterations: int, + epsilon: float | None, + local_learning: bool, + local_learning_kl: bool, + ) -> torch.Tensor: + + # Prepare h + h = torch.full( + ( + input.shape[0], + input.shape[1], + out_channels, + input.shape[-2], + input.shape[-1], + ), + 1.0 / float(out_channels), + device=input.device, + dtype=input.dtype, + ) + + for _ in range(0, iterations): + + reconstruction = grouped_linear_einsum_h_weights(h, weight) + reconstruction += 1e-20 + + if epsilon is None: + h *= grouped_linear_einsum_reconstruction_weights( + (input / reconstruction), weight + ) + else: + h *= 1 + epsilon * grouped_linear_einsum_reconstruction_weights( + (input / reconstruction), weight + ) + h /= h.sum(2, keepdim=True) + 10e-20 + + # ########################################################### + # Save the necessary data for the backward pass + # ########################################################### + ctx.save_for_backward(input, weight, h) + ctx.local_learning = local_learning + ctx.local_learning_kl = local_learning_kl + + assert torch.isfinite(h).all() + return h + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output: torch.Tensor) -> tuple[ # type: ignore + torch.Tensor, + torch.Tensor | None, + None, + None, + None, + None, + None, + ]: + + # ############################################## + # Default values + # ############################################## + grad_weight: torch.Tensor | None = None + + # ############################################## + # Get the variables back + # ############################################## + (input, weight, h) = ctx.saved_tensors + + # The back prop gradient + big_r = grouped_linear_einsum_h_weights(h, weight) + + big_r_div = 1.0 / (big_r + 1e-20) + + 
factor_x_div_r = input * big_r_div + + grad_input: torch.Tensor = ( + grouped_linear_einsum_h_weights(h * grad_output, weight) * big_r_div + ) + + del big_r_div + + # The weight gradient + if ctx.local_learning is False: + del big_r + + grad_weight = -grouped_linear_einsum_h_input( + h, (factor_x_div_r * grad_input) + ) + + grad_weight += grouped_linear_einsum_h_input( + (h * grad_output), + factor_x_div_r, + ) + + else: + if ctx.local_learning_kl: + + grad_weight = -grouped_linear_einsum_h_input( + h, + factor_x_div_r, + ) + + else: + grad_weight = -grouped_linear_einsum_h_input( + h, + (2 * (input - big_r)), + ) + + assert torch.isfinite(grad_input).all() + assert torch.isfinite(grad_weight).all() + + return ( + grad_input, + grad_weight, + None, + None, + None, + None, + None, + ) diff --git a/PositionalEncoding.py b/PositionalEncoding.py new file mode 100644 index 0000000..392e66d --- /dev/null +++ b/PositionalEncoding.py @@ -0,0 +1,22 @@ +import torch + + +class PositionalEncoding(torch.nn.Module): + + init_std: float + pos_embedding: torch.nn.Parameter + + def __init__(self, dim: list[int], init_std: float = 0.2): + super().__init__() + self.init_std = init_std + assert len(dim) == 3 + self.pos_embedding: torch.nn.Parameter = torch.nn.Parameter( + torch.randn(1, *dim) + ) + self.init_parameters() + + def init_parameters(self): + torch.nn.init.trunc_normal_(self.pos_embedding, std=self.init_std) + + def forward(self, input: torch.Tensor): + return input + self.pos_embedding diff --git a/SequentialSplit.py b/SequentialSplit.py new file mode 100644 index 0000000..c20dcbe --- /dev/null +++ b/SequentialSplit.py @@ -0,0 +1,169 @@ +import torch +from typing import Callable + + +class SequentialSplit(torch.nn.Module): + """ + A PyTorch module that splits the processing path of a input tensor + and processes it through multiple torch.nn.Sequential segments, + then combines the outputs using a specified methods. + + This module allows for creating split paths within a `torch.nn.Sequential` + model, making it possible to implement architectures with skip connections + or parallel paths without abandoning the sequential model structure. + + Attributes: + segments (torch.nn.Sequential[torch.nn.Sequential]): A list of sequential modules to + process the input tensor. + combine_func (Callable | None): A function to combine the outputs + from the segments. + dim (int | None): The dimension along which to concatenate + the outputs if `combine_func` is `torch.cat`. + + Args: + segments (torch.nn.Sequential[torch.nn.Sequential]): A torch.nn.Sequential + with a list of sequential modules to process the input tensor. + combine (str, optional): The method to combine the outputs. + "cat" for concatenation (default), "sum" for a summation, + or "func" to use a custom combine function. + dim (int | None, optional): The dimension along which to + concatenate the outputs if `combine` is "cat". + Defaults to 1. + combine_func (Callable | None, optional): A custom function + to combine the outputs if `combine` is "func". + Defaults to None. + + Example: + A simple example for the `SequentialSplit` module with two sub-torch.nn.Sequential: + + ----- segment_a ----- + main_Sequential ----| |---- main_Sequential + ----- segment_b ----- + + segments = [segment_a, segment_b] + y_split = SequentialSplit(segments) + result = y_split(input_tensor) + + Methods: + forward(input: torch.Tensor) -> torch.Tensor: + Processes the input tensor through the segments and + combines the results. 
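+
+    Note:
+        With combine="sum" the segment outputs are added element-wise, so their
+        shapes must match (or be broadcastable); with combine="cat" they must
+        agree in every dimension except `dim`.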
+ """ + + segments: torch.nn.Sequential + combine_func: Callable + dim: int | None + + def __init__( + self, + segments: torch.nn.Sequential, + combine: str = "cat", # "cat", "sum", "func", + dim: int | None = 1, + combine_func: Callable | None = None, + ): + super().__init__() + self.segments = segments + self.dim = dim + + self.combine = combine + + if combine.upper() == "CAT": + self.combine_func = torch.cat + elif combine.upper() == "SUM": + self.combine_func = self.sum + self.dim = None + else: + assert combine_func is not None + self.combine_func = combine_func + + def sum(self, input: list[torch.Tensor]) -> torch.Tensor | None: + + if len(input) == 0: + return None + + if len(input) == 1: + return input[0] + + output: torch.Tensor = input[0] + + for i in range(1, len(input)): + output = output + input[i] + + return output + + def forward(self, input: torch.Tensor) -> torch.Tensor: + results: list[torch.Tensor] = [] + for segment in self.segments: + results.append(segment(input)) + + if self.dim is None: + return self.combine_func(results) + else: + return self.combine_func(results, dim=self.dim) + + def extra_repr(self) -> str: + return self.combine + + +if __name__ == "__main__": + + print("Example CAT") + strain_a = torch.nn.Sequential(torch.nn.Identity()) + strain_b = torch.nn.Sequential(torch.nn.Identity()) + strain_c = torch.nn.Sequential(torch.nn.Identity()) + test_cat = SequentialSplit( + torch.nn.Sequential(strain_a, strain_b, strain_c), combine="cat", dim=2 + ) + print(test_cat) + input = torch.ones((10, 11, 12, 13)) + output = test_cat(input) + print(input.shape) + print(output.shape) + print(input[0, 0, 0, 0]) + print(output[0, 0, 0, 0]) + print() + + print("Example SUM") + strain_a = torch.nn.Sequential(torch.nn.Identity()) + strain_b = torch.nn.Sequential(torch.nn.Identity()) + strain_c = torch.nn.Sequential(torch.nn.Identity()) + test_sum = SequentialSplit( + torch.nn.Sequential(strain_a, strain_b, strain_c), combine="sum", dim=2 + ) + print(test_sum) + input = torch.ones((10, 11, 12, 13)) + output = test_sum(input) + print(input.shape) + print(output.shape) + print(input[0, 0, 0, 0]) + print(output[0, 0, 0, 0]) + print() + + print("Example Labeling") + strain_a = torch.nn.Sequential() + strain_a.add_module("Label for first strain", torch.nn.Identity()) + strain_b = torch.nn.Sequential() + strain_b.add_module("Label for second strain", torch.nn.Identity()) + strain_c = torch.nn.Sequential() + strain_c.add_module("Label for third strain", torch.nn.Identity()) + test_label = SequentialSplit(torch.nn.Sequential(strain_a, strain_b, strain_c)) + print(test_label) + print() + + print("Example Get Parameter") + input = torch.ones((10, 11, 12, 13)) + strain_a = torch.nn.Sequential() + strain_a.add_module("Identity", torch.nn.Identity()) + strain_b = torch.nn.Sequential() + strain_b.add_module( + "Conv2d", + torch.nn.Conv2d( + in_channels=input.shape[1], + out_channels=input.shape[1], + kernel_size=(1, 1), + ), + ) + test_parameter = SequentialSplit(torch.nn.Sequential(strain_a, strain_b)) + print(test_parameter) + for name, param in test_parameter.named_parameters(): + print(f"Parameter name: {name}, Shape: {param.shape}") diff --git a/convert_log_to_numpy.py b/convert_log_to_numpy.py new file mode 100644 index 0000000..b96c5d1 --- /dev/null +++ b/convert_log_to_numpy.py @@ -0,0 +1,29 @@ +import os +import glob + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +from tensorboard.backend.event_processing import event_accumulator # type: ignore +import numpy as np + + +def 
get_data(path: str = "log_cnn"): + acc = event_accumulator.EventAccumulator(path) + acc.Reload() + + which_scalar = "Test Number Correct" + te = acc.Scalars(which_scalar) + + np_temp = np.zeros((len(te), 2)) + + for id in range(0, len(te)): + np_temp[id, 0] = te[id].step + np_temp[id, 1] = te[id].value + print(np_temp[:, 1]/100) + return np_temp + + +for path in glob.glob("log_*"): + print(path) + data = get_data(path) + np.save("data_" + path + ".npy", data) diff --git a/data_loader.py b/data_loader.py new file mode 100644 index 0000000..0a0d430 --- /dev/null +++ b/data_loader.py @@ -0,0 +1,31 @@ +import torch + + +def data_loader( + pattern: torch.Tensor, + labels: torch.Tensor, + worker_init_fn, + generator, + batch_size: int = 128, + shuffle: bool = True, + torch_device: torch.device = torch.device("cpu"), +) -> torch.utils.data.dataloader.DataLoader: + + assert pattern.ndim >= 3 + + pattern_storage: torch.Tensor = pattern.to(torch_device).type(torch.float32) + if pattern_storage.ndim == 3: + pattern_storage = pattern_storage.unsqueeze(1) + pattern_storage /= pattern_storage.max() + + label_storage: torch.Tensor = labels.to(torch_device).type(torch.int64) + + dataloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(pattern_storage, label_storage), + batch_size=batch_size, + shuffle=shuffle, + worker_init_fn=worker_init_fn, + generator=generator, + ) + + return dataloader diff --git a/get_the_data.py b/get_the_data.py new file mode 100644 index 0000000..fc61064 --- /dev/null +++ b/get_the_data.py @@ -0,0 +1,147 @@ +import torch +import torchvision # type: ignore +from data_loader import data_loader + +from torchvision.transforms import v2 # type: ignore +import numpy as np + + +def get_the_data( + dataset: str, + batch_size_train: int, + batch_size_test: int, + torch_device: torch.device, + input_dim_x: int, + input_dim_y: int, + flip_p: float = 0.5, + jitter_brightness: float = 0.5, + jitter_contrast: float = 0.1, + jitter_saturation: float = 0.1, + jitter_hue: float = 0.15, + da_auto_mode: bool = False, +) -> tuple[ + torch.utils.data.dataloader.DataLoader, + torch.utils.data.dataloader.DataLoader, + torchvision.transforms.Compose, + torchvision.transforms.Compose, +]: + if dataset == "MNIST": + tv_dataset_train = torchvision.datasets.MNIST( + root="data", train=True, download=True + ) + tv_dataset_test = torchvision.datasets.MNIST( + root="data", train=False, download=True + ) + elif dataset == "FashionMNIST": + tv_dataset_train = torchvision.datasets.FashionMNIST( + root="data", train=True, download=True + ) + tv_dataset_test = torchvision.datasets.FashionMNIST( + root="data", train=False, download=True + ) + elif dataset == "CIFAR10": + tv_dataset_train = torchvision.datasets.CIFAR10( + root="data", train=True, download=True + ) + tv_dataset_test = torchvision.datasets.CIFAR10( + root="data", train=False, download=True + ) + else: + raise NotImplementedError("This dataset is not implemented.") + + def seed_worker(worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + torch.random.seed(worker_seed) + + g = torch.Generator() + g.manual_seed(0) + + if dataset == "MNIST" or dataset == "FashionMNIST": + + train_dataloader = data_loader( + torch_device=torch_device, + batch_size=batch_size_train, + pattern=tv_dataset_train.data, + labels=tv_dataset_train.targets, + shuffle=True, + worker_init_fn=seed_worker, + generator=g, + ) + + test_dataloader = data_loader( + torch_device=torch_device, + batch_size=batch_size_test, + 
pattern=tv_dataset_test.data, + labels=tv_dataset_test.targets, + shuffle=False, + worker_init_fn=seed_worker, + generator=g, + ) + + # Data augmentation filter + test_processing_chain = torchvision.transforms.Compose( + transforms=[torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))], + ) + + train_processing_chain = torchvision.transforms.Compose( + transforms=[torchvision.transforms.RandomCrop((input_dim_x, input_dim_y))], + ) + else: + + train_dataloader = data_loader( + torch_device=torch_device, + batch_size=batch_size_train, + pattern=torch.tensor(tv_dataset_train.data).movedim(-1, 1), + labels=torch.tensor(tv_dataset_train.targets), + shuffle=True, + worker_init_fn=seed_worker, + generator=g, + ) + + test_dataloader = data_loader( + torch_device=torch_device, + batch_size=batch_size_test, + pattern=torch.tensor(tv_dataset_test.data).movedim(-1, 1), + labels=torch.tensor(tv_dataset_test.targets), + shuffle=False, + worker_init_fn=seed_worker, + generator=g, + ) + + # Data augmentation filter + test_processing_chain = torchvision.transforms.Compose( + transforms=[torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))], + ) + + if da_auto_mode: + train_processing_chain = torchvision.transforms.Compose( + transforms=[ + v2.AutoAugment( + policy=torchvision.transforms.AutoAugmentPolicy( + v2.AutoAugmentPolicy.CIFAR10 + ) + ), + torchvision.transforms.CenterCrop((input_dim_x, input_dim_y)), + ], + ) + else: + train_processing_chain = torchvision.transforms.Compose( + transforms=[ + torchvision.transforms.RandomCrop((input_dim_x, input_dim_y)), + torchvision.transforms.RandomHorizontalFlip(p=flip_p), + torchvision.transforms.ColorJitter( + brightness=jitter_brightness, + contrast=jitter_contrast, + saturation=jitter_saturation, + hue=jitter_hue, + ), + ], + ) + + return ( + train_dataloader, + test_dataloader, + train_processing_chain, + test_processing_chain, + ) diff --git a/loss_function.py b/loss_function.py new file mode 100644 index 0000000..e256840 --- /dev/null +++ b/loss_function.py @@ -0,0 +1,64 @@ +import torch + + +# loss_mode == 0: "normal" SbS loss function mixture +# loss_mode == 1: cross_entropy +def loss_function( + h: torch.Tensor, + labels: torch.Tensor, + loss_mode: int = 0, + number_of_output_neurons: int = 10, + loss_coeffs_mse: float = 0.0, + loss_coeffs_kldiv: float = 0.0, +) -> torch.Tensor | None: + + assert loss_mode >= 0 + assert loss_mode <= 1 + + assert h.ndim == 2 + + if loss_mode == 0: + + # Convert label into one hot + target_one_hot: torch.Tensor = torch.zeros( + ( + labels.shape[0], + number_of_output_neurons, + ), + device=h.device, + dtype=h.dtype, + ) + + target_one_hot.scatter_( + 1, + labels.to(h.device).unsqueeze(1), + torch.ones( + (labels.shape[0], 1), + device=h.device, + dtype=h.dtype, + ), + ) + + my_loss: torch.Tensor = ((h - target_one_hot) ** 2).sum(dim=0).mean( + dim=0 + ) * loss_coeffs_mse + + my_loss = ( + my_loss + + ( + (target_one_hot * torch.log((target_one_hot + 1e-20) / (h + 1e-20))) + .sum(dim=0) + .mean(dim=0) + ) + * loss_coeffs_kldiv + ) + + my_loss = my_loss / (abs(loss_coeffs_kldiv) + abs(loss_coeffs_mse)) + + return my_loss + + elif loss_mode == 1: + my_loss = torch.nn.functional.cross_entropy(h, labels.to(h.device)) + return my_loss + else: + return None diff --git a/make_network.py b/make_network.py new file mode 100644 index 0000000..9390638 --- /dev/null +++ b/make_network.py @@ -0,0 +1,367 @@ +import torch +from PositionalEncoding import PositionalEncoding +from SequentialSplit import SequentialSplit 
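+
+# Overview of what this module assembles (descriptive summary of the code below):
+#   * make_network(): 4x4 / stride-4 Conv2d patch embedding -> PositionalEncoding
+#     -> channel-last LayerNorm (via permutes), then a stack of blocks, a spatial
+#     mean, a final Linear classifier and a Softmax.
+#   * add_block(): per block, an attention-like branch (LayerNorm -> clamp ->
+#     unfold/fold windowing -> NNMF2dGrouped -> LayerNorm -> 1x1 Conv2d) is
+#     summed with an identity skip via SequentialSplit, followed by an MLP branch
+#     (LayerNorm -> Linear -> GELU -> Linear) summed with a second identity skip.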
+from NNMF2dGrouped import NNMF2dGrouped +from Functional2Layer import Functional2Layer + + +def add_block( + network: torch.nn.Sequential, + embed_dim: int, + num_heads: int, + dtype: torch.dtype, + device: torch.device, + example_image: torch.Tensor, + mlp_ratio: int = 4, + block_id: int = 0, + iterations: int = 20, + padding: int = 1, + kernel_size: tuple[int, int] = (3, 3), +) -> torch.Tensor | None: + + # ########### + # Attention # + # ########### + + example_image_a: torch.Tensor = example_image.clone() + example_image_b: torch.Tensor = example_image.clone() + + attention_a_sequential = torch.nn.Sequential() + + attention_a_sequential.add_module( + "Attention Layer Norm 1 [Pre-Permute]", + Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)), + ) + example_image_a = attention_a_sequential[-1](example_image_a) + + attention_a_sequential.add_module( + "Attention Layer Norm 1", + torch.nn.LayerNorm( + normalized_shape=example_image_a.shape[-1], + eps=1e-06, + bias=True, + dtype=dtype, + device=device, + ), + ) + example_image_a = attention_a_sequential[-1](example_image_a) + + attention_a_sequential.add_module( + "Attention Layer Norm 1 [Post-Permute]", + Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2)), + ) + example_image_a = attention_a_sequential[-1](example_image_a) + + attention_a_sequential.add_module( + "Attention Clamp Layer", Functional2Layer(func=torch.clamp, min=1e-6) + ) + example_image_a = attention_a_sequential[-1](example_image_a) + + backup_image_dim = example_image_a.shape[1] + + attention_a_sequential.add_module( + "Attention Zero Padding Layer", torch.nn.ZeroPad2d(padding=padding) + ) + example_image_a = attention_a_sequential[-1](example_image_a) + + # I need the output size + mock_output_shape = ( + torch.nn.functional.conv2d( + torch.zeros( + 1, + 1, + example_image_a.shape[2], + example_image_a.shape[3], + ), + torch.zeros((1, 1, kernel_size[0], kernel_size[1])), + stride=1, + padding=0, + dilation=1, + ) + .squeeze(0) + .squeeze(0) + ).shape + + attention_a_sequential.add_module( + "Attention Windowing [Part 1]", + torch.nn.Unfold( + kernel_size=(kernel_size[-2], kernel_size[-1]), + dilation=1, + padding=0, + stride=1, + ), + ) + example_image_a = attention_a_sequential[-1](example_image_a) + + attention_a_sequential.add_module( + "Attention Windowing [Part 2]", + torch.nn.Fold( + output_size=mock_output_shape, + kernel_size=(1, 1), + dilation=1, + padding=0, + stride=1, + ), + ) + example_image_a = attention_a_sequential[-1](example_image_a) + + attention_a_sequential.add_module("Attention NNMFConv2d", torch.nn.ReLU()) + example_image_a = attention_a_sequential[-1](example_image_a) + + attention_a_sequential.add_module( + "Attention NNMFConv2d", + NNMF2dGrouped( + in_channels=example_image_a.shape[1], + out_channels=embed_dim, + groups=num_heads, + device=device, + dtype=dtype, + iterations=iterations, + ), + ) + example_image_a = attention_a_sequential[-1](example_image_a) + + attention_a_sequential.add_module( + "Attention Layer Norm 2 [Pre-Permute]", + Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)), + ) + example_image_a = attention_a_sequential[-1](example_image_a) + + attention_a_sequential.add_module( + "Attention Layer Norm 2", + torch.nn.LayerNorm( + normalized_shape=example_image_a.shape[-1], + eps=1e-06, + bias=True, + dtype=dtype, + device=device, + ), + ) + example_image_a = attention_a_sequential[-1](example_image_a) + + attention_a_sequential.add_module( + "Attention Layer Norm 2 [Post-Permute]", + 
Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2)), + ) + example_image_a = attention_a_sequential[-1](example_image_a) + + attention_a_sequential.add_module( + "Attention Conv2d Layer ", + torch.nn.Conv2d( + in_channels=example_image_a.shape[1], + out_channels=backup_image_dim, + kernel_size=1, + dtype=dtype, + device=device, + ), + ) + example_image_a = attention_a_sequential[-1](example_image_a) + + attention_b_sequential = torch.nn.Sequential() + attention_b_sequential.add_module( + "Attention Identity for the skip", torch.nn.Identity() + ) + example_image_b = attention_b_sequential[-1](example_image_b) + + assert example_image_b.shape == example_image_a.shape + + network.add_module( + f"Block Number {block_id} [Attention]", + SequentialSplit( + torch.nn.Sequential( + attention_a_sequential, + attention_b_sequential, + ), + combine="SUM", + ), + ) + example_image = network[-1](example_image) + + # ###### + # MLP # + # ##### + + example_image_a = example_image.clone() + example_image_b = example_image.clone() + + mlp_a_sequential = torch.nn.Sequential() + + mlp_a_sequential.add_module( + "MLP [Pre-Permute]", Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)) + ) + example_image_a = mlp_a_sequential[-1](example_image_a) + + mlp_a_sequential.add_module( + "MLP Layer Norm", + torch.nn.LayerNorm( + normalized_shape=example_image_a.shape[-1], + eps=1e-06, + bias=True, + dtype=dtype, + device=device, + ), + ) + example_image_a = mlp_a_sequential[-1](example_image_a) + + mlp_a_sequential.add_module( + "MLP Linear Layer A", + torch.nn.Linear( + example_image_a.shape[-1], + int(example_image_a.shape[-1] * mlp_ratio), + dtype=dtype, + device=device, + ), + ) + example_image_a = mlp_a_sequential[-1](example_image_a) + + mlp_a_sequential.add_module("MLP GELU", torch.nn.GELU()) + example_image_a = mlp_a_sequential[-1](example_image_a) + + mlp_a_sequential.add_module( + "MLP Linear Layer B", + torch.nn.Linear( + example_image_a.shape[-1], + int(example_image_a.shape[-1] // mlp_ratio), + dtype=dtype, + device=device, + ), + ) + example_image_a = mlp_a_sequential[-1](example_image_a) + + mlp_a_sequential.add_module( + "MLP [Post-Permute]", Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2)) + ) + example_image_a = mlp_a_sequential[-1](example_image_a) + + mlp_b_sequential = torch.nn.Sequential() + mlp_b_sequential.add_module("MLP Identity for the skip", torch.nn.Identity()) + + example_image_b = attention_b_sequential[-1](example_image_b) + + assert example_image_b.shape == example_image_a.shape + + network.add_module( + f"Block Number {block_id} [MLP]", + SequentialSplit( + torch.nn.Sequential( + mlp_a_sequential, + mlp_b_sequential, + ), + combine="SUM", + ), + ) + example_image = network[-1](example_image) + + return example_image + + +def make_network( + in_channels: int = 3, + dims: list[int] = [72, 72, 72], + embed_dims: list[int] = [192, 192, 192], + n_classes: int = 10, + heads: int = 12, + example_image_shape: list[int] = [1, 3, 28, 28], + dtype: torch.dtype = torch.float32, + device: torch.device | None = None, + iterations: int = 20, +) -> torch.nn.Sequential: + + assert device is not None + + network = torch.nn.Sequential() + + example_image: torch.Tensor = torch.zeros( + example_image_shape, dtype=dtype, device=device + ) + + network.add_module( + "Encode Conv2d", + torch.nn.Conv2d( + in_channels, + dims[0], + kernel_size=4, + stride=4, + padding=0, + dtype=dtype, + device=device, + ), + ) + example_image = network[-1](example_image) + + network.add_module( + "Encode Offset", 
+ PositionalEncoding( + [example_image.shape[-3], example_image.shape[-2], example_image.shape[-1]] + ).to(device=device), + ) + example_image = network[-1](example_image) + + network.add_module( + "Encode Layer Norm [Pre-Permute]", + Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)), + ) + example_image = network[-1](example_image) + + network.add_module( + "Encode Layer Norm", + torch.nn.LayerNorm( + normalized_shape=example_image.shape[-1], + eps=1e-06, + bias=True, + dtype=dtype, + device=device, + ), + ) + example_image = network[-1](example_image) + + network.add_module( + "Encode Layer Norm [Post-Permute]", + Functional2Layer(func=torch.permute, dims=(0, 3, 1, 2)), + ) + example_image = network[-1](example_image) + + for i in range(len(dims)): + example_image = add_block( + network=network, + embed_dim=embed_dims[i], + num_heads=heads, + mlp_ratio=2, + block_id=i, + example_image=example_image, + dtype=dtype, + device=device, + iterations=iterations, + ) + + network.add_module( + "Spatial Mean Layer", Functional2Layer(func=torch.mean, dim=(-1, -2)) + ) + example_image = network[-1](example_image) + + network.add_module( + "Final Linear Layer", + torch.nn.Linear(example_image.shape[-1], n_classes, dtype=dtype, device=device), + ) + example_image = network[-1](example_image) + + network.add_module("Final Softmax Layer", torch.nn.Softmax(dim=-1)) + example_image = network[-1](example_image) + + assert example_image.ndim == 2 + assert example_image.shape[0] == example_image_shape[0] + assert example_image.shape[1] == n_classes + + return network + + +if __name__ == "__main__": + network = make_network(device=torch.device("cuda:0")) + print(network) + + number_of_parameter: int = 0 + for name, param in network.named_parameters(): + print(f"Parameter name: {name}, Shape: {param.shape}") + number_of_parameter += param.numel() + + print("Number of total parameters:", number_of_parameter) diff --git a/make_optimize.py b/make_optimize.py new file mode 100644 index 0000000..ab1a4e0 --- /dev/null +++ b/make_optimize.py @@ -0,0 +1,32 @@ +import torch + + +def make_optimize( + parameters: list[list[torch.nn.parameter.Parameter]], + lr_initial: list[float], + eps=1e-10, +) -> tuple[ + list[torch.optim.Adam | None], + list[torch.optim.lr_scheduler.ReduceLROnPlateau | None], +]: + list_optimizer: list[torch.optim.Adam | None] = [] + list_lr_scheduler: list[torch.optim.lr_scheduler.ReduceLROnPlateau | None] = [] + + assert len(parameters) == len(lr_initial) + + for i in range(0, len(parameters)): + if len(parameters[i]) > 0: + list_optimizer.append(torch.optim.Adam(parameters[i], lr=lr_initial[i])) + else: + list_optimizer.append(None) + + for i in range(0, len(list_optimizer)): + if list_optimizer[i] is not None: + pass + list_lr_scheduler.append( + torch.optim.lr_scheduler.ReduceLROnPlateau(list_optimizer[i], eps=eps) # type: ignore + ) + else: + list_lr_scheduler.append(None) + + return (list_optimizer, list_lr_scheduler) diff --git a/non_linear_weigth_function.py b/non_linear_weigth_function.py new file mode 100644 index 0000000..053a9b6 --- /dev/null +++ b/non_linear_weigth_function.py @@ -0,0 +1,26 @@ +import torch + + +def non_linear_weigth_function( + weight: torch.Tensor, beta: torch.Tensor | None, positive_function_type: int +) -> torch.Tensor: + + if positive_function_type == 0: + positive_weights = torch.abs(weight) + + elif positive_function_type == 1: + assert beta is not None + positive_weights = weight + max_value = torch.abs(positive_weights).max() + if max_value > 80: + 
positive_weights = 80.0 * positive_weights / max_value + positive_weights = torch.exp((torch.tanh(beta) + 1.0) * 0.5 * positive_weights) + + elif positive_function_type == 2: + assert beta is not None + positive_weights = (torch.tanh(beta * weight) + 1.0) * 0.5 + + else: + positive_weights = weight + + return positive_weights diff --git a/run_network.py b/run_network.py new file mode 100644 index 0000000..f10a5bf --- /dev/null +++ b/run_network.py @@ -0,0 +1,263 @@ +import os + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +import argh + +import time +import numpy as np +import torch + +rand_seed: int = 21 +torch.manual_seed(rand_seed) +torch.cuda.manual_seed(rand_seed) +np.random.seed(rand_seed) + +from torch.utils.tensorboard import SummaryWriter + +from make_network import make_network +from get_the_data import get_the_data +from loss_function import loss_function +from make_optimize import make_optimize + + +def main( + lr_initial_nnmf: float = 0.01, + lr_initial_cnn: float = 0.01, + iterations: int = 25, + heads: int = 12, + dataset: str = "CIFAR10", # "CIFAR10", "FashionMNIST", "MNIST" + only_print_network: bool = False, + da_auto_mode: bool = False, +) -> None: + + lr_limit: float = 1e-9 + + torch_device: torch.device = ( + torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + ) + torch.set_default_dtype(torch.float32) + + # Some parameters + batch_size_train: int = 500 + batch_size_test: int = 500 + number_of_epoch: int = 5000 + + prefix = "" + + loss_mode: int = 0 + loss_coeffs_mse: float = 0.5 + loss_coeffs_kldiv: float = 1.0 + print( + "loss_mode: ", + loss_mode, + "loss_coeffs_mse: ", + loss_coeffs_mse, + "loss_coeffs_kldiv: ", + loss_coeffs_kldiv, + ) + + if dataset == "MNIST" or dataset == "FashionMNIST": + input_number_of_channel: int = 1 + input_dim_x: int = 24 + input_dim_y: int = 24 + else: + input_number_of_channel = 3 + input_dim_x = 28 + input_dim_y = 28 + + train_dataloader, test_dataloader, train_processing_chain, test_processing_chain = ( + get_the_data( + dataset, + batch_size_train, + batch_size_test, + torch_device, + input_dim_x, + input_dim_y, + flip_p=0.5, + jitter_brightness=0.5, + jitter_contrast=0.1, + jitter_saturation=0.1, + jitter_hue=0.15, + da_auto_mode=da_auto_mode, + ) + ) + + network = make_network( + in_channels=input_number_of_channel, + dims=[72, 72, 72], + embed_dims=[192, 192, 192], + n_classes=10, + heads=heads, + example_image_shape=[1, input_number_of_channel, input_dim_x, input_dim_y], + dtype=torch.float32, + device=torch_device, + iterations=iterations, + ) + print(network) + + print() + print("Information about used parameters:") + + parameter_list: list[list] = [] + parameter_list.append([]) + parameter_list.append([]) + + number_of_parameter: int = 0 + for name, param in network.named_parameters(): + + if name.find("NNMF") == -1: + parameter_list[0].append(param) + else: + parameter_list[1].append(param) + print("!!! NNMF !!! 
", end=" ") + + print(f"Parameter name: {name}, Shape: {param.shape}") + number_of_parameter += param.numel() + print() + print("Number of total parameters:", number_of_parameter) + print("Number of parameter sets in CNN:", len(parameter_list[0])) + print("Number of parameter sets in NNMF:", len(parameter_list[1])) + + if only_print_network: + exit() + + ( + optimizers, + lr_schedulers, + ) = make_optimize( + parameters=parameter_list, + lr_initial=[ + lr_initial_cnn, + lr_initial_nnmf, + ], + ) + + my_string: str = "_lr_" + for i in range(0, len(lr_schedulers)): + if lr_schedulers[i] is not None: + my_string += f"{lr_schedulers[i].get_last_lr()[0]:.4e}_" # type: ignore + else: + my_string += "-_" + + default_path: str = f"{prefix}_iter{iterations}{my_string}" + log_dir: str = f"log_{default_path}" + + tb = SummaryWriter(log_dir=log_dir) + + for epoch_id in range(0, number_of_epoch): + print() + print(f"Epoch: {epoch_id}") + t_start: float = time.perf_counter() + + train_loss: float = 0.0 + train_correct: int = 0 + train_number: int = 0 + test_correct: int = 0 + test_number: int = 0 + + # Switch the network into training mode + network.train() + + # This runs in total for one epoch split up into mini-batches + for image, target in train_dataloader: + + # Clean the gradient + for i in range(0, len(optimizers)): + if optimizers[i] is not None: + optimizers[i].zero_grad() # type: ignore + + output = network(train_processing_chain(image)) + + loss = loss_function( + h=output, + labels=target, + number_of_output_neurons=output.shape[1], + loss_mode=loss_mode, + loss_coeffs_mse=loss_coeffs_mse, + loss_coeffs_kldiv=loss_coeffs_kldiv, + ) + + assert loss is not None + train_loss += loss.item() + train_correct += (output.argmax(dim=1) == target).sum().cpu().numpy() + train_number += target.shape[0] + + # Calculate backprop + loss.backward() + + # Update the parameter + # Clean the gradient + for i in range(0, len(optimizers)): + if optimizers[i] is not None: + optimizers[i].step() # type: ignore + + perfomance_train_correct: float = 100.0 * train_correct / train_number + # Update the learning rate + for i in range(0, len(lr_schedulers)): + if lr_schedulers[i] is not None: + lr_schedulers[i].step(train_loss) # type: ignore + + my_string = "Actual lr: " + for i in range(0, len(lr_schedulers)): + if lr_schedulers[i] is not None: + my_string += f" {lr_schedulers[i].get_last_lr()[0]:.4e} " # type: ignore + else: + my_string += " --- " + + print(my_string) + t_training: float = time.perf_counter() + + # Switch the network into evalution mode + network.eval() + + with torch.no_grad(): + + for image, target in test_dataloader: + output = network(test_processing_chain(image)) + + test_correct += (output.argmax(dim=1) == target).sum().cpu().numpy() + test_number += target.shape[0] + + t_testing = time.perf_counter() + + perfomance_test_correct: float = 100.0 * test_correct / test_number + + tb.add_scalar("Train Loss", train_loss / float(train_number), epoch_id) + tb.add_scalar("Train Number Correct", train_correct, epoch_id) + tb.add_scalar("Test Number Correct", test_correct, epoch_id) + + print( + f"Training: Loss={train_loss / float(train_number):.5f} Correct={perfomance_train_correct:.2f}%" + ) + print(f"Testing: Correct={perfomance_test_correct:.2f}%") + print( + f"Time: Training={(t_training - t_start):.1f}sec, Testing={(t_testing - t_training):.1f}sec" + ) + + tb.flush() + + lr_check: list[float] = [] + for i in range(0, len(lr_schedulers)): + if lr_schedulers[i] is not None: + 
lr_check.append(lr_schedulers[i].get_last_lr()[0]) # type: ignore + + lr_check_max = float(torch.tensor(lr_check).max()) + + if lr_check_max < lr_limit: + torch.save(network, f"Model_{default_path}.pt") + tb.close() + print("Done (lr_limit)") + return + + torch.save(network, f"Model_{default_path}.pt") + print() + + tb.close() + print("Done (loop end)") + + return + + +if __name__ == "__main__": + argh.dispatch_command(main)
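For reference, the loop inside FunctionalNNMF2dGrouped.forward implements, per group and per spatial position, a multiplicative NMF update. With W the positive, row-normalised weights (shape groups x out x in), x the input normalised over its in-channels, and h initialised uniformly at 1/out_channels:

    R_i  = sum_o h_o * W[o, i]              # reconstruction of the input channels
    h_o <- h_o * sum_i W[o, i] * x_i / R_i  # multiplicative update (KL-type NMF)
    h_o <- h_o / sum_k h_k                  # renormalisation over the out-channels

When epsilon is set, the update is damped to h_o <- h_o * (1 + epsilon * sum_i W[o, i] * x_i / R_i) instead.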
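A minimal shape-check sketch for the NNMF2dGrouped layer (run from the repository root; the parameter values below are illustrative, not taken from the patch):

    import torch
    from NNMF2dGrouped import NNMF2dGrouped

    layer = NNMF2dGrouped(in_channels=12, out_channels=24, groups=4, iterations=20)
    x = torch.rand(2, 12, 8, 8)   # (batch, channels, height, width), non-negative
    y = layer(x)                  # channels are processed in 4 groups of 3
    print(y.shape)                # torch.Size([2, 24, 8, 8]); y is normalised over dim=1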
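Training is started through run_network.py, which dispatches main() via argh; assuming argh's default underscore-to-dash flag mapping, typical invocations are:

    python run_network.py --dataset CIFAR10 --iterations 25 --heads 12 --lr-initial-nnmf 0.01 --lr-initial-cnn 0.01
    python run_network.py --only-print-network    # print the architecture and parameter counts, then exit

convert_log_to_numpy.py can afterwards be run in the same directory to export the "Test Number Correct" scalars from every log_* TensorBoard directory into data_*.npy files.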