First version

parent c224b86e20
commit 2928901a38

6 changed files with 2432 additions and 0 deletions

Dataset.py (new file, 422 lines added)
@@ -0,0 +1,422 @@
# MIT License
# Copyright 2022 University of Bremen
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
# THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
#
# David Rotermund ( davrot@uni-bremen.de )
#
#
# Release history:
# ================
# 1.0.0 -- 01.05.2022: first release
#
#

from abc import ABC, abstractmethod
import torch
import numpy as np
import torchvision as tv  # type: ignore
from Parameter import Config


class DatasetMaster(torch.utils.data.Dataset, ABC):

    path_label: str
    label_storage: np.ndarray
    pattern_storage: np.ndarray
    number_of_pattern: int
    mean: list[float]

    # Initialize
    def __init__(
        self,
        train: bool = False,
        path_pattern: str = "./",
        path_label: str = "./",
    ) -> None:
        super().__init__()

        if train is True:
            self.label_storage = np.load(path_label + "/TrainLabelStorage.npy")
        else:
            self.label_storage = np.load(path_label + "/TestLabelStorage.npy")

        if train is True:
            self.pattern_storage = np.load(path_pattern + "/TrainPatternStorage.npy")
        else:
            self.pattern_storage = np.load(path_pattern + "/TestPatternStorage.npy")

        self.number_of_pattern = self.label_storage.shape[0]

        self.mean = []

    def __len__(self) -> int:
        return self.number_of_pattern

    # Get one pattern at position index
    @abstractmethod
    def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
        pass

    @abstractmethod
    def pattern_filter_test(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
        pass

    @abstractmethod
    def pattern_filter_train(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
        pass

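# Illustrative sketch (hypothetical helper, not defined elsewhere in this commit;
# assumes torchvision's MNIST download is reachable): one way to produce the
# TrainPatternStorage.npy / TrainLabelStorage.npy files that
# DatasetMaster.__init__ expects. The (N, H, W) pattern layout and the (N,)
# label layout are inferred from how DatasetMNIST indexes the arrays.
def _make_mnist_npy_files(root: str = "./downloads", out_path: str = "./") -> None:
    mnist_train = tv.datasets.MNIST(root=root, train=True, download=True)
    np.save(out_path + "/TrainPatternStorage.npy", mnist_train.data.numpy())
    np.save(out_path + "/TrainLabelStorage.npy", mnist_train.targets.numpy())

    mnist_test = tv.datasets.MNIST(root=root, train=False, download=True)
    np.save(out_path + "/TestPatternStorage.npy", mnist_test.data.numpy())
    np.save(out_path + "/TestLabelStorage.npy", mnist_test.targets.numpy())
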
class DatasetMNIST(DatasetMaster):
    """Constructor"""

    # Initialize
    def __init__(
        self,
        train: bool = False,
        path_pattern: str = "./",
        path_label: str = "./",
    ) -> None:
        super().__init__(train, path_pattern, path_label)

        self.pattern_storage = np.ascontiguousarray(
            self.pattern_storage[:, np.newaxis, :, :].astype(dtype=np.float32)
        )

        self.pattern_storage /= np.max(self.pattern_storage)

        mean = self.pattern_storage.mean(3).mean(2).mean(0)
        self.mean = [*mean]

    def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:

        image = self.pattern_storage[index, 0:1, :, :]
        target = int(self.label_storage[index])
        return torch.tensor(image), target

    def pattern_filter_test(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
        """0. The test image comes in
        1. is center cropped
        2. is on/off filtered
        3. is returned.

        This is a 1 channel version (e.g. one gray channel).
        """

        assert len(cfg.image_statistics.mean) == 1
        assert len(cfg.image_statistics.the_size) == 2
        assert cfg.image_statistics.the_size[0] > 0
        assert cfg.image_statistics.the_size[1] > 0

        # Transformation chain
        my_transforms: torch.nn.Sequential = torch.nn.Sequential(
            tv.transforms.CenterCrop(size=cfg.image_statistics.the_size),
        )
        scripted_transforms = torch.jit.script(my_transforms)

        # Preprocess the input data
        pattern = scripted_transforms(pattern)

        # => On/Off
        my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
        gray: torch.Tensor = my_on_off_filter(
            pattern[:, 0:1, :, :],
        )

        return gray

    def pattern_filter_train(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
        """0. The training image comes in
        1. is cropped from a random position
        2. is on/off filtered
        3. is returned.

        This is a 1 channel version (e.g. one gray channel).
        """

        assert len(cfg.image_statistics.mean) == 1
        assert len(cfg.image_statistics.the_size) == 2
        assert cfg.image_statistics.the_size[0] > 0
        assert cfg.image_statistics.the_size[1] > 0

        # Transformation chain
        my_transforms: torch.nn.Sequential = torch.nn.Sequential(
            tv.transforms.RandomCrop(size=cfg.image_statistics.the_size),
        )
        scripted_transforms = torch.jit.script(my_transforms)

        # Preprocess the input data
        pattern = scripted_transforms(pattern)

        # => On/Off
        my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
        gray: torch.Tensor = my_on_off_filter(
            pattern[:, 0:1, :, :],
        )

        return gray

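# Illustrative sketch (hypothetical helper; assumes the MNIST .npy files are
# present in the working directory and uses a made-up mean of 0.5): the random
# crop removes pixels at the border and the on/off filter doubles the channel
# count, so a (B, 1, 28, 28) batch becomes (B, 2, 24, 24) for the_size=[24, 24].
# Note that constructing Config() also creates its default output directories.
def _demo_mnist_filter_shapes() -> None:
    cfg = Config()
    cfg.image_statistics.mean = [0.5]
    cfg.image_statistics.the_size = [24, 24]
    dataset = DatasetMNIST(train=True, path_pattern="./", path_label="./")
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
    batch, _ = next(iter(loader))
    print(batch.shape)  # torch.Size([4, 1, 28, 28])
    print(dataset.pattern_filter_train(batch, cfg).shape)  # torch.Size([4, 2, 24, 24])
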
class DatasetFashionMNIST(DatasetMaster):
    """Constructor"""

    # Initialize
    def __init__(
        self,
        train: bool = False,
        path_pattern: str = "./",
        path_label: str = "./",
    ) -> None:
        super().__init__(train, path_pattern, path_label)

        self.pattern_storage = np.ascontiguousarray(
            self.pattern_storage[:, np.newaxis, :, :].astype(dtype=np.float32)
        )

        self.pattern_storage /= np.max(self.pattern_storage)

        mean = self.pattern_storage.mean(3).mean(2).mean(0)
        self.mean = [*mean]

    def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:

        image = self.pattern_storage[index, 0:1, :, :]
        target = int(self.label_storage[index])
        return torch.tensor(image), target

    def pattern_filter_test(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
        """0. The test image comes in
        1. is center cropped
        2. is on/off filtered
        3. is returned.

        This is a 1 channel version (e.g. one gray channel).
        """

        assert len(cfg.image_statistics.mean) == 1
        assert len(cfg.image_statistics.the_size) == 2
        assert cfg.image_statistics.the_size[0] > 0
        assert cfg.image_statistics.the_size[1] > 0

        # Transformation chain
        my_transforms: torch.nn.Sequential = torch.nn.Sequential(
            tv.transforms.CenterCrop(size=cfg.image_statistics.the_size),
        )
        scripted_transforms = torch.jit.script(my_transforms)

        # Preprocess the input data
        pattern = scripted_transforms(pattern)

        # => On/Off
        my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
        gray: torch.Tensor = my_on_off_filter(
            pattern[:, 0:1, :, :],
        )

        return gray

    def pattern_filter_train(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
        """0. The training image comes in
        1. is cropped from a random position
        2. is randomly horizontally flipped
        3. is randomly color jittered
        4. is on/off filtered
        5. is returned.

        This is a 1 channel version (e.g. one gray channel).
        """

        assert len(cfg.image_statistics.mean) == 1
        assert len(cfg.image_statistics.the_size) == 2
        assert cfg.image_statistics.the_size[0] > 0
        assert cfg.image_statistics.the_size[1] > 0

        # Transformation chain
        my_transforms: torch.nn.Sequential = torch.nn.Sequential(
            tv.transforms.RandomCrop(size=cfg.image_statistics.the_size),
            tv.transforms.RandomHorizontalFlip(p=cfg.augmentation.flip_p),
            tv.transforms.ColorJitter(
                brightness=cfg.augmentation.jitter_brightness,
                contrast=cfg.augmentation.jitter_contrast,
                saturation=cfg.augmentation.jitter_saturation,
                hue=cfg.augmentation.jitter_hue,
            ),
        )
        scripted_transforms = torch.jit.script(my_transforms)

        # Preprocess the input data
        pattern = scripted_transforms(pattern)

        # => On/Off
        my_on_off_filter: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
        gray: torch.Tensor = my_on_off_filter(
            pattern[:, 0:1, :, :],
        )

        return gray

class DatasetCIFAR(DatasetMaster):
    """Constructor"""

    # Initialize
    def __init__(
        self,
        train: bool = False,
        path_pattern: str = "./",
        path_label: str = "./",
    ) -> None:
        super().__init__(train, path_pattern, path_label)

        self.pattern_storage = np.ascontiguousarray(
            np.moveaxis(self.pattern_storage.astype(dtype=np.float32), 3, 1)
        )
        self.pattern_storage /= np.max(self.pattern_storage)

        mean = self.pattern_storage.mean(3).mean(2).mean(0)
        self.mean = [*mean]

    def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:

        image = self.pattern_storage[index, :, :, :]
        target = int(self.label_storage[index])
        return torch.tensor(image), target

    def pattern_filter_test(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
        """0. The test image comes in
        1. is center cropped
        2. is on/off filtered
        3. is returned.

        This is a 3 channel version (e.g. r,g,b channels).
        """

        assert len(cfg.image_statistics.mean) == 3
        assert len(cfg.image_statistics.the_size) == 2
        assert cfg.image_statistics.the_size[0] > 0
        assert cfg.image_statistics.the_size[1] > 0

        # Transformation chain
        my_transforms: torch.nn.Sequential = torch.nn.Sequential(
            tv.transforms.CenterCrop(size=cfg.image_statistics.the_size),
        )
        scripted_transforms = torch.jit.script(my_transforms)

        # Preprocess the input data
        pattern = scripted_transforms(pattern)

        # => On/Off

        my_on_off_filter_r: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
        my_on_off_filter_g: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[1])
        my_on_off_filter_b: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[2])
        r: torch.Tensor = my_on_off_filter_r(
            pattern[:, 0:1, :, :],
        )
        g: torch.Tensor = my_on_off_filter_g(
            pattern[:, 1:2, :, :],
        )
        b: torch.Tensor = my_on_off_filter_b(
            pattern[:, 2:3, :, :],
        )

        new_tensor: torch.Tensor = torch.cat((r, g, b), dim=1)
        return new_tensor

    def pattern_filter_train(self, pattern: torch.Tensor, cfg: Config) -> torch.Tensor:
        """0. The training image comes in
        1. is cropped from a random position
        2. is randomly horizontally flipped
        3. is randomly color jittered
        4. is on/off filtered
        5. is returned.

        This is a 3 channel version (e.g. r,g,b channels).
        """
        assert len(cfg.image_statistics.mean) == 3
        assert len(cfg.image_statistics.the_size) == 2
        assert cfg.image_statistics.the_size[0] > 0
        assert cfg.image_statistics.the_size[1] > 0

        # Transformation chain
        my_transforms: torch.nn.Sequential = torch.nn.Sequential(
            tv.transforms.RandomCrop(size=cfg.image_statistics.the_size),
            tv.transforms.RandomHorizontalFlip(p=cfg.augmentation.flip_p),
            tv.transforms.ColorJitter(
                brightness=cfg.augmentation.jitter_brightness,
                contrast=cfg.augmentation.jitter_contrast,
                saturation=cfg.augmentation.jitter_saturation,
                hue=cfg.augmentation.jitter_hue,
            ),
        )
        scripted_transforms = torch.jit.script(my_transforms)

        # Preprocess the input data
        pattern = scripted_transforms(pattern)

        # => On/Off
        my_on_off_filter_r: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[0])
        my_on_off_filter_g: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[1])
        my_on_off_filter_b: OnOffFilter = OnOffFilter(p=cfg.image_statistics.mean[2])
        r: torch.Tensor = my_on_off_filter_r(
            pattern[:, 0:1, :, :],
        )
        g: torch.Tensor = my_on_off_filter_g(
            pattern[:, 1:2, :, :],
        )
        b: torch.Tensor = my_on_off_filter_b(
            pattern[:, 2:3, :, :],
        )

        new_tensor: torch.Tensor = torch.cat((r, g, b), dim=1)
        return new_tensor

class OnOffFilter(torch.nn.Module):
    def __init__(self, p: float = 0.5) -> None:
        super(OnOffFilter, self).__init__()
        self.p: float = p

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:

        assert tensor.shape[1] == 1

        # Center the values around the reference level p
        tensor_clone = 2.0 * (tensor - self.p)

        # "Off" channel: magnitude of the part below p
        temp_0: torch.Tensor = torch.where(
            tensor_clone < 0.0,
            -tensor_clone,
            tensor_clone.new_zeros(tensor_clone.shape, dtype=tensor_clone.dtype),
        )

        # "On" channel: the part above p
        temp_1: torch.Tensor = torch.where(
            tensor_clone >= 0.0,
            tensor_clone,
            tensor_clone.new_zeros(tensor_clone.shape, dtype=tensor_clone.dtype),
        )

        # Stack both channels; the channel dimension doubles from 1 to 2
        new_tensor: torch.Tensor = torch.cat((temp_0, temp_1), dim=1)

        return new_tensor

    def __repr__(self) -> str:
        return self.__class__.__name__ + "(p={0})".format(self.p)

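# Illustrative sketch (hypothetical helper): OnOffFilter splits one channel into
# an "off" and an "on" channel, so a (B, 1, H, W) input becomes (B, 2, H, W).
def _demo_on_off_filter() -> None:
    demo_filter = OnOffFilter(p=0.5)
    demo_input = torch.tensor([[[[0.0, 0.25], [0.75, 1.0]]]])  # shape (1, 1, 2, 2)
    demo_output = demo_filter(demo_input)                      # shape (1, 2, 2, 2)
    # Channel 0 ("off") holds 2 * (p - x) where x <  p: [[1.0, 0.5], [0.0, 0.0]]
    # Channel 1 ("on")  holds 2 * (x - p) where x >= p: [[0.0, 0.0], [0.5, 1.0]]
    print(demo_output)
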
if __name__ == "__main__":
    pass

Parameter.py (new file, 164 lines added)

@@ -0,0 +1,164 @@
# MIT License
# Copyright 2022 University of Bremen
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
# THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
#
# David Rotermund ( davrot@uni-bremen.de )
#
#
# Release history:
# ================
# 1.0.0 -- 01.05.2022: first release
#
#

# %%
from dataclasses import dataclass, field
import numpy as np
import torch
import os


@dataclass
class Network:
    """Parameters of the network. The details about
    its layers and the number of output neurons."""

    number_of_output_neurons: int = field(default=0)
    forward_kernel_size: list[list[int]] = field(default_factory=list)
    forward_neuron_numbers: list[list[int]] = field(default_factory=list)
    strides: list[list[int]] = field(default_factory=list)
    dilation: list[list[int]] = field(default_factory=list)
    padding: list[list[int]] = field(default_factory=list)
    is_pooling_layer: list[bool] = field(default_factory=list)
    w_trainable: list[bool] = field(default_factory=list)
    eps_xy_trainable: list[bool] = field(default_factory=list)
    eps_xy_mean: list[bool] = field(default_factory=list)


@dataclass
class LearningParameters:
    """Parameters required for training"""

    loss_coeffs_mse: float = field(default=0.5)
    loss_coeffs_kldiv: float = field(default=1.0)
    learning_rate_gamma_w: float = field(default=-1.0)
    learning_rate_gamma_eps_xy: float = field(default=-1.0)
    learning_rate_threshold_w: float = field(default=0.00001)
    learning_rate_threshold_eps_xy: float = field(default=0.00001)
    learning_active: bool = field(default=True)
    weight_noise_amplitude: float = field(default=0.01)
    eps_xy_intitial: float = field(default=0.1)
    test_every_x_learning_steps: int = field(default=50)
    test_during_learning: bool = field(default=True)
    lr_scheduler_factor: float = field(default=0.75)
    lr_scheduler_patience: int = field(default=10)
    optimizer_name: str = field(default="Adam")
    lr_schedule_name: str = field(default="ReduceLROnPlateau")
    number_of_batches_for_one_update: int = field(default=1)
    alpha_number_of_iterations: int = field(default=0)
    overload_path: str = field(default="./Previous")


@dataclass
class Augmentation:
    """Parameters used for data augmentation."""

    crop_width_in_pixel: int = field(default=2)
    flip_p: float = field(default=0.5)
    jitter_brightness: float = field(default=0.5)
    jitter_contrast: float = field(default=0.1)
    jitter_saturation: float = field(default=0.1)
    jitter_hue: float = field(default=0.15)


@dataclass
class ImageStatistics:
    """(Statistical) information about the input, i.e.
    the mean values and the x and y size of the input."""

    mean: list[float] = field(default_factory=list)
    the_size: list[int] = field(default_factory=list)


@dataclass
class Config:
    """Master config class."""

    # Sub classes
    network_structure: Network = field(default_factory=Network)
    learning_parameters: LearningParameters = field(default_factory=LearningParameters)
    augmentation: Augmentation = field(default_factory=Augmentation)
    image_statistics: ImageStatistics = field(default_factory=ImageStatistics)

    batch_size: int = field(default=500)
    data_mode: str = field(default="")

    learning_step: int = field(default=0)
    learning_step_max: int = field(default=10000)

    number_of_cpu_processes: int = field(default=-1)

    number_of_spikes: int = field(default=0)
    cooldown_after_number_of_spikes: int = field(default=0)

    weight_path: str = field(default="./Weights/")
    eps_xy_path: str = field(default="./EpsXY/")
    data_path: str = field(default="./")

    reduction_cooldown: float = field(default=25.0)
    epsilon_0: float = field(default=1.0)

    update_after_x_batch: float = field(default=1.0)

    def __post_init__(self) -> None:
        """Post init determines the number of cores,
        creates the required directories and adjusts the
        batch size to the number of cores."""
        number_of_cpu_processes_temp = os.cpu_count()

        if self.number_of_cpu_processes < 1:
            if number_of_cpu_processes_temp is None:
                self.number_of_cpu_processes = 1
            else:
                self.number_of_cpu_processes = number_of_cpu_processes_temp

        os.makedirs(self.weight_path, exist_ok=True)
        os.makedirs(self.eps_xy_path, exist_ok=True)
        os.makedirs(self.data_path, exist_ok=True)

        self.batch_size = (
            self.batch_size // self.number_of_cpu_processes
        ) * self.number_of_cpu_processes

        self.batch_size = np.max((self.batch_size, self.number_of_cpu_processes))
        self.batch_size = int(self.batch_size)

    def get_epsilon_t(self):
        """Generates the time series of the basic epsilon."""
        np_epsilon_t: np.ndarray = np.ones((self.number_of_spikes), dtype=np.float32)
        np_epsilon_t[
            self.cooldown_after_number_of_spikes : self.number_of_spikes
        ] /= self.reduction_cooldown
        return torch.tensor(np_epsilon_t)

    def get_update_after_x_pattern(self):
        """Tells us after how many patterns we need to update the weights."""
        return self.batch_size * self.update_after_x_batch

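# Illustrative sketch (hypothetical helper; the numbers are made up): shows what
# __post_init__ and get_epsilon_t do. Constructing a Config also creates the
# default output directories as a side effect of __post_init__.
def _demo_config() -> None:
    cfg = Config(
        data_mode="MNIST",
        batch_size=500,
        number_of_cpu_processes=7,
        number_of_spikes=1200,
        cooldown_after_number_of_spikes=1000,
        reduction_cooldown=25.0,
    )
    # The batch size is rounded down to a multiple of the CPU processes:
    # (500 // 7) * 7 = 497.
    print(cfg.batch_size)
    # epsilon_t is 1.0 for the first 1000 spikes and 1.0 / 25.0 afterwards.
    epsilon_t = cfg.get_epsilon_t()
    print(epsilon_t.shape, epsilon_t[0].item(), epsilon_t[-1].item())
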

PyHDynamicCNNManyIP.pyi (new file, 18 lines added)

@@ -0,0 +1,18 @@
#
# AUTOMATICALLY GENERATED FILE, DO NOT EDIT!
#

"""HDynamicCNNManyIP Module"""
from __future__ import annotations
import PyHDynamicCNNManyIP
import typing

__all__ = [
    "HDynamicCNNManyIP"
]


class HDynamicCNNManyIP():
    def __init__(self) -> None: ...
    def update_with_init_vector_multi_pattern(self, arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, arg7: int, arg8: int, arg9: int, arg10: int, arg11: int, arg12: int, arg13: int, arg14: int, arg15: int, arg16: int, arg17: int, arg18: int, arg19: int, arg20: int) -> bool: ...
    pass

PySpikeGeneration2DManyIP.pyi (new file, 18 lines added)

@@ -0,0 +1,18 @@
#
# AUTOMATICALLY GENERATED FILE, DO NOT EDIT!
#

"""SpikeGeneration2DManyIP Module"""
from __future__ import annotations
import PySpikeGeneration2DManyIP
import typing

__all__ = [
    "SpikeGeneration2DManyIP"
]


class SpikeGeneration2DManyIP():
    def __init__(self) -> None: ...
    def spike_generation_multi_pattern(self, arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, arg7: int, arg8: int, arg9: int, arg10: int, arg11: int, arg12: int, arg13: int, arg14: int, arg15: int) -> bool: ...
    pass

learn_it.py (new file, 590 lines added)

@@ -0,0 +1,590 @@
# MIT License
# Copyright 2022 University of Bremen
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
# THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
#
# David Rotermund ( davrot@uni-bremen.de )
#
#
# Release history:
# ================
# 1.0.0 -- 01.05.2022: first release
#
#

# %%
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import numpy as np
import sys
import torch
import time
import dataconf
import logging
from datetime import datetime

from Dataset import (
    DatasetMaster,
    DatasetCIFAR,
    DatasetMNIST,
    DatasetFashionMNIST,
)
from Parameter import Config
from SbS import SbS

from torch.utils.tensorboard import SummaryWriter

tb = SummaryWriter()

#######################################################################
# We want to log what is going on into a file and screen              #
#######################################################################

now = datetime.now()
dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S")
logging.basicConfig(
    filename="log_" + dt_string_filename + ".txt",
    filemode="w",
    level=logging.INFO,
    format="%(asctime)s %(message)s",
)
logging.getLogger().addHandler(logging.StreamHandler())

#######################################################################
# Load the config data from the json file                             #
#######################################################################

if len(sys.argv) < 2:
    raise Exception("Argument: Config file name is missing")

filename: str = sys.argv[1]

if os.path.exists(filename) is False:
    raise Exception(f"Config file not found! {filename}")

cfg = dataconf.file(filename, Config)
logging.info(f"Using configuration file: {filename}")


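# A possible shape for the configuration file (illustration only; the values
# are made up). dataconf maps the keys onto the dataclasses in Parameter.py, so
# the key names mirror the dataclass field names and omitted fields keep their
# defaults:
#
# {
#     "data_mode": "MNIST",
#     "data_path": "./",
#     "number_of_spikes": 1200,
#     "cooldown_after_number_of_spikes": 1000,
#     "network_structure": {
#         "number_of_output_neurons": 10,
#         "forward_neuron_numbers": [[2, 32], [32, 10]],
#         ...
#     },
#     ...
# }
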
#######################################################################
# Prepare the test and training data                                  #
#######################################################################

# Load the input data
the_dataset_train: DatasetMaster
the_dataset_test: DatasetMaster
if cfg.data_mode == "CIFAR10":
    the_dataset_train = DatasetCIFAR(
        train=True, path_pattern=cfg.data_path, path_label=cfg.data_path
    )
    the_dataset_test = DatasetCIFAR(
        train=False, path_pattern=cfg.data_path, path_label=cfg.data_path
    )
elif cfg.data_mode == "MNIST":
    the_dataset_train = DatasetMNIST(
        train=True, path_pattern=cfg.data_path, path_label=cfg.data_path
    )
    the_dataset_test = DatasetMNIST(
        train=False, path_pattern=cfg.data_path, path_label=cfg.data_path
    )
elif cfg.data_mode == "MNIST_FASHION":
    the_dataset_train = DatasetFashionMNIST(
        train=True, path_pattern=cfg.data_path, path_label=cfg.data_path
    )
    the_dataset_test = DatasetFashionMNIST(
        train=False, path_pattern=cfg.data_path, path_label=cfg.data_path
    )
else:
    raise Exception("data_mode unknown")

cfg.image_statistics.mean = the_dataset_train.mean

# The basic size
cfg.image_statistics.the_size = [
    the_dataset_train.pattern_storage.shape[2],
    the_dataset_train.pattern_storage.shape[3],
]

# Minus the stuff we cut away in the pattern filter
cfg.image_statistics.the_size[0] -= 2 * cfg.augmentation.crop_width_in_pixel
cfg.image_statistics.the_size[1] -= 2 * cfg.augmentation.crop_width_in_pixel

my_loader_test: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
    the_dataset_test, batch_size=cfg.batch_size, shuffle=False
)
my_loader_train: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
    the_dataset_train, batch_size=cfg.batch_size, shuffle=True
)

logging.info("*** Data loaded.")

#######################################################################
# Build the network                                                   #
#######################################################################

wf: list[np.ndarray] = []
eps_xy: list[np.ndarray] = []
network = torch.nn.Sequential()
for id in range(0, len(cfg.network_structure.is_pooling_layer)):
    if id == 0:
        input_size: list[int] = cfg.image_statistics.the_size
    else:
        input_size = network[id - 1].output_size.tolist()

    network.append(
        SbS(
            number_of_input_neurons=cfg.network_structure.forward_neuron_numbers[id][0],
            number_of_neurons=cfg.network_structure.forward_neuron_numbers[id][1],
            input_size=input_size,
            forward_kernel_size=cfg.network_structure.forward_kernel_size[id],
            number_of_spikes=cfg.number_of_spikes,
            epsilon_t=cfg.get_epsilon_t(),
            epsilon_xy_intitial=cfg.learning_parameters.eps_xy_intitial,
            epsilon_0=cfg.epsilon_0,
            weight_noise_amplitude=cfg.learning_parameters.weight_noise_amplitude,
            is_pooling_layer=cfg.network_structure.is_pooling_layer[id],
            strides=cfg.network_structure.strides[id],
            dilation=cfg.network_structure.dilation[id],
            padding=cfg.network_structure.padding[id],
            alpha_number_of_iterations=cfg.learning_parameters.alpha_number_of_iterations,
            number_of_cpu_processes=cfg.number_of_cpu_processes,
        )
    )

    eps_xy.append(network[id].epsilon_xy.detach().clone().numpy())
    wf.append(network[id].weights.detach().clone().numpy())

logging.info("*** Network generated.")

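# The checkpoint files restored in the loop below are the ones written by the
# training loop further down; both sides use the naming scheme
# <weight_path>/Weight_L<layer>_S<learning step>.npy and
# <eps_xy_path>/EpsXY_L<layer>_S<learning step>.npy.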
for id in range(0, len(network)):
    # Load previous weights and epsilon xy
    if cfg.learning_step > 0:
        network[id].weights = torch.tensor(
            np.load(
                cfg.weight_path
                + "/Weight_L"
                + str(id)
                + "_S"
                + str(cfg.learning_step)
                + ".npy"
            ),
            dtype=torch.float64,
        )

        wf[id] = np.load(
            cfg.weight_path
            + "/Weight_L"
            + str(id)
            + "_S"
            + str(cfg.learning_step)
            + ".npy"
        )

        network[id].epsilon_xy = torch.tensor(
            np.load(
                cfg.eps_xy_path
                + "/EpsXY_L"
                + str(id)
                + "_S"
                + str(cfg.learning_step)
                + ".npy"
            ),
            dtype=torch.float64,
        )

        eps_xy[id] = np.load(
            cfg.eps_xy_path
            + "/EpsXY_L"
            + str(id)
            + "_S"
            + str(cfg.learning_step)
            + ".npy"
        )

    # Are there weights that overwrite the initial weights?
    file_to_load: str = (
        cfg.learning_parameters.overload_path + "/Weight_L" + str(id) + ".npy"
    )
    if os.path.exists(file_to_load) is True:
        network[id].weights = torch.tensor(
            np.load(file_to_load),
            dtype=torch.float64,
        )
        wf[id] = np.load(file_to_load)
        logging.info(f"File used: {file_to_load}")

    file_to_load = cfg.learning_parameters.overload_path + "/EpsXY_L" + str(id) + ".npy"
    if os.path.exists(file_to_load) is True:
        network[id].epsilon_xy = torch.tensor(
            np.load(file_to_load),
            dtype=torch.float64,
        )
        eps_xy[id] = np.load(file_to_load)
        logging.info(f"File used: {file_to_load}")

#######################################################################
# Optimizer and LR Scheduler                                          #
#######################################################################

# I keep weights and epsilon xy separate to
# set the initial learning rates independently
parameter_list_weights: list = []
parameter_list_epsilon_xy: list = []

for id in range(0, len(network)):
    parameter_list_weights.append(network[id]._weights)
    parameter_list_epsilon_xy.append(network[id]._epsilon_xy)

if cfg.learning_parameters.optimizer_name == "Adam":
    if cfg.learning_parameters.learning_rate_gamma_w > 0:
        optimizer_wf: torch.optim.Optimizer = torch.optim.Adam(
            parameter_list_weights,
            lr=cfg.learning_parameters.learning_rate_gamma_w,
        )
    else:
        optimizer_wf = torch.optim.Adam(
            parameter_list_weights,
        )

    if cfg.learning_parameters.learning_rate_gamma_eps_xy > 0:
        optimizer_eps: torch.optim.Optimizer = torch.optim.Adam(
            parameter_list_epsilon_xy,
            lr=cfg.learning_parameters.learning_rate_gamma_eps_xy,
        )
    else:
        optimizer_eps = torch.optim.Adam(
            parameter_list_epsilon_xy,
        )
else:
    raise Exception("Optimizer not implemented")

if cfg.learning_parameters.lr_schedule_name == "ReduceLROnPlateau":
    lr_scheduler_wf = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_wf,
        factor=cfg.learning_parameters.lr_scheduler_factor,
        patience=cfg.learning_parameters.lr_scheduler_patience,
    )

    lr_scheduler_eps = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_eps,
        factor=cfg.learning_parameters.lr_scheduler_factor,
        patience=cfg.learning_parameters.lr_scheduler_patience,
    )
else:
    raise Exception("lr_scheduler not implemented")

logging.info("*** Optimizer prepared.")


#######################################################################
# Some variable declarations                                          #
#######################################################################

test_correct: int = 0
test_all: int = 0
test_complete: int = the_dataset_test.__len__()

train_correct: int = 0
train_all: int = 0
train_complete: int = the_dataset_train.__len__()

train_number_of_processed_pattern: int = 0

train_loss: np.ndarray = np.zeros((1), dtype=np.float32)

last_test_performance: float = -1.0


logging.info("")

with torch.no_grad():
    if cfg.learning_parameters.learning_active is True:
        while True:

            ###############################################
            # Run a training data batch                   #
            ###############################################

            for h_x, h_x_labels in my_loader_train:
                time_0: float = time.perf_counter()

                if train_number_of_processed_pattern == 0:
                    # Reset the gradient of the torch optimizers
                    optimizer_wf.zero_grad()
                    optimizer_eps.zero_grad()

                with torch.enable_grad():

                    h_collection = []
                    h_collection.append(
                        the_dataset_train.pattern_filter_train(h_x, cfg).type(
                            dtype=torch.float64
                        )
                    )
                    for id in range(0, len(network)):
                        h_collection.append(network[id](h_collection[-1]))

                    # Convert label into one hot
                    target_one_hot: torch.Tensor = torch.zeros(
                        (
                            h_x_labels.shape[0],
                            int(cfg.network_structure.number_of_output_neurons),
                        )
                    )
                    target_one_hot.scatter_(
                        1, h_x_labels.unsqueeze(1), torch.ones((h_x_labels.shape[0], 1))
                    )
                    target_one_hot = (
                        target_one_hot.unsqueeze(2)
                        .unsqueeze(2)
                        .type(dtype=torch.float64)
                    )

                    # Through the loss functions
                    # (log of the network output for the KL divergence;
                    # nan_to_num removes the -inf that log(0) would produce)
                    h_y1 = torch.log(h_collection[-1])
                    h_y2 = torch.nan_to_num(h_y1, nan=0.0, posinf=0.0, neginf=0.0)

                    # Weighted combination of MSE and KL divergence,
                    # normalized by the sum of the two coefficients
                    my_loss: torch.Tensor = (
                        (
                            torch.nn.functional.mse_loss(
                                h_collection[-1], target_one_hot, reduction="none"
                            )
                            * cfg.learning_parameters.loss_coeffs_mse
                            + torch.nn.functional.kl_div(
                                h_y2, target_one_hot, reduction="none"
                            )
                            * cfg.learning_parameters.loss_coeffs_kldiv
                        )
                        / (
                            cfg.learning_parameters.loss_coeffs_kldiv
                            + cfg.learning_parameters.loss_coeffs_mse
                        )
                    ).mean()

                    time_1: float = time.perf_counter()

                    my_loss.backward()
                    my_loss_float = my_loss.item()
                    time_2: float = time.perf_counter()

                train_correct += (
                    (h_collection[-1].argmax(dim=1).squeeze() == h_x_labels)
                    .sum()
                    .numpy()
                )
                train_all += h_collection[-1].shape[0]

                performance: float = 100.0 * train_correct / train_all

                time_measure_a: float = time_1 - time_0

                logging.info(
                    (
                        f"{cfg.learning_step:^6} Training \t{train_all:^6} pattern "
                        f"with {performance/100.0:^6.2%} "
                        f"\t\tForward time: \t{time_measure_a:^6.2f}sec"
                    )
                )

                train_loss[0] += my_loss_float
                train_number_of_processed_pattern += h_collection[-1].shape[0]

                time_measure_b: float = time_2 - time_1

                logging.info(
                    (
                        f"\t\t\tLoss: {train_loss[0]/train_number_of_processed_pattern:^15.3e} "
                        f"\t\t\tBackward time: \t{time_measure_b:^6.2f}sec "
                    )
                )

                if (
                    train_number_of_processed_pattern
                    >= cfg.get_update_after_x_pattern()
                ):
                    logging.info("\t\t\t*** Updating the weights ***")
                    my_loss_for_batch: float = (
                        train_loss[0] / train_number_of_processed_pattern
                    )

                    optimizer_wf.step()
                    optimizer_eps.step()

                    for id in range(0, len(network)):
                        if cfg.network_structure.w_trainable[id] is True:
                            network[id].norm_weights()
                            network[id].threshold_weights(
                                cfg.learning_parameters.learning_rate_threshold_w
                            )
                            network[id].norm_weights()
                        else:
                            network[id].weights = torch.tensor(
                                wf[id], dtype=torch.float64
                            )

                        if cfg.network_structure.eps_xy_trainable[id] is True:
                            network[id].threshold_epsilon_xy(
                                cfg.learning_parameters.learning_rate_threshold_eps_xy
                            )
                        else:
                            network[id].epsilon_xy = torch.tensor(
                                eps_xy[id], dtype=torch.float64
                            )

                        # Save the new values
                        np.save(
                            cfg.weight_path
                            + "/Weight_L"
                            + str(id)
                            + "_S"
                            + str(cfg.learning_step)
                            + ".npy",
                            network[id].weights.detach().numpy(),
                        )

                        try:
                            tb.add_histogram(
                                "Weights " + str(id),
                                network[id].weights,
                                cfg.learning_step,
                            )
                        except ValueError:
                            pass

                        np.save(
                            cfg.eps_xy_path
                            + "/EpsXY_L"
                            + str(id)
                            + "_S"
                            + str(cfg.learning_step)
                            + ".npy",
                            network[id].epsilon_xy.detach().numpy(),
                        )
                        try:
                            tb.add_histogram(
                                "Epsilon XY " + str(id),
                                network[id].epsilon_xy.detach().numpy(),
                                cfg.learning_step,
                            )
                        except ValueError:
                            pass

                    # Let the torch learning rate scheduler update the
                    # learning rates of the optimizers
                    lr_scheduler_wf.step(my_loss_for_batch)
                    lr_scheduler_eps.step(my_loss_for_batch)

                    tb.add_scalar(
                        "Train Performance", 100.0 - performance, cfg.learning_step
                    )
                    tb.add_scalar("Train Loss", my_loss_for_batch, cfg.learning_step)
                    tb.add_scalar(
                        "Learning Rate Scale WF",
                        optimizer_wf.param_groups[-1]["lr"],
                        cfg.learning_step,
                    )
                    tb.add_scalar(
                        "Learning Rate Scale Eps XY ",
                        optimizer_eps.param_groups[-1]["lr"],
                        cfg.learning_step,
                    )

                    cfg.learning_step += 1
                    train_loss = np.zeros((1), dtype=np.float32)
                    train_correct = 0
                    train_all = 0
                    performance = 0
                    train_number_of_processed_pattern = 0

                    tb.flush()

                    test_correct = 0
                    test_all = 0

                    if last_test_performance < 0:
                        logging.info("")
                    else:
                        logging.info(
                            f"\t\t\tLast test performance: {last_test_performance/100.0:^6.2%}"
                        )
                    logging.info("")

                    ###############################################
                    # Run a test data performance measurement     #
                    ###############################################
                    if (
                        (
                            (
                                (
                                    cfg.learning_step
                                    % cfg.learning_parameters.test_every_x_learning_steps
                                )
                                == 0
                            )
                            or (cfg.learning_step == cfg.learning_step_max)
                        )
                        and (cfg.learning_parameters.test_during_learning is True)
                        and (cfg.learning_step > 0)
                    ):
                        logging.info("")
                        logging.info("Testing:")

                        for h_x, h_x_labels in my_loader_test:
                            time_0 = time.perf_counter()

                            h_h: torch.Tensor = network(
                                the_dataset_train.pattern_filter_test(h_x, cfg).type(
                                    dtype=torch.float64
                                )
                            )

                            test_correct += (
                                (h_h.argmax(dim=1).squeeze() == h_x_labels)
                                .sum()
                                .numpy()
                            )
                            test_all += h_h.shape[0]
                            performance = 100.0 * test_correct / test_all
                            time_1 = time.perf_counter()
                            time_measure_a = time_1 - time_0

                            logging.info(
                                (
                                    f"\t\t{test_all} of {test_complete}"
                                    f" with {performance/100:^6.2%} \t Time used: {time_measure_a:^6.2f}sec"
                                )
                            )

                        logging.info("")

                        last_test_performance = performance

                        tb.add_scalar(
                            "Test Error", 100.0 - performance, cfg.learning_step
                        )
                        tb.flush()

                    if cfg.learning_step == cfg.learning_step_max:
                        tb.close()
                        exit(1)

# %%
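# Typical invocation (illustration only; the config file name is arbitrary):
#   python learn_it.py ./config.json
# Training curves are written via torch.utils.tensorboard.SummaryWriter and can
# be inspected with: tensorboard --logdir runs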