From 2928901a38816b72f2caec59d633b44e762cd93b Mon Sep 17 00:00:00 2001
From: davrot <54365609+davrot@users.noreply.github.com>
Date: Sat, 30 Apr 2022 02:07:09 +0200
Subject: [PATCH] First version

---
 Dataset.py                    |  422 ++++++++++++
 Parameter.py                  |  164 +++++
 PyHDynamicCNNManyIP.pyi       |   18 +
 PySpikeGeneration2DManyIP.pyi |   18 +
 SbS.py                        | 1220 +++++++++++++++++++++++++++++++++
 learn_it.py                   |  590 ++++++++++++++++
 6 files changed, 2432 insertions(+)
 create mode 100644 Dataset.py
 create mode 100644 Parameter.py
 create mode 100644 PyHDynamicCNNManyIP.pyi
 create mode 100644 PySpikeGeneration2DManyIP.pyi
 create mode 100644 SbS.py
 create mode 100644 learn_it.py

diff --git a/Dataset.py b/Dataset.py
new file mode 100644
index 0000000..11f9854
--- /dev/null
+++ b/Dataset.py
@@ -0,0 +1,422 @@
+# MIT License
+# Copyright 2022 University of Bremen
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
+# THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+#
+#
+# David Rotermund ( davrot@uni-bremen.de )
+#
+#
+# Release history:
+# ================
+# 1.0.0 -- 01.05.2022: first release
+#
+#
+
+from abc import ABC, abstractmethod
+import torch
+import numpy as np
+import torchvision as tv  # type: ignore
+from Parameter import Config
+
+
+class DatasetMaster(torch.utils.data.Dataset, ABC):
+
+    path_label: str
+    label_storage: np.ndarray
+    pattern_storage: np.ndarray
+    number_of_pattern: int
+    mean: list[float]
+
+    # Initialize
+    def __init__(
+        self,
+        train: bool = False,
+        path_pattern: str = "./",
+        path_label: str = "./",
+    ) -> None:
+        super().__init__()
+
+        if train is True:
+            self.label_storage = np.load(path_label + "/TrainLabelStorage.npy")
+        else:
+            self.label_storage = np.load(path_label + "/TestLabelStorage.npy")
+
+        if train is True:
+            self.pattern_storage = np.load(path_pattern + "/TrainPatternStorage.npy")
+        else:
+            self.pattern_storage = np.load(path_pattern + "/TestPatternStorage.npy")
+
+        self.number_of_pattern = self.label_storage.shape[0]
+
+        self.mean = []
+
+    def __len__(self) -> int:
+        return self.number_of_pattern
+
+    # Get one pattern at position index
+    @abstractmethod
+    def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
+        pass
+
+    @abstractmethod
+    def pattern_filter_test(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
+        pass
+
+    @abstractmethod
+    def pattern_filter_train(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
+        pass
+
+
+class DatasetMNIST(DatasetMaster):
+    """Constructor"""
+
+    # Initialize
+    def __init__(
+        self,
+        train: bool = False,
+        path_pattern: str = "./",
+        path_label: str = "./",
+    ) -> None:
+        super().__init__(train, path_pattern, path_label)
+
+        self.pattern_storage = np.ascontiguousarray(
+            self.pattern_storage[:, np.newaxis, :, :].astype(dtype=np.float32)
+        )
+
+        self.pattern_storage /= np.max(self.pattern_storage)
+
+        mean = self.pattern_storage.mean(3).mean(2).mean(0)
+        self.mean = [*mean]
+
+    def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
+
+        image = self.pattern_storage[index, 0:1, :, :]
+        target = int(self.label_storage[index])
+        return torch.tensor(image), target
+
+    def pattern_filter_test(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
+        """0. The test image comes in
+        1. is center cropped
+        2. is on/off filtered
+        3. is returned.
+
+        This is a 1 channel version (e.g. one gray channel).
+        """
+
+        assert len(cfg.image_statistics.mean) == 1
+        assert len(cfg.image_statistics.the_size) == 2
+        assert cfg.image_statistics.the_size[0] > 0
+        assert cfg.image_statistics.the_size[1] > 0
+
+        # Transformation chain
+        my_transforms: torch.nn.Sequential = torch.nn.Sequential(
+            tv.transforms.CenterCrop(size=cfg.image_statistics.the_size),
+        )
+        scripted_transforms = torch.jit.script(my_transforms)
+
+        # Preprocess the input data
+        pattern = scripted_transforms(pattern)
+
+        # => On/Off
+        my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
+        gray: torch.Tensor = my_on_off_filter(
+            pattern[:, 0:1, :, :],
+        )
+
+        return gray
+
+    def pattern_filter_train(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
+        """0. The training image comes in
+        1. is cropped from a random position
+        2. is on/off filtered
+        3. is returned.
+
+        This is a 1 channel version (e.g. one gray channel).
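+
+        A shape sketch (assuming 28x28 MNIST patterns and a crop to
+        the_size == [24, 24], i.e. crop_width_in_pixel == 2):
+        [batch, 1, 28, 28] -> random crop -> [batch, 1, 24, 24]
+        -> on/off filter -> [batch, 2, 24, 24] (off and on channels).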
+ """ + + assert len(cfg.image_statistics.mean) == 1 + assert len(cfg.image_statistics.the_size) == 2 + assert cfg.image_statistics.the_size[0] > 0 + assert cfg.image_statistics.the_size[1] > 0 + + # Transformation chain + my_transforms: torch.nn.Sequential = torch.nn.Sequential( + tv.transforms.RandomCrop(size=cfg.image_statistics.the_size), + ) + scripted_transforms = torch.jit.script(my_transforms) + + # Preprocess the input data + pattern = scripted_transforms(pattern) + + # => On/Off + my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0]) + gray: torch.Tensor = my_on_off_filter( + pattern[:, 0:1, :, :], + ) + + return gray + + +class DatasetFashionMNIST(DatasetMaster): + """Contstructor""" + + # Initialize + def __init__( + self, + train: bool = False, + path_pattern: str = "./", + path_label: str = "./", + ) -> None: + super().__init__(train, path_pattern, path_label) + + self.pattern_storage = np.ascontiguousarray( + self.pattern_storage[:, np.newaxis, :, :].astype(dtype=np.float32) + ) + + self.pattern_storage /= np.max(self.pattern_storage) + + mean = self.pattern_storage.mean(3).mean(2).mean(0) + self.mean = [*mean] + + def __getitem__(self, index: int) -> tuple[torch.Tensor, int]: + + image = self.pattern_storage[index, 0:1, :, :] + target = int(self.label_storage[index]) + return torch.tensor(image), target + + def pattern_filter_test(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor: + """0. The test image comes in + 1. is center cropped + 2. on/off filteres + 3. returned. + + This is a 1 channel version (e.g. one gray channel). + """ + + assert len(cfg.image_statistics.mean) == 1 + assert len(cfg.image_statistics.the_size) == 2 + assert cfg.image_statistics.the_size[0] > 0 + assert cfg.image_statistics.the_size[1] > 0 + + # Transformation chain + my_transforms: torch.nn.Sequential = torch.nn.Sequential( + tv.transforms.CenterCrop(size=cfg.image_statistics.the_size), + ) + scripted_transforms = torch.jit.script(my_transforms) + + # Preprocess the input data + pattern = scripted_transforms(pattern) + + # => On/Off + my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0]) + gray: torch.Tensor = my_on_off_filter( + pattern[:, 0:1, :, :], + ) + + return gray + + def pattern_filter_train(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor: + """0. The training image comes in + 1. is cropped from a random position + 2. on/off filteres + 3. returned. + + This is a 1 channel version (e.g. one gray channel). 
+ """ + + assert len(cfg.image_statistics.mean) == 1 + assert len(cfg.image_statistics.the_size) == 2 + assert cfg.image_statistics.the_size[0] > 0 + assert cfg.image_statistics.the_size[1] > 0 + + # Transformation chain + my_transforms: torch.nn.Sequential = torch.nn.Sequential( + tv.transforms.RandomCrop(size=cfg.image_statistics.the_size), + tv.transforms.RandomHorizontalFlip(p=cfg.augmentation.flip_p), + tv.transforms.ColorJitter( + brightness=cfg.augmentation.jitter_brightness, + contrast=cfg.augmentation.jitter_contrast, + saturation=cfg.augmentation.jitter_saturation, + hue=cfg.augmentation.jitter_hue, + ), + ) + scripted_transforms = torch.jit.script(my_transforms) + + # Preprocess the input data + pattern = scripted_transforms(pattern) + + # => On/Off + my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0]) + gray: torch.Tensor = my_on_off_filter( + pattern[:, 0:1, :, :], + ) + + return gray + + +class DatasetCIFAR(DatasetMaster): + """Contstructor""" + + # Initialize + def __init__( + self, + train: bool = False, + path_pattern: str = "./", + path_label: str = "./", + ) -> None: + super().__init__(train, path_pattern, path_label) + + self.pattern_storage = np.ascontiguousarray( + np.moveaxis(self.pattern_storage.astype(dtype=np.float32), 3, 1) + ) + self.pattern_storage /= np.max(self.pattern_storage) + + mean = self.pattern_storage.mean(3).mean(2).mean(0) + self.mean = [*mean] + + def __getitem__(self, index: int) -> tuple[torch.Tensor, int]: + + image = self.pattern_storage[index, :, :, :] + target = int(self.label_storage[index]) + return torch.tensor(image), target + + def pattern_filter_test(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor: + """0. The test image comes in + 1. is center cropped + 2. on/off filteres + 3. returned. + + This is a 3 channel version (e.g. r,g,b channels). + """ + + assert len(cfg.image_statistics.mean) == 3 + assert len(cfg.image_statistics.the_size) == 2 + assert cfg.image_statistics.the_size[0] > 0 + assert cfg.image_statistics.the_size[1] > 0 + + # Transformation chain + my_transforms: torch.nn.Sequential = torch.nn.Sequential( + tv.transforms.CenterCrop(size=cfg.image_statistics.the_size), + ) + scripted_transforms = torch.jit.script(my_transforms) + + # Preprocess the input data + pattern = scripted_transforms(pattern) + + # => On/Off + + my_on_off_filter_r: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0]) + my_on_off_filter_g: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[1]) + my_on_off_filter_b: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[2]) + r: torch.Tensor = my_on_off_filter_r( + pattern[:, 0:1, :, :], + ) + g: torch.Tensor = my_on_off_filter_g( + pattern[:, 1:2, :, :], + ) + b: torch.Tensor = my_on_off_filter_b( + pattern[:, 2:3, :, :], + ) + + new_tensor: torch.Tensor = torch.cat((r, g, b), dim=1) + return new_tensor + + def pattern_filter_train(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor: + """0. The training image comes in + 1. is cropped from a random position + 2. is randomly horizontally flipped + 3. is randomly color jitteres + 4. on/off filteres + 5. returned. + + This is a 3 channel version (e.g. r,g,b channels). 
+ """ + assert len(cfg.image_statistics.mean) == 3 + assert len(cfg.image_statistics.the_size) == 2 + assert cfg.image_statistics.the_size[0] > 0 + assert cfg.image_statistics.the_size[1] > 0 + + # Transformation chain + my_transforms: torch.nn.Sequential = torch.nn.Sequential( + tv.transforms.RandomCrop(size=cfg.image_statistics.the_size), + tv.transforms.RandomHorizontalFlip(p=cfg.augmentation.flip_p), + tv.transforms.ColorJitter( + brightness=cfg.augmentation.jitter_brightness, + contrast=cfg.augmentation.jitter_contrast, + saturation=cfg.augmentation.jitter_saturation, + hue=cfg.augmentation.jitter_hue, + ), + ) + scripted_transforms = torch.jit.script(my_transforms) + + # Preprocess the input data + pattern = scripted_transforms(pattern) + + # => On/Off + my_on_off_filter_r: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0]) + my_on_off_filter_g: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[1]) + my_on_off_filter_b: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[2]) + r: torch.Tensor = my_on_off_filter_r( + pattern[:, 0:1, :, :], + ) + g: torch.Tensor = my_on_off_filter_g( + pattern[:, 1:2, :, :], + ) + b: torch.Tensor = my_on_off_filter_b( + pattern[:, 2:3, :, :], + ) + + new_tensor: torch.Tensor = torch.cat((r, g, b), dim=1) + return new_tensor + + +class OnOffFilter(torch.nn.Module): + def __init__(self, p: float = 0.5) -> None: + super(OnOffFilter, self).__init__() + self.p: float = p + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + + assert tensor.shape[1] == 1 + + tensor_clone = 2.0 * (tensor - self.p) + + temp_0: torch.Tensor = torch.where( + tensor_clone < 0.0, + -tensor_clone, + tensor_clone.new_zeros(tensor_clone.shape, dtype=tensor_clone.dtype), + ) + + temp_1: torch.Tensor = torch.where( + tensor_clone >= 0.0, + tensor_clone, + tensor_clone.new_zeros(tensor_clone.shape, dtype=tensor_clone.dtype), + ) + + new_tensor: torch.Tensor = torch.cat((temp_0, temp_1), dim=1) + + return new_tensor + + def __repr__(self) -> str: + return self.__class__.__name__ + "(p={0})".format(self.p) + + +if __name__ == "__main__": + pass diff --git a/Parameter.py b/Parameter.py new file mode 100644 index 0000000..92fe247 --- /dev/null +++ b/Parameter.py @@ -0,0 +1,164 @@ +# MIT License +# Copyright 2022 University of Bremen +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR +# THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+#
+#
+# David Rotermund ( davrot@uni-bremen.de )
+#
+#
+# Release history:
+# ================
+# 1.0.0 -- 01.05.2022: first release
+#
+#
+
+# %%
+from dataclasses import dataclass, field
+import numpy as np
+import torch
+import os
+
+
+@dataclass
+class Network:
+    """Parameters of the network. The details about
+    its layers and the number of output neurons."""
+
+    number_of_output_neurons: int = field(default=0)
+    forward_kernel_size: list[list[int]] = field(default_factory=list)
+    forward_neuron_numbers: list[list[int]] = field(default_factory=list)
+    strides: list[list[int]] = field(default_factory=list)
+    dilation: list[list[int]] = field(default_factory=list)
+    padding: list[list[int]] = field(default_factory=list)
+    is_pooling_layer: list[bool] = field(default_factory=list)
+    w_trainable: list[bool] = field(default_factory=list)
+    eps_xy_trainable: list[bool] = field(default_factory=list)
+    eps_xy_mean: list[bool] = field(default_factory=list)
+
+
+@dataclass
+class LearningParameters:
+    """Parameters required for training"""
+
+    loss_coeffs_mse: float = field(default=0.5)
+    loss_coeffs_kldiv: float = field(default=1.0)
+    learning_rate_gamma_w: float = field(default=-1.0)
+    learning_rate_gamma_eps_xy: float = field(default=-1.0)
+    learning_rate_threshold_w: float = field(default=0.00001)
+    learning_rate_threshold_eps_xy: float = field(default=0.00001)
+    learning_active: bool = field(default=True)
+    weight_noise_amplitude: float = field(default=0.01)
+    eps_xy_intitial: float = field(default=0.1)
+    test_every_x_learning_steps: int = field(default=50)
+    test_during_learning: bool = field(default=True)
+    lr_scheduler_factor: float = field(default=0.75)
+    lr_scheduler_patience: int = field(default=10)
+    optimizer_name: str = field(default="Adam")
+    lr_schedule_name: str = field(default="ReduceLROnPlateau")
+    number_of_batches_for_one_update: int = field(default=1)
+    alpha_number_of_iterations: int = field(default=0)
+    overload_path: str = field(default="./Previous")
+
+
+@dataclass
+class Augmentation:
+    """Parameters used for data augmentation."""
+
+    crop_width_in_pixel: int = field(default=2)
+    flip_p: float = field(default=0.5)
+    jitter_brightness: float = field(default=0.5)
+    jitter_contrast: float = field(default=0.1)
+    jitter_saturation: float = field(default=0.1)
+    jitter_hue: float = field(default=0.15)
+
+
+@dataclass
+class ImageStatistics:
+    """(Statistical) information about the input, i.e.
+    mean values and the x and y size of the input"""
+
+    mean: list[float] = field(default_factory=list)
+    the_size: list[int] = field(default_factory=list)
+
+
+@dataclass
+class Config:
+    """Master config class."""
+
+    # Sub classes
+    network_structure: Network = field(default_factory=Network)
+    learning_parameters: LearningParameters = field(default_factory=LearningParameters)
+    augmentation: Augmentation = field(default_factory=Augmentation)
+    image_statistics: ImageStatistics = field(default_factory=ImageStatistics)
+
+    batch_size: int = field(default=500)
+    data_mode: str = field(default="")
+
+    learning_step: int = field(default=0)
+    learning_step_max: int = field(default=10000)
+
+    number_of_cpu_processes: int = field(default=-1)
+
+    number_of_spikes: int = field(default=0)
+    cooldown_after_number_of_spikes: int = field(default=0)
+
+    weight_path: str = field(default="./Weights/")
+    eps_xy_path: str = field(default="./EpsXY/")
+    data_path: str = field(default="./")
+
+    reduction_cooldown: float = field(default=25.0)
+    epsilon_0: float = field(default=1.0)
+
+    update_after_x_batch: float = field(default=1.0)
+
+    def __post_init__(self) -> None:
+        """Post init determines the number of cores.
+        Creates the required directories and gives us a batch size
+        optimized for the number of cores."""
+        number_of_cpu_processes_temp = os.cpu_count()
+
+        if self.number_of_cpu_processes < 1:
+            if number_of_cpu_processes_temp is None:
+                self.number_of_cpu_processes = 1
+            else:
+                self.number_of_cpu_processes = number_of_cpu_processes_temp
+
+        os.makedirs(self.weight_path, exist_ok=True)
+        os.makedirs(self.eps_xy_path, exist_ok=True)
+        os.makedirs(self.data_path, exist_ok=True)
+
+        # Round the batch size down to a multiple of the number of
+        # processes (but keep at least one pattern per process)
+        self.batch_size = (
+            self.batch_size // self.number_of_cpu_processes
+        ) * self.number_of_cpu_processes
+
+        self.batch_size = np.max((self.batch_size, self.number_of_cpu_processes))
+        self.batch_size = int(self.batch_size)
+
+    def get_epsilon_t(self):
+        """Generates the time series of the basic epsilon."""
+        np_epsilon_t: np.ndarray = np.ones((self.number_of_spikes), dtype=np.float32)
+        # After the cooldown spike count, epsilon is divided by
+        # reduction_cooldown
+        np_epsilon_t[
+            self.cooldown_after_number_of_spikes : self.number_of_spikes
+        ] /= self.reduction_cooldown
+        return torch.tensor(np_epsilon_t)
+
+    def get_update_after_x_pattern(self):
+        """Tells us after how many patterns we need to update the weights."""
+        return self.batch_size * self.update_after_x_batch
diff --git a/PyHDynamicCNNManyIP.pyi b/PyHDynamicCNNManyIP.pyi
new file mode 100644
index 0000000..d86ff54
--- /dev/null
+++ b/PyHDynamicCNNManyIP.pyi
@@ -0,0 +1,18 @@
+#
+# AUTOMATICALLY GENERATED FILE, DO NOT EDIT!
+#
+
+"""HDynamicCNNManyIP Module"""
+from __future__ import annotations
+import PyHDynamicCNNManyIP
+import typing
+
+__all__ = [
+    "HDynamicCNNManyIP"
+]
+
+
+class HDynamicCNNManyIP():
+    def __init__(self) -> None: ...
+    def update_with_init_vector_multi_pattern(self, arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, arg7: int, arg8: int, arg9: int, arg10: int, arg11: int, arg12: int, arg13: int, arg14: int, arg15: int, arg16: int, arg17: int, arg18: int, arg19: int, arg20: int) -> bool: ...
+    pass
diff --git a/PySpikeGeneration2DManyIP.pyi b/PySpikeGeneration2DManyIP.pyi
new file mode 100644
index 0000000..57e593a
--- /dev/null
+++ b/PySpikeGeneration2DManyIP.pyi
@@ -0,0 +1,18 @@
+#
+# AUTOMATICALLY GENERATED FILE, DO NOT EDIT!
+#
+
+"""SpikeGeneration2DManyIP Module"""
+from __future__ import annotations
+import PySpikeGeneration2DManyIP
+import typing
+
+__all__ = [
+    "SpikeGeneration2DManyIP"
+]
+
+
+class SpikeGeneration2DManyIP():
+    def __init__(self) -> None: ...
+    def spike_generation_multi_pattern(self, arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, arg7: int, arg8: int, arg9: int, arg10: int, arg11: int, arg12: int, arg13: int, arg14: int, arg15: int) -> bool: ...
+    pass
diff --git a/SbS.py b/SbS.py
new file mode 100644
index 0000000..7b56a08
--- /dev/null
+++ b/SbS.py
@@ -0,0 +1,1220 @@
+# MIT License
+# Copyright 2022 University of Bremen
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
+# THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+#
+#
+# David Rotermund ( davrot@uni-bremen.de )
+#
+#
+# Release history:
+# ================
+# 1.0.0 -- 01.05.2022: first release
+#
+#
+
+# %%
+import torch
+import numpy as np
+
+try:
+    import PySpikeGeneration2DManyIP
+
+    cpp_spike: bool = True
+except Exception:
+    cpp_spike = False
+
+try:
+    import PyHDynamicCNNManyIP
+
+    cpp_sbs: bool = True
+except Exception:
+    cpp_sbs = False
+
+
+class SbS(torch.nn.Module):
+
+    _epsilon_xy: torch.nn.parameter.Parameter
+    _epsilon_xy_exists: bool = False
+    _epsilon_0: torch.Tensor | None = None
+    _epsilon_t: torch.Tensor | None = None
+    _weights: torch.nn.parameter.Parameter
+    _weights_exists: bool = False
+    _kernel_size: torch.Tensor | None = None
+    _stride: torch.Tensor | None = None
+    _dilation: torch.Tensor | None = None
+    _padding: torch.Tensor | None = None
+    _output_size: torch.Tensor | None = None
+    _number_of_spikes: torch.Tensor | None = None
+    _number_of_cpu_processes: torch.Tensor | None = None
+    _number_of_neurons: torch.Tensor | None = None
+    _number_of_input_neurons: torch.Tensor | None = None
+    _h_initial: torch.Tensor | None = None
+    _epsilon_xy_backup: torch.Tensor | None = None
+    _weights_backup: torch.Tensor | None = None
+    _alpha_number_of_iterations: torch.Tensor | None = None
+
+    def __init__(
+        self,
+        number_of_input_neurons: int,
+        number_of_neurons: int,
+        input_size: list[int],
+        forward_kernel_size: list[int],
+        number_of_spikes: int,
+        epsilon_t: torch.Tensor,
+        epsilon_xy_intitial: float = 0.1,
+        epsilon_0: float = 1.0,
+        weight_noise_amplitude: float = 0.01,
+        is_pooling_layer: bool = False,
+        strides: list[int] = [1, 1],
+        dilation: list[int] = [1, 1],  # (1 == no dilation; the setter rejects 0)
+        padding: list[int] = [0, 0],
+        alpha_number_of_iterations: int = 0,
+        number_of_cpu_processes: int = 1,
+    ) -> None:
+        """Constructor"""
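+
+        # Setup order: store the geometry (stride, dilation, padding,
+        # kernel size), derive the output size, set a uniform h_initial,
+        # create epsilon_xy / epsilon_0 / epsilon_t, build the (random
+        # or pooling) weights, then bind the autograd functional.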
+        super().__init__()
+
+        self.stride = torch.tensor(strides, dtype=torch.int64)
+
+        self.dilation = torch.tensor(dilation, dtype=torch.int64)
+
+        self.padding = torch.tensor(padding, dtype=torch.int64)
+
+        self.kernel_size = torch.tensor(
+            forward_kernel_size,
+            dtype=torch.int64,
+        )
+
+        self.number_of_input_neurons = torch.tensor(
+            number_of_input_neurons,
+            dtype=torch.int64,
+        )
+
+        self.number_of_neurons = torch.tensor(
+            number_of_neurons,
+            dtype=torch.int64,
+        )
+
+        self.alpha_number_of_iterations = torch.tensor(
+            alpha_number_of_iterations, dtype=torch.int64
+        )
+
+        self.calculate_output_size(torch.tensor(input_size, dtype=torch.int64))
+
+        self.set_h_init_to_uniform()
+
+        self.initialize_epsilon_xy(epsilon_xy_intitial)
+
+        self.epsilon_0 = torch.tensor(epsilon_0, dtype=torch.float64)
+
+        self.number_of_cpu_processes = torch.tensor(
+            number_of_cpu_processes, dtype=torch.int64
+        )
+
+        self.number_of_spikes = torch.tensor(number_of_spikes, dtype=torch.int64)
+
+        self.epsilon_t = epsilon_t.type(dtype=torch.float64)
+
+        self.initialize_weights(
+            is_pooling_layer=is_pooling_layer,
+            noise_amplitude=weight_noise_amplitude,
+        )
+
+        self.functional_sbs = FunctionalSbS.apply
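+
+    # A construction sketch (hypothetical values; in learn_it.py these
+    # all come from the JSON config), e.g. for a first layer on
+    # on/off filtered MNIST:
+    #
+    #   layer = SbS(
+    #       number_of_input_neurons=2,
+    #       number_of_neurons=32,
+    #       input_size=[24, 24],
+    #       forward_kernel_size=[5, 5],
+    #       number_of_spikes=1200,
+    #       epsilon_t=cfg.get_epsilon_t(),
+    #   )
+    #   h = layer(pattern)  # pattern: torch.float64, [batch, 2, 24, 24]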
+
+    ####################################################################
+    # Variables in and out                                             #
+    ####################################################################
+
+    @property
+    def epsilon_xy(self) -> torch.Tensor | None:
+        if self._epsilon_xy_exists is False:
+            return None
+        else:
+            return self._epsilon_xy
+
+    @epsilon_xy.setter
+    def epsilon_xy(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert value.dim() == 2
+        assert value.dtype == torch.float64
+        if self._epsilon_xy_exists is False:
+            self._epsilon_xy = torch.nn.parameter.Parameter(
+                value.detach().clone(memory_format=torch.contiguous_format),
+                requires_grad=True,
+            )
+            self._epsilon_xy_exists = True
+        else:
+            self._epsilon_xy.data = value.detach().clone(
+                memory_format=torch.contiguous_format
+            )
+
+    @property
+    def epsilon_0(self) -> torch.Tensor | None:
+        return self._epsilon_0
+
+    @epsilon_0.setter
+    def epsilon_0(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert torch.numel(value) == 1
+        assert value.dtype == torch.float64
+        assert value.item() > 0
+        self._epsilon_0 = value.detach().clone(memory_format=torch.contiguous_format)
+        self._epsilon_0.requires_grad_(False)
+
+    @property
+    def epsilon_t(self) -> torch.Tensor | None:
+        return self._epsilon_t
+
+    @epsilon_t.setter
+    def epsilon_t(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert value.dim() == 1
+        assert value.dtype == torch.float64
+        self._epsilon_t = value.detach().clone(memory_format=torch.contiguous_format)
+        self._epsilon_t.requires_grad_(False)
+
+    @property
+    def weights(self) -> torch.Tensor | None:
+        if self._weights_exists is False:
+            return None
+        else:
+            return self._weights
+
+    @weights.setter
+    def weights(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert value.dim() == 2
+        assert value.dtype == torch.float64
+        temp: torch.Tensor = value.detach().clone(memory_format=torch.contiguous_format)
+        temp /= temp.sum(dim=0, keepdim=True, dtype=torch.float64)
+        if self._weights_exists is False:
+            self._weights = torch.nn.parameter.Parameter(
+                temp,
+                requires_grad=True,
+            )
+            self._weights_exists = True
+        else:
+            self._weights.data = temp
+
+    @property
+    def kernel_size(self) -> torch.Tensor | None:
+        return self._kernel_size
+
+    @kernel_size.setter
+    def kernel_size(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert value.dim() == 1
+        assert torch.numel(value) == 2
+        assert value.dtype == torch.int64
+        assert value[0] > 0
+        assert value[1] > 0
+        self._kernel_size = value.detach().clone(memory_format=torch.contiguous_format)
+        self._kernel_size.requires_grad_(False)
+
+    @property
+    def stride(self) -> torch.Tensor | None:
+        return self._stride
+
+    @stride.setter
+    def stride(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert value.dim() == 1
+        assert torch.numel(value) == 2
+        assert value.dtype == torch.int64
+        assert value[0] > 0
+        assert value[1] > 0
+        self._stride = value.detach().clone(memory_format=torch.contiguous_format)
+        self._stride.requires_grad_(False)
+
+    @property
+    def dilation(self) -> torch.Tensor | None:
+        return self._dilation
+
+    @dilation.setter
+    def dilation(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert value.dim() == 1
+        assert torch.numel(value) == 2
+        assert value.dtype == torch.int64
+        assert value[0] > 0
+        assert value[1] > 0
+        self._dilation = value.detach().clone(memory_format=torch.contiguous_format)
+        self._dilation.requires_grad_(False)
+
+    @property
+    def padding(self) -> torch.Tensor | None:
+        return self._padding
+
+    @padding.setter
+    def padding(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert value.dim() == 1
+        assert torch.numel(value) == 2
+        assert value.dtype == torch.int64
+        assert value[0] >= 0
+        assert value[1] >= 0
+        self._padding = value.detach().clone(memory_format=torch.contiguous_format)
+        self._padding.requires_grad_(False)
+
+    @property
+    def output_size(self) -> torch.Tensor | None:
+        return self._output_size
+
+    @output_size.setter
+    def output_size(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert value.dim() == 1
+        assert torch.numel(value) == 2
+        assert value.dtype == torch.int64
+        assert value[0] > 0
+        assert value[1] > 0
+        self._output_size = value.detach().clone(memory_format=torch.contiguous_format)
+        self._output_size.requires_grad_(False)
+
+    @property
+    def number_of_spikes(self) -> torch.Tensor | None:
+        return self._number_of_spikes
+
+    @number_of_spikes.setter
+    def number_of_spikes(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert torch.numel(value) == 1
+        assert value.dtype == torch.int64
+        assert value.item() > 0
+        self._number_of_spikes = value.detach().clone(
+            memory_format=torch.contiguous_format
+        )
+        self._number_of_spikes.requires_grad_(False)
+
+    @property
+    def number_of_cpu_processes(self) -> torch.Tensor | None:
+        return self._number_of_cpu_processes
+
+    @number_of_cpu_processes.setter
+    def number_of_cpu_processes(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert torch.numel(value) == 1
+        assert value.dtype == torch.int64
+        assert value.item() > 0
+        self._number_of_cpu_processes = value.detach().clone(
+            memory_format=torch.contiguous_format
+        )
+        self._number_of_cpu_processes.requires_grad_(False)
+
+    @property
+    def number_of_neurons(self) -> torch.Tensor | None:
+        return self._number_of_neurons
+
+    @number_of_neurons.setter
+    def number_of_neurons(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert torch.numel(value) == 1
+        assert value.dtype == torch.int64
+        assert value.item() > 0
+        self._number_of_neurons = value.detach().clone(
+            memory_format=torch.contiguous_format
+        )
+        self._number_of_neurons.requires_grad_(False)
+
+    @property
+    def number_of_input_neurons(self) -> torch.Tensor | None:
+        return self._number_of_input_neurons
+
+    @number_of_input_neurons.setter
+    def number_of_input_neurons(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert torch.numel(value) == 1
+        assert value.dtype == torch.int64
+        assert value.item() > 0
+        self._number_of_input_neurons = value.detach().clone(
+            memory_format=torch.contiguous_format
+        )
+        self._number_of_input_neurons.requires_grad_(False)
+
+    @property
+    def h_initial(self) -> torch.Tensor | None:
+        return self._h_initial
+
+    @h_initial.setter
+    def h_initial(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert value.dim() == 1
+        assert value.dtype == torch.float32
+        self._h_initial = value.detach().clone(memory_format=torch.contiguous_format)
+        self._h_initial.requires_grad_(False)
+
+    @property
+    def alpha_number_of_iterations(self) -> torch.Tensor | None:
+        return self._alpha_number_of_iterations
+
+    @alpha_number_of_iterations.setter
+    def alpha_number_of_iterations(self, value: torch.Tensor):
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert torch.numel(value) == 1
+        assert value.dtype == torch.int64
+        assert value.item() >= 0
+        self._alpha_number_of_iterations = value.detach().clone(
+            memory_format=torch.contiguous_format
+        )
+        self._alpha_number_of_iterations.requires_grad_(False)
+
+    ####################################################################
+    # Forward                                                          #
+    ####################################################################
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """PyTorch Forward method. Does the work."""
+
+        # Are we happy with the input?
+        assert input is not None
+        assert torch.is_tensor(input) is True
+        assert input.dim() == 4
+        assert input.dtype == torch.float64
+
+        # Are we happy with the rest of the network?
+        assert self._epsilon_xy_exists is True
+        assert self._epsilon_xy is not None
+        assert self._epsilon_0 is not None
+        assert self._epsilon_t is not None
+        assert self._weights_exists is True
+        assert self._weights is not None
+        assert self._kernel_size is not None
+        assert self._stride is not None
+        assert self._dilation is not None
+        assert self._padding is not None
+        assert self._output_size is not None
+        assert self._number_of_spikes is not None
+        assert self._number_of_cpu_processes is not None
+        assert self._h_initial is not None
+        assert self._alpha_number_of_iterations is not None
+
+        # SbS forward functional
+        return self.functional_sbs(
+            input,
+            self._epsilon_xy,
+            self._epsilon_0,
+            self._epsilon_t,
+            self._weights,
+            self._kernel_size,
+            self._stride,
+            self._dilation,
+            self._padding,
+            self._output_size,
+            self._number_of_spikes,
+            self._number_of_cpu_processes,
+            self._h_initial,
+            self._alpha_number_of_iterations,
+        )
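+
+    # Per spike s_t and position (x, y), the functional applies the
+    # update (cf. the Python fallback branch in FunctionalSbS.forward
+    # below):
+    #
+    #   h <- (h + eps * W[s_t, :] * h / <W[s_t, :], h>) / (1 + eps)
+    #   eps = epsilon_xy[x, y] * epsilon_t[t] * epsilon_0
+    #
+    # where the code defers the 1 / (1 + eps) normalization via
+    # epsilon_scale until the end of the spike train.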
+
+    ####################################################################
+    # Helper functions                                                 #
+    ####################################################################
+
+    def calculate_output_size(self, value: torch.Tensor) -> None:
+
+        coordinates_0, coordinates_1 = self._get_coordinates(value)
+
+        self._output_size: torch.Tensor = torch.tensor(
+            [
+                coordinates_0.shape[1],
+                coordinates_1.shape[1],
+            ],
+            dtype=torch.int64,
+        )
+        self._output_size.requires_grad_(False)
+
+    def _get_coordinates(
+        self, value: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Converts the parameters into coordinates
+        for the convolution window"""
+
+        assert value is not None
+        assert torch.is_tensor(value) is True
+        assert value.dim() == 1
+        assert torch.numel(value) == 2
+        assert value.dtype == torch.int64
+        assert value[0] > 0
+        assert value[1] > 0
+
+        assert self._kernel_size is not None
+        assert self._stride is not None
+        assert self._dilation is not None
+        assert self._padding is not None
+
+        assert torch.numel(self._kernel_size) == 2
+        assert torch.numel(self._stride) == 2
+        assert torch.numel(self._dilation) == 2
+        assert torch.numel(self._padding) == 2
+
+        unfold_0: torch.nn.Unfold = torch.nn.Unfold(
+            kernel_size=(int(self._kernel_size[0]), 1),
+            dilation=int(self._dilation[0]),
+            padding=int(self._padding[0]),
+            stride=int(self._stride[0]),
+        )
+
+        unfold_1: torch.nn.Unfold = torch.nn.Unfold(
+            kernel_size=(1, int(self._kernel_size[1])),
+            dilation=int(self._dilation[1]),
+            padding=int(self._padding[1]),
+            stride=int(self._stride[1]),
+        )
+
+        coordinates_0: torch.Tensor = (
+            unfold_0(
+                torch.unsqueeze(
+                    torch.unsqueeze(
+                        torch.unsqueeze(
+                            torch.arange(0, int(value[0]), dtype=torch.float64),
+                            1,
+                        ),
+                        0,
+                    ),
+                    0,
+                )
+            )
+            .squeeze(0)
+            .type(torch.int64)
+        )
+
+        coordinates_1: torch.Tensor = (
+            unfold_1(
+                torch.unsqueeze(
+                    torch.unsqueeze(
+                        torch.unsqueeze(
+                            torch.arange(0, int(value[1]), dtype=torch.float64),
+                            0,
+                        ),
+                        0,
+                    ),
+                    0,
+                )
+            )
+            .squeeze(0)
+            .type(torch.int64)
+        )
+
+        return coordinates_0, coordinates_1
+
+    def _initial_random_weights(self, noise_amplitude: torch.Tensor) -> torch.Tensor:
+        """Creates initial weights:
+        uniform plus random noise plus normalization
+        """
+
+        assert torch.numel(noise_amplitude) == 1
+        assert noise_amplitude.item() >= 0
+        assert noise_amplitude.dtype == torch.float64
+
+        assert self._number_of_neurons is not None
+        assert self._number_of_input_neurons is not None
+        assert self._kernel_size is not None
+
+        weights = torch.empty(
+            (
+                int(self._kernel_size[0]),
+                int(self._kernel_size[1]),
+                int(self._number_of_input_neurons),
+                int(self._number_of_neurons),
+            ),
+            dtype=torch.float64,
+        )
+        torch.nn.init.uniform_(weights, a=1.0, b=(1.0 + noise_amplitude.item()))
+
+        return weights
+
+    def _make_pooling_weights(self) -> torch.Tensor:
+        """For generating the pooling weights."""
+
+        assert self._number_of_neurons is not None
+        assert self._kernel_size is not None
+
+        norm: float = 1.0 / (self._kernel_size[0] * self._kernel_size[1])
+
+        weights: torch.Tensor = torch.zeros(
+            (
+                int(self._kernel_size[0]),
+                int(self._kernel_size[1]),
+                int(self._number_of_neurons),
+                int(self._number_of_neurons),
+            ),
+            dtype=torch.float64,
+        )
+
+        for i in range(0, int(self._number_of_neurons)):
+            weights[:, :, i, i] = norm
+
+        return weights
+
+    def initialize_weights(
+        self,
+        is_pooling_layer: bool = False,
+        noise_amplitude: float = 0.01,
+    ) -> None:
+        """For the generation of the initial weights.
+        Switches between normal initial random weights and pooling weights."""
+
+        assert self._kernel_size is not None
+
+        if is_pooling_layer is True:
+            weights = self._make_pooling_weights()
+        else:
+            weights = self._initial_random_weights(
+                torch.tensor(noise_amplitude, dtype=torch.float64)
+            )
+
+        weights = weights.moveaxis(-1, 0).moveaxis(-1, 1)
+
+        weights_t = torch.nn.functional.unfold(
+            input=weights,
+            kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
+            dilation=(1, 1),
+            padding=(0, 0),
+            stride=(1, 1),
+        ).squeeze()
+
+        weights_t = torch.moveaxis(weights_t, 0, 1)
+
+        self.weights = weights_t
+
+    def initialize_epsilon_xy(
+        self,
+        eps_xy_intitial: float,
+    ) -> None:
+        """Creates initial epsilon xy matrices"""
+
+        assert self._output_size is not None
+        assert eps_xy_intitial > 0
+
+        eps_xy_temp: torch.Tensor = torch.full(
+            (int(self._output_size[0]), int(self._output_size[1])),
+            eps_xy_intitial,
+            dtype=torch.float64,
+        )
+
+        self.epsilon_xy = eps_xy_temp
+
+    def set_h_init_to_uniform(self) -> None:
+
+        assert self._number_of_neurons is not None
+
+        h_initial: torch.Tensor = torch.full(
+            (int(self._number_of_neurons.item()),),
+            (1.0 / float(self._number_of_neurons.item())),
+            dtype=torch.float32,
+        )
+
+        self.h_initial = h_initial
+
+    # Epsilon XY
+    def backup_epsilon_xy(self) -> None:
+        assert self._epsilon_xy_exists is True
+        self._epsilon_xy_backup = self._epsilon_xy.data.clone()
+
+    def restore_epsilon_xy(self) -> None:
+        assert self._epsilon_xy_backup is not None
+        assert self._epsilon_xy_exists is True
+        self._epsilon_xy.data = self._epsilon_xy_backup.clone()
+
+    def mean_epsilon_xy(self) -> None:
+        assert self._epsilon_xy_exists is True
+
+        fill_value: float = float(self._epsilon_xy.data.mean())
+        self._epsilon_xy.data = torch.full_like(
+            self._epsilon_xy.data, fill_value, dtype=torch.float64
+        )
+
+    def threshold_epsilon_xy(self, threshold: float) -> None:
+        assert self._epsilon_xy_exists is True
+        assert threshold >= 0
+        torch.clamp(
+            self._epsilon_xy.data,
+            min=float(threshold),
+            max=None,
+            out=self._epsilon_xy.data,
+        )
+
+    # Weights
+    def backup_weights(self) -> None:
+        assert self._weights_exists is True
+        self._weights_backup = self._weights.data.clone()
+
+    def restore_weights(self) -> None:
+        assert self._weights_backup is not None
+        assert self._weights_exists is True
+        self._weights.data = self._weights_backup.clone()
+
+    def norm_weights(self) -> None:
+        assert self._weights_exists is True
+        temp: torch.Tensor = (
+            self._weights.data.detach()
+            .clone(memory_format=torch.contiguous_format)
+            .type(dtype=torch.float64)
+        )
+        temp /= temp.sum(dim=0, keepdim=True, dtype=torch.float64)
+        self._weights.data = temp
+
+    def threshold_weights(self, threshold: float) -> None:
+        assert self._weights_exists is True
+        assert threshold >= 0
+        torch.clamp(
+            self._weights.data,
+            min=float(threshold),
+            max=None,
+            out=self._weights.data,
+        )
+
+
+class FunctionalSbS(torch.autograd.Function):
+    @staticmethod
+    def forward(  # type: ignore
+        ctx,
+        input_float64: torch.Tensor,
+        epsilon_xy_float64: torch.Tensor,
+        epsilon_0_float64: torch.Tensor,
+        epsilon_t_float64: torch.Tensor,
+        weights_float64: torch.Tensor,
+        kernel_size: torch.Tensor,
+        stride: torch.Tensor,
+        dilation: torch.Tensor,
+        padding: torch.Tensor,
+        output_size: torch.Tensor,
+        number_of_spikes: torch.Tensor,
+        number_of_cpu_processes: torch.Tensor,
+        h_initial: torch.Tensor,
+        alpha_number_of_iterations: torch.Tensor,
+    ) -> torch.Tensor:
+
+        input = input_float64.type(dtype=torch.float32)
+        epsilon_xy = epsilon_xy_float64.type(dtype=torch.float32)
+        weights = weights_float64.type(dtype=torch.float32)
+        epsilon_0 = epsilon_0_float64.type(dtype=torch.float32)
+        epsilon_t = epsilon_t_float64.type(dtype=torch.float32)
+
+        assert input.dim() == 4
+        assert torch.numel(kernel_size) == 2
+        assert torch.numel(dilation) == 2
+        assert torch.numel(padding) == 2
+        assert torch.numel(stride) == 2
+        assert torch.numel(output_size) == 2
+
+        assert torch.numel(epsilon_0) == 1
+        assert torch.numel(number_of_spikes) == 1
+        assert torch.numel(number_of_cpu_processes) == 1
+        assert torch.numel(alpha_number_of_iterations) == 1
+
+        input_size = torch.tensor([input.shape[2], input.shape[3]])
+
+        ############################################################
+        # Pre convolving the input                                 #
+        ############################################################
+
+        input_convolved_temp = torch.nn.functional.unfold(
+            input,
+            kernel_size=tuple(kernel_size.tolist()),
+            dilation=tuple(dilation.tolist()),
+            padding=tuple(padding.tolist()),
+            stride=tuple(stride.tolist()),
+        )
+
+        input_convolved = torch.nn.functional.fold(
+            input_convolved_temp,
+            output_size=tuple(output_size.tolist()),
+            kernel_size=(1, 1),
+            dilation=(1, 1),
+            padding=(0, 0),
+            stride=(1, 1),
+        ).requires_grad_(True)
+
+        ############################################################
+        # Spike generation                                         #
+        ############################################################
+
+        if cpp_spike is False:
+            # Alternative to the C++ module but 5x slower:
+            spikes = (
+                (
+                    input_convolved.movedim(source=(2, 3), destination=(0, 1))
+                    .reshape(
+                        shape=(
+                            input_convolved.shape[2]
+                            * input_convolved.shape[3]
+                            * input_convolved.shape[0],
+                            input_convolved.shape[1],
+                        )
+                    )
+                    .multinomial(
+                        num_samples=int(number_of_spikes.item()), replacement=True
+                    )
+                )
+                .reshape(
+                    shape=(
+                        input_convolved.shape[2],
+                        input_convolved.shape[3],
+                        input_convolved.shape[0],
+                        int(number_of_spikes.item()),
+                    )
+                )
+                .movedim(source=(0, 1), destination=(2, 3))
+            ).contiguous(memory_format=torch.contiguous_format)
+        else:
+            # Normalized cumsum
+            input_cumsum: torch.Tensor = torch.cumsum(
+                input_convolved, dim=1, dtype=torch.float32
+            )
+            input_cumsum_last: torch.Tensor = input_cumsum[:, -1, :, :].unsqueeze(1)
+            input_cumsum /= input_cumsum_last
+
+            random_values = torch.rand(
+                size=[
+                    input_cumsum.shape[0],
+                    int(number_of_spikes.item()),
+                    input_cumsum.shape[2],
+                    input_cumsum.shape[3],
+                ],
+                dtype=torch.float32,
+            )
+
+            spikes = torch.empty_like(random_values, dtype=torch.int64)
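+
+            # (The C++ sources are not part of this patch; judging from
+            # the buffers handed over below, the extension draws each
+            # spike by inverse-CDF sampling: for every uniform random
+            # value it selects the first input index whose normalized
+            # cumsum exceeds it, matching the multinomial branch above.)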
+
+            # Prepare for Export (Pointer and stuff)->
+            np_input: np.ndarray = input_cumsum.detach().numpy()
+            assert input_cumsum.dtype == torch.float32
+            assert np_input.flags["C_CONTIGUOUS"] is True
+            assert np_input.ndim == 4
+
+            np_random_values: np.ndarray = random_values.detach().numpy()
+            assert random_values.dtype == torch.float32
+            assert np_random_values.flags["C_CONTIGUOUS"] is True
+            assert np_random_values.ndim == 4
+
+            np_spikes: np.ndarray = spikes.detach().numpy()
+            assert spikes.dtype == torch.int64
+            assert np_spikes.flags["C_CONTIGUOUS"] is True
+            assert np_spikes.ndim == 4
+            # <- Prepare for Export
+
+            spike_generation: PySpikeGeneration2DManyIP.SpikeGeneration2DManyIP = (
+                PySpikeGeneration2DManyIP.SpikeGeneration2DManyIP()
+            )
+
+            # The __array_interface__["data"][0] integers are raw host
+            # pointers; the asserts above guarantee the contiguous
+            # layout and dtypes the C++ side expects.
+            spike_generation.spike_generation_multi_pattern(
+                np_input.__array_interface__["data"][0],
+                int(np_input.shape[0]),
+                int(np_input.shape[1]),
+                int(np_input.shape[2]),
+                int(np_input.shape[3]),
+                np_random_values.__array_interface__["data"][0],
+                int(np_random_values.shape[0]),
+                int(np_random_values.shape[1]),
+                int(np_random_values.shape[2]),
+                int(np_random_values.shape[3]),
+                np_spikes.__array_interface__["data"][0],
+                int(np_spikes.shape[0]),
+                int(np_spikes.shape[1]),
+                int(np_spikes.shape[2]),
+                int(np_spikes.shape[3]),
+                int(number_of_cpu_processes.item()),
+            )
+
+        ############################################################
+        # H dynamic                                                #
+        ############################################################
+
+        assert epsilon_t.ndim == 1
+        assert epsilon_t.shape[0] >= number_of_spikes
+
+        if cpp_sbs is False:
+            h = torch.tile(
+                h_initial.unsqueeze(0).unsqueeze(0).unsqueeze(0),
+                dims=[int(input.shape[0]), int(output_size[0]), int(output_size[1]), 1],
+            )
+
+            epsilon_scale: torch.Tensor = torch.ones(
+                size=[1, int(epsilon_xy.shape[0]), int(epsilon_xy.shape[1]), 1],
+                dtype=torch.float32,
+            )
+
+            for t in range(0, spikes.shape[1]):
+
+                if epsilon_scale.max() > 1e10:
+                    h /= epsilon_scale
+                    epsilon_scale = torch.ones_like(epsilon_scale)
+
+                h_temp: torch.Tensor = weights[spikes[:, t, :, :], :] * h
+                epsilon_subsegment: torch.Tensor = (
+                    epsilon_xy.unsqueeze(0).unsqueeze(-1) * epsilon_t[t] * epsilon_0
+                )
+                h_temp_sum: torch.Tensor = (
+                    epsilon_scale * epsilon_subsegment / h_temp.sum(dim=3, keepdim=True)
+                )
+                torch.nan_to_num(
+                    h_temp_sum, out=h_temp_sum, nan=0.0, posinf=0.0, neginf=0.0
+                )
+                h_temp *= h_temp_sum
+                h += h_temp
+
+                epsilon_scale *= 1.0 + epsilon_subsegment
+
+            h /= epsilon_scale
+            output = h.movedim(3, 1)
+        else:
+            epsilon_t_0: torch.Tensor = epsilon_t * epsilon_0
+
+            h_shape: tuple[int, int, int, int] = (
+                int(input.shape[0]),
+                int(weights.shape[1]),
+                int(output_size[0]),
+                int(output_size[1]),
+            )
+
+            output = torch.empty(h_shape, dtype=torch.float32)
+
+            # Prepare the export to C++ ->
+            np_h: np.ndarray = output.detach().numpy()
+            assert output.dtype == torch.float32
+            assert np_h.flags["C_CONTIGUOUS"] is True
+            assert np_h.ndim == 4
+
+            np_epsilon_xy: np.ndarray = epsilon_xy.detach().numpy()
+            assert epsilon_xy.dtype == torch.float32
+            assert np_epsilon_xy.flags["C_CONTIGUOUS"] is True
+            assert np_epsilon_xy.ndim == 2
+
+            np_epsilon_t: np.ndarray = epsilon_t_0.detach().numpy()
+            assert epsilon_t_0.dtype == torch.float32
+            assert np_epsilon_t.flags["C_CONTIGUOUS"] is True
+            assert np_epsilon_t.ndim == 1
+
+            np_weights: np.ndarray = weights.detach().numpy()
+            assert weights.dtype == torch.float32
+            assert np_weights.flags["C_CONTIGUOUS"] is True
+            assert np_weights.ndim == 2
+
+            np_spikes = spikes.contiguous().detach().numpy()
+            assert spikes.dtype == torch.int64
+            assert np_spikes.flags["C_CONTIGUOUS"] is True
+            assert np_spikes.ndim == 4
+
+            np_h_initial = h_initial.contiguous().detach().numpy()
+            assert h_initial.dtype == torch.float32
+            assert np_h_initial.flags["C_CONTIGUOUS"] is True
+            assert np_h_initial.ndim == 1
+            # <- Prepare the export to C++
+
+            h_dynamic: PyHDynamicCNNManyIP.HDynamicCNNManyIP = (
+                PyHDynamicCNNManyIP.HDynamicCNNManyIP()
+            )
+
+            h_dynamic.update_with_init_vector_multi_pattern(
+                np_h.__array_interface__["data"][0],
+                int(np_h.shape[0]),
+                int(np_h.shape[1]),
+                int(np_h.shape[2]),
+                int(np_h.shape[3]),
+                np_epsilon_xy.__array_interface__["data"][0],
+                int(np_epsilon_xy.shape[0]),
+                int(np_epsilon_xy.shape[1]),
+                np_epsilon_t.__array_interface__["data"][0],
+                int(np_epsilon_t.shape[0]),
+                np_weights.__array_interface__["data"][0],
+                int(np_weights.shape[0]),
+                int(np_weights.shape[1]),
+                np_spikes.__array_interface__["data"][0],
+                int(np_spikes.shape[0]),
+                int(np_spikes.shape[1]),
+                int(np_spikes.shape[2]),
+                int(np_spikes.shape[3]),
+                np_h_initial.__array_interface__["data"][0],
+                int(np_h_initial.shape[0]),
+                int(number_of_cpu_processes.item()),
+            )
+
+        ############################################################
+        # Alpha                                                    #
+        ############################################################
+        alpha_number_of_iterations_int: int = int(alpha_number_of_iterations.item())
+
+        if alpha_number_of_iterations_int > 0:
+            # Initialization
+            virtual_reconstruction_weight: torch.Tensor = torch.einsum(
+                "bixy,ji->bjxy", output, weights
+            )
+            alpha_fill_value: float = 1.0 / (
+                virtual_reconstruction_weight.shape[2]
+                * virtual_reconstruction_weight.shape[3]
+            )
+            alpha_dynamic: torch.Tensor = torch.full(
+                (
+                    int(virtual_reconstruction_weight.shape[0]),
+                    1,
+                    int(virtual_reconstruction_weight.shape[2]),
+                    int(virtual_reconstruction_weight.shape[3]),
+                ),
+                alpha_fill_value,
+                dtype=torch.float32,
+            )
+
+            # Iterations
+            for _ in range(0, alpha_number_of_iterations_int):
+                alpha_temp: torch.Tensor = alpha_dynamic * virtual_reconstruction_weight
+                alpha_temp /= alpha_temp.sum(dim=3, keepdim=True).sum(
+                    dim=2, keepdim=True
+                )
+                torch.nan_to_num(
+                    alpha_temp, out=alpha_temp, nan=0.0, posinf=0.0, neginf=0.0
+                )
+
+                alpha_temp = torch.nn.functional.unfold(
+                    alpha_temp,
+                    kernel_size=(1, 1),
+                    dilation=1,
+                    padding=0,
+                    stride=1,
+                )
+
+                alpha_temp = torch.nn.functional.fold(
+                    alpha_temp,
+                    output_size=tuple(input_size.tolist()),
+                    kernel_size=tuple(kernel_size.tolist()),
+                    dilation=tuple(dilation.tolist()),
+                    padding=tuple(padding.tolist()),
+                    stride=tuple(stride.tolist()),
+                )
+
+                alpha_temp = (alpha_temp * input).sum(dim=1, keepdim=True)
+
+                alpha_temp = torch.nn.functional.unfold(
+                    alpha_temp,
+                    kernel_size=tuple(kernel_size.tolist()),
+                    dilation=tuple(dilation.tolist()),
+                    padding=tuple(padding.tolist()),
+                    stride=tuple(stride.tolist()),
+                )
+
+                alpha_temp = torch.nn.functional.fold(
+                    alpha_temp,
+                    output_size=tuple(output_size.tolist()),
+                    kernel_size=(1, 1),
+                    dilation=(1, 1),
+                    padding=(0, 0),
+                    stride=(1, 1),
+                )
+                alpha_dynamic = alpha_temp.sum(dim=1, keepdim=True)
+
+                alpha_dynamic += torch.finfo(torch.float32).eps * 1000
+
+                # Alpha normalization
+                alpha_dynamic /= alpha_dynamic.sum(dim=3, keepdim=True).sum(
+                    dim=2, keepdim=True
+                )
+                torch.nan_to_num(
+                    alpha_dynamic, out=alpha_dynamic, nan=0.0, posinf=0.0, neginf=0.0
+                )
+
+            # Applied to the output
+            output *= alpha_dynamic
+
+        ############################################################
+        # Save the necessary data for the backward pass            #
+        ############################################################
+
+        output = output.type(dtype=torch.float64)
+
+        ctx.save_for_backward(
+            input_convolved,
+            epsilon_xy_float64,
+            epsilon_0_float64,
+            weights_float64,
+            output,
+            kernel_size,
+            stride,
+            dilation,
+            padding,
+            input_size,
+        )
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+
+        # Get the variables back
+        (
+            input_float32,
+            epsilon_xy,
+            epsilon_0,
+            weights,
+            output,
+            kernel_size,
+            stride,
+            dilation,
+            padding,
+            input_size,
+        ) = ctx.saved_tensors
+
+        input = input_float32.type(dtype=torch.float64)
+        input /= input.sum(dim=1, keepdim=True, dtype=torch.float64)
+
+        # For debugging:
+        # print(
+        #     f"S: O: {output.min().item():e} {output.max().item():e} I: {input.min().item():e} {input.max().item():e} G: {grad_output.min().item():e} {grad_output.max().item():e}"
+        # )
+
+        epsilon_0_float: float = epsilon_0.item()
+
+        temp_e: torch.Tensor = 1.0 / ((epsilon_xy * epsilon_0_float) + 1.0)
+
+        eps_a: torch.Tensor = temp_e.clone()
+        eps_a *= epsilon_xy * epsilon_0_float
+
+        eps_b: torch.Tensor = temp_e**2 * epsilon_0_float
+
+        backprop_r: torch.Tensor = weights.unsqueeze(0).unsqueeze(-1).unsqueeze(
+            -1
+        ) * output.unsqueeze(1)
+
+        backprop_bigr: torch.Tensor = backprop_r.sum(axis=2)
+
+        temp: torch.Tensor = input / backprop_bigr**2
+
+        backprop_f: torch.Tensor = output.unsqueeze(1) * temp.unsqueeze(2)
+        # Keep the float64 intermediates finite (1e300 is just below
+        # the float64 maximum); same guard is used for all results below
+        torch.nan_to_num(
+            backprop_f, out=backprop_f, nan=1e300, posinf=1e300, neginf=-1e300
+        )
+        torch.clip(backprop_f, out=backprop_f, min=-1e300, max=1e300)
+
+        tempz: torch.Tensor = 1.0 / backprop_bigr
+
+        backprop_z: torch.Tensor = backprop_r * tempz.unsqueeze(2)
+        torch.nan_to_num(
+            backprop_z, out=backprop_z, nan=1e300, posinf=1e300, neginf=-1e300
+        )
+        torch.clip(backprop_z, out=backprop_z, min=-1e300, max=1e300)
+
+        backprop_y: torch.Tensor = (
+            torch.einsum("bijxy,bixy->bjxy", backprop_z, input) - output
+        )
+
+        result_omega: torch.Tensor = backprop_bigr.unsqueeze(2) * grad_output.unsqueeze(
+            1
+        )
+        result_omega -= torch.einsum(
+            "bijxy,bjxy->bixy", backprop_r, grad_output
+        ).unsqueeze(2)
+        result_omega *= backprop_f
+        torch.nan_to_num(
+            result_omega, out=result_omega, nan=1e300, posinf=1e300, neginf=-1e300
+        )
+        torch.clip(result_omega, out=result_omega, min=-1e300, max=1e300)
+
+        result_eps_xy: torch.Tensor = (
+            torch.einsum("bixy,bixy->bxy", backprop_y, grad_output) * eps_b
+        )
+        torch.nan_to_num(
+            result_eps_xy, out=result_eps_xy, nan=1e300, posinf=1e300, neginf=-1e300
+        )
+        torch.clip(result_eps_xy, out=result_eps_xy, min=-1e300, max=1e300)
+
+        result_phi: torch.Tensor = torch.einsum(
+            "bijxy,bjxy->bixy", backprop_z, grad_output
+        ) * eps_a.unsqueeze(0).unsqueeze(0)
+        torch.nan_to_num(
+            result_phi, out=result_phi, nan=1e300, posinf=1e300, neginf=-1e300
+        )
+        torch.clip(result_phi, out=result_phi, min=-1e300, max=1e300)
+
+        grad_weights = result_omega.sum(0).sum(-1).sum(-1)
+        grad_eps_xy = result_eps_xy.sum(0)
+
+        grad_input = torch.nn.functional.fold(
+            torch.nn.functional.unfold(
+                result_phi,
+                kernel_size=(1, 1),
+                dilation=1,
+                padding=0,
+                stride=1,
+            ),
+            output_size=input_size,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            padding=padding,
+            stride=stride,
+        )
+
+        torch.nan_to_num(
+            grad_input, out=grad_input, nan=1e300, posinf=1e300, neginf=-1e300
+        )
+        torch.clip(grad_input, out=grad_input, min=-1e300, max=1e300)
+
+        torch.nan_to_num(
+            grad_eps_xy, out=grad_eps_xy, nan=1e300, posinf=1e300, neginf=-1e300
+        )
+        torch.clip(grad_eps_xy, out=grad_eps_xy, min=-1e300, max=1e300)
+
+        torch.nan_to_num(
+            grad_weights, out=grad_weights, nan=1e300, posinf=1e300, neginf=-1e300
+        )
+        torch.clip(grad_weights, out=grad_weights, min=-1e300, max=1e300)
+
+        grad_epsilon_0 = None
+        grad_epsilon_t = None
+        grad_kernel_size = None
+        grad_stride = None
+        grad_dilation = None
+        grad_padding = None
+        grad_output_size = None
+        grad_number_of_spikes = None
+        grad_number_of_cpu_processes = None
+        grad_h_initial = None
+        grad_alpha_number_of_iterations = None
+
+        return (
+            grad_input,
+            grad_eps_xy,
+            grad_epsilon_0,
+            grad_epsilon_t,
+            grad_weights,
+            grad_kernel_size,
+            grad_stride,
+            grad_dilation,
+            grad_padding,
+            grad_output_size,
+            grad_number_of_spikes,
+            grad_number_of_cpu_processes,
+            grad_h_initial,
+            grad_alpha_number_of_iterations,
+        )
diff --git a/learn_it.py b/learn_it.py
new file mode 100644
index 0000000..5d3ff57
--- /dev/null
+++ b/learn_it.py
@@ -0,0 +1,590 @@
+# MIT License
+# Copyright 2022 University of Bremen
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included
+# in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
+# THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+#
+#
+# David Rotermund ( davrot@uni-bremen.de )
+#
+#
+# Release history:
+# ================
+# 1.0.0 -- 01.05.2022: first release
+#
+#
+
+# %%
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+import numpy as np
+import sys
+import torch
+import time
+import dataconf
+import logging
+from datetime import datetime
+
+from Dataset import (
+    DatasetMaster,
+    DatasetCIFAR,
+    DatasetMNIST,
+    DatasetFashionMNIST,
+)
+from Parameter import Config
+from SbS import SbS
+
+from torch.utils.tensorboard import SummaryWriter
+
+tb = SummaryWriter()
+
+#######################################################################
+# We want to log what is going on into a file and screen              #
+#######################################################################
+
+now = datetime.now()
+dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S")
+logging.basicConfig(
+    filename="log_" + dt_string_filename + ".txt",
+    filemode="w",
+    level=logging.INFO,
+    format="%(asctime)s %(message)s",
+)
+logging.getLogger().addHandler(logging.StreamHandler())
+
+#######################################################################
+# Load the config data from the json file                             #
+#######################################################################
+
+if len(sys.argv) < 2:
+    raise Exception("Argument: Config file name is missing")
+
+filename: str = sys.argv[1]
+
+if os.path.exists(filename) is False:
+    raise Exception(f"Config file not found! {filename}")
+
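+# The JSON file is parsed by dataconf directly into the Config
+# dataclasses from Parameter.py. A minimal sketch (illustrative values
+# only, no such file ships with this patch) could look like:
+#
+#   {
+#     "data_mode": "MNIST",
+#     "data_path": "./Data",
+#     "number_of_spikes": 1200,
+#     "cooldown_after_number_of_spikes": 1000,
+#     "network_structure": {
+#       "forward_neuron_numbers": [[2, 32], [32, 10]],
+#       "forward_kernel_size": [[5, 5], [24, 24]],
+#       "strides": [[1, 1], [1, 1]],
+#       "dilation": [[1, 1], [1, 1]],
+#       "padding": [[0, 0], [0, 0]],
+#       "is_pooling_layer": [false, false]
+#     }
+#   }
+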
{filename}") + +cfg = dataconf.file(filename, Config) +logging.info(f"Using configuration file: {filename}") + + +####################################################################### +# Prepare the test and training data # +####################################################################### + +# Load the input data +the_dataset_train: DatasetMaster +the_dataset_test: DatasetMaster +if cfg.data_mode == "CIFAR10": + the_dataset_train = DatasetCIFAR( + train=True, path_pattern=cfg.data_path, path_label=cfg.data_path + ) + the_dataset_test = DatasetCIFAR( + train=False, path_pattern=cfg.data_path, path_label=cfg.data_path + ) +elif cfg.data_mode == "MNIST": + the_dataset_train = DatasetMNIST( + train=True, path_pattern=cfg.data_path, path_label=cfg.data_path + ) + the_dataset_test = DatasetMNIST( + train=False, path_pattern=cfg.data_path, path_label=cfg.data_path + ) +elif cfg.data_mode == "MNIST_FASHION": + the_dataset_train = DatasetFashionMNIST( + train=True, path_pattern=cfg.data_path, path_label=cfg.data_path + ) + the_dataset_test = DatasetFashionMNIST( + train=False, path_pattern=cfg.data_path, path_label=cfg.data_path + ) +else: + raise Exception("data_mode unknown") + +cfg.image_statistics.mean = the_dataset_train.mean + +# The basic size +cfg.image_statistics.the_size = [ + the_dataset_train.pattern_storage.shape[2], + the_dataset_train.pattern_storage.shape[3], +] + +# Minus the stuff we cut away in the pattern filter +cfg.image_statistics.the_size[0] -= 2 * cfg.augmentation.crop_width_in_pixel +cfg.image_statistics.the_size[1] -= 2 * cfg.augmentation.crop_width_in_pixel + +my_loader_test: torch.utils.data.DataLoader = torch.utils.data.DataLoader( + the_dataset_test, batch_size=cfg.batch_size, shuffle=False +) +my_loader_train: torch.utils.data.DataLoader = torch.utils.data.DataLoader( + the_dataset_train, batch_size=cfg.batch_size, shuffle=True +) + +logging.info("*** Data loaded.") + +####################################################################### +# Build the network # +####################################################################### + +wf: list[np.ndarray] = [] +eps_xy: list[np.ndarray] = [] +network = torch.nn.Sequential() +for id in range(0, len(cfg.network_structure.is_pooling_layer)): + if id == 0: + input_size: list[int] = cfg.image_statistics.the_size + else: + input_size = network[id - 1].output_size.tolist() + + network.append( + SbS( + number_of_input_neurons=cfg.network_structure.forward_neuron_numbers[id][0], + number_of_neurons=cfg.network_structure.forward_neuron_numbers[id][1], + input_size=input_size, + forward_kernel_size=cfg.network_structure.forward_kernel_size[id], + number_of_spikes=cfg.number_of_spikes, + epsilon_t=cfg.get_epsilon_t(), + epsilon_xy_intitial=cfg.learning_parameters.eps_xy_intitial, + epsilon_0=cfg.epsilon_0, + weight_noise_amplitude=cfg.learning_parameters.weight_noise_amplitude, + is_pooling_layer=cfg.network_structure.is_pooling_layer[id], + strides=cfg.network_structure.strides[id], + dilation=cfg.network_structure.dilation[id], + padding=cfg.network_structure.padding[id], + alpha_number_of_iterations=cfg.learning_parameters.alpha_number_of_iterations, + number_of_cpu_processes=cfg.number_of_cpu_processes, + ) + ) + + eps_xy.append(network[id].epsilon_xy.detach().clone().numpy()) + wf.append(network[id].weights.detach().clone().numpy()) + +logging.info("*** Network generated.") + +for id in range(0, len(network)): + # Load previous weights and epsilon xy + if cfg.learning_step > 0: + network[id].weights = 
+            np.load(
+                cfg.weight_path
+                + "/Weight_L"
+                + str(id)
+                + "_S"
+                + str(cfg.learning_step)
+                + ".npy"
+            ),
+            dtype=torch.float64,
+        )
+
+        wf[id] = np.load(
+            cfg.weight_path
+            + "/Weight_L"
+            + str(id)
+            + "_S"
+            + str(cfg.learning_step)
+            + ".npy"
+        )
+
+        network[id].epsilon_xy = torch.tensor(
+            np.load(
+                cfg.eps_xy_path
+                + "/EpsXY_L"
+                + str(id)
+                + "_S"
+                + str(cfg.learning_step)
+                + ".npy"
+            ),
+            dtype=torch.float64,
+        )
+
+        eps_xy[id] = np.load(
+            cfg.eps_xy_path
+            + "/EpsXY_L"
+            + str(id)
+            + "_S"
+            + str(cfg.learning_step)
+            + ".npy"
+        )
+
+    # Are there weights that overwrite the initial weights?
+    file_to_load: str = (
+        cfg.learning_parameters.overload_path + "/Weight_L" + str(id) + ".npy"
+    )
+    if os.path.exists(file_to_load) is True:
+        network[id].weights = torch.tensor(
+            np.load(file_to_load),
+            dtype=torch.float64,
+        )
+        wf[id] = np.load(file_to_load)
+        logging.info(f"File used: {file_to_load}")
+
+    file_to_load = cfg.learning_parameters.overload_path + "/EpsXY_L" + str(id) + ".npy"
+    if os.path.exists(file_to_load) is True:
+        network[id].epsilon_xy = torch.tensor(
+            np.load(file_to_load),
+            dtype=torch.float64,
+        )
+        eps_xy[id] = np.load(file_to_load)
+        logging.info(f"File used: {file_to_load}")
+
+#######################################################################
+# Optimizer and LR Scheduler                                          #
+#######################################################################
+
+# I keep weights and epsilon xy separate to
+# set the initial learning rate independently
+parameter_list_weights: list = []
+parameter_list_epsilon_xy: list = []
+
+for id in range(0, len(network)):
+    parameter_list_weights.append(network[id]._weights)
+    parameter_list_epsilon_xy.append(network[id]._epsilon_xy)
+
+if cfg.learning_parameters.optimizer_name == "Adam":
+    if cfg.learning_parameters.learning_rate_gamma_w > 0:
+        optimizer_wf: torch.optim.Optimizer = torch.optim.Adam(
+            parameter_list_weights,
+            lr=cfg.learning_parameters.learning_rate_gamma_w,
+        )
+    else:
+        optimizer_wf = torch.optim.Adam(
+            parameter_list_weights,
+        )
+
+    if cfg.learning_parameters.learning_rate_gamma_eps_xy > 0:
+        optimizer_eps: torch.optim.Optimizer = torch.optim.Adam(
+            parameter_list_epsilon_xy,
+            lr=cfg.learning_parameters.learning_rate_gamma_eps_xy,
+        )
+    else:
+        optimizer_eps = torch.optim.Adam(
+            parameter_list_epsilon_xy,
+        )
+else:
+    raise Exception("Optimizer not implemented")
+
+if cfg.learning_parameters.lr_schedule_name == "ReduceLROnPlateau":
+    lr_scheduler_wf = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer_wf,
+        factor=cfg.learning_parameters.lr_scheduler_factor,
+        patience=cfg.learning_parameters.lr_scheduler_patience,
+    )
+
+    lr_scheduler_eps = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer_eps,
+        factor=cfg.learning_parameters.lr_scheduler_factor,
+        patience=cfg.learning_parameters.lr_scheduler_patience,
+    )
+else:
+    raise Exception("lr_scheduler not implemented")
+
+logging.info("*** Optimizer prepared.")
+
+
+#######################################################################
+# Some variable declarations                                          #
+#######################################################################
+
+test_correct: int = 0
+test_all: int = 0
+test_complete: int = len(the_dataset_test)
+
+train_correct: int = 0
+train_all: int = 0
+train_complete: int = len(the_dataset_train)
+
+train_number_of_processed_pattern: int = 0
+
+train_loss: np.ndarray = np.zeros((1), dtype=np.float32)
+
+last_test_performance: float = -1.0
+
+
+logging.info("")
+
+with torch.no_grad():
+    if cfg.learning_parameters.learning_active is True:
+        while True:
+
+            ###############################################
+            # Run a training data batch                   #
+            ###############################################
+
+            for h_x, h_x_labels in my_loader_train:
+                time_0: float = time.perf_counter()
+
+                if train_number_of_processed_pattern == 0:
+                    # Reset the gradient of the torch optimizers
+                    optimizer_wf.zero_grad()
+                    optimizer_eps.zero_grad()
+
+                with torch.enable_grad():
+
+                    h_collection = []
+                    h_collection.append(
+                        the_dataset_train.pattern_filter_train(h_x, cfg).type(
+                            dtype=torch.float64
+                        )
+                    )
+                    for id in range(0, len(network)):
+                        h_collection.append(network[id](h_collection[-1]))
+
+                    # Convert label into one hot
+                    target_one_hot: torch.Tensor = torch.zeros(
+                        (
+                            h_x_labels.shape[0],
+                            int(cfg.network_structure.number_of_output_neurons),
+                        )
+                    )
+                    target_one_hot.scatter_(
+                        1, h_x_labels.unsqueeze(1), torch.ones((h_x_labels.shape[0], 1))
+                    )
+                    target_one_hot = (
+                        target_one_hot.unsqueeze(2)
+                        .unsqueeze(2)
+                        .type(dtype=torch.float64)
+                    )
+
+                    # Send the output through the loss functions.
+                    # kl_div expects log-probabilities; the -inf from
+                    # log(0) is replaced by zero first.
+                    h_y1 = torch.log(h_collection[-1])
+                    h_y2 = torch.nan_to_num(h_y1, nan=0.0, posinf=0.0, neginf=0.0)
+
+                    # Weighted mixture of MSE and KL divergence,
+                    # normalized by the sum of the two coefficients
+                    my_loss: torch.Tensor = (
+                        (
+                            torch.nn.functional.mse_loss(
+                                h_collection[-1], target_one_hot, reduction="none"
+                            )
+                            * cfg.learning_parameters.loss_coeffs_mse
+                            + torch.nn.functional.kl_div(
+                                h_y2, target_one_hot, reduction="none"
+                            )
+                            * cfg.learning_parameters.loss_coeffs_kldiv
+                        )
+                        / (
+                            cfg.learning_parameters.loss_coeffs_kldiv
+                            + cfg.learning_parameters.loss_coeffs_mse
+                        )
+                    ).mean()
+
+                    time_1: float = time.perf_counter()
+
+                    my_loss.backward()
+                    my_loss_float = my_loss.item()
+                    time_2: float = time.perf_counter()
+
+                train_correct += (
+                    (h_collection[-1].argmax(dim=1).squeeze() == h_x_labels)
+                    .sum()
+                    .numpy()
+                )
+                train_all += h_collection[-1].shape[0]
+
+                performance: float = 100.0 * train_correct / train_all
+
+                time_measure_a: float = time_1 - time_0
+
+                logging.info(
+                    (
+                        f"{cfg.learning_step:^6} Training \t{train_all:^6} pattern "
+                        f"with {performance/100.0:^6.2%} "
+                        f"\t\tForward time: \t{time_measure_a:^6.2f}sec"
+                    )
+                )
+
+                train_loss[0] += my_loss_float
+                train_number_of_processed_pattern += h_collection[-1].shape[0]
+
+                time_measure_b: float = time_2 - time_1
+
+                logging.info(
+                    (
+                        f"\t\t\tLoss: {train_loss[0]/train_number_of_processed_pattern:^15.3e} "
+                        f"\t\t\tBackward time: \t{time_measure_b:^6.2f}sec "
+                    )
+                )
+
+                if (
+                    train_number_of_processed_pattern
+                    >= cfg.get_update_after_x_pattern()
+                ):
+                    logging.info("\t\t\t*** Updating the weights ***")
+                    my_loss_for_batch: float = (
+                        train_loss[0] / train_number_of_processed_pattern
+                    )
+
+                    optimizer_wf.step()
+                    optimizer_eps.step()
+
+                    for id in range(0, len(network)):
+                        if cfg.network_structure.w_trainable[id] is True:
+                            network[id].norm_weights()
+                            network[id].threshold_weights(
+                                cfg.learning_parameters.learning_rate_threshold_w
+                            )
+                            network[id].norm_weights()
+                        else:
+                            network[id].weights = torch.tensor(
+                                wf[id], dtype=torch.float64
+                            )
+
+                        if cfg.network_structure.eps_xy_trainable[id] is True:
+                            network[id].threshold_epsilon_xy(
+                                cfg.learning_parameters.learning_rate_threshold_eps_xy
+                            )
+                        else:
+                            network[id].epsilon_xy = torch.tensor(
+                                eps_xy[id], dtype=torch.float64
+                            )
+
+                        # Save the new values
+                        np.save(
+                            cfg.weight_path
+                            + "/Weight_L"
+                            + str(id)
+                            + "_S"
+                            + str(cfg.learning_step)
+                            + ".npy",
+                            network[id].weights.detach().numpy(),
+                        )
+
+                        try:
+                            tb.add_histogram(
+                                "Weights " + str(id),
+                                network[id].weights,
+                                cfg.learning_step,
+                            )
+                        except ValueError:
+                            pass
+
+                        np.save(
+                            cfg.eps_xy_path
+                            + "/EpsXY_L"
+                            + str(id)
+                            + "_S"
+                            + str(cfg.learning_step)
+                            + ".npy",
+                            network[id].epsilon_xy.detach().numpy(),
+                        )
+                        try:
+                            tb.add_histogram(
+                                "Epsilon XY " + str(id),
+                                network[id].epsilon_xy.detach().numpy(),
+                                cfg.learning_step,
+                            )
+                        except ValueError:
+                            pass
+
+                    # Let the torch learning rate scheduler update the
+                    # learning rates of the optimizers
+                    lr_scheduler_wf.step(my_loss_for_batch)
+                    lr_scheduler_eps.step(my_loss_for_batch)
+
+                    tb.add_scalar(
+                        "Train Performance", 100.0 - performance, cfg.learning_step
+                    )
+                    tb.add_scalar("Train Loss", my_loss_for_batch, cfg.learning_step)
+                    tb.add_scalar(
+                        "Learning Rate Scale WF",
+                        optimizer_wf.param_groups[-1]["lr"],
+                        cfg.learning_step,
+                    )
+                    tb.add_scalar(
+                        "Learning Rate Scale Eps XY",
+                        optimizer_eps.param_groups[-1]["lr"],
+                        cfg.learning_step,
+                    )
+
+                    cfg.learning_step += 1
+                    train_loss = np.zeros((1), dtype=np.float32)
+                    train_correct = 0
+                    train_all = 0
+                    performance = 0
+                    train_number_of_processed_pattern = 0
+
+                    tb.flush()
+
+                    test_correct = 0
+                    test_all = 0
+
+                    if last_test_performance < 0:
+                        logging.info("")
+                    else:
+                        logging.info(
+                            f"\t\t\tLast test performance: {last_test_performance/100.0:^6.2%}"
+                        )
+                    logging.info("")
+
+                    ###############################################
+                    # Run a test data performance measurement     #
+                    ###############################################
+                    if (
+                        (
+                            (
+                                (
+                                    cfg.learning_step
+                                    % cfg.learning_parameters.test_every_x_learning_steps
+                                )
+                                == 0
+                            )
+                            or (cfg.learning_step == cfg.learning_step_max)
+                        )
+                        and (cfg.learning_parameters.test_during_learning is True)
+                        and (cfg.learning_step > 0)
+                    ):
+                        logging.info("")
+                        logging.info("Testing:")
+
+                        for h_x, h_x_labels in my_loader_test:
+                            time_0 = time.perf_counter()
+
+                            h_h: torch.Tensor = network(
+                                the_dataset_test.pattern_filter_test(h_x, cfg).type(
+                                    dtype=torch.float64
+                                )
+                            )
+
+                            test_correct += (
+                                (h_h.argmax(dim=1).squeeze() == h_x_labels)
+                                .sum()
+                                .numpy()
+                            )
+                            test_all += h_h.shape[0]
+                            performance = 100.0 * test_correct / test_all
+                            time_1 = time.perf_counter()
+                            time_measure_a = time_1 - time_0
+
+                            logging.info(
+                                (
+                                    f"\t\t{test_all} of {test_complete}"
+                                    f" with {performance/100:^6.2%} \t Time used: {time_measure_a:^6.2f}sec"
+                                )
+                            )
+
+                        logging.info("")
+
+                        last_test_performance = performance
+
+                        tb.add_scalar(
+                            "Test Error", 100.0 - performance, cfg.learning_step
+                        )
+                        tb.flush()
+
+                    # Normal end of training: close the writer and
+                    # exit with a success status
+                    if cfg.learning_step == cfg.learning_step_max:
+                        tb.close()
+                        sys.exit(0)
+
+# %%
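
Usage note (not part of the patch): learn_it.py expects the path to a
dataconf JSON configuration file as its only command line argument,
e.g.

    python learn_it.py ./config.json

A minimal sketch of such a file, assuming the JSON keys mirror the
Config fields referenced above; the values are placeholders only:

    {
        "data_mode": "MNIST",
        "data_path": "./DATA_MNIST",
        "batch_size": 500,
        "learning_step": 0,
        "learning_step_max": 10000
    }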