Add files via upload
This commit is contained in:
parent
9a3e9273b6
commit
dcee82fca6
4 changed files with 158 additions and 65 deletions
|
@ -1,7 +1,6 @@
|
||||||
# %%
|
# %%
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
@ -101,6 +100,8 @@ class Config:
|
||||||
default_factory=ApproximationSetting
|
default_factory=ApproximationSetting
|
||||||
)
|
)
|
||||||
|
|
||||||
|
extract_noisy_pictures: bool = field(default=False)
|
||||||
|
|
||||||
# For labeling simulations
|
# For labeling simulations
|
||||||
# (not actively used)
|
# (not actively used)
|
||||||
simulation_id: int = field(default=0)
|
simulation_id: int = field(default=0)
|
||||||
|
@ -163,21 +164,6 @@ class Config:
|
||||||
self.batch_size = np.max((self.batch_size, self.number_of_cpu_processes))
|
self.batch_size = np.max((self.batch_size, self.number_of_cpu_processes))
|
||||||
self.batch_size = int(self.batch_size)
|
self.batch_size = int(self.batch_size)
|
||||||
|
|
||||||
def get_epsilon_t(self, number_of_spikes: int):
|
|
||||||
"""Generates the time series of the basic epsilon."""
|
|
||||||
t = np.arange(0, number_of_spikes, dtype=np.float32) + 1
|
|
||||||
np_epsilon_t: np.ndarray = t ** (
|
|
||||||
-1.0 / 2.0
|
|
||||||
) # np.ones((number_of_spikes), dtype=np.float32)
|
|
||||||
|
|
||||||
if (self.cooldown_after_number_of_spikes < number_of_spikes) and (
|
|
||||||
self.cooldown_after_number_of_spikes >= 0
|
|
||||||
):
|
|
||||||
np_epsilon_t[
|
|
||||||
self.cooldown_after_number_of_spikes : number_of_spikes
|
|
||||||
] /= self.reduction_cooldown
|
|
||||||
return torch.tensor(np_epsilon_t)
|
|
||||||
|
|
||||||
def get_update_after_x_pattern(self):
|
def get_update_after_x_pattern(self):
|
||||||
"""Tells us after how many pattern we need to update the weights."""
|
"""Tells us after how many pattern we need to update the weights."""
|
||||||
return (
|
return (
|
||||||
|
|
161
network/SbS.py
161
network/SbS.py
|
@ -3,6 +3,8 @@ import torch
|
||||||
from network.CPP.PySpikeGeneration2DManyIP import SpikeGeneration2DManyIP
|
from network.CPP.PySpikeGeneration2DManyIP import SpikeGeneration2DManyIP
|
||||||
from network.CPP.PyHDynamicCNNManyIP import HDynamicCNNManyIP
|
from network.CPP.PyHDynamicCNNManyIP import HDynamicCNNManyIP
|
||||||
from network.calculate_output_size import calculate_output_size
|
from network.calculate_output_size import calculate_output_size
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
global_sbs_gpu_setting: list[torch.Tensor] = []
|
global_sbs_gpu_setting: list[torch.Tensor] = []
|
||||||
global_sbs_size: list[torch.Tensor] = []
|
global_sbs_size: list[torch.Tensor] = []
|
||||||
|
@ -16,7 +18,6 @@ class SbS(torch.nn.Module):
|
||||||
|
|
||||||
_epsilon_xy: torch.Tensor | None = None
|
_epsilon_xy: torch.Tensor | None = None
|
||||||
_epsilon_0: float
|
_epsilon_0: float
|
||||||
_epsilon_t: torch.Tensor | None = None
|
|
||||||
_weights: torch.nn.parameter.Parameter
|
_weights: torch.nn.parameter.Parameter
|
||||||
_weights_exists: bool = False
|
_weights_exists: bool = False
|
||||||
_kernel_size: list[int]
|
_kernel_size: list[int]
|
||||||
|
@ -58,6 +59,9 @@ class SbS(torch.nn.Module):
|
||||||
spike_generation_cpp_position: int = -1
|
spike_generation_cpp_position: int = -1
|
||||||
spike_generation_gpu_setting_position: int = -1
|
spike_generation_gpu_setting_position: int = -1
|
||||||
|
|
||||||
|
_cooldown_after_number_of_spikes: int = -1
|
||||||
|
_reduction_cooldown: float = 1.0
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
number_of_input_neurons: int,
|
number_of_input_neurons: int,
|
||||||
|
@ -65,7 +69,6 @@ class SbS(torch.nn.Module):
|
||||||
input_size: list[int],
|
input_size: list[int],
|
||||||
forward_kernel_size: list[int],
|
forward_kernel_size: list[int],
|
||||||
number_of_spikes: int,
|
number_of_spikes: int,
|
||||||
epsilon_t: torch.Tensor,
|
|
||||||
epsilon_xy_intitial: float = 0.1,
|
epsilon_xy_intitial: float = 0.1,
|
||||||
epsilon_0: float = 1.0,
|
epsilon_0: float = 1.0,
|
||||||
weight_noise_range: list[float] = [0.0, 1.0],
|
weight_noise_range: list[float] = [0.0, 1.0],
|
||||||
|
@ -83,6 +86,8 @@ class SbS(torch.nn.Module):
|
||||||
default_dtype: torch.dtype | None = None,
|
default_dtype: torch.dtype | None = None,
|
||||||
gpu_tuning_factor: int = 5,
|
gpu_tuning_factor: int = 5,
|
||||||
layer_id: int = -1,
|
layer_id: int = -1,
|
||||||
|
cooldown_after_number_of_spikes: int = -1,
|
||||||
|
reduction_cooldown: float = 1.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -107,6 +112,8 @@ class SbS(torch.nn.Module):
|
||||||
self._number_of_spikes = int(number_of_spikes)
|
self._number_of_spikes = int(number_of_spikes)
|
||||||
self._weight_noise_range = weight_noise_range
|
self._weight_noise_range = weight_noise_range
|
||||||
self._is_pooling_layer = bool(is_pooling_layer)
|
self._is_pooling_layer = bool(is_pooling_layer)
|
||||||
|
self._cooldown_after_number_of_spikes = int(cooldown_after_number_of_spikes)
|
||||||
|
self.reduction_cooldown = float(reduction_cooldown)
|
||||||
|
|
||||||
assert len(input_size) == 2
|
assert len(input_size) == 2
|
||||||
self._input_size = input_size
|
self._input_size = input_size
|
||||||
|
@ -145,8 +152,6 @@ class SbS(torch.nn.Module):
|
||||||
forgetting_offset, dtype=self.default_dtype, device=self.device
|
forgetting_offset, dtype=self.default_dtype, device=self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
self.epsilon_t = epsilon_t.type(dtype=self.default_dtype).to(device=self.device)
|
|
||||||
|
|
||||||
self._output_size = calculate_output_size(
|
self._output_size = calculate_output_size(
|
||||||
value=input_size,
|
value=input_size,
|
||||||
kernel_size=self._kernel_size,
|
kernel_size=self._kernel_size,
|
||||||
|
@ -158,6 +163,7 @@ class SbS(torch.nn.Module):
|
||||||
self.set_h_init_to_uniform()
|
self.set_h_init_to_uniform()
|
||||||
|
|
||||||
self.functional_sbs = FunctionalSbS.apply
|
self.functional_sbs = FunctionalSbS.apply
|
||||||
|
self.functional_spike_generation = FunctionalSpikeGeneration.apply
|
||||||
|
|
||||||
# ###############################################################
|
# ###############################################################
|
||||||
# Initialize the weights
|
# Initialize the weights
|
||||||
|
@ -190,22 +196,23 @@ class SbS(torch.nn.Module):
|
||||||
# Variables in and out #
|
# Variables in and out #
|
||||||
####################################################################
|
####################################################################
|
||||||
|
|
||||||
@property
|
def get_epsilon_t(self, number_of_spikes: int):
|
||||||
def epsilon_t(self) -> torch.Tensor | None:
|
"""Generates the time series of the basic epsilon."""
|
||||||
return self._epsilon_t
|
t = np.arange(0, number_of_spikes, dtype=np.float32) + 1
|
||||||
|
np_epsilon_t: np.ndarray = t ** (
|
||||||
|
-1.0 / 2.0
|
||||||
|
) # np.ones((number_of_spikes), dtype=np.float32)
|
||||||
|
|
||||||
@epsilon_t.setter
|
if (self._cooldown_after_number_of_spikes < number_of_spikes) and (
|
||||||
def epsilon_t(self, value: torch.Tensor):
|
self._cooldown_after_number_of_spikes >= 0
|
||||||
assert value is not None
|
):
|
||||||
assert torch.is_tensor(value) is True
|
np_epsilon_t[
|
||||||
assert value.dim() == 1
|
self._cooldown_after_number_of_spikes : number_of_spikes
|
||||||
assert value.dtype == self.default_dtype
|
] /= self._reduction_cooldown
|
||||||
self._epsilon_t = (
|
return (
|
||||||
value.detach()
|
torch.tensor(np_epsilon_t)
|
||||||
.clone(memory_format=torch.contiguous_format)
|
|
||||||
.type(dtype=self.default_dtype)
|
.type(dtype=self.default_dtype)
|
||||||
.to(device=self.device)
|
.to(device=self.device)
|
||||||
.requires_grad_(False)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -348,7 +355,13 @@ class SbS(torch.nn.Module):
|
||||||
####################################################################
|
####################################################################
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input: torch.Tensor, labels: torch.Tensor | None = None
|
self,
|
||||||
|
input: torch.Tensor,
|
||||||
|
labels: torch.Tensor | None = None,
|
||||||
|
extract_noisy_pictures: bool = False,
|
||||||
|
layer_id: int = -1,
|
||||||
|
mini_batch_id: int = -1,
|
||||||
|
overwrite_number_of_spikes: int = -1,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
# Are we happy with the input?
|
# Are we happy with the input?
|
||||||
|
@ -362,7 +375,6 @@ class SbS(torch.nn.Module):
|
||||||
|
|
||||||
# Are we happy with the rest of the network?
|
# Are we happy with the rest of the network?
|
||||||
assert self._epsilon_0 is not None
|
assert self._epsilon_0 is not None
|
||||||
assert self._epsilon_t is not None
|
|
||||||
|
|
||||||
assert self._h_initial is not None
|
assert self._h_initial is not None
|
||||||
assert self._forgetting_offset is not None
|
assert self._forgetting_offset is not None
|
||||||
|
@ -405,8 +417,15 @@ class SbS(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
self.last_input_data = None
|
self.last_input_data = None
|
||||||
|
|
||||||
|
if overwrite_number_of_spikes >= 1:
|
||||||
|
_number_of_spikes = int(overwrite_number_of_spikes)
|
||||||
|
else:
|
||||||
|
_number_of_spikes = int(self._number_of_spikes)
|
||||||
|
|
||||||
epsilon_t_0: torch.Tensor = (
|
epsilon_t_0: torch.Tensor = (
|
||||||
(self._epsilon_t * self._epsilon_0).type(input.dtype).to(input.device)
|
(self.get_epsilon_t(_number_of_spikes) * self._epsilon_0)
|
||||||
|
.type(input.dtype)
|
||||||
|
.to(input.device)
|
||||||
)
|
)
|
||||||
|
|
||||||
parameter_list = torch.tensor(
|
parameter_list = torch.tensor(
|
||||||
|
@ -415,7 +434,7 @@ class SbS(torch.nn.Module):
|
||||||
int(self._disable_scale_grade), # 1
|
int(self._disable_scale_grade), # 1
|
||||||
int(self._keep_last_grad_scale), # 2
|
int(self._keep_last_grad_scale), # 2
|
||||||
int(self._skip_gradient_calculation), # 3
|
int(self._skip_gradient_calculation), # 3
|
||||||
int(self._number_of_spikes), # 4
|
int(_number_of_spikes), # 4
|
||||||
int(self._number_of_cpu_processes), # 5
|
int(self._number_of_cpu_processes), # 5
|
||||||
int(self._output_size[0]), # 6
|
int(self._output_size[0]), # 6
|
||||||
int(self._output_size[1]), # 7
|
int(self._output_size[1]), # 7
|
||||||
|
@ -448,9 +467,46 @@ class SbS(torch.nn.Module):
|
||||||
assert self._epsilon_xy.shape[1] == input_convolved.shape[2]
|
assert self._epsilon_xy.shape[1] == input_convolved.shape[2]
|
||||||
assert self._epsilon_xy.shape[2] == input_convolved.shape[3]
|
assert self._epsilon_xy.shape[2] == input_convolved.shape[3]
|
||||||
|
|
||||||
|
spike = self.functional_spike_generation(input_convolved, parameter_list)
|
||||||
|
|
||||||
|
if (
|
||||||
|
(extract_noisy_pictures is True)
|
||||||
|
and (layer_id == 0)
|
||||||
|
and (labels is not None)
|
||||||
|
and (mini_batch_id >= 0)
|
||||||
|
):
|
||||||
|
assert labels.shape[0] == spike.shape[0]
|
||||||
|
|
||||||
|
path_sub: str = "noisy_picture_data"
|
||||||
|
path_sub_spikes: str = f"{int(_number_of_spikes)}"
|
||||||
|
path = os.path.join(path_sub, path_sub_spikes)
|
||||||
|
os.makedirs(path_sub, exist_ok=True)
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
|
|
||||||
|
the_images = torch.zeros_like(
|
||||||
|
input_convolved, dtype=torch.int64, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
for p_id in range(0, the_images.shape[0]):
|
||||||
|
for sp_id in range(0, spike.shape[1]):
|
||||||
|
for x_id in range(0, the_images.shape[2]):
|
||||||
|
for y_id in range(0, the_images.shape[3]):
|
||||||
|
the_images[
|
||||||
|
p_id, spike[p_id, sp_id, x_id, y_id], x_id, y_id
|
||||||
|
] += 1
|
||||||
|
|
||||||
|
np.savez_compressed(
|
||||||
|
os.path.join(path, f"{mini_batch_id}.npz"),
|
||||||
|
the_images=the_images.cpu().numpy(),
|
||||||
|
labels=labels.cpu().numpy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert spike.shape[1] == _number_of_spikes
|
||||||
|
|
||||||
# SbS forward functional
|
# SbS forward functional
|
||||||
output = self.functional_sbs(
|
output = self.functional_sbs(
|
||||||
input_convolved,
|
input_convolved,
|
||||||
|
spike,
|
||||||
self._epsilon_xy,
|
self._epsilon_xy,
|
||||||
epsilon_t_0,
|
epsilon_t_0,
|
||||||
self._weights,
|
self._weights,
|
||||||
|
@ -468,19 +524,12 @@ class SbS(torch.nn.Module):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
class FunctionalSbS(torch.autograd.Function):
|
class FunctionalSpikeGeneration(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward( # type: ignore
|
def forward( # type: ignore
|
||||||
ctx,
|
ctx,
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
epsilon_xy: torch.Tensor,
|
|
||||||
epsilon_t_0: torch.Tensor,
|
|
||||||
weights: torch.Tensor,
|
|
||||||
h_initial: torch.Tensor,
|
|
||||||
parameter_list: torch.Tensor,
|
parameter_list: torch.Tensor,
|
||||||
grad_output_scale: torch.Tensor,
|
|
||||||
forgetting_offset: torch.Tensor,
|
|
||||||
labels: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
assert input.dim() == 4
|
assert input.dim() == 4
|
||||||
|
@ -492,17 +541,6 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
else:
|
else:
|
||||||
spike_number_of_cpu_processes = -1
|
spike_number_of_cpu_processes = -1
|
||||||
|
|
||||||
if input.device == torch.device("cpu"):
|
|
||||||
hdyn_number_of_cpu_processes: int = int(parameter_list[5])
|
|
||||||
else:
|
|
||||||
hdyn_number_of_cpu_processes = -1
|
|
||||||
|
|
||||||
output_size_0: int = int(parameter_list[6])
|
|
||||||
output_size_1: int = int(parameter_list[7])
|
|
||||||
gpu_tuning_factor: int = int(parameter_list[8])
|
|
||||||
|
|
||||||
sbs_gpu_setting_position = int(parameter_list[11])
|
|
||||||
sbs_hdynamic_cpp_position = int(parameter_list[12])
|
|
||||||
spike_generation_cpp_position = int(parameter_list[13])
|
spike_generation_cpp_position = int(parameter_list[13])
|
||||||
spike_generation_gpu_setting_position = int(parameter_list[14])
|
spike_generation_gpu_setting_position = int(parameter_list[14])
|
||||||
|
|
||||||
|
@ -615,6 +653,45 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
del random_values
|
del random_values
|
||||||
del input_cumsum
|
del input_cumsum
|
||||||
|
|
||||||
|
return spikes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
grad_input = grad_output
|
||||||
|
grad_parameter_list = None
|
||||||
|
return (grad_input, grad_parameter_list)
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionalSbS(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward( # type: ignore
|
||||||
|
ctx,
|
||||||
|
input: torch.Tensor,
|
||||||
|
spikes: torch.Tensor,
|
||||||
|
epsilon_xy: torch.Tensor,
|
||||||
|
epsilon_t_0: torch.Tensor,
|
||||||
|
weights: torch.Tensor,
|
||||||
|
h_initial: torch.Tensor,
|
||||||
|
parameter_list: torch.Tensor,
|
||||||
|
grad_output_scale: torch.Tensor,
|
||||||
|
forgetting_offset: torch.Tensor,
|
||||||
|
labels: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
number_of_spikes: int = int(parameter_list[4])
|
||||||
|
|
||||||
|
if input.device == torch.device("cpu"):
|
||||||
|
hdyn_number_of_cpu_processes: int = int(parameter_list[5])
|
||||||
|
else:
|
||||||
|
hdyn_number_of_cpu_processes = -1
|
||||||
|
|
||||||
|
output_size_0: int = int(parameter_list[6])
|
||||||
|
output_size_1: int = int(parameter_list[7])
|
||||||
|
gpu_tuning_factor: int = int(parameter_list[8])
|
||||||
|
|
||||||
|
sbs_gpu_setting_position = int(parameter_list[11])
|
||||||
|
sbs_hdynamic_cpp_position = int(parameter_list[12])
|
||||||
|
|
||||||
# ###########################################################
|
# ###########################################################
|
||||||
# H dynamic
|
# H dynamic
|
||||||
# ###########################################################
|
# ###########################################################
|
||||||
|
@ -713,7 +790,6 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
float(forgetting_offset.item()),
|
float(forgetting_offset.item()),
|
||||||
int(gpu_tuning_factor),
|
int(gpu_tuning_factor),
|
||||||
)
|
)
|
||||||
del spikes
|
|
||||||
|
|
||||||
# ###########################################################
|
# ###########################################################
|
||||||
# Save the necessary data for the backward pass
|
# Save the necessary data for the backward pass
|
||||||
|
@ -750,6 +826,7 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
# Default output
|
# Default output
|
||||||
# ##############################################
|
# ##############################################
|
||||||
grad_input = None
|
grad_input = None
|
||||||
|
grad_spikes = None
|
||||||
grad_eps_xy = None
|
grad_eps_xy = None
|
||||||
grad_epsilon_t_0 = None
|
grad_epsilon_t_0 = None
|
||||||
grad_weights = None
|
grad_weights = None
|
||||||
|
@ -789,6 +866,7 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
|
|
||||||
return (
|
return (
|
||||||
grad_input,
|
grad_input,
|
||||||
|
grad_spikes,
|
||||||
grad_eps_xy,
|
grad_eps_xy,
|
||||||
grad_epsilon_t_0,
|
grad_epsilon_t_0,
|
||||||
grad_weights,
|
grad_weights,
|
||||||
|
@ -894,6 +972,7 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
|
|
||||||
return (
|
return (
|
||||||
grad_input,
|
grad_input,
|
||||||
|
grad_spikes,
|
||||||
grad_eps_xy,
|
grad_eps_xy,
|
||||||
grad_epsilon_t_0,
|
grad_epsilon_t_0,
|
||||||
grad_weights,
|
grad_weights,
|
||||||
|
|
|
@ -150,7 +150,6 @@ def build_network(
|
||||||
input_size=input_size[-1],
|
input_size=input_size[-1],
|
||||||
forward_kernel_size=kernel_size,
|
forward_kernel_size=kernel_size,
|
||||||
number_of_spikes=number_of_spikes,
|
number_of_spikes=number_of_spikes,
|
||||||
epsilon_t=cfg.get_epsilon_t(number_of_spikes),
|
|
||||||
epsilon_xy_intitial=cfg.learning_parameters.eps_xy_intitial,
|
epsilon_xy_intitial=cfg.learning_parameters.eps_xy_intitial,
|
||||||
epsilon_0=cfg.epsilon_0,
|
epsilon_0=cfg.epsilon_0,
|
||||||
weight_noise_range=weight_noise_range,
|
weight_noise_range=weight_noise_range,
|
||||||
|
@ -167,6 +166,8 @@ def build_network(
|
||||||
device=device,
|
device=device,
|
||||||
default_dtype=default_dtype,
|
default_dtype=default_dtype,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
cooldown_after_number_of_spikes=cfg.cooldown_after_number_of_spikes,
|
||||||
|
reduction_cooldown=cfg.reduction_cooldown,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Adding the x,y output dimensions
|
# Adding the x,y output dimensions
|
||||||
|
|
|
@ -185,11 +185,14 @@ def forward_pass_train(
|
||||||
|
|
||||||
def forward_pass_test(
|
def forward_pass_test(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
|
labels: torch.Tensor | None,
|
||||||
the_dataset_test,
|
the_dataset_test,
|
||||||
cfg: Config,
|
cfg: Config,
|
||||||
network: torch.nn.modules.container.Sequential,
|
network: torch.nn.modules.container.Sequential,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
default_dtype: torch.dtype,
|
default_dtype: torch.dtype,
|
||||||
|
mini_batch_id: int = -1,
|
||||||
|
overwrite_number_of_spikes: int = -1,
|
||||||
) -> list[torch.Tensor]:
|
) -> list[torch.Tensor]:
|
||||||
|
|
||||||
h_collection = []
|
h_collection = []
|
||||||
|
@ -199,7 +202,22 @@ def forward_pass_test(
|
||||||
.to(device=device)
|
.to(device=device)
|
||||||
)
|
)
|
||||||
for id in range(0, len(network)):
|
for id in range(0, len(network)):
|
||||||
h_collection.append(network[id](h_collection[-1]))
|
if (cfg.extract_noisy_pictures is True) or (overwrite_number_of_spikes != -1):
|
||||||
|
if isinstance(network[id], SbS) is True:
|
||||||
|
h_collection.append(
|
||||||
|
network[id](
|
||||||
|
h_collection[-1],
|
||||||
|
layer_id=id,
|
||||||
|
labels=labels,
|
||||||
|
extract_noisy_pictures=cfg.extract_noisy_pictures,
|
||||||
|
mini_batch_id=mini_batch_id,
|
||||||
|
overwrite_number_of_spikes=overwrite_number_of_spikes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
h_collection.append(network[id](h_collection[-1]))
|
||||||
|
else:
|
||||||
|
h_collection.append(network[id](h_collection[-1]))
|
||||||
|
|
||||||
return h_collection
|
return h_collection
|
||||||
|
|
||||||
|
@ -545,7 +563,8 @@ def loop_test(
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
default_dtype: torch.dtype,
|
default_dtype: torch.dtype,
|
||||||
logging,
|
logging,
|
||||||
tb: SummaryWriter,
|
tb: SummaryWriter | None,
|
||||||
|
overwrite_number_of_spikes: int = -1,
|
||||||
) -> float:
|
) -> float:
|
||||||
|
|
||||||
test_correct = 0
|
test_correct = 0
|
||||||
|
@ -554,17 +573,21 @@ def loop_test(
|
||||||
|
|
||||||
logging.info("")
|
logging.info("")
|
||||||
logging.info("Testing:")
|
logging.info("Testing:")
|
||||||
|
mini_batch_id: int = 0
|
||||||
|
|
||||||
for h_x, h_x_labels in my_loader_test:
|
for h_x, h_x_labels in my_loader_test:
|
||||||
time_0 = time.perf_counter()
|
time_0 = time.perf_counter()
|
||||||
|
|
||||||
h_collection = forward_pass_test(
|
h_collection = forward_pass_test(
|
||||||
input=h_x,
|
input=h_x,
|
||||||
|
labels=h_x_labels,
|
||||||
the_dataset_test=the_dataset_test,
|
the_dataset_test=the_dataset_test,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
network=network,
|
network=network,
|
||||||
device=device,
|
device=device,
|
||||||
default_dtype=default_dtype,
|
default_dtype=default_dtype,
|
||||||
|
mini_batch_id=mini_batch_id,
|
||||||
|
overwrite_number_of_spikes=overwrite_number_of_spikes,
|
||||||
)
|
)
|
||||||
h_h: torch.Tensor = h_collection[-1].detach().clone().cpu()
|
h_h: torch.Tensor = h_collection[-1].detach().clone().cpu()
|
||||||
|
|
||||||
|
@ -580,11 +603,13 @@ def loop_test(
|
||||||
f" with {performance/100:^6.2%} \t Time used: {time_measure_a:^6.2f}sec"
|
f" with {performance/100:^6.2%} \t Time used: {time_measure_a:^6.2f}sec"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
mini_batch_id += 1
|
||||||
|
|
||||||
logging.info("")
|
logging.info("")
|
||||||
|
|
||||||
tb.add_scalar("Test Error", 100.0 - performance, epoch_id)
|
if tb is not None:
|
||||||
tb.flush()
|
tb.add_scalar("Test Error", 100.0 - performance, epoch_id)
|
||||||
|
tb.flush()
|
||||||
|
|
||||||
return performance
|
return performance
|
||||||
|
|
||||||
|
@ -598,7 +623,7 @@ def loop_test_reconstruction(
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
default_dtype: torch.dtype,
|
default_dtype: torch.dtype,
|
||||||
logging,
|
logging,
|
||||||
tb: SummaryWriter,
|
tb: SummaryWriter | None,
|
||||||
) -> float:
|
) -> float:
|
||||||
|
|
||||||
test_count: int = 0
|
test_count: int = 0
|
||||||
|
@ -613,6 +638,7 @@ def loop_test_reconstruction(
|
||||||
|
|
||||||
h_collection = forward_pass_test(
|
h_collection = forward_pass_test(
|
||||||
input=h_x,
|
input=h_x,
|
||||||
|
labels=None,
|
||||||
the_dataset_test=the_dataset_test,
|
the_dataset_test=the_dataset_test,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
network=network,
|
network=network,
|
||||||
|
@ -645,7 +671,8 @@ def loop_test_reconstruction(
|
||||||
|
|
||||||
logging.info("")
|
logging.info("")
|
||||||
|
|
||||||
tb.add_scalar("Test Error", performance, epoch_id)
|
if tb is not None:
|
||||||
tb.flush()
|
tb.add_scalar("Test Error", performance, epoch_id)
|
||||||
|
tb.flush()
|
||||||
|
|
||||||
return performance
|
return performance
|
||||||
|
|
Loading…
Reference in a new issue