diff --git a/DATA_CIFAR10/PyTorch_Non_Spike_Network/Dataset.py b/DATA_CIFAR10/PyTorch_Non_Spike_Network/Dataset.py
new file mode 100644
index 0000000..11f9854
--- /dev/null
+++ b/DATA_CIFAR10/PyTorch_Non_Spike_Network/Dataset.py
@@ -0,0 +1,422 @@
+# MIT License
+# Copyright 2022 University of Bremen
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
+# THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+#
+#
+# David Rotermund ( davrot@uni-bremen.de )
+#
+#
+# Release history:
+# ================
+# 1.0.0 -- 01.05.2022: first release
+#
+#
+
+from abc import ABC, abstractmethod
+import torch
+import numpy as np
+import torchvision as tv  # type: ignore
+from Parameter import Config
+
+
+class DatasetMaster(torch.utils.data.Dataset, ABC):
+
+    path_label: str
+    label_storage: np.ndarray
+    pattern_storage: np.ndarray
+    number_of_pattern: int
+    mean: list[float]
+
+    # Initialize
+    def __init__(
+        self,
+        train: bool = False,
+        path_pattern: str = "./",
+        path_label: str = "./",
+    ) -> None:
+        super().__init__()
+
+        if train is True:
+            self.label_storage = np.load(path_label + "/TrainLabelStorage.npy")
+        else:
+            self.label_storage = np.load(path_label + "/TestLabelStorage.npy")
+
+        if train is True:
+            self.pattern_storage = np.load(path_pattern + "/TrainPatternStorage.npy")
+        else:
+            self.pattern_storage = np.load(path_pattern + "/TestPatternStorage.npy")
+
+        self.number_of_pattern = self.label_storage.shape[0]
+
+        self.mean = []
+
+    def __len__(self) -> int:
+        return self.number_of_pattern
+
+    # Get one pattern at position index
+    @abstractmethod
+    def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
+        pass
+
+    @abstractmethod
+    def pattern_filter_test(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
+        pass
+
+    @abstractmethod
+    def pattern_filter_train(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
+        pass
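+
+
+# Annotation (added in review, not executed anywhere): the *.npy storages
+# loaded above are assumed to hold the raw images with the channel axis
+# last, plus the integer class labels. For CIFAR-10 they could be
+# generated from torchvision along these lines (a sketch; the actual
+# export script is not part of this commit):
+#
+#   import numpy as np
+#   import torchvision as tv
+#
+#   data = tv.datasets.CIFAR10(root="./", train=True, download=True)
+#   np.save("TrainPatternStorage.npy", data.data)             # [50000, 32, 32, 3], uint8
+#   np.save("TrainLabelStorage.npy", np.array(data.targets))  # [50000]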
+
+
+class DatasetMNIST(DatasetMaster):
+    """MNIST dataset."""
+
+    # Initialize
+    def __init__(
+        self,
+        train: bool = False,
+        path_pattern: str = "./",
+        path_label: str = "./",
+    ) -> None:
+        super().__init__(train, path_pattern, path_label)
+
+        self.pattern_storage = np.ascontiguousarray(
+            self.pattern_storage[:, np.newaxis, :, :].astype(dtype=np.float32)
+        )
+
+        self.pattern_storage /= np.max(self.pattern_storage)
+
+        mean = self.pattern_storage.mean(3).mean(2).mean(0)
+        self.mean = [*mean]
+
+    def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
+
+        image = self.pattern_storage[index, 0:1, :, :]
+        target = int(self.label_storage[index])
+        return torch.tensor(image), target
+
+    def pattern_filter_test(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
+        """0. The test image comes in,
+        1. is center cropped,
+        2. is on/off filtered, and
+        3. is returned.
+
+        This is a 1 channel version (e.g. one gray channel).
+        """
+
+        assert len(cfg.image_statistics.mean) == 1
+        assert len(cfg.image_statistics.the_size) == 2
+        assert cfg.image_statistics.the_size[0] > 0
+        assert cfg.image_statistics.the_size[1] > 0
+
+        # Transformation chain
+        my_transforms: torch.nn.Sequential = torch.nn.Sequential(
+            tv.transforms.CenterCrop(size=cfg.image_statistics.the_size),
+        )
+        scripted_transforms = torch.jit.script(my_transforms)
+
+        # Preprocess the input data
+        pattern = scripted_transforms(pattern)
+
+        # => On/Off
+        my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
+        gray: torch.Tensor = my_on_off_filter(
+            pattern[:, 0:1, :, :],
+        )
+
+        return gray
+
+    def pattern_filter_train(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
+        """0. The training image comes in,
+        1. is cropped from a random position,
+        2. is on/off filtered, and
+        3. is returned.
+
+        This is a 1 channel version (e.g. one gray channel).
+        """
+
+        assert len(cfg.image_statistics.mean) == 1
+        assert len(cfg.image_statistics.the_size) == 2
+        assert cfg.image_statistics.the_size[0] > 0
+        assert cfg.image_statistics.the_size[1] > 0
+
+        # Transformation chain
+        my_transforms: torch.nn.Sequential = torch.nn.Sequential(
+            tv.transforms.RandomCrop(size=cfg.image_statistics.the_size),
+        )
+        scripted_transforms = torch.jit.script(my_transforms)
+
+        # Preprocess the input data
+        pattern = scripted_transforms(pattern)
+
+        # => On/Off
+        my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
+        gray: torch.Tensor = my_on_off_filter(
+            pattern[:, 0:1, :, :],
+        )
+
+        return gray
+
+
+class DatasetFashionMNIST(DatasetMaster):
+    """Fashion-MNIST dataset."""
+
+    # Initialize
+    def __init__(
+        self,
+        train: bool = False,
+        path_pattern: str = "./",
+        path_label: str = "./",
+    ) -> None:
+        super().__init__(train, path_pattern, path_label)
+
+        self.pattern_storage = np.ascontiguousarray(
+            self.pattern_storage[:, np.newaxis, :, :].astype(dtype=np.float32)
+        )
+
+        self.pattern_storage /= np.max(self.pattern_storage)
+
+        mean = self.pattern_storage.mean(3).mean(2).mean(0)
+        self.mean = [*mean]
+
+    def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
+
+        image = self.pattern_storage[index, 0:1, :, :]
+        target = int(self.label_storage[index])
+        return torch.tensor(image), target
+
+    def pattern_filter_test(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
+        """0. The test image comes in,
+        1. is center cropped,
+        2. is on/off filtered, and
+        3. is returned.
+
+        This is a 1 channel version (e.g. one gray channel).
+        """
+
+        assert len(cfg.image_statistics.mean) == 1
+        assert len(cfg.image_statistics.the_size) == 2
+        assert cfg.image_statistics.the_size[0] > 0
+        assert cfg.image_statistics.the_size[1] > 0
+
+        # Transformation chain
+        my_transforms: torch.nn.Sequential = torch.nn.Sequential(
+            tv.transforms.CenterCrop(size=cfg.image_statistics.the_size),
+        )
+        scripted_transforms = torch.jit.script(my_transforms)
+
+        # Preprocess the input data
+        pattern = scripted_transforms(pattern)
+
+        # => On/Off
+        my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
+        gray: torch.Tensor = my_on_off_filter(
+            pattern[:, 0:1, :, :],
+        )
+
+        return gray
+
+    def pattern_filter_train(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
+        """0. The training image comes in,
+        1. is cropped from a random position,
+        2. is randomly horizontally flipped,
+        3. is randomly color jittered,
+        4. is on/off filtered, and
+        5. is returned.
+
+        This is a 1 channel version (e.g. one gray channel).
+        """
+
+        assert len(cfg.image_statistics.mean) == 1
+        assert len(cfg.image_statistics.the_size) == 2
+        assert cfg.image_statistics.the_size[0] > 0
+        assert cfg.image_statistics.the_size[1] > 0
+
+        # Transformation chain
+        my_transforms: torch.nn.Sequential = torch.nn.Sequential(
+            tv.transforms.RandomCrop(size=cfg.image_statistics.the_size),
+            tv.transforms.RandomHorizontalFlip(p=cfg.augmentation.flip_p),
+            tv.transforms.ColorJitter(
+                brightness=cfg.augmentation.jitter_brightness,
+                contrast=cfg.augmentation.jitter_contrast,
+                saturation=cfg.augmentation.jitter_saturation,
+                hue=cfg.augmentation.jitter_hue,
+            ),
+        )
+        scripted_transforms = torch.jit.script(my_transforms)
+
+        # Preprocess the input data
+        pattern = scripted_transforms(pattern)
+
+        # => On/Off
+        my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
+        gray: torch.Tensor = my_on_off_filter(
+            pattern[:, 0:1, :, :],
+        )
+
+        return gray
+
+
+class DatasetCIFAR(DatasetMaster):
+    """CIFAR-10 dataset."""
+
+    # Initialize
+    def __init__(
+        self,
+        train: bool = False,
+        path_pattern: str = "./",
+        path_label: str = "./",
+    ) -> None:
+        super().__init__(train, path_pattern, path_label)
+
+        self.pattern_storage = np.ascontiguousarray(
+            np.moveaxis(self.pattern_storage.astype(dtype=np.float32), 3, 1)
+        )
+        self.pattern_storage /= np.max(self.pattern_storage)
+
+        mean = self.pattern_storage.mean(3).mean(2).mean(0)
+        self.mean = [*mean]
+
+    def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
+
+        image = self.pattern_storage[index, :, :, :]
+        target = int(self.label_storage[index])
+        return torch.tensor(image), target
+
+    def pattern_filter_test(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
+        """0. The test image comes in,
+        1. is center cropped,
+        2. is on/off filtered, and
+        3. is returned.
+
+        This is a 3 channel version (e.g. r, g, b channels).
+        """
+
+        assert len(cfg.image_statistics.mean) == 3
+        assert len(cfg.image_statistics.the_size) == 2
+        assert cfg.image_statistics.the_size[0] > 0
+        assert cfg.image_statistics.the_size[1] > 0
+
+        # Transformation chain
+        my_transforms: torch.nn.Sequential = torch.nn.Sequential(
+            tv.transforms.CenterCrop(size=cfg.image_statistics.the_size),
+        )
+        scripted_transforms = torch.jit.script(my_transforms)
+
+        # Preprocess the input data
+        pattern = scripted_transforms(pattern)
+
+        # => On/Off
+        my_on_off_filter_r: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
+        my_on_off_filter_g: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[1])
+        my_on_off_filter_b: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[2])
+        r: torch.Tensor = my_on_off_filter_r(
+            pattern[:, 0:1, :, :],
+        )
+        g: torch.Tensor = my_on_off_filter_g(
+            pattern[:, 1:2, :, :],
+        )
+        b: torch.Tensor = my_on_off_filter_b(
+            pattern[:, 2:3, :, :],
+        )
+
+        new_tensor: torch.Tensor = torch.cat((r, g, b), dim=1)
+        return new_tensor
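+
+    # Shape annotation (added in review): with [N, 3, 32, 32] CIFAR-10
+    # batches and cfg.image_statistics.the_size == [28, 28], both
+    # pattern_filter_test and pattern_filter_train return a
+    # [N, 6, 28, 28] tensor: each of the r, g, b channels is split into
+    # an off and an on channel by OnOffFilter (defined below).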
+ """ + assert len(cfg.image_statistics.mean) == 3 + assert len(cfg.image_statistics.the_size) == 2 + assert cfg.image_statistics.the_size[0] > 0 + assert cfg.image_statistics.the_size[1] > 0 + + # Transformation chain + my_transforms: torch.nn.Sequential = torch.nn.Sequential( + tv.transforms.RandomCrop(size=cfg.image_statistics.the_size), + tv.transforms.RandomHorizontalFlip(p=cfg.augmentation.flip_p), + tv.transforms.ColorJitter( + brightness=cfg.augmentation.jitter_brightness, + contrast=cfg.augmentation.jitter_contrast, + saturation=cfg.augmentation.jitter_saturation, + hue=cfg.augmentation.jitter_hue, + ), + ) + scripted_transforms = torch.jit.script(my_transforms) + + # Preprocess the input data + pattern = scripted_transforms(pattern) + + # => On/Off + my_on_off_filter_r: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0]) + my_on_off_filter_g: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[1]) + my_on_off_filter_b: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[2]) + r: torch.Tensor = my_on_off_filter_r( + pattern[:, 0:1, :, :], + ) + g: torch.Tensor = my_on_off_filter_g( + pattern[:, 1:2, :, :], + ) + b: torch.Tensor = my_on_off_filter_b( + pattern[:, 2:3, :, :], + ) + + new_tensor: torch.Tensor = torch.cat((r, g, b), dim=1) + return new_tensor + + +class OnOffFilter(torch.nn.Module): + def __init__(self, p: float = 0.5) -> None: + super(OnOffFilter, self).__init__() + self.p: float = p + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + + assert tensor.shape[1] == 1 + + tensor_clone = 2.0 * (tensor - self.p) + + temp_0: torch.Tensor = torch.where( + tensor_clone < 0.0, + -tensor_clone, + tensor_clone.new_zeros(tensor_clone.shape, dtype=tensor_clone.dtype), + ) + + temp_1: torch.Tensor = torch.where( + tensor_clone >= 0.0, + tensor_clone, + tensor_clone.new_zeros(tensor_clone.shape, dtype=tensor_clone.dtype), + ) + + new_tensor: torch.Tensor = torch.cat((temp_0, temp_1), dim=1) + + return new_tensor + + def __repr__(self) -> str: + return self.__class__.__name__ + "(p={0})".format(self.p) + + +if __name__ == "__main__": + pass diff --git a/DATA_CIFAR10/PyTorch_Non_Spike_Network/Error.png b/DATA_CIFAR10/PyTorch_Non_Spike_Network/Error.png new file mode 100644 index 0000000..ec03281 Binary files /dev/null and b/DATA_CIFAR10/PyTorch_Non_Spike_Network/Error.png differ diff --git a/DATA_CIFAR10/PyTorch_Non_Spike_Network/Model_MNIST_A_499.pt.gz b/DATA_CIFAR10/PyTorch_Non_Spike_Network/Model_MNIST_A_499.pt.gz new file mode 100644 index 0000000..80bd539 Binary files /dev/null and b/DATA_CIFAR10/PyTorch_Non_Spike_Network/Model_MNIST_A_499.pt.gz differ diff --git a/DATA_CIFAR10/PyTorch_Non_Spike_Network/Parameter.py b/DATA_CIFAR10/PyTorch_Non_Spike_Network/Parameter.py new file mode 100644 index 0000000..92fe247 --- /dev/null +++ b/DATA_CIFAR10/PyTorch_Non_Spike_Network/Parameter.py @@ -0,0 +1,164 @@ +# MIT License +# Copyright 2022 University of Bremen +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. 
+
+
+if __name__ == "__main__":
+    pass
diff --git a/DATA_CIFAR10/PyTorch_Non_Spike_Network/Error.png b/DATA_CIFAR10/PyTorch_Non_Spike_Network/Error.png
new file mode 100644
index 0000000..ec03281
Binary files /dev/null and b/DATA_CIFAR10/PyTorch_Non_Spike_Network/Error.png differ
diff --git a/DATA_CIFAR10/PyTorch_Non_Spike_Network/Model_MNIST_A_499.pt.gz b/DATA_CIFAR10/PyTorch_Non_Spike_Network/Model_MNIST_A_499.pt.gz
new file mode 100644
index 0000000..80bd539
Binary files /dev/null and b/DATA_CIFAR10/PyTorch_Non_Spike_Network/Model_MNIST_A_499.pt.gz differ
diff --git a/DATA_CIFAR10/PyTorch_Non_Spike_Network/Parameter.py b/DATA_CIFAR10/PyTorch_Non_Spike_Network/Parameter.py
new file mode 100644
index 0000000..92fe247
--- /dev/null
+++ b/DATA_CIFAR10/PyTorch_Non_Spike_Network/Parameter.py
@@ -0,0 +1,164 @@
+# MIT License
+# Copyright 2022 University of Bremen
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
+# THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+#
+#
+# David Rotermund ( davrot@uni-bremen.de )
+#
+#
+# Release history:
+# ================
+# 1.0.0 -- 01.05.2022: first release
+#
+#
+
+# %%
+from dataclasses import dataclass, field
+import numpy as np
+import torch
+import os
+
+
+@dataclass
+class Network:
+    """Parameters of the network: details about its layers and the
+    number of output neurons."""
+
+    number_of_output_neurons: int = field(default=0)
+    forward_kernel_size: list[list[int]] = field(default_factory=list)
+    forward_neuron_numbers: list[list[int]] = field(default_factory=list)
+    strides: list[list[int]] = field(default_factory=list)
+    dilation: list[list[int]] = field(default_factory=list)
+    padding: list[list[int]] = field(default_factory=list)
+    is_pooling_layer: list[bool] = field(default_factory=list)
+    w_trainable: list[bool] = field(default_factory=list)
+    eps_xy_trainable: list[bool] = field(default_factory=list)
+    eps_xy_mean: list[bool] = field(default_factory=list)
+
+
+@dataclass
+class LearningParameters:
+    """Parameters required for training."""
+
+    loss_coeffs_mse: float = field(default=0.5)
+    loss_coeffs_kldiv: float = field(default=1.0)
+    learning_rate_gamma_w: float = field(default=-1.0)
+    learning_rate_gamma_eps_xy: float = field(default=-1.0)
+    learning_rate_threshold_w: float = field(default=0.00001)
+    learning_rate_threshold_eps_xy: float = field(default=0.00001)
+    learning_active: bool = field(default=True)
+    weight_noise_amplitude: float = field(default=0.01)
+    eps_xy_intitial: float = field(default=0.1)
+    test_every_x_learning_steps: int = field(default=50)
+    test_during_learning: bool = field(default=True)
+    lr_scheduler_factor: float = field(default=0.75)
+    lr_scheduler_patience: int = field(default=10)
+    optimizer_name: str = field(default="Adam")
+    lr_schedule_name: str = field(default="ReduceLROnPlateau")
+    number_of_batches_for_one_update: int = field(default=1)
+    alpha_number_of_iterations: int = field(default=0)
+    overload_path: str = field(default="./Previous")
+
+
+@dataclass
+class Augmentation:
+    """Parameters used for data augmentation."""
+
+    crop_width_in_pixel: int = field(default=2)
+    flip_p: float = field(default=0.5)
+    jitter_brightness: float = field(default=0.5)
+    jitter_contrast: float = field(default=0.1)
+    jitter_saturation: float = field(default=0.1)
+    jitter_hue: float = field(default=0.15)
+
+
+@dataclass
+class ImageStatistics:
+    """Statistical information about the input, i.e. the mean values
+    and the x and y size of the input."""
+
+    mean: list[float] = field(default_factory=list)
+    the_size: list[int] = field(default_factory=list)
+
+
+@dataclass
+class Config:
+    """Master config class."""
+
+    # Sub classes
+    network_structure: Network = field(default_factory=Network)
+    learning_parameters: LearningParameters = field(default_factory=LearningParameters)
+    augmentation: Augmentation = field(default_factory=Augmentation)
+    image_statistics: ImageStatistics = field(default_factory=ImageStatistics)
+
+    batch_size: int = field(default=500)
+    data_mode: str = field(default="")
+
+    learning_step: int = field(default=0)
+    learning_step_max: int = field(default=10000)
+
+    number_of_cpu_processes: int = field(default=-1)
+
+    number_of_spikes: int = field(default=0)
+    cooldown_after_number_of_spikes: int = field(default=0)
+
+    weight_path: str = field(default="./Weights/")
+    eps_xy_path: str = field(default="./EpsXY/")
+    data_path: str = field(default="./")
+
+    reduction_cooldown: float = field(default=25.0)
+    epsilon_0: float = field(default=1.0)
+
+    update_after_x_batch: float = field(default=1.0)
+
+    def __post_init__(self) -> None:
+        """Post init determines the number of CPU cores, creates the
+        required directories, and rounds the batch size to a multiple
+        of the number of cores."""
+        number_of_cpu_processes_temp = os.cpu_count()
+
+        if self.number_of_cpu_processes < 1:
+            if number_of_cpu_processes_temp is None:
+                self.number_of_cpu_processes = 1
+            else:
+                self.number_of_cpu_processes = number_of_cpu_processes_temp
+
+        os.makedirs(self.weight_path, exist_ok=True)
+        os.makedirs(self.eps_xy_path, exist_ok=True)
+        os.makedirs(self.data_path, exist_ok=True)
+
+        self.batch_size = (
+            self.batch_size // self.number_of_cpu_processes
+        ) * self.number_of_cpu_processes
+
+        self.batch_size = np.max((self.batch_size, self.number_of_cpu_processes))
+        self.batch_size = int(self.batch_size)
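+
+    # Worked example (annotation added in review): with the default
+    # batch_size of 500 and 12 CPU processes, (500 // 12) * 12 == 492,
+    # so the batch size is rounded down to a multiple of the process
+    # count, and np.max above keeps it at no less than one pattern per
+    # process.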
"events.out.tfevents.1651334099.fedora.121264.0" + +acc = event_accumulator.EventAccumulator(filename) +acc.Reload() + +# What is available? +# available_scalar = acc.Tags()["scalars"] +# print("Available Scalars") +# print(available_scalar) + +which_scalar: str = "Test Number Correct" +te = acc.Scalars(which_scalar) + +temp: list = [] +for te_item in te: + temp.append((te_item[1], te_item[2])) +temp_np = np.array(temp) + +plt.semilogy(temp_np[:, 0], (1.0 - (temp_np[:, 1] / 10000)) * 100) +plt.xlabel("Epochs") +plt.ylabel("Error [%]") +plt.savefig("Error.png") +plt.show() diff --git a/DATA_CIFAR10/PyTorch_Non_Spike_Network/run.py b/DATA_CIFAR10/PyTorch_Non_Spike_Network/run.py new file mode 100644 index 0000000..a7f59ab --- /dev/null +++ b/DATA_CIFAR10/PyTorch_Non_Spike_Network/run.py @@ -0,0 +1,203 @@ +# %% +import torch +from Dataset import DatasetCIFAR +from Parameter import Config +import torchvision as tv # type: ignore + +# Some parameters + +cfg = Config() + +input_number_of_channel: int = 3 +input_dim_x: int = 28 +input_dim_y: int = 28 + +number_of_output_channels_conv1: int = 96 +number_of_output_channels_conv2: int = 192 +number_of_output_channels_flatten1: int = 3072 +number_of_output_channels_full1: int = 10 + +kernel_size_conv1: tuple[int, int] = (5, 5) +kernel_size_pool1: tuple[int, int] = (2, 2) +kernel_size_conv2: tuple[int, int] = (5, 5) +kernel_size_pool2: tuple[int, int] = (2, 2) + +stride_conv1: tuple[int, int] = (1, 1) +stride_pool1: tuple[int, int] = (2, 2) +stride_conv2: tuple[int, int] = (1, 1) +stride_pool2: tuple[int, int] = (2, 2) + +padding_conv1: int = 0 +padding_pool1: int = 0 +padding_conv2: int = 0 +padding_pool2: int = 0 + +network = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=input_number_of_channel, + out_channels=number_of_output_channels_conv1, + kernel_size=kernel_size_conv1, + stride=stride_conv1, + padding=padding_conv1, + ), + torch.nn.ReLU(), + torch.nn.MaxPool2d( + kernel_size=kernel_size_pool1, stride=stride_pool1, padding=padding_pool1 + ), + torch.nn.Conv2d( + in_channels=number_of_output_channels_conv1, + out_channels=number_of_output_channels_conv2, + kernel_size=kernel_size_conv2, + stride=stride_conv2, + padding=padding_conv2, + ), + torch.nn.ReLU(), + torch.nn.MaxPool2d( + kernel_size=kernel_size_pool2, stride=stride_pool2, padding=padding_pool2 + ), + torch.nn.Flatten( + start_dim=1, + ), + torch.nn.Linear( + in_features=number_of_output_channels_flatten1, + out_features=number_of_output_channels_full1, + bias=True, + ), + torch.nn.Softmax(dim=1), +) +# %% +path_pattern: str = "./DATA_CIFAR10/" +path_label: str = "./DATA_CIFAR10/" + +dataset_train = DatasetCIFAR( + train=True, path_pattern=path_pattern, path_label=path_label +) +dataset_test = DatasetCIFAR( + train=False, path_pattern=path_pattern, path_label=path_label +) +cfg.image_statistics.mean = dataset_train.mean +# The basic size +cfg.image_statistics.the_size = [ + dataset_train.pattern_storage.shape[2], + dataset_train.pattern_storage.shape[3], +] +# Minus the stuff we cut away in the pattern filter +cfg.image_statistics.the_size[0] -= 2 * cfg.augmentation.crop_width_in_pixel +cfg.image_statistics.the_size[1] -= 2 * cfg.augmentation.crop_width_in_pixel + + +batch_size_train: int = 100 +batch_size_test: int = 100 + + +train_data_load = torch.utils.data.DataLoader( + dataset_train, batch_size=batch_size_train, shuffle=True +) + +test_data_load = torch.utils.data.DataLoader( + dataset_test, batch_size=batch_size_test, shuffle=False +) + +transforms_test: 
torch.nn.Sequential = torch.nn.Sequential( + tv.transforms.CenterCrop(size=cfg.image_statistics.the_size), +) +scripted_transforms_test = torch.jit.script(transforms_test) + +transforms_train: torch.nn.Sequential = torch.nn.Sequential( + tv.transforms.RandomCrop(size=cfg.image_statistics.the_size), + tv.transforms.RandomHorizontalFlip(p=cfg.augmentation.flip_p), + tv.transforms.ColorJitter( + brightness=cfg.augmentation.jitter_brightness, + contrast=cfg.augmentation.jitter_contrast, + saturation=cfg.augmentation.jitter_saturation, + hue=cfg.augmentation.jitter_hue, + ), +) +scripted_transforms_train = torch.jit.script(transforms_train) +# %% +# The optimizer +optimizer = torch.optim.Adam(network.parameters(), lr=0.001) +# The LR Scheduler +lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.75) + +# %% +number_of_test_pattern: int = dataset_test.__len__() +number_of_train_pattern: int = dataset_train.__len__() + +number_of_epoch: int = 500 + +# %% +import time +from torch.utils.tensorboard import SummaryWriter + +tb = SummaryWriter() + +# %% +loss_function = torch.nn.CrossEntropyLoss() + +for epoch_id in range(0, number_of_epoch): + print(f"Epoch: {epoch_id}") + t_start: float = time.perf_counter() + + train_loss: float = 0.0 + train_correct: int = 0 + train_number: int = 0 + test_correct: int = 0 + test_number: int = 0 + + # Switch the network into training mode + network.train() + + # This runs in total for one epoch split up into mini-batches + for image, target in train_data_load: + + # Clean the gradient + optimizer.zero_grad() + + output = network(scripted_transforms_train(image)) + + loss = loss_function(output, target) + + train_loss += loss.item() + train_correct += (output.argmax(dim=1) == target).sum().numpy() + train_number += target.shape[0] + # Calculate backprop + loss.backward() + + # Update the parameter + optimizer.step() + + # Update the learning rate + lr_scheduler.step(train_loss) + + t_training: float = time.perf_counter() + + # Switch the network into evalution mode + network.eval() + with torch.no_grad(): + for image, target in test_data_load: + + output = network(scripted_transforms_test(image)) + + test_correct += (output.argmax(dim=1) == target).sum().numpy() + test_number += target.shape[0] + + t_testing = time.perf_counter() + + perfomance_test_correct: float = 100.0 * test_correct / test_number + perfomance_train_correct: float = 100.0 * train_correct / train_number + + tb.add_scalar("Train Loss", train_loss, epoch_id) + tb.add_scalar("Train Number Correct", train_correct, epoch_id) + tb.add_scalar("Test Number Correct", test_correct, epoch_id) + + print(f"Training: Loss={train_loss:.5f} Correct={perfomance_train_correct:.2f}%") + print(f"Testing: Correct={perfomance_test_correct:.2f}%") + print( + f"Time: Training={(t_training-t_start):.1f}sec, Testing={(t_testing-t_training):.1f}sec" + ) + torch.save(network, "Model_MNIST_A_" + str(epoch_id) + ".pt") + print() + +# %% +tb.close()