Add files via upload

This commit is contained in:
David Rotermund 2023-02-21 14:37:51 +01:00 committed by GitHub
parent 1a807b44df
commit a85805e92c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 659 additions and 46 deletions

View file

@ -151,11 +151,7 @@ class Adam(torch.optim.Optimizer):
if sbs_setting[i] is False:
param -= step_size * (exp_avg / denom)
else:
# delta = torch.exp(-step_size * (exp_avg / denom))
delta = torch.tanh(-step_size * (exp_avg / denom))
delta += 1.0
delta *= 0.5
delta += 0.5
delta = 0.5 * torch.tanh(-step_size * (exp_avg / denom)) + 1.0
self._logging.info(
f"ADAM: Layer {i} -> dw_min:{float(delta.min()):.4e} dw_max:{float(delta.max()):.4e} lr:{lr:.4e}"
)

View file

@ -15,6 +15,8 @@ class DatasetMaster(torch.utils.data.Dataset, ABC):
initial_size: list[int]
channel_size: int
alpha: float
# Initialize
def __init__(
self,

90
network/DatasetMix.py Normal file
View file

@ -0,0 +1,90 @@
import torch
from network.Dataset import DatasetMNIST, DatasetFashionMNIST, DatasetCIFAR
import math
class DatasetMNISTMix(DatasetMNIST):
def __init__(
self,
train: bool = False,
path_pattern: str = "./",
path_label: str = "./",
alpha: float = 1.0,
) -> None:
super().__init__(train, path_pattern, path_label)
self.alpha = alpha
def __getitem__(self, index: int) -> tuple[torch.Tensor, list[int]]: # type: ignore
assert self.alpha >= 0.0
assert self.alpha <= 1.0
image_a, target_a = super().__getitem__(index)
target_b: int = target_a
while target_b == target_a:
image_b, target_b = super().__getitem__(
int(math.floor(self.number_of_pattern * torch.rand((1)).item()))
)
image = self.alpha * image_a + (1.0 - self.alpha) * image_b
target = [target_a, target_b]
return image, target
class DatasetFashionMNISTMix(DatasetFashionMNIST):
def __init__(
self,
train: bool = False,
path_pattern: str = "./",
path_label: str = "./",
alpha: float = 1.0,
) -> None:
super().__init__(train, path_pattern, path_label)
self.alpha = alpha
def __getitem__(self, index: int) -> tuple[torch.Tensor, list[int]]: # type: ignore
assert self.alpha >= 0.0
assert self.alpha <= 1.0
image_a, target_a = super().__getitem__(index)
target_b: int = target_a
while target_b == target_a:
image_b, target_b = super().__getitem__(
int(math.floor(self.number_of_pattern * torch.rand((1)).item()))
)
image = self.alpha * image_a + (1.0 - self.alpha) * image_b
target = [target_a, target_b]
return image, target
class DatasetCIFARMix(DatasetCIFAR):
def __init__(
self,
train: bool = False,
path_pattern: str = "./",
path_label: str = "./",
alpha: float = 1.0,
) -> None:
super().__init__(train, path_pattern, path_label)
self.alpha = alpha
def __getitem__(self, index: int) -> tuple[torch.Tensor, list[int]]: # type: ignore
assert self.alpha >= 0.0
assert self.alpha <= 1.0
image_a, target_a = super().__getitem__(index)
target_b: int = target_a
while target_b == target_a:
image_b, target_b = super().__getitem__(
int(math.floor(self.number_of_pattern * torch.rand((1)).item()))
)
image = self.alpha * image_a + (1.0 - self.alpha) * image_b
target = [target_a, target_b]
return image, target

View file

@ -443,7 +443,6 @@ class FunctionalSbS(torch.autograd.Function):
)
elif (parameter_output_layer is True) and (parameter_local_learning is True):
target_one_hot: torch.Tensor = torch.zeros(
(
labels.shape[0],

View file

@ -51,7 +51,13 @@ class InputSpikeImage(torch.nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.number_of_spikes < 1:
return input
output = input
output = output.type(dtype=input.dtype)
if self._normalize is True:
output = output * output.shape[-1] * output.shape[-2] * output.shape[-3] / output.sum(dim=-1, keepdim=True).sum(
dim=-2, keepdim=True
).sum(dim=-3, keepdim=True)
return output
input_shape: list[int] = [
int(input.shape[0]),
@ -95,9 +101,10 @@ class InputSpikeImage(torch.nn.Module):
)
)
output = output.type(dtype=input_work.dtype)
if self._normalize is True:
output = output.type(dtype=input_work.dtype)
output = output / output.sum(dim=-1, keepdim=True).sum(
output = output * output.shape[-1] * output.shape[-2] * output.shape[-3] / output.sum(dim=-1, keepdim=True).sum(
dim=-2, keepdim=True
).sum(dim=-3, keepdim=True)

280
network/NNMFLayer.py Normal file
View file

@ -0,0 +1,280 @@
import torch
from network.calculate_output_size import calculate_output_size
class NNMFLayer(torch.nn.Module):
_epsilon_0: float
_weights: torch.nn.parameter.Parameter
_weights_exists: bool = False
_kernel_size: list[int]
_stride: list[int]
_dilation: list[int]
_padding: list[int]
_output_size: torch.Tensor
_number_of_neurons: int
_number_of_input_neurons: int
_h_initial: torch.Tensor | None = None
_w_trainable: bool
_weight_noise_range: list[float]
_input_size: list[int]
_output_layer: bool = False
_number_of_iterations: int
_local_learning: bool = False
device: torch.device
default_dtype: torch.dtype
_number_of_grad_weight_contributions: float = 0.0
last_input_store: bool = False
last_input_data: torch.Tensor | None = None
_layer_id: int = -1
def __init__(
self,
number_of_input_neurons: int,
number_of_neurons: int,
input_size: list[int],
forward_kernel_size: list[int],
number_of_iterations: int,
epsilon_0: float = 1.0,
weight_noise_range: list[float] = [0.0, 1.0],
strides: list[int] = [1, 1],
dilation: list[int] = [0, 0],
padding: list[int] = [0, 0],
w_trainable: bool = False,
device: torch.device | None = None,
default_dtype: torch.dtype | None = None,
layer_id: int = -1,
local_learning: bool = False,
output_layer: bool = False,
) -> None:
super().__init__()
assert device is not None
assert default_dtype is not None
self.device = device
self.default_dtype = default_dtype
self._w_trainable = bool(w_trainable)
self._stride = strides
self._dilation = dilation
self._padding = padding
self._kernel_size = forward_kernel_size
self._number_of_input_neurons = int(number_of_input_neurons)
self._number_of_neurons = int(number_of_neurons)
self._epsilon_0 = float(epsilon_0)
self._number_of_iterations = int(number_of_iterations)
self._weight_noise_range = weight_noise_range
self._layer_id = layer_id
self._local_learning = local_learning
self._output_layer = output_layer
assert len(input_size) == 2
self._input_size = input_size
self._output_size = calculate_output_size(
value=input_size,
kernel_size=self._kernel_size,
stride=self._stride,
dilation=self._dilation,
padding=self._padding,
)
self.set_h_init_to_uniform()
# ###############################################################
# Initialize the weights
# ###############################################################
assert len(self._weight_noise_range) == 2
weights = torch.empty(
(
int(self._kernel_size[0])
* int(self._kernel_size[1])
* int(self._number_of_input_neurons),
int(self._number_of_neurons),
),
dtype=self.default_dtype,
device=self.device,
)
torch.nn.init.uniform_(
weights,
a=float(self._weight_noise_range[0]),
b=float(self._weight_noise_range[1]),
)
self.weights = weights
@property
def weights(self) -> torch.Tensor | None:
if self._weights_exists is False:
return None
else:
return self._weights
@weights.setter
def weights(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 2
temp: torch.Tensor = (
value.detach()
.clone(memory_format=torch.contiguous_format)
.type(dtype=self.default_dtype)
.to(device=self.device)
)
temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype)
if self._weights_exists is False:
self._weights = torch.nn.parameter.Parameter(temp, requires_grad=True)
self._weights_exists = True
else:
self._weights.data = temp
@property
def h_initial(self) -> torch.Tensor | None:
return self._h_initial
@h_initial.setter
def h_initial(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 1
assert value.dtype == self.default_dtype
self._h_initial = (
value.detach()
.clone(memory_format=torch.contiguous_format)
.type(dtype=self.default_dtype)
.to(device=self.device)
.requires_grad_(False)
)
def update_pre_care(self):
if self._weights.grad is not None:
assert self._number_of_grad_weight_contributions > 0
self._weights.grad /= self._number_of_grad_weight_contributions
self._number_of_grad_weight_contributions = 0.0
def update_after_care(self, threshold_weight: float):
if self._w_trainable is True:
self.norm_weights()
self.threshold_weights(threshold_weight)
self.norm_weights()
def set_h_init_to_uniform(self) -> None:
assert self._number_of_neurons > 2
self.h_initial: torch.Tensor = torch.full(
(self._number_of_neurons,),
(1.0 / float(self._number_of_neurons)),
dtype=self.default_dtype,
device=self.device,
)
def norm_weights(self) -> None:
assert self._weights_exists is True
temp: torch.Tensor = (
self._weights.data.detach()
.clone(memory_format=torch.contiguous_format)
.type(dtype=self.default_dtype)
.to(device=self.device)
)
temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype)
self._weights.data = temp
def threshold_weights(self, threshold: float) -> None:
assert self._weights_exists is True
assert threshold >= 0
torch.clamp(
self._weights.data,
min=float(threshold),
max=None,
out=self._weights.data,
)
####################################################################
# Forward #
####################################################################
def forward(
self,
input: torch.Tensor,
) -> torch.Tensor:
# Are we happy with the input?
assert input is not None
assert torch.is_tensor(input) is True
assert input.dim() == 4
assert input.dtype == self.default_dtype
assert input.shape[1] == self._number_of_input_neurons
assert input.shape[2] == self._input_size[0]
assert input.shape[3] == self._input_size[1]
# Are we happy with the rest of the network?
assert self._epsilon_0 is not None
assert self._h_initial is not None
assert self._weights_exists is True
assert self._weights is not None
# Convolution of the input...
# Well, this is a convoltion layer
# there needs to be convolution somewhere
input_convolved = torch.nn.functional.fold(
torch.nn.functional.unfold(
input.requires_grad_(True),
kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
dilation=(int(self._dilation[0]), int(self._dilation[1])),
padding=(int(self._padding[0]), int(self._padding[1])),
stride=(int(self._stride[0]), int(self._stride[1])),
),
output_size=tuple(self._output_size.tolist()),
kernel_size=(1, 1),
dilation=(1, 1),
padding=(0, 0),
stride=(1, 1),
)
# We might need the convolved input for other layers
# let us keep it for the future
if self.last_input_store is True:
self.last_input_data = input_convolved.detach().clone()
self.last_input_data /= self.last_input_data.sum(dim=1, keepdim=True)
else:
self.last_input_data = None
input_convolved = input_convolved / input_convolved.sum(dim=1, keepdim=True)
h = torch.tile(
self._h_initial.unsqueeze(0).unsqueeze(-1).unsqueeze(-1),
dims=[
int(input.shape[0]),
1,
int(self._output_size[0]),
int(self._output_size[1]),
],
).requires_grad_(True)
for _ in range(0, self._number_of_iterations):
h_w = h.unsqueeze(1) * self._weights.unsqueeze(0).unsqueeze(-1).unsqueeze(
-1
)
h_w = h_w / (h_w.sum(dim=2, keepdim=True) + 1e-20)
h_w = (h_w * input_convolved.unsqueeze(2)).sum(dim=1)
if self._epsilon_0 > 0:
h = h + self._epsilon_0 * h_w
else:
h = h_w
h = h / (h.sum(dim=1, keepdim=True) + 1e-20)
self._number_of_grad_weight_contributions += (
h.shape[0] * h.shape[-2] * h.shape[-1]
)
return h

View file

@ -44,7 +44,7 @@ class LearningParameters:
overload_path: str = field(default="Previous")
weight_noise_range: list[float] = field(default_factory=list)
eps_xy_intitial: float = field(default=0.1)
eps_xy_intitial: float = field(default=1.0)
disable_scale_grade: bool = field(default=False)
kepp_last_grad_scale: bool = field(default=True)
@ -55,7 +55,6 @@ class LearningParameters:
w_trainable: list[bool] = field(default_factory=list)
@dataclass
class Augmentation:
"""Parameters used for data augmentation."""

View file

@ -87,6 +87,8 @@ class SbSLayer(torch.nn.Module):
spike_full_layer_input_distribution: bool = False,
force_forward_spike_on_cpu: bool = False,
force_forward_spike_output_on_cpu: bool = False,
local_learning: bool = False,
output_layer: bool = False,
) -> None:
super().__init__()
@ -117,6 +119,8 @@ class SbSLayer(torch.nn.Module):
self._epsilon_xy_use = epsilon_xy_use
self._force_forward_h_dynamic_on_cpu = force_forward_h_dynamic_on_cpu
self._spike_full_layer_input_distribution = spike_full_layer_input_distribution
self._local_learning = local_learning
self._output_layer = output_layer
assert len(input_size) == 2
self._input_size = input_size

View file

@ -28,27 +28,28 @@ class SplitOnOffLayer(torch.nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
assert input.ndim == 4
# self.training is switched by network.eval() and network.train()
if self.training is True:
mean_temp = (
input.mean(dim=0, keepdim=True)
.mean(dim=1, keepdim=True)
.detach()
.clone()
)
# # self.training is switched by network.eval() and network.train()
# if self.training is True:
# mean_temp = (
# input.mean(dim=0, keepdim=True)
# .mean(dim=1, keepdim=True)
# .detach()
# .clone()
# )
#
# if self.mean is None:
# self.mean = mean_temp
# else:
# self.mean = (1.0 - self.epsilon) * self.mean + self.epsilon * mean_temp
#
# assert self.mean is not None
if self.mean is None:
self.mean = mean_temp
else:
self.mean = (1.0 - self.epsilon) * self.mean + self.epsilon * mean_temp
assert self.mean is not None
temp = input - self.mean.detach().clone()
# temp = input - self.mean.detach().clone()
temp = input - 0.5
temp_a = torch.nn.functional.relu(temp)
temp_b = torch.nn.functional.relu(-temp)
output = torch.cat((temp_a, temp_b), dim=1)
output /= output.sum(dim=1, keepdim=True) + 1e-20
#output /= output.sum(dim=1, keepdim=True) + 1e-20
return output

View file

@ -6,6 +6,11 @@ from network.Dataset import (
DatasetMNIST,
DatasetFashionMNIST,
)
from network.DatasetMix import (
DatasetCIFARMix,
DatasetMNISTMix,
DatasetFashionMNISTMix,
)
from network.Parameter import Config
@ -42,6 +47,27 @@ def build_datasets(
the_dataset_test = DatasetFashionMNIST(
train=False, path_pattern=cfg.data_path, path_label=cfg.data_path
)
elif cfg.data_mode == "MIX_CIFAR10":
the_dataset_train = DatasetCIFARMix(
train=True, path_pattern=cfg.data_path, path_label=cfg.data_path
)
the_dataset_test = DatasetCIFARMix(
train=False, path_pattern=cfg.data_path, path_label=cfg.data_path
)
elif cfg.data_mode == "MIX_MNIST":
the_dataset_train = DatasetMNISTMix(
train=True, path_pattern=cfg.data_path, path_label=cfg.data_path
)
the_dataset_test = DatasetMNISTMix(
train=False, path_pattern=cfg.data_path, path_label=cfg.data_path
)
elif cfg.data_mode == "MIX_MNIST_FASHION":
the_dataset_train = DatasetFashionMNISTMix(
train=True, path_pattern=cfg.data_path, path_label=cfg.data_path
)
the_dataset_test = DatasetFashionMNISTMix(
train=False, path_pattern=cfg.data_path, path_label=cfg.data_path
)
else:
raise Exception("data_mode unknown")

View file

@ -39,7 +39,8 @@ def build_lr_scheduler(
):
lr_scheduler_list.append(
torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer[id_optimizer],eps=1e-14,
optimizer[id_optimizer],
eps=1e-14,
)
)
else:

View file

@ -4,6 +4,7 @@ import torch
from network.calculate_output_size import calculate_output_size
from network.Parameter import Config
from network.SbSLayer import SbSLayer
from network.NNMFLayer import NNMFLayer
from network.SplitOnOffLayer import SplitOnOffLayer
from network.Conv2dApproximation import Conv2dApproximation
from network.SbSReconstruction import SbSReconstruction
@ -153,6 +154,14 @@ def build_network(
if cfg.network_structure.layer_type[layer_id].upper().find("POOLING") != -1:
is_pooling_layer = True
local_learning = False
if cfg.network_structure.layer_type[layer_id].upper().find("LOCAL") != -1:
local_learning = True
output_layer = False
if layer_id == len(cfg.network_structure.layer_type) - 1:
output_layer = True
network.append(
SbSLayer(
number_of_input_neurons=in_channels,
@ -180,19 +189,13 @@ def build_network(
reduction_cooldown=cfg.reduction_cooldown,
force_forward_h_dynamic_on_cpu=cfg.force_forward_h_dynamic_on_cpu,
spike_full_layer_input_distribution=spike_full_layer_input_distribution,
local_learning=local_learning,
output_layer=output_layer,
)
)
# Adding the x,y output dimensions
input_size.append(network[-1]._output_size.tolist())
network[-1]._output_layer = False
if layer_id == len(cfg.network_structure.layer_type) - 1:
network[-1]._output_layer = True
network[-1]._local_learning = False
if cfg.network_structure.layer_type[layer_id].upper().find("LOCAL") != -1:
network[-1]._local_learning = True
elif (
cfg.network_structure.layer_type[layer_id]
.upper()
@ -276,6 +279,8 @@ def build_network(
):
logging.info(f"Layer: {layer_id} -> RELU Layer")
network.append(torch.nn.ReLU())
network[-1]._w_trainable = False
input_size.append(input_size[-1])
# #############################################################
@ -296,6 +301,8 @@ def build_network(
)
)
network[-1]._w_trainable = False
# Calculate the x,y output dimensions
input_size_temp = calculate_output_size(
value=input_size[-1],
@ -323,6 +330,9 @@ def build_network(
padding=(int(padding[0]), int(padding[1])),
)
)
network[-1]._w_trainable = False
# Calculate the x,y output dimensions
input_size_temp = calculate_output_size(
value=input_size[-1],
@ -405,8 +415,67 @@ def build_network(
)
)
network[-1]._w_trainable = False
input_size.append(input_size[-1])
# #############################################################
# NNMF:
# #############################################################
elif (
cfg.network_structure.layer_type[layer_id].upper().startswith("NNMF")
is True
):
assert in_channels > 0
assert out_channels > 0
number_of_iterations: int = -1
if len(cfg.number_of_spikes) > layer_id:
number_of_iterations = cfg.number_of_spikes[layer_id]
elif len(cfg.number_of_spikes) == 1:
number_of_iterations = cfg.number_of_spikes[0]
assert number_of_iterations > 0
logging.info(
(
f"Layer: {layer_id} -> NNMF Layer with {number_of_iterations} iterations "
)
)
local_learning = False
if cfg.network_structure.layer_type[layer_id].upper().find("LOCAL") != -1:
local_learning = True
output_layer = False
if layer_id == len(cfg.network_structure.layer_type) - 1:
output_layer = True
network.append(
NNMFLayer(
number_of_input_neurons=in_channels,
number_of_neurons=out_channels,
input_size=input_size[-1],
forward_kernel_size=kernel_size,
number_of_iterations=number_of_iterations,
epsilon_0=cfg.epsilon_0,
weight_noise_range=weight_noise_range,
strides=strides,
dilation=dilation,
padding=padding,
w_trainable=w_trainable,
device=device,
default_dtype=default_dtype,
layer_id=layer_id,
local_learning=local_learning,
output_layer=output_layer,
)
)
# Adding the x,y output dimensions
input_size.append(network[-1]._output_size.tolist())
# #############################################################
# Failure becaue we didn't found the selection of layer
# #############################################################

View file

@ -2,6 +2,8 @@
import torch
from network.Parameter import Config
from network.SbSLayer import SbSLayer
from network.NNMFLayer import NNMFLayer
from network.Conv2dApproximation import Conv2dApproximation
from network.Adam import Adam
@ -26,6 +28,12 @@ def build_optimizer(
parameter_list_weights.append(network[id]._weights)
parameter_list_sbs.append(True)
if (isinstance(network[id], NNMFLayer) is True) and (
network[id]._w_trainable is True
):
parameter_list_weights.append(network[id]._weights)
parameter_list_sbs.append(True)
if (isinstance(network[id], torch.nn.modules.conv.Conv2d) is True) and (
network[id]._w_trainable is True
):

View file

@ -4,6 +4,8 @@ import glob
import numpy as np
from network.SbSLayer import SbSLayer
from network.NNMFLayer import NNMFLayer
from network.SplitOnOffLayer import SplitOnOffLayer
from network.Conv2dApproximation import Conv2dApproximation
import os
@ -46,6 +48,23 @@ def load_previous_weights(
)
logging.info(f"Weights file used for layer {id} : {file_to_load[0]}")
if isinstance(network[id], NNMFLayer) is True:
filename_wilcard = os.path.join(
overload_path, f"Weight_L{id}_*{post_fix}.npy"
)
file_to_load = glob.glob(filename_wilcard)
if len(file_to_load) > 1:
raise Exception(f"Too many previous weights files {filename_wilcard}")
if len(file_to_load) == 1:
network[id].weights = torch.tensor(
np.load(file_to_load[0]),
dtype=default_dtype,
device=device,
)
logging.info(f"Weights file used for layer {id} : {file_to_load[0]}")
if isinstance(network[id], torch.nn.modules.conv.Conv2d) is True:
# #################################################

View file

@ -4,6 +4,7 @@ from network.Parameter import Config
from torch.utils.tensorboard import SummaryWriter
from network.SbSLayer import SbSLayer
from network.NNMFLayer import NNMFLayer
from network.save_weight_and_bias import save_weight_and_bias
from network.SbSReconstruction import SbSReconstruction
@ -19,7 +20,9 @@ def add_weight_and_bias_to_histogram(
# ################################################
# Log the SbS Weights
# ################################################
if isinstance(network[id], SbSLayer) is True:
if (isinstance(network[id], SbSLayer) is True) or (
isinstance(network[id], NNMFLayer) is True
):
if network[id]._w_trainable is True:
try:
@ -228,7 +231,9 @@ def run_optimizer(
cfg: Config,
) -> None:
for id in range(0, len(network)):
if isinstance(network[id], SbSLayer) is True:
if (isinstance(network[id], SbSLayer) is True) or (
isinstance(network[id], NNMFLayer) is True
):
network[id].update_pre_care()
for optimizer_item in optimizer:
@ -236,7 +241,9 @@ def run_optimizer(
optimizer_item.step()
for id in range(0, len(network)):
if isinstance(network[id], SbSLayer) is True:
if (isinstance(network[id], SbSLayer) is True) or (
isinstance(network[id], NNMFLayer) is True
):
network[id].update_after_care(
cfg.learning_parameters.learning_rate_threshold_w
/ float(
@ -618,6 +625,94 @@ def loop_test(
return performance
def loop_test_mix(
epoch_id: int,
cfg: Config,
network: torch.nn.modules.container.Sequential,
my_loader_test: torch.utils.data.dataloader.DataLoader,
the_dataset_test,
device: torch.device,
default_dtype: torch.dtype,
logging,
tb: SummaryWriter | None,
overwrite_number_of_spikes: int = -1,
) -> tuple[float, float]:
test_correct_a_0: int = 0
test_correct_a_1: int = 0
test_correct_b_0: int = 0
test_correct_b_1: int = 0
test_count: int = 0
test_complete: int = the_dataset_test.__len__()
logging.info("")
logging.info("Testing:")
mini_batch_id: int = 0
for h_x, h_x_labels in my_loader_test:
assert len(h_x_labels) == 2
label_a = h_x_labels[0]
label_b = h_x_labels[1]
assert label_a.shape[0] == label_b.shape[0]
assert h_x.shape[0] == label_b.shape[0]
time_0 = time.perf_counter()
h_collection = forward_pass_test(
input=h_x,
labels=label_a,
the_dataset_test=the_dataset_test,
cfg=cfg,
network=network,
device=device,
default_dtype=default_dtype,
mini_batch_id=mini_batch_id,
overwrite_number_of_spikes=overwrite_number_of_spikes,
)
h_h: torch.Tensor = h_collection[-1].detach().clone().cpu()
# -------------
for id in range(0, h_h.shape[0]):
temp = h_h[id, ...].squeeze().argsort(descending=True)
test_correct_a_0 += float(label_a[id] == int(temp[0]))
test_correct_a_1 += float(label_a[id] == int(temp[1]))
test_correct_b_0 += float(label_b[id] == int(temp[0]))
test_correct_b_1 += float(label_b[id] == int(temp[1]))
test_count += h_h.shape[0]
performance_a_0: float = 100.0 * test_correct_a_0 / test_count
performance_a_1: float = 100.0 * test_correct_a_1 / test_count
performance_b_0: float = 100.0 * test_correct_b_0 / test_count
performance_b_1: float = 100.0 * test_correct_b_1 / test_count
time_1 = time.perf_counter()
time_measure_a = time_1 - time_0
logging.info(
(
f"\t\t{test_count} of {test_complete}"
f" with {performance_a_0/100:^6.2%}, "
f"{performance_a_1/100:^6.2%}, "
f"{performance_b_0/100:^6.2%}, "
f"{performance_b_1/100:^6.2%} \t "
f"Time used: {time_measure_a:^6.2f}sec"
)
)
mini_batch_id += 1
logging.info("")
if tb is not None:
tb.add_scalar("Test Error A0", 100.0 - performance_a_0, epoch_id)
tb.add_scalar("Test Error A1", 100.0 - performance_a_1, epoch_id)
tb.add_scalar("Test Error B0", 100.0 - performance_b_0, epoch_id)
tb.add_scalar("Test Error B1", 100.0 - performance_b_1, epoch_id)
tb.flush()
return performance_a_0, performance_a_1, performance_b_0, performance_b_1
def loop_test_reconstruction(
epoch_id: int,
cfg: Config,

View file

@ -4,6 +4,7 @@ from network.Parameter import Config
import numpy as np
from network.SbSLayer import SbSLayer
from network.NNMFLayer import NNMFLayer
from network.SplitOnOffLayer import SplitOnOffLayer
from network.Conv2dApproximation import Conv2dApproximation
@ -38,6 +39,21 @@ def save_weight_and_bias(
network[id].weights.detach().cpu().numpy(),
)
# ################################################
# Save the NNMF Weights
# ################################################
if isinstance(network[id], NNMFLayer) is True:
if network[id]._w_trainable is True:
np.save(
os.path.join(
cfg.weight_path,
f"Weight_L{id}_S{iteration_number}{post_fix}.npy",
),
network[id].weights.detach().cpu().numpy(),
)
# ################################################
# Save the Conv2 Weights and Biases
# ################################################
@ -88,9 +104,10 @@ def save_weight_and_bias(
if isinstance(network[id], SplitOnOffLayer) is True:
np.save(
os.path.join(
cfg.weight_path, f"Mean_L{id}_S{iteration_number}{post_fix}.npy"
),
network[id].mean.detach().cpu().numpy(),
)
if network[id].mean is not None:
np.save(
os.path.join(
cfg.weight_path, f"Mean_L{id}_S{iteration_number}{post_fix}.npy"
),
network[id].mean.detach().cpu().numpy(),
)