Add files via upload
This commit is contained in:
parent
9a3e9273b6
commit
dcee82fca6
4 changed files with 158 additions and 65 deletions
|
@ -1,7 +1,6 @@
|
|||
# %%
|
||||
from dataclasses import dataclass, field
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
|
||||
|
||||
|
@ -101,6 +100,8 @@ class Config:
|
|||
default_factory=ApproximationSetting
|
||||
)
|
||||
|
||||
extract_noisy_pictures: bool = field(default=False)
|
||||
|
||||
# For labeling simulations
|
||||
# (not actively used)
|
||||
simulation_id: int = field(default=0)
|
||||
|
@ -163,21 +164,6 @@ class Config:
|
|||
self.batch_size = np.max((self.batch_size, self.number_of_cpu_processes))
|
||||
self.batch_size = int(self.batch_size)
|
||||
|
||||
def get_epsilon_t(self, number_of_spikes: int):
|
||||
"""Generates the time series of the basic epsilon."""
|
||||
t = np.arange(0, number_of_spikes, dtype=np.float32) + 1
|
||||
np_epsilon_t: np.ndarray = t ** (
|
||||
-1.0 / 2.0
|
||||
) # np.ones((number_of_spikes), dtype=np.float32)
|
||||
|
||||
if (self.cooldown_after_number_of_spikes < number_of_spikes) and (
|
||||
self.cooldown_after_number_of_spikes >= 0
|
||||
):
|
||||
np_epsilon_t[
|
||||
self.cooldown_after_number_of_spikes : number_of_spikes
|
||||
] /= self.reduction_cooldown
|
||||
return torch.tensor(np_epsilon_t)
|
||||
|
||||
def get_update_after_x_pattern(self):
|
||||
"""Tells us after how many pattern we need to update the weights."""
|
||||
return (
|
||||
|
|
161
network/SbS.py
161
network/SbS.py
|
@ -3,6 +3,8 @@ import torch
|
|||
from network.CPP.PySpikeGeneration2DManyIP import SpikeGeneration2DManyIP
|
||||
from network.CPP.PyHDynamicCNNManyIP import HDynamicCNNManyIP
|
||||
from network.calculate_output_size import calculate_output_size
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
global_sbs_gpu_setting: list[torch.Tensor] = []
|
||||
global_sbs_size: list[torch.Tensor] = []
|
||||
|
@ -16,7 +18,6 @@ class SbS(torch.nn.Module):
|
|||
|
||||
_epsilon_xy: torch.Tensor | None = None
|
||||
_epsilon_0: float
|
||||
_epsilon_t: torch.Tensor | None = None
|
||||
_weights: torch.nn.parameter.Parameter
|
||||
_weights_exists: bool = False
|
||||
_kernel_size: list[int]
|
||||
|
@ -58,6 +59,9 @@ class SbS(torch.nn.Module):
|
|||
spike_generation_cpp_position: int = -1
|
||||
spike_generation_gpu_setting_position: int = -1
|
||||
|
||||
_cooldown_after_number_of_spikes: int = -1
|
||||
_reduction_cooldown: float = 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
number_of_input_neurons: int,
|
||||
|
@ -65,7 +69,6 @@ class SbS(torch.nn.Module):
|
|||
input_size: list[int],
|
||||
forward_kernel_size: list[int],
|
||||
number_of_spikes: int,
|
||||
epsilon_t: torch.Tensor,
|
||||
epsilon_xy_intitial: float = 0.1,
|
||||
epsilon_0: float = 1.0,
|
||||
weight_noise_range: list[float] = [0.0, 1.0],
|
||||
|
@ -83,6 +86,8 @@ class SbS(torch.nn.Module):
|
|||
default_dtype: torch.dtype | None = None,
|
||||
gpu_tuning_factor: int = 5,
|
||||
layer_id: int = -1,
|
||||
cooldown_after_number_of_spikes: int = -1,
|
||||
reduction_cooldown: float = 1.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
@ -107,6 +112,8 @@ class SbS(torch.nn.Module):
|
|||
self._number_of_spikes = int(number_of_spikes)
|
||||
self._weight_noise_range = weight_noise_range
|
||||
self._is_pooling_layer = bool(is_pooling_layer)
|
||||
self._cooldown_after_number_of_spikes = int(cooldown_after_number_of_spikes)
|
||||
self.reduction_cooldown = float(reduction_cooldown)
|
||||
|
||||
assert len(input_size) == 2
|
||||
self._input_size = input_size
|
||||
|
@ -145,8 +152,6 @@ class SbS(torch.nn.Module):
|
|||
forgetting_offset, dtype=self.default_dtype, device=self.device
|
||||
)
|
||||
|
||||
self.epsilon_t = epsilon_t.type(dtype=self.default_dtype).to(device=self.device)
|
||||
|
||||
self._output_size = calculate_output_size(
|
||||
value=input_size,
|
||||
kernel_size=self._kernel_size,
|
||||
|
@ -158,6 +163,7 @@ class SbS(torch.nn.Module):
|
|||
self.set_h_init_to_uniform()
|
||||
|
||||
self.functional_sbs = FunctionalSbS.apply
|
||||
self.functional_spike_generation = FunctionalSpikeGeneration.apply
|
||||
|
||||
# ###############################################################
|
||||
# Initialize the weights
|
||||
|
@ -190,22 +196,23 @@ class SbS(torch.nn.Module):
|
|||
# Variables in and out #
|
||||
####################################################################
|
||||
|
||||
@property
|
||||
def epsilon_t(self) -> torch.Tensor | None:
|
||||
return self._epsilon_t
|
||||
def get_epsilon_t(self, number_of_spikes: int):
|
||||
"""Generates the time series of the basic epsilon."""
|
||||
t = np.arange(0, number_of_spikes, dtype=np.float32) + 1
|
||||
np_epsilon_t: np.ndarray = t ** (
|
||||
-1.0 / 2.0
|
||||
) # np.ones((number_of_spikes), dtype=np.float32)
|
||||
|
||||
@epsilon_t.setter
|
||||
def epsilon_t(self, value: torch.Tensor):
|
||||
assert value is not None
|
||||
assert torch.is_tensor(value) is True
|
||||
assert value.dim() == 1
|
||||
assert value.dtype == self.default_dtype
|
||||
self._epsilon_t = (
|
||||
value.detach()
|
||||
.clone(memory_format=torch.contiguous_format)
|
||||
if (self._cooldown_after_number_of_spikes < number_of_spikes) and (
|
||||
self._cooldown_after_number_of_spikes >= 0
|
||||
):
|
||||
np_epsilon_t[
|
||||
self._cooldown_after_number_of_spikes : number_of_spikes
|
||||
] /= self._reduction_cooldown
|
||||
return (
|
||||
torch.tensor(np_epsilon_t)
|
||||
.type(dtype=self.default_dtype)
|
||||
.to(device=self.device)
|
||||
.requires_grad_(False)
|
||||
)
|
||||
|
||||
@property
|
||||
|
@ -348,7 +355,13 @@ class SbS(torch.nn.Module):
|
|||
####################################################################
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, labels: torch.Tensor | None = None
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
labels: torch.Tensor | None = None,
|
||||
extract_noisy_pictures: bool = False,
|
||||
layer_id: int = -1,
|
||||
mini_batch_id: int = -1,
|
||||
overwrite_number_of_spikes: int = -1,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# Are we happy with the input?
|
||||
|
@ -362,7 +375,6 @@ class SbS(torch.nn.Module):
|
|||
|
||||
# Are we happy with the rest of the network?
|
||||
assert self._epsilon_0 is not None
|
||||
assert self._epsilon_t is not None
|
||||
|
||||
assert self._h_initial is not None
|
||||
assert self._forgetting_offset is not None
|
||||
|
@ -405,8 +417,15 @@ class SbS(torch.nn.Module):
|
|||
else:
|
||||
self.last_input_data = None
|
||||
|
||||
if overwrite_number_of_spikes >= 1:
|
||||
_number_of_spikes = int(overwrite_number_of_spikes)
|
||||
else:
|
||||
_number_of_spikes = int(self._number_of_spikes)
|
||||
|
||||
epsilon_t_0: torch.Tensor = (
|
||||
(self._epsilon_t * self._epsilon_0).type(input.dtype).to(input.device)
|
||||
(self.get_epsilon_t(_number_of_spikes) * self._epsilon_0)
|
||||
.type(input.dtype)
|
||||
.to(input.device)
|
||||
)
|
||||
|
||||
parameter_list = torch.tensor(
|
||||
|
@ -415,7 +434,7 @@ class SbS(torch.nn.Module):
|
|||
int(self._disable_scale_grade), # 1
|
||||
int(self._keep_last_grad_scale), # 2
|
||||
int(self._skip_gradient_calculation), # 3
|
||||
int(self._number_of_spikes), # 4
|
||||
int(_number_of_spikes), # 4
|
||||
int(self._number_of_cpu_processes), # 5
|
||||
int(self._output_size[0]), # 6
|
||||
int(self._output_size[1]), # 7
|
||||
|
@ -448,9 +467,46 @@ class SbS(torch.nn.Module):
|
|||
assert self._epsilon_xy.shape[1] == input_convolved.shape[2]
|
||||
assert self._epsilon_xy.shape[2] == input_convolved.shape[3]
|
||||
|
||||
spike = self.functional_spike_generation(input_convolved, parameter_list)
|
||||
|
||||
if (
|
||||
(extract_noisy_pictures is True)
|
||||
and (layer_id == 0)
|
||||
and (labels is not None)
|
||||
and (mini_batch_id >= 0)
|
||||
):
|
||||
assert labels.shape[0] == spike.shape[0]
|
||||
|
||||
path_sub: str = "noisy_picture_data"
|
||||
path_sub_spikes: str = f"{int(_number_of_spikes)}"
|
||||
path = os.path.join(path_sub, path_sub_spikes)
|
||||
os.makedirs(path_sub, exist_ok=True)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
the_images = torch.zeros_like(
|
||||
input_convolved, dtype=torch.int64, device=self.device
|
||||
)
|
||||
|
||||
for p_id in range(0, the_images.shape[0]):
|
||||
for sp_id in range(0, spike.shape[1]):
|
||||
for x_id in range(0, the_images.shape[2]):
|
||||
for y_id in range(0, the_images.shape[3]):
|
||||
the_images[
|
||||
p_id, spike[p_id, sp_id, x_id, y_id], x_id, y_id
|
||||
] += 1
|
||||
|
||||
np.savez_compressed(
|
||||
os.path.join(path, f"{mini_batch_id}.npz"),
|
||||
the_images=the_images.cpu().numpy(),
|
||||
labels=labels.cpu().numpy(),
|
||||
)
|
||||
|
||||
assert spike.shape[1] == _number_of_spikes
|
||||
|
||||
# SbS forward functional
|
||||
output = self.functional_sbs(
|
||||
input_convolved,
|
||||
spike,
|
||||
self._epsilon_xy,
|
||||
epsilon_t_0,
|
||||
self._weights,
|
||||
|
@ -468,19 +524,12 @@ class SbS(torch.nn.Module):
|
|||
return output
|
||||
|
||||
|
||||
class FunctionalSbS(torch.autograd.Function):
|
||||
class FunctionalSpikeGeneration(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward( # type: ignore
|
||||
ctx,
|
||||
input: torch.Tensor,
|
||||
epsilon_xy: torch.Tensor,
|
||||
epsilon_t_0: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
h_initial: torch.Tensor,
|
||||
parameter_list: torch.Tensor,
|
||||
grad_output_scale: torch.Tensor,
|
||||
forgetting_offset: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert input.dim() == 4
|
||||
|
@ -492,17 +541,6 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
else:
|
||||
spike_number_of_cpu_processes = -1
|
||||
|
||||
if input.device == torch.device("cpu"):
|
||||
hdyn_number_of_cpu_processes: int = int(parameter_list[5])
|
||||
else:
|
||||
hdyn_number_of_cpu_processes = -1
|
||||
|
||||
output_size_0: int = int(parameter_list[6])
|
||||
output_size_1: int = int(parameter_list[7])
|
||||
gpu_tuning_factor: int = int(parameter_list[8])
|
||||
|
||||
sbs_gpu_setting_position = int(parameter_list[11])
|
||||
sbs_hdynamic_cpp_position = int(parameter_list[12])
|
||||
spike_generation_cpp_position = int(parameter_list[13])
|
||||
spike_generation_gpu_setting_position = int(parameter_list[14])
|
||||
|
||||
|
@ -615,6 +653,45 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
del random_values
|
||||
del input_cumsum
|
||||
|
||||
return spikes
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_input = grad_output
|
||||
grad_parameter_list = None
|
||||
return (grad_input, grad_parameter_list)
|
||||
|
||||
|
||||
class FunctionalSbS(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward( # type: ignore
|
||||
ctx,
|
||||
input: torch.Tensor,
|
||||
spikes: torch.Tensor,
|
||||
epsilon_xy: torch.Tensor,
|
||||
epsilon_t_0: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
h_initial: torch.Tensor,
|
||||
parameter_list: torch.Tensor,
|
||||
grad_output_scale: torch.Tensor,
|
||||
forgetting_offset: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
number_of_spikes: int = int(parameter_list[4])
|
||||
|
||||
if input.device == torch.device("cpu"):
|
||||
hdyn_number_of_cpu_processes: int = int(parameter_list[5])
|
||||
else:
|
||||
hdyn_number_of_cpu_processes = -1
|
||||
|
||||
output_size_0: int = int(parameter_list[6])
|
||||
output_size_1: int = int(parameter_list[7])
|
||||
gpu_tuning_factor: int = int(parameter_list[8])
|
||||
|
||||
sbs_gpu_setting_position = int(parameter_list[11])
|
||||
sbs_hdynamic_cpp_position = int(parameter_list[12])
|
||||
|
||||
# ###########################################################
|
||||
# H dynamic
|
||||
# ###########################################################
|
||||
|
@ -713,7 +790,6 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
float(forgetting_offset.item()),
|
||||
int(gpu_tuning_factor),
|
||||
)
|
||||
del spikes
|
||||
|
||||
# ###########################################################
|
||||
# Save the necessary data for the backward pass
|
||||
|
@ -750,6 +826,7 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
# Default output
|
||||
# ##############################################
|
||||
grad_input = None
|
||||
grad_spikes = None
|
||||
grad_eps_xy = None
|
||||
grad_epsilon_t_0 = None
|
||||
grad_weights = None
|
||||
|
@ -789,6 +866,7 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
|
||||
return (
|
||||
grad_input,
|
||||
grad_spikes,
|
||||
grad_eps_xy,
|
||||
grad_epsilon_t_0,
|
||||
grad_weights,
|
||||
|
@ -894,6 +972,7 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
|
||||
return (
|
||||
grad_input,
|
||||
grad_spikes,
|
||||
grad_eps_xy,
|
||||
grad_epsilon_t_0,
|
||||
grad_weights,
|
||||
|
|
|
@ -150,7 +150,6 @@ def build_network(
|
|||
input_size=input_size[-1],
|
||||
forward_kernel_size=kernel_size,
|
||||
number_of_spikes=number_of_spikes,
|
||||
epsilon_t=cfg.get_epsilon_t(number_of_spikes),
|
||||
epsilon_xy_intitial=cfg.learning_parameters.eps_xy_intitial,
|
||||
epsilon_0=cfg.epsilon_0,
|
||||
weight_noise_range=weight_noise_range,
|
||||
|
@ -167,6 +166,8 @@ def build_network(
|
|||
device=device,
|
||||
default_dtype=default_dtype,
|
||||
layer_id=layer_id,
|
||||
cooldown_after_number_of_spikes=cfg.cooldown_after_number_of_spikes,
|
||||
reduction_cooldown=cfg.reduction_cooldown,
|
||||
)
|
||||
)
|
||||
# Adding the x,y output dimensions
|
||||
|
|
|
@ -185,11 +185,14 @@ def forward_pass_train(
|
|||
|
||||
def forward_pass_test(
|
||||
input: torch.Tensor,
|
||||
labels: torch.Tensor | None,
|
||||
the_dataset_test,
|
||||
cfg: Config,
|
||||
network: torch.nn.modules.container.Sequential,
|
||||
device: torch.device,
|
||||
default_dtype: torch.dtype,
|
||||
mini_batch_id: int = -1,
|
||||
overwrite_number_of_spikes: int = -1,
|
||||
) -> list[torch.Tensor]:
|
||||
|
||||
h_collection = []
|
||||
|
@ -199,7 +202,22 @@ def forward_pass_test(
|
|||
.to(device=device)
|
||||
)
|
||||
for id in range(0, len(network)):
|
||||
h_collection.append(network[id](h_collection[-1]))
|
||||
if (cfg.extract_noisy_pictures is True) or (overwrite_number_of_spikes != -1):
|
||||
if isinstance(network[id], SbS) is True:
|
||||
h_collection.append(
|
||||
network[id](
|
||||
h_collection[-1],
|
||||
layer_id=id,
|
||||
labels=labels,
|
||||
extract_noisy_pictures=cfg.extract_noisy_pictures,
|
||||
mini_batch_id=mini_batch_id,
|
||||
overwrite_number_of_spikes=overwrite_number_of_spikes,
|
||||
)
|
||||
)
|
||||
else:
|
||||
h_collection.append(network[id](h_collection[-1]))
|
||||
else:
|
||||
h_collection.append(network[id](h_collection[-1]))
|
||||
|
||||
return h_collection
|
||||
|
||||
|
@ -545,7 +563,8 @@ def loop_test(
|
|||
device: torch.device,
|
||||
default_dtype: torch.dtype,
|
||||
logging,
|
||||
tb: SummaryWriter,
|
||||
tb: SummaryWriter | None,
|
||||
overwrite_number_of_spikes: int = -1,
|
||||
) -> float:
|
||||
|
||||
test_correct = 0
|
||||
|
@ -554,17 +573,21 @@ def loop_test(
|
|||
|
||||
logging.info("")
|
||||
logging.info("Testing:")
|
||||
mini_batch_id: int = 0
|
||||
|
||||
for h_x, h_x_labels in my_loader_test:
|
||||
time_0 = time.perf_counter()
|
||||
|
||||
h_collection = forward_pass_test(
|
||||
input=h_x,
|
||||
labels=h_x_labels,
|
||||
the_dataset_test=the_dataset_test,
|
||||
cfg=cfg,
|
||||
network=network,
|
||||
device=device,
|
||||
default_dtype=default_dtype,
|
||||
mini_batch_id=mini_batch_id,
|
||||
overwrite_number_of_spikes=overwrite_number_of_spikes,
|
||||
)
|
||||
h_h: torch.Tensor = h_collection[-1].detach().clone().cpu()
|
||||
|
||||
|
@ -580,11 +603,13 @@ def loop_test(
|
|||
f" with {performance/100:^6.2%} \t Time used: {time_measure_a:^6.2f}sec"
|
||||
)
|
||||
)
|
||||
mini_batch_id += 1
|
||||
|
||||
logging.info("")
|
||||
|
||||
tb.add_scalar("Test Error", 100.0 - performance, epoch_id)
|
||||
tb.flush()
|
||||
if tb is not None:
|
||||
tb.add_scalar("Test Error", 100.0 - performance, epoch_id)
|
||||
tb.flush()
|
||||
|
||||
return performance
|
||||
|
||||
|
@ -598,7 +623,7 @@ def loop_test_reconstruction(
|
|||
device: torch.device,
|
||||
default_dtype: torch.dtype,
|
||||
logging,
|
||||
tb: SummaryWriter,
|
||||
tb: SummaryWriter | None,
|
||||
) -> float:
|
||||
|
||||
test_count: int = 0
|
||||
|
@ -613,6 +638,7 @@ def loop_test_reconstruction(
|
|||
|
||||
h_collection = forward_pass_test(
|
||||
input=h_x,
|
||||
labels=None,
|
||||
the_dataset_test=the_dataset_test,
|
||||
cfg=cfg,
|
||||
network=network,
|
||||
|
@ -645,7 +671,8 @@ def loop_test_reconstruction(
|
|||
|
||||
logging.info("")
|
||||
|
||||
tb.add_scalar("Test Error", performance, epoch_id)
|
||||
tb.flush()
|
||||
if tb is not None:
|
||||
tb.add_scalar("Test Error", performance, epoch_id)
|
||||
tb.flush()
|
||||
|
||||
return performance
|
||||
|
|
Loading…
Reference in a new issue