diff --git a/MLP_equivalent/L1NormLayer.py b/MLP_equivalent/L1NormLayer.py new file mode 100644 index 0000000..6816b3a --- /dev/null +++ b/MLP_equivalent/L1NormLayer.py @@ -0,0 +1,13 @@ +import torch + + +class L1NormLayer(torch.nn.Module): + + epsilon: float + + def __init__(self, epsilon: float = 10e-20) -> None: + super().__init__() + self.epsilon = epsilon + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return input / (input.sum(dim=1, keepdim=True) + self.epsilon) diff --git a/MLP_equivalent/NNMF2d.py b/MLP_equivalent/NNMF2d.py new file mode 100644 index 0000000..b84d083 --- /dev/null +++ b/MLP_equivalent/NNMF2d.py @@ -0,0 +1,252 @@ +import torch +from non_linear_weigth_function import non_linear_weigth_function + + +class NNMF2d(torch.nn.Module): + + in_channels: int + out_channels: int + weight: torch.Tensor + iterations: int + epsilon: float | None + init_min: float + init_max: float + beta: torch.Tensor | None + positive_function_type: int + local_learning: bool + local_learning_kl: bool + + def __init__( + self, + in_channels: int, + out_channels: int, + device=None, + dtype=None, + iterations: int = 20, + epsilon: float | None = None, + init_min: float = 0.0, + init_max: float = 1.0, + beta: float | None = None, + positive_function_type: int = 0, + local_learning: bool = False, + local_learning_kl: bool = False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + + super().__init__() + + self.positive_function_type = positive_function_type + self.init_min = init_min + self.init_max = init_max + + self.in_channels = in_channels + self.out_channels = out_channels + + self.iterations = iterations + self.local_learning = local_learning + self.local_learning_kl = local_learning_kl + + self.weight = torch.nn.parameter.Parameter( + torch.empty((out_channels, in_channels), **factory_kwargs) + ) + + if beta is not None: + self.beta = torch.nn.parameter.Parameter(torch.empty((1), **factory_kwargs)) + self.beta.data[0] = beta + else: + self.beta = None + + self.reset_parameters() + self.functional_nnmf2d = FunctionalNNMF2d.apply + + self.epsilon = epsilon + + def extra_repr(self) -> str: + s: str = f"{self.in_channels}, {self.out_channels}" + + if self.epsilon is not None: + s += f", epsilon={self.epsilon}" + s += f", pfunctype={self.positive_function_type}" + s += f", local_learning={self.local_learning}" + + if self.local_learning: + s += f", local_learning_kl={self.local_learning_kl}" + + return s + + def reset_parameters(self) -> None: + torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + + positive_weights = non_linear_weigth_function( + self.weight, self.beta, self.positive_function_type + ) + positive_weights = positive_weights / ( + positive_weights.sum(dim=1, keepdim=True) + 10e-20 + ) + + h_dyn = self.functional_nnmf2d( + input, + positive_weights, + self.out_channels, + self.iterations, + self.epsilon, + self.local_learning, + self.local_learning_kl, + ) + + return h_dyn + + +class FunctionalNNMF2d(torch.autograd.Function): + @staticmethod + def forward( # type: ignore + ctx, + input: torch.Tensor, + weight: torch.Tensor, + out_channels: int, + iterations: int, + epsilon: float | None, + local_learning: bool, + local_learning_kl: bool, + ) -> torch.Tensor: + + # Prepare h + h = torch.full( + (input.shape[0], out_channels, input.shape[-2], input.shape[-1]), + 1.0 / float(out_channels), + device=input.device, + dtype=input.dtype, + ) + + h = h.movedim(1, -1) + input = input.movedim(1, -1) + for _ in range(0, iterations): + reconstruction = torch.nn.functional.linear(h, weight.T) + reconstruction += 1e-20 + if epsilon is None: + h *= torch.nn.functional.linear((input / reconstruction), weight) + else: + h *= 1 + epsilon * torch.nn.functional.linear( + (input / reconstruction), weight + ) + h /= h.sum(-1, keepdim=True) + 10e-20 + h = h.movedim(-1, 1) + input = input.movedim(-1, 1) + + # ########################################################### + # Save the necessary data for the backward pass + # ########################################################### + ctx.save_for_backward(input, weight, h) + ctx.local_learning = local_learning + ctx.local_learning_kl = local_learning_kl + + assert torch.isfinite(h).all() + return h + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output: torch.Tensor) -> tuple[ # type: ignore + torch.Tensor, + torch.Tensor | None, + None, + None, + None, + None, + None, + ]: + + # ############################################## + # Default values + # ############################################## + grad_weight: torch.Tensor | None = None + + # ############################################## + # Get the variables back + # ############################################## + (input, weight, h) = ctx.saved_tensors + + # The back prop gradient + h = h.movedim(1, -1) + grad_output = grad_output.movedim(1, -1) + input = input.movedim(1, -1) + big_r = torch.nn.functional.linear(h, weight.T) + big_r_div = 1.0 / (big_r + 1e-20) + + factor_x_div_r = input * big_r_div + + grad_input: torch.Tensor = ( + torch.nn.functional.linear(h * grad_output, weight.T) * big_r_div + ) + + del big_r_div + + # The weight gradient + if ctx.local_learning is False: + del big_r + + grad_weight = -torch.nn.functional.linear( + h.reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + h.shape[3], + ).T, + (factor_x_div_r * grad_input) + .reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + grad_input.shape[3], + ) + .T, + ) + + grad_weight += torch.nn.functional.linear( + (h * grad_output) + .reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + h.shape[3], + ) + .T, + factor_x_div_r.reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + grad_input.shape[3], + ).T, + ) + + else: + if ctx.local_learning_kl: + grad_weight = -torch.nn.functional.linear( + h.reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + h.shape[3], + ).T, + factor_x_div_r.reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + grad_input.shape[3], + ).T, + ) + else: + grad_weight = -torch.nn.functional.linear( + h.reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + h.shape[3], + ).T, + (2 * (input - big_r)) + .reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + grad_input.shape[3], + ) + .T, + ) + grad_input = grad_input.movedim(-1, 1) + assert torch.isfinite(grad_input).all() + assert torch.isfinite(grad_weight).all() + + return ( + grad_input, + grad_weight, + None, + None, + None, + None, + None, + ) diff --git a/MLP_equivalent/append_block.py b/MLP_equivalent/append_block.py new file mode 100644 index 0000000..b6796c4 --- /dev/null +++ b/MLP_equivalent/append_block.py @@ -0,0 +1,151 @@ +import torch +from L1NormLayer import L1NormLayer +from append_parameter import append_parameter + + +def append_block( + network: torch.nn.Sequential, + out_channels: int, + test_image: torch.Tensor, + parameter_cnn_top: list[torch.nn.parameter.Parameter], + parameter_nnmf: list[torch.nn.parameter.Parameter], + parameter_norm: list[torch.nn.parameter.Parameter], + torch_device: torch.device, + dilation: tuple[int, int] | int = 1, + padding: tuple[int, int] | int = 0, + stride: tuple[int, int] | int = 1, + kernel_size: tuple[int, int] = (5, 5), + epsilon: float | None = None, + positive_function_type: int = 0, + beta: float | None = None, + iterations: int = 20, + local_learning: bool = False, + local_learning_kl: bool = False, + momentum: float = 0.1, + track_running_stats: bool = False, + last_layer: bool= False, +) -> torch.Tensor: + + kernel_size_internal: list[int] = [kernel_size[-2], kernel_size[-1]] + + if kernel_size[0] < 1: + kernel_size_internal[0] = test_image.shape[-2] + + if kernel_size[1] < 1: + kernel_size_internal[1] = test_image.shape[-1] + + # Main + network.append(torch.nn.ReLU()) + test_image = network[-1](test_image) + + # I need the output size + mock_output = ( + torch.nn.functional.conv2d( + torch.zeros( + 1, + 1, + test_image.shape[2], + test_image.shape[3], + ), + torch.zeros((1, 1, kernel_size_internal[0], kernel_size_internal[1])), + stride=stride, + padding=padding, + dilation=dilation, + ) + .squeeze(0) + .squeeze(0) + ) + network.append( + torch.nn.Unfold( + kernel_size=(kernel_size_internal[-2], kernel_size_internal[-1]), + dilation=dilation, + padding=padding, + stride=stride, + ) + ) + test_image = network[-1](test_image) + + network.append( + torch.nn.Fold( + output_size=mock_output.shape, + kernel_size=(1, 1), + dilation=1, + padding=0, + stride=1, + ) + ) + test_image = network[-1](test_image) + + network.append(L1NormLayer()) + test_image = network[-1](test_image) + + network.append( + torch.nn.Conv2d( + in_channels=test_image.shape[1], + out_channels=out_channels, + kernel_size=(1, 1), + bias=False, + ).to(torch_device) + ) + test_image = network[-1](test_image) + append_parameter(module=network[-1], parameter_list=parameter_nnmf) + + if (test_image.shape[-1] > 1) or (test_image.shape[-2] > 1): + network.append( + torch.nn.BatchNorm2d( + num_features=test_image.shape[1], + momentum=momentum, + track_running_stats=track_running_stats, + device=torch_device, + ) + ) + test_image = network[-1](test_image) + append_parameter(module=network[-1], parameter_list=parameter_norm) + + if last_layer is False: + + network.append(torch.nn.ReLU()) + test_image = network[-1](test_image) + + + network.append( + torch.nn.Conv2d( + in_channels=test_image.shape[1], + out_channels=out_channels, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + device=torch_device, + ) + ) + # Init the cnn top layers 1x1 conv2d layers + for name, param in network[-1].named_parameters(): + with torch.no_grad(): + if name == "bias": + param.data *= 0 + if name == "weight": + assert param.shape[-2] == 1 + assert param.shape[-1] == 1 + param[: param.shape[0], : param.shape[0], 0, 0] = torch.eye( + param.shape[0], dtype=param.dtype, device=param.device + ) + param[param.shape[0] :, :, 0, 0] = 0 + param[:, param.shape[0] :, 0, 0] = 0 + + test_image = network[-1](test_image) + append_parameter(module=network[-1], parameter_list=parameter_cnn_top) + + if (test_image.shape[-1] > 1) or (test_image.shape[-2] > 1): + network.append( + torch.nn.BatchNorm2d( + num_features=test_image.shape[1], + device=torch_device, + momentum=momentum, + track_running_stats=track_running_stats, + ) + ) + test_image = network[-1](test_image) + append_parameter(module=network[-1], parameter_list=parameter_norm) + + return test_image diff --git a/MLP_equivalent/append_parameter.py b/MLP_equivalent/append_parameter.py new file mode 100644 index 0000000..b972e39 --- /dev/null +++ b/MLP_equivalent/append_parameter.py @@ -0,0 +1,8 @@ +import torch + + +def append_parameter( + module: torch.nn.Module, parameter_list: list[torch.nn.parameter.Parameter] +): + for netp in module.parameters(): + parameter_list.append(netp) diff --git a/MLP_equivalent/convert_log_to_numpy.py b/MLP_equivalent/convert_log_to_numpy.py new file mode 100644 index 0000000..6a1343a --- /dev/null +++ b/MLP_equivalent/convert_log_to_numpy.py @@ -0,0 +1,30 @@ +import os +import glob + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +from tensorboard.backend.event_processing import event_accumulator # type: ignore +import numpy as np + + +def get_data(path: str = "log_cnn"): + acc = event_accumulator.EventAccumulator(path) + acc.Reload() + + which_scalar = "Test Number Correct" + te = acc.Scalars(which_scalar) + + np_temp = np.zeros((len(te), 2)) + + for id in range(0, len(te)): + np_temp[id, 0] = te[id].step + np_temp[id, 1] = te[id].value + + print(np_temp[:, 1] / 100) + return np_temp + + +for path in glob.glob("log_*"): + print(path) + data = get_data(path) + np.save("data_" + path + ".npy", data) diff --git a/MLP_equivalent/data_loader.py b/MLP_equivalent/data_loader.py new file mode 100644 index 0000000..0a0d430 --- /dev/null +++ b/MLP_equivalent/data_loader.py @@ -0,0 +1,31 @@ +import torch + + +def data_loader( + pattern: torch.Tensor, + labels: torch.Tensor, + worker_init_fn, + generator, + batch_size: int = 128, + shuffle: bool = True, + torch_device: torch.device = torch.device("cpu"), +) -> torch.utils.data.dataloader.DataLoader: + + assert pattern.ndim >= 3 + + pattern_storage: torch.Tensor = pattern.to(torch_device).type(torch.float32) + if pattern_storage.ndim == 3: + pattern_storage = pattern_storage.unsqueeze(1) + pattern_storage /= pattern_storage.max() + + label_storage: torch.Tensor = labels.to(torch_device).type(torch.int64) + + dataloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(pattern_storage, label_storage), + batch_size=batch_size, + shuffle=shuffle, + worker_init_fn=worker_init_fn, + generator=generator, + ) + + return dataloader diff --git a/MLP_equivalent/get_the_data.py b/MLP_equivalent/get_the_data.py new file mode 100644 index 0000000..fc61064 --- /dev/null +++ b/MLP_equivalent/get_the_data.py @@ -0,0 +1,147 @@ +import torch +import torchvision # type: ignore +from data_loader import data_loader + +from torchvision.transforms import v2 # type: ignore +import numpy as np + + +def get_the_data( + dataset: str, + batch_size_train: int, + batch_size_test: int, + torch_device: torch.device, + input_dim_x: int, + input_dim_y: int, + flip_p: float = 0.5, + jitter_brightness: float = 0.5, + jitter_contrast: float = 0.1, + jitter_saturation: float = 0.1, + jitter_hue: float = 0.15, + da_auto_mode: bool = False, +) -> tuple[ + torch.utils.data.dataloader.DataLoader, + torch.utils.data.dataloader.DataLoader, + torchvision.transforms.Compose, + torchvision.transforms.Compose, +]: + if dataset == "MNIST": + tv_dataset_train = torchvision.datasets.MNIST( + root="data", train=True, download=True + ) + tv_dataset_test = torchvision.datasets.MNIST( + root="data", train=False, download=True + ) + elif dataset == "FashionMNIST": + tv_dataset_train = torchvision.datasets.FashionMNIST( + root="data", train=True, download=True + ) + tv_dataset_test = torchvision.datasets.FashionMNIST( + root="data", train=False, download=True + ) + elif dataset == "CIFAR10": + tv_dataset_train = torchvision.datasets.CIFAR10( + root="data", train=True, download=True + ) + tv_dataset_test = torchvision.datasets.CIFAR10( + root="data", train=False, download=True + ) + else: + raise NotImplementedError("This dataset is not implemented.") + + def seed_worker(worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + torch.random.seed(worker_seed) + + g = torch.Generator() + g.manual_seed(0) + + if dataset == "MNIST" or dataset == "FashionMNIST": + + train_dataloader = data_loader( + torch_device=torch_device, + batch_size=batch_size_train, + pattern=tv_dataset_train.data, + labels=tv_dataset_train.targets, + shuffle=True, + worker_init_fn=seed_worker, + generator=g, + ) + + test_dataloader = data_loader( + torch_device=torch_device, + batch_size=batch_size_test, + pattern=tv_dataset_test.data, + labels=tv_dataset_test.targets, + shuffle=False, + worker_init_fn=seed_worker, + generator=g, + ) + + # Data augmentation filter + test_processing_chain = torchvision.transforms.Compose( + transforms=[torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))], + ) + + train_processing_chain = torchvision.transforms.Compose( + transforms=[torchvision.transforms.RandomCrop((input_dim_x, input_dim_y))], + ) + else: + + train_dataloader = data_loader( + torch_device=torch_device, + batch_size=batch_size_train, + pattern=torch.tensor(tv_dataset_train.data).movedim(-1, 1), + labels=torch.tensor(tv_dataset_train.targets), + shuffle=True, + worker_init_fn=seed_worker, + generator=g, + ) + + test_dataloader = data_loader( + torch_device=torch_device, + batch_size=batch_size_test, + pattern=torch.tensor(tv_dataset_test.data).movedim(-1, 1), + labels=torch.tensor(tv_dataset_test.targets), + shuffle=False, + worker_init_fn=seed_worker, + generator=g, + ) + + # Data augmentation filter + test_processing_chain = torchvision.transforms.Compose( + transforms=[torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))], + ) + + if da_auto_mode: + train_processing_chain = torchvision.transforms.Compose( + transforms=[ + v2.AutoAugment( + policy=torchvision.transforms.AutoAugmentPolicy( + v2.AutoAugmentPolicy.CIFAR10 + ) + ), + torchvision.transforms.CenterCrop((input_dim_x, input_dim_y)), + ], + ) + else: + train_processing_chain = torchvision.transforms.Compose( + transforms=[ + torchvision.transforms.RandomCrop((input_dim_x, input_dim_y)), + torchvision.transforms.RandomHorizontalFlip(p=flip_p), + torchvision.transforms.ColorJitter( + brightness=jitter_brightness, + contrast=jitter_contrast, + saturation=jitter_saturation, + hue=jitter_hue, + ), + ], + ) + + return ( + train_dataloader, + test_dataloader, + train_processing_chain, + test_processing_chain, + ) diff --git a/MLP_equivalent/loss_function.py b/MLP_equivalent/loss_function.py new file mode 100644 index 0000000..e256840 --- /dev/null +++ b/MLP_equivalent/loss_function.py @@ -0,0 +1,64 @@ +import torch + + +# loss_mode == 0: "normal" SbS loss function mixture +# loss_mode == 1: cross_entropy +def loss_function( + h: torch.Tensor, + labels: torch.Tensor, + loss_mode: int = 0, + number_of_output_neurons: int = 10, + loss_coeffs_mse: float = 0.0, + loss_coeffs_kldiv: float = 0.0, +) -> torch.Tensor | None: + + assert loss_mode >= 0 + assert loss_mode <= 1 + + assert h.ndim == 2 + + if loss_mode == 0: + + # Convert label into one hot + target_one_hot: torch.Tensor = torch.zeros( + ( + labels.shape[0], + number_of_output_neurons, + ), + device=h.device, + dtype=h.dtype, + ) + + target_one_hot.scatter_( + 1, + labels.to(h.device).unsqueeze(1), + torch.ones( + (labels.shape[0], 1), + device=h.device, + dtype=h.dtype, + ), + ) + + my_loss: torch.Tensor = ((h - target_one_hot) ** 2).sum(dim=0).mean( + dim=0 + ) * loss_coeffs_mse + + my_loss = ( + my_loss + + ( + (target_one_hot * torch.log((target_one_hot + 1e-20) / (h + 1e-20))) + .sum(dim=0) + .mean(dim=0) + ) + * loss_coeffs_kldiv + ) + + my_loss = my_loss / (abs(loss_coeffs_kldiv) + abs(loss_coeffs_mse)) + + return my_loss + + elif loss_mode == 1: + my_loss = torch.nn.functional.cross_entropy(h, labels.to(h.device)) + return my_loss + else: + return None diff --git a/MLP_equivalent/make_network.py b/MLP_equivalent/make_network.py new file mode 100644 index 0000000..335bc3c --- /dev/null +++ b/MLP_equivalent/make_network.py @@ -0,0 +1,208 @@ +import torch +from append_block import append_block +from L1NormLayer import L1NormLayer +from append_parameter import append_parameter + + +def make_network( + input_dim_x: int, + input_dim_y: int, + input_number_of_channel: int, + iterations: int, + torch_device: torch.device, + epsilon: bool | None = None, + positive_function_type: int = 0, + beta: float | None = None, + # Conv: + number_of_output_channels: list[int] = [32 * 1, 64 * 1, 96 * 1, 10], + kernel_size_conv: list[tuple[int, int]] = [ + (5, 5), + (5, 5), + (-1, -1), # Take the whole input image x and y size + (1, 1), + ], + stride_conv: list[tuple[int, int]] = [ + (1, 1), + (1, 1), + (1, 1), + (1, 1), + ], + padding_conv: list[tuple[int, int]] = [ + (0, 0), + (0, 0), + (0, 0), + (0, 0), + ], + dilation_conv: list[tuple[int, int]] = [ + (1, 1), + (1, 1), + (1, 1), + (1, 1), + ], + # Pool: + kernel_size_pool: list[tuple[int, int]] = [ + (2, 2), + (2, 2), + (-1, -1), # No pooling layer + (-1, -1), # No pooling layer + ], + stride_pool: list[tuple[int, int]] = [ + (2, 2), + (2, 2), + (-1, -1), + (-1, -1), + ], + padding_pool: list[tuple[int, int]] = [ + (0, 0), + (0, 0), + (0, 0), + (0, 0), + ], + dilation_pool: list[tuple[int, int]] = [ + (1, 1), + (1, 1), + (1, 1), + (1, 1), + ], + enable_onoff: bool = False, +) -> tuple[ + torch.nn.Sequential, + list[list[torch.nn.parameter.Parameter]], + list[str], +]: + + assert len(number_of_output_channels) == len(kernel_size_conv) + assert len(number_of_output_channels) == len(stride_conv) + assert len(number_of_output_channels) == len(padding_conv) + assert len(number_of_output_channels) == len(dilation_conv) + assert len(number_of_output_channels) == len(kernel_size_pool) + assert len(number_of_output_channels) == len(stride_pool) + assert len(number_of_output_channels) == len(padding_pool) + assert len(number_of_output_channels) == len(dilation_pool) + + if enable_onoff: + input_number_of_channel *= 2 + + parameter_cnn_top: list[torch.nn.parameter.Parameter] = [] + parameter_nnmf: list[torch.nn.parameter.Parameter] = [] + parameter_norm: list[torch.nn.parameter.Parameter] = [] + + test_image = torch.ones( + (1, input_number_of_channel, input_dim_x, input_dim_y), device=torch_device + ) + + network = torch.nn.Sequential() + network = network.to(torch_device) + + for block_id in range(0, len(number_of_output_channels)): + + test_image = append_block( + network=network, + out_channels=number_of_output_channels[block_id], + test_image=test_image, + dilation=dilation_conv[block_id], + padding=padding_conv[block_id], + stride=stride_conv[block_id], + kernel_size=kernel_size_conv[block_id], + epsilon=epsilon, + positive_function_type=positive_function_type, + beta=beta, + iterations=iterations, + torch_device=torch_device, + parameter_cnn_top=parameter_cnn_top, + parameter_nnmf=parameter_nnmf, + parameter_norm=parameter_norm, + last_layer = block_id == len(number_of_output_channels)-1, + ) + + if (kernel_size_pool[block_id][0] > 0) and (kernel_size_pool[block_id][1] > 0): + network.append(torch.nn.ReLU()) + test_image = network[-1](test_image) + + mock_output = ( + torch.nn.functional.conv2d( + torch.zeros( + 1, + 1, + test_image.shape[2], + test_image.shape[3], + ), + torch.zeros((1, 1, 2, 2)), + stride=(2, 2), + padding=(0, 0), + dilation=(1, 1), + ) + .squeeze(0) + .squeeze(0) + ) + + network.append( + torch.nn.Unfold( + kernel_size=(2, 2), + stride=(2, 2), + padding=(0, 0), + dilation=(1, 1), + ) + ) + test_image = network[-1](test_image) + + network.append( + torch.nn.Fold( + output_size=mock_output.shape, + kernel_size=(1, 1), + dilation=1, + padding=0, + stride=1, + ) + ) + test_image = network[-1](test_image) + + network.append(L1NormLayer()) + test_image = network[-1](test_image) + + network.append( + torch.nn.Conv2d( + in_channels=test_image.shape[1], + out_channels=test_image.shape[1] // 4, + kernel_size=(1, 1), + bias=False, + ).to(torch_device) + ) + + test_image = network[-1](test_image) + append_parameter(module=network[-1], parameter_list=parameter_nnmf) + + network.append( + torch.nn.BatchNorm2d( + num_features=test_image.shape[1], + device=torch_device, + momentum=0.1, + track_running_stats=False, + ) + ) + test_image = network[-1](test_image) + append_parameter(module=network[-1], parameter_list=parameter_norm) + + network.append(torch.nn.Softmax(dim=1)) + test_image = network[-1](test_image) + + network.append(torch.nn.Flatten()) + test_image = network[-1](test_image) + + parameters: list[list[torch.nn.parameter.Parameter]] = [ + parameter_cnn_top, + parameter_nnmf, + parameter_norm, + ] + + name_list: list[str] = [ + "cnn_top", + "nnmf", + "batchnorm2d", + ] + + return ( + network, + parameters, + name_list, + ) diff --git a/MLP_equivalent/make_optimize.py b/MLP_equivalent/make_optimize.py new file mode 100644 index 0000000..ab1a4e0 --- /dev/null +++ b/MLP_equivalent/make_optimize.py @@ -0,0 +1,32 @@ +import torch + + +def make_optimize( + parameters: list[list[torch.nn.parameter.Parameter]], + lr_initial: list[float], + eps=1e-10, +) -> tuple[ + list[torch.optim.Adam | None], + list[torch.optim.lr_scheduler.ReduceLROnPlateau | None], +]: + list_optimizer: list[torch.optim.Adam | None] = [] + list_lr_scheduler: list[torch.optim.lr_scheduler.ReduceLROnPlateau | None] = [] + + assert len(parameters) == len(lr_initial) + + for i in range(0, len(parameters)): + if len(parameters[i]) > 0: + list_optimizer.append(torch.optim.Adam(parameters[i], lr=lr_initial[i])) + else: + list_optimizer.append(None) + + for i in range(0, len(list_optimizer)): + if list_optimizer[i] is not None: + pass + list_lr_scheduler.append( + torch.optim.lr_scheduler.ReduceLROnPlateau(list_optimizer[i], eps=eps) # type: ignore + ) + else: + list_lr_scheduler.append(None) + + return (list_optimizer, list_lr_scheduler) diff --git a/MLP_equivalent/non_linear_weigth_function.py b/MLP_equivalent/non_linear_weigth_function.py new file mode 100644 index 0000000..053a9b6 --- /dev/null +++ b/MLP_equivalent/non_linear_weigth_function.py @@ -0,0 +1,26 @@ +import torch + + +def non_linear_weigth_function( + weight: torch.Tensor, beta: torch.Tensor | None, positive_function_type: int +) -> torch.Tensor: + + if positive_function_type == 0: + positive_weights = torch.abs(weight) + + elif positive_function_type == 1: + assert beta is not None + positive_weights = weight + max_value = torch.abs(positive_weights).max() + if max_value > 80: + positive_weights = 80.0 * positive_weights / max_value + positive_weights = torch.exp((torch.tanh(beta) + 1.0) * 0.5 * positive_weights) + + elif positive_function_type == 2: + assert beta is not None + positive_weights = (torch.tanh(beta * weight) + 1.0) * 0.5 + + else: + positive_weights = weight + + return positive_weights diff --git a/MLP_equivalent/plot.py b/MLP_equivalent/plot.py new file mode 100644 index 0000000..ad22d33 --- /dev/null +++ b/MLP_equivalent/plot.py @@ -0,0 +1,15 @@ +import numpy as np +import matplotlib.pyplot as plt + +data = np.load("data_log.npy") +plt.loglog( + data[:, 0], + 100.0 * (1.0 - data[:, 1] / 10000.0), + "k", +) + +plt.legend() +plt.xlabel("Epoch") +plt.ylabel("Error [%]") +plt.title("CIFAR10") +plt.show() diff --git a/MLP_equivalent/run_network.py b/MLP_equivalent/run_network.py new file mode 100644 index 0000000..e361a32 --- /dev/null +++ b/MLP_equivalent/run_network.py @@ -0,0 +1,251 @@ +import os + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +import argh + +import time +import numpy as np +import torch + +rand_seed: int = 21 +torch.manual_seed(rand_seed) +torch.cuda.manual_seed(rand_seed) +np.random.seed(rand_seed) + +from torch.utils.tensorboard import SummaryWriter + +from make_network import make_network +from get_the_data import get_the_data +from loss_function import loss_function +from make_optimize import make_optimize + + +def main( + lr_initial_nnmf: float = 0.01, + lr_initial_cnn_top: float = 0.001, + lr_initial_norm: float = 0.001, + iterations: int = 20, + dataset: str = "CIFAR10", # "CIFAR10", "FashionMNIST", "MNIST" + only_print_network: bool = False, +) -> None: + + da_auto_mode: bool = False # Automatic Data Augmentation from TorchVision + lr_limit: float = 1e-9 + + torch_device: torch.device = ( + torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + ) + torch.set_default_dtype(torch.float32) + + # Some parameters + batch_size_train: int = 50 # 0 + batch_size_test: int = 50 # 0 + number_of_epoch: int = 5000 + + loss_mode: int = 0 + loss_coeffs_mse: float = 0.5 + loss_coeffs_kldiv: float = 1.0 + print( + "loss_mode: ", + loss_mode, + "loss_coeffs_mse: ", + loss_coeffs_mse, + "loss_coeffs_kldiv: ", + loss_coeffs_kldiv, + ) + + if dataset == "MNIST" or dataset == "FashionMNIST": + input_number_of_channel: int = 1 + input_dim_x: int = 24 + input_dim_y: int = 24 + else: + input_number_of_channel = 3 + input_dim_x = 28 + input_dim_y = 28 + + train_dataloader, test_dataloader, train_processing_chain, test_processing_chain = ( + get_the_data( + dataset, + batch_size_train, + batch_size_test, + torch_device, + input_dim_x, + input_dim_y, + flip_p=0.5, + jitter_brightness=0.5, + jitter_contrast=0.1, + jitter_saturation=0.1, + jitter_hue=0.15, + da_auto_mode=da_auto_mode, + ) + ) + + ( + network, + parameters, + name_list, + ) = make_network( + input_dim_x=input_dim_x, + input_dim_y=input_dim_y, + input_number_of_channel=input_number_of_channel, + iterations=iterations, + torch_device=torch_device, + ) + + print(network) + + print() + print("Information about used parameters:") + number_of_parameter: int = 0 + for i, parameter_list in enumerate(parameters): + count_parameter: int = 0 + for parameter_element in parameter_list: + count_parameter += parameter_element.numel() + print(f"{name_list[i]}: {count_parameter}") + number_of_parameter += count_parameter + print(f"total number of parameter: {number_of_parameter}") + + if only_print_network: + exit() + + ( + optimizers, + lr_schedulers, + ) = make_optimize( + parameters=parameters, + lr_initial=[ + lr_initial_cnn_top, + lr_initial_nnmf, + lr_initial_norm, + ], + ) + + my_string: str = "_lr_" + for i in range(0, len(lr_schedulers)): + if lr_schedulers[i] is not None: + my_string += f"{lr_schedulers[i].get_last_lr()[0]:.4e}_" # type: ignore + else: + my_string += "-_" + + default_path: str = f"iter{iterations}{my_string}" + log_dir: str = f"log_{default_path}" + + tb = SummaryWriter(log_dir=log_dir) + + for epoch_id in range(0, number_of_epoch): + print() + print(f"Epoch: {epoch_id}") + t_start: float = time.perf_counter() + + train_loss: float = 0.0 + train_correct: int = 0 + train_number: int = 0 + test_correct: int = 0 + test_number: int = 0 + + # Switch the network into training mode + network.train() + + # This runs in total for one epoch split up into mini-batches + for image, target in train_dataloader: + + # Clean the gradient + for i in range(0, len(optimizers)): + if optimizers[i] is not None: + optimizers[i].zero_grad() # type: ignore + + output = network(train_processing_chain(image)) + + loss = loss_function( + h=output, + labels=target, + number_of_output_neurons=output.shape[1], + loss_mode=loss_mode, + loss_coeffs_mse=loss_coeffs_mse, + loss_coeffs_kldiv=loss_coeffs_kldiv, + ) + + assert loss is not None + train_loss += loss.item() + train_correct += (output.argmax(dim=1) == target).sum().cpu().numpy() + train_number += target.shape[0] + + # Calculate backprop + loss.backward() + + # Update the parameter + # Clean the gradient + for i in range(0, len(optimizers)): + if optimizers[i] is not None: + optimizers[i].step() # type: ignore + + perfomance_train_correct: float = 100.0 * train_correct / train_number + # Update the learning rate + for i in range(0, len(lr_schedulers)): + if lr_schedulers[i] is not None: + lr_schedulers[i].step(train_loss) # type: ignore + + my_string = "Actual lr: " + for i in range(0, len(lr_schedulers)): + if lr_schedulers[i] is not None: + my_string += f" {lr_schedulers[i].get_last_lr()[0]:.4e} " # type: ignore + else: + my_string += " --- " + + print(my_string) + t_training: float = time.perf_counter() + + # Switch the network into evalution mode + network.eval() + + with torch.no_grad(): + + for image, target in test_dataloader: + output = network(test_processing_chain(image)) + + test_correct += (output.argmax(dim=1) == target).sum().cpu().numpy() + test_number += target.shape[0] + + t_testing = time.perf_counter() + + perfomance_test_correct: float = 100.0 * test_correct / test_number + + tb.add_scalar("Train Loss", train_loss / float(train_number), epoch_id) + tb.add_scalar("Train Number Correct", train_correct, epoch_id) + tb.add_scalar("Test Number Correct", test_correct, epoch_id) + + print( + f"Training: Loss={train_loss / float(train_number):.5f} Correct={perfomance_train_correct:.2f}%" + ) + print(f"Testing: Correct={perfomance_test_correct:.2f}%") + print( + f"Time: Training={(t_training - t_start):.1f}sec, Testing={(t_testing - t_training):.1f}sec" + ) + + tb.flush() + + lr_check: list[float] = [] + for i in range(0, len(lr_schedulers)): + if lr_schedulers[i] is not None: + lr_check.append(lr_schedulers[i].get_last_lr()[0]) # type: ignore + + lr_check_max = float(torch.tensor(lr_check).max()) + + if lr_check_max < lr_limit: + torch.save(network, f"Model_{default_path}.pt") + tb.close() + print("Done (lr_limit)") + return + + torch.save(network, f"Model_{default_path}.pt") + print() + + tb.close() + print("Done (loop end)") + + return + + +if __name__ == "__main__": + argh.dispatch_command(main)