Add files via upload

This commit is contained in:
David Rotermund 2023-02-04 14:24:47 +01:00 committed by GitHub
parent 5ac2f1dc96
commit d23e2edd8d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 1406 additions and 84 deletions

View file

@ -12,6 +12,8 @@ class DatasetMaster(torch.utils.data.Dataset, ABC):
pattern_storage: np.ndarray
number_of_pattern: int
mean: list[float]
initial_size: list[int]
channel_size: int
# Initialize
def __init__(
@ -36,6 +38,9 @@ class DatasetMaster(torch.utils.data.Dataset, ABC):
self.mean = []
self.initial_size = [0, 0]
self.channel_size = 0
def __len__(self) -> int:
return self.number_of_pattern
@ -74,6 +79,9 @@ class DatasetMNIST(DatasetMaster):
mean = self.pattern_storage.mean(3).mean(2).mean(0)
self.mean = [*mean]
self.initial_size = [28, 28]
self.channel_size = 1
def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
image = self.pattern_storage[index, 0:1, :, :]
@ -154,6 +162,9 @@ class DatasetFashionMNIST(DatasetMaster):
mean = self.pattern_storage.mean(3).mean(2).mean(0)
self.mean = [*mean]
self.initial_size = [28, 28]
self.channel_size = 1
def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
image = self.pattern_storage[index, 0:1, :, :]
@ -240,6 +251,9 @@ class DatasetCIFAR(DatasetMaster):
mean = self.pattern_storage.mean(3).mean(2).mean(0)
self.mean = [*mean]
self.initial_size = [32, 32]
self.channel_size = 3
def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
image = self.pattern_storage[index, :, :, :]

451
network/HDynamicLayer.py Normal file
View file

@ -0,0 +1,451 @@
import torch
from network.PyHDynamicCNNCPU import HDynamicCNNCPU
from network.PyHDynamicCNNGPU import HDynamicCNNGPU
# Module-level registries shared by every HDynamicLayer instance.
# HDynamicLayer.__init__ appends one entry to each list and remembers the
# index; FunctionalSbS.forward looks the entries up again through the indices
# packed into its parameter_list argument.

# Per-layer GPU occupancy profile; starts as the placeholder tensor([0]) and
# is replaced inside FunctionalSbS.forward when running on a non-CPU device.
global_sbs_gpu_setting: list[torch.Tensor] = []
# Per-layer output shape (4 entries) for which the stored profile was made.
global_sbs_size: list[torch.Tensor] = []
# Per-layer backend object (HDynamicCNNCPU or HDynamicCNNGPU).
global_sbs_hdynamic_cpp: list[HDynamicCNNCPU | HDynamicCNNGPU] = []
class HDynamicLayer(torch.nn.Module):
    """Module wrapper around ``FunctionalSbS``.

    On construction the layer registers one slot in each of the module-level
    lists ``global_sbs_gpu_setting``, ``global_sbs_size`` and
    ``global_sbs_hdynamic_cpp`` and stores the slot indices. ``forward`` packs
    the layer configuration into an int64 tensor and calls
    ``FunctionalSbS.apply``.
    """

    _sbs_gpu_setting_position: int
    _sbs_hdynamic_cpp_position: int
    _gpu_tuning_factor: int
    _number_of_cpu_processes: int
    _output_size: list[int]
    _w_trainable: bool
    _output_layer: bool
    _local_learning: bool
    device: torch.device
    default_dtype: torch.dtype

    def __init__(
        self,
        output_size: list[int],
        output_layer: bool = False,
        local_learning: bool = False,
        number_of_cpu_processes: int = 1,
        w_trainable: bool = False,
        skip_gradient_calculation: bool = False,
        device: torch.device | None = None,
        default_dtype: torch.dtype | None = None,
        gpu_tuning_factor: int = 5,
    ) -> None:
        """Store the configuration and register the backend for this layer.

        Args:
            output_size: Two spatial output sizes (entries 0 and 1 are used).
            output_layer: Flag forwarded to the backward pass (slot 10).
            local_learning: Flag forwarded to the backward pass (slot 11).
            number_of_cpu_processes: Forwarded to the backend on CPU devices.
            w_trainable: Whether a weight gradient is produced in backward.
            skip_gradient_calculation: If set, backward returns early.
            device: Required; selects the CPU or the GPU backend class.
            default_dtype: dtype used for the forgetting-offset tensor.
            gpu_tuning_factor: Forwarded to the backend when the spike tensor
                has more spatial positions than this value, otherwise 0.
        """
        super().__init__()

        assert device is not None

        # Plain configuration values.
        self.device = device
        self.default_dtype = default_dtype
        self._output_size = output_size
        self._gpu_tuning_factor = int(gpu_tuning_factor)
        self._number_of_cpu_processes = int(number_of_cpu_processes)
        self._w_trainable = bool(w_trainable)
        self._skip_gradient_calculation = bool(skip_gradient_calculation)
        self._output_layer = bool(output_layer)
        self._local_learning = bool(local_learning)

        # One slot per layer in each module-level registry.
        global_sbs_gpu_setting.append(torch.tensor([0]))
        global_sbs_size.append(torch.tensor([0, 0, 0, 0]))
        on_cpu = device == torch.device("cpu")
        global_sbs_hdynamic_cpp.append(
            HDynamicCNNCPU() if on_cpu else HDynamicCNNGPU()
        )

        self._sbs_gpu_setting_position = len(global_sbs_gpu_setting) - 1
        self._sbs_hdynamic_cpp_position = len(global_sbs_hdynamic_cpp) - 1

        self.functional_sbs = FunctionalSbS.apply

    ####################################################################
    # Forward                                                          #
    ####################################################################

    def forward(
        self,
        input: torch.Tensor,
        spike: torch.Tensor,
        epsilon_xy: torch.Tensor,
        epsilon_t_0: torch.Tensor,
        weights: torch.Tensor,
        h_initial: torch.Tensor,
        last_grad_scale: torch.Tensor,
        labels: torch.Tensor | None = None,
        keep_last_grad_scale: bool = False,
        disable_scale_grade: bool = True,
        forgetting_offset: float = -1.0,
    ) -> torch.Tensor:
        """Pack the settings into a tensor and run ``FunctionalSbS``.

        ``labels`` is copied to an int64 tensor on ``self.device``; when it is
        ``None`` an empty int64 tensor is passed instead.
        """
        if labels is not None:
            labels_copy: torch.Tensor = (
                labels.detach().clone().type(dtype=torch.int64).to(device=self.device)
            )
        else:
            labels_copy = torch.tensor([], dtype=torch.int64, device=self.device)

        # The tuning factor is only forwarded for sufficiently large spike maps.
        spatial_positions = spike.shape[-2] * spike.shape[-1]
        gpu_tuning_factor = (
            self._gpu_tuning_factor
            if spatial_positions > self._gpu_tuning_factor
            else 0
        )

        # Slot order is part of the contract with FunctionalSbS.
        settings = (
            self._number_of_cpu_processes,  # 0
            self._output_size[0],  # 1
            self._output_size[1],  # 2
            gpu_tuning_factor,  # 3
            self._sbs_gpu_setting_position,  # 4
            self._sbs_hdynamic_cpp_position,  # 5
            self._w_trainable,  # 6
            disable_scale_grade,  # 7
            keep_last_grad_scale,  # 8
            self._skip_gradient_calculation,  # 9
            self._output_layer,  # 10
            self._local_learning,  # 11
        )
        parameter_list = torch.tensor(
            [int(value) for value in settings],
            dtype=torch.int64,
        )

        forgetting_offset_tensor = torch.tensor(
            forgetting_offset, device=self.device, dtype=self.default_dtype
        )

        # SbS forward functional
        return self.functional_sbs(
            input,
            spike,
            epsilon_xy,
            epsilon_t_0,
            weights,
            h_initial,
            parameter_list,
            last_grad_scale,
            forgetting_offset_tensor,
            labels_copy,
        )
class FunctionalSbS(torch.autograd.Function):
@staticmethod
def forward( # type: ignore
ctx,
input: torch.Tensor,
spikes: torch.Tensor,
epsilon_xy: torch.Tensor | None,
epsilon_t_0: torch.Tensor,
weights: torch.Tensor,
h_initial: torch.Tensor,
parameter_list: torch.Tensor,
grad_output_scale: torch.Tensor,
forgetting_offset: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
number_of_spikes: int = int(spikes.shape[1])
if input.device == torch.device("cpu"):
hdyn_number_of_cpu_processes: int = int(parameter_list[0])
else:
hdyn_number_of_cpu_processes = -1
output_size_0: int = int(parameter_list[1])
output_size_1: int = int(parameter_list[2])
gpu_tuning_factor: int = int(parameter_list[3])
sbs_gpu_setting_position = int(parameter_list[4])
sbs_hdynamic_cpp_position = int(parameter_list[5])
# ###########################################################
# H dynamic
# ###########################################################
assert epsilon_t_0.ndim == 1
assert epsilon_t_0.shape[0] >= number_of_spikes
# ############################################
# Make space for the results
# ############################################
output = torch.empty(
(
int(input.shape[0]),
int(weights.shape[1]),
output_size_0,
output_size_1,
),
dtype=input.dtype,
device=input.device,
)
assert output.is_contiguous() is True
if epsilon_xy is not None:
assert epsilon_xy.is_contiguous() is True
assert epsilon_xy.ndim == 3
assert epsilon_t_0.is_contiguous() is True
assert weights.is_contiguous() is True
assert spikes.is_contiguous() is True
assert h_initial.is_contiguous() is True
assert weights.ndim == 2
assert h_initial.ndim == 1
sbs_profile = global_sbs_gpu_setting[sbs_gpu_setting_position].clone()
sbs_size = global_sbs_size[sbs_gpu_setting_position].clone()
if input.device != torch.device("cpu"):
if (
(sbs_profile.numel() == 1)
or (sbs_size[0] != int(output.shape[0]))
or (sbs_size[1] != int(output.shape[1]))
or (sbs_size[2] != int(output.shape[2]))
or (sbs_size[3] != int(output.shape[3]))
):
sbs_profile = torch.zeros(
(14, 7), dtype=torch.int64, device=torch.device("cpu")
)
global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_export(
int(output.shape[2]),
int(output.shape[3]),
int(output.shape[0]),
int(output.shape[1]),
sbs_profile.data_ptr(),
int(sbs_profile.shape[0]),
int(sbs_profile.shape[1]),
)
global_sbs_gpu_setting[sbs_gpu_setting_position] = sbs_profile.clone()
sbs_size[0] = int(output.shape[0])
sbs_size[1] = int(output.shape[1])
sbs_size[2] = int(output.shape[2])
sbs_size[3] = int(output.shape[3])
global_sbs_size[sbs_gpu_setting_position] = sbs_size.clone()
else:
global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_import(
sbs_profile.data_ptr(),
int(sbs_profile.shape[0]),
int(sbs_profile.shape[1]),
)
global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].update(
output.data_ptr(),
int(output.shape[0]),
int(output.shape[1]),
int(output.shape[2]),
int(output.shape[3]),
epsilon_xy.data_ptr() if epsilon_xy is not None else int(0),
int(epsilon_xy.shape[0]) if epsilon_xy is not None else int(0),
int(epsilon_xy.shape[1]) if epsilon_xy is not None else int(0),
int(epsilon_xy.shape[2]) if epsilon_xy is not None else int(0),
epsilon_t_0.data_ptr(),
int(epsilon_t_0.shape[0]),
weights.data_ptr(),
int(weights.shape[0]),
int(weights.shape[1]),
spikes.data_ptr(),
int(spikes.shape[0]),
int(spikes.shape[1]),
int(spikes.shape[2]),
int(spikes.shape[3]),
h_initial.data_ptr(),
int(h_initial.shape[0]),
hdyn_number_of_cpu_processes,
float(forgetting_offset.cpu().item()),
int(gpu_tuning_factor),
)
# ###########################################################
# Save the necessary data for the backward pass
# ###########################################################
ctx.save_for_backward(
input,
weights,
output,
parameter_list,
grad_output_scale,
labels,
)
return output
@staticmethod
def backward(ctx, grad_output):
# ##############################################
# Get the variables back
# ##############################################
(
input,
weights,
output,
parameter_list,
last_grad_scale,
labels,
) = ctx.saved_tensors
assert labels.numel() > 0
# ##############################################
# Default output
# ##############################################
grad_input = None
grad_spikes = None
grad_eps_xy = None
grad_epsilon_t_0 = None
grad_weights = None
grad_h_initial = None
grad_parameter_list = None
grad_forgetting_offset = None
grad_labels = None
# ##############################################
# Parameters
# ##############################################
parameter_w_trainable: bool = bool(parameter_list[6])
parameter_disable_scale_grade: bool = bool(parameter_list[7])
parameter_keep_last_grad_scale: bool = bool(parameter_list[8])
parameter_skip_gradient_calculation: bool = bool(parameter_list[9])
parameter_output_layer: bool = bool(parameter_list[10])
parameter_local_learning: bool = bool(parameter_list[11])
# ##############################################
# Dealing with overall scale of the gradient
# ##############################################
if parameter_disable_scale_grade is False:
if parameter_keep_last_grad_scale is True:
last_grad_scale = torch.tensor(
[torch.abs(grad_output).max(), last_grad_scale]
).max()
grad_output /= last_grad_scale
grad_output_scale = last_grad_scale.clone()
input /= input.sum(dim=1, keepdim=True, dtype=weights.dtype)
# #################################################
# User doesn't want us to calculate the gradients
# #################################################
if parameter_skip_gradient_calculation is True:
return (
grad_input,
grad_spikes,
grad_eps_xy,
grad_epsilon_t_0,
grad_weights,
grad_h_initial,
grad_parameter_list,
grad_output_scale,
grad_forgetting_offset,
grad_labels,
)
# #################################################
# Calculate backprop error (grad_input)
# #################################################
backprop_r: torch.Tensor = weights.unsqueeze(0).unsqueeze(-1).unsqueeze(
-1
) * output.unsqueeze(1)
backprop_bigr: torch.Tensor = backprop_r.sum(dim=2)
backprop_z: torch.Tensor = backprop_r * (
1.0 / (backprop_bigr + 1e-20)
).unsqueeze(2)
grad_input: torch.Tensor = (backprop_z * grad_output.unsqueeze(1)).sum(2)
del backprop_z
# #################################################
# Calculate weight gradient (grad_weights)
# #################################################
if parameter_w_trainable is False:
# #################################################
# We don't train this weight
# #################################################
grad_weights = None
elif (parameter_output_layer is False) and (parameter_local_learning is True):
# #################################################
# Local learning
# #################################################
grad_weights = (
(-2 * (input - backprop_bigr).unsqueeze(2) * output.unsqueeze(1))
.sum(0)
.sum(-1)
.sum(-1)
)
elif (parameter_output_layer is True) and (parameter_local_learning is True):
target_one_hot: torch.Tensor = torch.zeros(
(
labels.shape[0],
output.shape[1],
),
device=input.device,
dtype=input.dtype,
)
target_one_hot.scatter_(
1,
labels.to(input.device).unsqueeze(1),
torch.ones(
(labels.shape[0], 1),
device=input.device,
dtype=input.dtype,
),
)
target_one_hot = target_one_hot.unsqueeze(-1).unsqueeze(-1)
# (-2 * (input - backprop_bigr).unsqueeze(2) * (target_one_hot-output).unsqueeze(1))
# (-2 * input.unsqueeze(2) * (target_one_hot-output).unsqueeze(1))
grad_weights = (
(
-2
* (input - backprop_bigr).unsqueeze(2)
* target_one_hot.unsqueeze(1)
)
.sum(0)
.sum(-1)
.sum(-1)
)
else:
# #################################################
# Backprop
# #################################################
backprop_f: torch.Tensor = output.unsqueeze(1) * (
input / (backprop_bigr**2 + 1e-20)
).unsqueeze(2)
result_omega: torch.Tensor = backprop_bigr.unsqueeze(
2
) * grad_output.unsqueeze(1)
result_omega -= (backprop_r * grad_output.unsqueeze(1)).sum(2).unsqueeze(2)
result_omega *= backprop_f
del backprop_f
grad_weights = result_omega.sum(0).sum(-1).sum(-1)
del result_omega
del backprop_bigr
del backprop_r
return (
grad_input,
grad_spikes,
grad_eps_xy,
grad_epsilon_t_0,
grad_weights,
grad_h_initial,
grad_parameter_list,
grad_output_scale,
grad_forgetting_offset,
grad_labels,
)

104
network/InputSpikeImage.py Normal file
View file

@ -0,0 +1,104 @@
import torch
from network.SpikeLayer import SpikeLayer
from network.SpikeCountLayer import SpikeCountLayer
class InputSpikeImage(torch.nn.Module):
_reshape: bool
_normalize: bool
_device: torch.device
number_of_spikes: int
def __init__(
self,
number_of_spikes: int = -1,
number_of_cpu_processes: int = 1,
reshape: bool = False,
normalize: bool = True,
device: torch.device | None = None,
) -> None:
super().__init__()
assert device is not None
self._device = device
self._reshape = bool(reshape)
self._normalize = bool(normalize)
self.number_of_spikes = int(number_of_spikes)
if device != torch.device("cpu"):
number_of_cpu_processes_spike_generator = 0
else:
number_of_cpu_processes_spike_generator = number_of_cpu_processes
self.spike_generator = SpikeLayer(
number_of_cpu_processes=number_of_cpu_processes_spike_generator,
device=device,
)
self.spike_count = SpikeCountLayer(
number_of_cpu_processes=number_of_cpu_processes
)
####################################################################
# Forward #
####################################################################
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.number_of_spikes < 1:
return input
input_shape: list[int] = [
int(input.shape[0]),
int(input.shape[1]),
int(input.shape[2]),
int(input.shape[3]),
]
if self._reshape is True:
input_work = (
input.detach()
.clone()
.to(self._device)
.reshape(
(input_shape[0], input_shape[1] * input_shape[2] * input_shape[3])
)
.unsqueeze(-1)
.unsqueeze(-1)
)
else:
input_work = input.detach().clone().to(self._device)
spikes = self.spike_generator(
input=input_work, number_of_spikes=self.number_of_spikes
)
if self._reshape is True:
dim_s: int = input_shape[1] * input_shape[2] * input_shape[3]
else:
dim_s = input_shape[1]
output: torch.Tensor = self.spike_count(spikes, dim_s)
if self._reshape is True:
output = (
output.squeeze(-1)
.squeeze(-1)
.reshape(
(input_shape[0], input_shape[1], input_shape[2], input_shape[3])
)
)
if self._normalize is True:
output = output.type(dtype=input_work.dtype)
output = output / output.sum(dim=-1, keepdim=True).sum(
dim=-2, keepdim=True
).sum(dim=-3, keepdim=True)
return output

View file

@ -3,20 +3,22 @@ export
all:
cd h_dynamic_cnn_cpu_cpp && $(MAKE) all
cd h_dynamic_cnn_gpu_cpp && $(MAKE) all
cd h_dynamic_cnn_gpu_cpp_v1 && $(MAKE) all
cd spike_generation_cpu_cpp && $(MAKE) all
cd spike_generation_gpu_cpp_v2 && $(MAKE) all
cd multiplication_approximation_cpu_cpp && $(MAKE) all
cd multiplication_approximation_gpu_cpp && $(MAKE) all
cd count_spikes_cpu_cpp && $(MAKE) all
cd sort_spikes_cpu_cpp && $(MAKE) all
$(PYBIN)python3 pybind11_auto_pyi.py
clean:
cd h_dynamic_cnn_cpu_cpp && $(MAKE) clean
cd h_dynamic_cnn_gpu_cpp && $(MAKE) clean
cd h_dynamic_cnn_gpu_cpp_v1 && $(MAKE) clean
cd spike_generation_cpu_cpp && $(MAKE) clean
cd spike_generation_gpu_cpp_v2 && $(MAKE) clean
cd multiplication_approximation_cpu_cpp && $(MAKE) clean
cd multiplication_approximation_gpu_cpp && $(MAKE) clean
cd count_spikes_cpu_cpp && $(MAKE) clean
cd sort_spikes_cpu_cpp && $(MAKE) clean

480
network/SbSLayer.py Normal file
View file

@ -0,0 +1,480 @@
import torch
from network.SpikeLayer import SpikeLayer
from network.HDynamicLayer import HDynamicLayer
from network.calculate_output_size import calculate_output_size
from network.SortSpikesLayer import SortSpikesLayer
class SbSLayer(torch.nn.Module):
_epsilon_xy: torch.Tensor | None = None
_epsilon_xy_use: bool
_epsilon_0: float
_weights: torch.nn.parameter.Parameter
_weights_exists: bool = False
_kernel_size: list[int]
_stride: list[int]
_dilation: list[int]
_padding: list[int]
_output_size: torch.Tensor
_number_of_spikes: int
_number_of_cpu_processes: int
_number_of_neurons: int
_number_of_input_neurons: int
_epsilon_xy_intitial: float
_h_initial: torch.Tensor | None = None
_w_trainable: bool
_last_grad_scale: torch.nn.parameter.Parameter
_keep_last_grad_scale: bool
_disable_scale_grade: bool
_forgetting_offset: float
_weight_noise_range: list[float]
_skip_gradient_calculation: bool
_is_pooling_layer: bool
_input_size: list[int]
_output_layer: bool = False
_local_learning: bool = False
device: torch.device
default_dtype: torch.dtype
_gpu_tuning_factor: int
_max_grad_weights: torch.Tensor | None = None
_number_of_grad_weight_contributions: float = 0.0
last_input_store: bool = False
last_input_data: torch.Tensor | None = None
_cooldown_after_number_of_spikes: int = -1
_reduction_cooldown: float = 1.0
_layer_id: int = -1
spike_full_layer_input_distribution: bool = False
def __init__(
self,
number_of_input_neurons: int,
number_of_neurons: int,
input_size: list[int],
forward_kernel_size: list[int],
number_of_spikes: int,
epsilon_xy_intitial: float = 0.1,
epsilon_xy_use: bool = False,
epsilon_0: float = 1.0,
weight_noise_range: list[float] = [0.0, 1.0],
is_pooling_layer: bool = False,
strides: list[int] = [1, 1],
dilation: list[int] = [0, 0],
padding: list[int] = [0, 0],
number_of_cpu_processes: int = 1,
w_trainable: bool = False,
keep_last_grad_scale: bool = False,
disable_scale_grade: bool = True,
forgetting_offset: float = -1.0,
skip_gradient_calculation: bool = False,
device: torch.device | None = None,
default_dtype: torch.dtype | None = None,
gpu_tuning_factor: int = 10,
layer_id: int = -1,
cooldown_after_number_of_spikes: int = -1,
reduction_cooldown: float = 1.0,
) -> None:
super().__init__()
assert device is not None
assert default_dtype is not None
self.device = device
self.default_dtype = default_dtype
self._w_trainable = bool(w_trainable)
self._keep_last_grad_scale = bool(keep_last_grad_scale)
self._skip_gradient_calculation = bool(skip_gradient_calculation)
self._disable_scale_grade = bool(disable_scale_grade)
self._epsilon_xy_intitial = float(epsilon_xy_intitial)
self._stride = strides
self._dilation = dilation
self._padding = padding
self._kernel_size = forward_kernel_size
self._number_of_input_neurons = int(number_of_input_neurons)
self._number_of_neurons = int(number_of_neurons)
self._epsilon_0 = float(epsilon_0)
self._number_of_cpu_processes = int(number_of_cpu_processes)
self._number_of_spikes = int(number_of_spikes)
self._weight_noise_range = weight_noise_range
self._is_pooling_layer = bool(is_pooling_layer)
self._cooldown_after_number_of_spikes = int(cooldown_after_number_of_spikes)
self.reduction_cooldown = float(reduction_cooldown)
self._layer_id = layer_id
self._epsilon_xy_use = epsilon_xy_use
assert len(input_size) == 2
self._input_size = input_size
# The GPU hates me...
# Too many SbS threads == bad
# Thus I need to limit them...
# (Reminder: We cannot access the mini-batch size here,
# which is part of the GPU thread size calculation...)
self._last_grad_scale = torch.nn.parameter.Parameter(
torch.tensor(-1.0, dtype=self.default_dtype),
requires_grad=True,
)
self._forgetting_offset = float(forgetting_offset)
self._output_size = calculate_output_size(
value=input_size,
kernel_size=self._kernel_size,
stride=self._stride,
dilation=self._dilation,
padding=self._padding,
)
self.set_h_init_to_uniform()
self.spike_generator = SpikeLayer(
number_of_spikes=self._number_of_spikes,
number_of_cpu_processes=self._number_of_cpu_processes,
device=self.device,
)
self.h_dynamic = HDynamicLayer(
output_size=self._output_size.tolist(),
output_layer=self._output_layer,
local_learning=self._local_learning,
number_of_cpu_processes=number_of_cpu_processes,
w_trainable=w_trainable,
skip_gradient_calculation=skip_gradient_calculation,
device=device,
default_dtype=self.default_dtype,
gpu_tuning_factor=gpu_tuning_factor,
)
assert len(input_size) >= 2
self.spikes_sorter = SortSpikesLayer(
kernel_size=self._kernel_size,
input_shape=[
self._number_of_input_neurons,
int(input_size[0]),
int(input_size[1]),
],
output_size=self._output_size.clone(),
strides=self._stride,
dilation=self._dilation,
padding=self._padding,
number_of_cpu_processes=number_of_cpu_processes,
)
# TODO: TEST
if layer_id == 0:
self.spike_full_layer_input_distribution = True
# ###############################################################
# Initialize the weights
# ###############################################################
if self._is_pooling_layer is True:
self.weights = self._make_pooling_weights()
else:
assert len(self._weight_noise_range) == 2
weights = torch.empty(
(
int(self._kernel_size[0])
* int(self._kernel_size[1])
* int(self._number_of_input_neurons),
int(self._number_of_neurons),
),
dtype=self.default_dtype,
device=self.device,
)
torch.nn.init.uniform_(
weights,
a=float(self._weight_noise_range[0]),
b=float(self._weight_noise_range[1]),
)
self.weights = weights
####################################################################
# Variables in and out #
####################################################################
def get_epsilon_t(self, number_of_spikes: int):
"""Generates the time series of the basic epsilon."""
t = (
torch.arange(
0, number_of_spikes, dtype=self.default_dtype, device=self.device
)
+ 1
)
# torch.ones((number_of_spikes), dtype=self.default_dtype, device=self.device
epsilon_t: torch.Tensor = t ** (-1.0 / 2.0)
if (self._cooldown_after_number_of_spikes < number_of_spikes) and (
self._cooldown_after_number_of_spikes >= 0
):
epsilon_t[
self._cooldown_after_number_of_spikes : number_of_spikes
] /= self._reduction_cooldown
return epsilon_t
@property
def weights(self) -> torch.Tensor | None:
if self._weights_exists is False:
return None
else:
return self._weights
@weights.setter
def weights(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 2
temp: torch.Tensor = (
value.detach()
.clone(memory_format=torch.contiguous_format)
.type(dtype=self.default_dtype)
.to(device=self.device)
)
temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype)
if self._weights_exists is False:
self._weights = torch.nn.parameter.Parameter(temp, requires_grad=True)
self._weights_exists = True
else:
self._weights.data = temp
@property
def h_initial(self) -> torch.Tensor | None:
return self._h_initial
@h_initial.setter
def h_initial(self, value: torch.Tensor):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 1
assert value.dtype == self.default_dtype
self._h_initial = (
value.detach()
.clone(memory_format=torch.contiguous_format)
.type(dtype=self.default_dtype)
.to(device=self.device)
.requires_grad_(False)
)
def update_pre_care(self):
if self._weights.grad is not None:
assert self._number_of_grad_weight_contributions > 0
self._weights.grad /= self._number_of_grad_weight_contributions
self._number_of_grad_weight_contributions = 0.0
def update_after_care(self, threshold_weight: float):
if self._w_trainable is True:
self.norm_weights()
self.threshold_weights(threshold_weight)
self.norm_weights()
def after_batch(self, new_state: bool = False):
if self._keep_last_grad_scale is True:
self._last_grad_scale.data = self._last_grad_scale.grad
self._keep_last_grad_scale = new_state
self._last_grad_scale.grad = torch.zeros_like(self._last_grad_scale.grad)
####################################################################
# Helper functions #
####################################################################
def _make_pooling_weights(self) -> torch.Tensor:
"""For generating the pooling weights."""
assert self._number_of_neurons is not None
assert self._kernel_size is not None
weights: torch.Tensor = torch.zeros(
(
int(self._kernel_size[0]),
int(self._kernel_size[1]),
int(self._number_of_neurons),
int(self._number_of_neurons),
),
dtype=self.default_dtype,
device=self.device,
)
for i in range(0, int(self._number_of_neurons)):
weights[:, :, i, i] = 1.0
weights = weights.moveaxis(-1, 0).moveaxis(-1, 1)
weights = torch.nn.functional.unfold(
input=weights,
kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
dilation=(1, 1),
padding=(0, 0),
stride=(1, 1),
).squeeze()
weights = torch.moveaxis(weights, 0, 1)
return weights
def set_h_init_to_uniform(self) -> None:
assert self._number_of_neurons > 2
self.h_initial: torch.Tensor = torch.full(
(self._number_of_neurons,),
(1.0 / float(self._number_of_neurons)),
dtype=self.default_dtype,
device=self.device,
)
def norm_weights(self) -> None:
assert self._weights_exists is True
temp: torch.Tensor = (
self._weights.data.detach()
.clone(memory_format=torch.contiguous_format)
.type(dtype=self.default_dtype)
.to(device=self.device)
)
temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype)
self._weights.data = temp
def threshold_weights(self, threshold: float) -> None:
assert self._weights_exists is True
assert threshold >= 0
torch.clamp(
self._weights.data,
min=float(threshold),
max=None,
out=self._weights.data,
)
####################################################################
# Forward #
####################################################################
def forward(
self,
input: torch.Tensor,
labels: torch.Tensor | None = None,
) -> torch.Tensor:
# Are we happy with the input?
assert input is not None
assert torch.is_tensor(input) is True
assert input.dim() == 4
assert input.dtype == self.default_dtype
assert input.shape[1] == self._number_of_input_neurons
assert input.shape[2] == self._input_size[0]
assert input.shape[3] == self._input_size[1]
# Are we happy with the rest of the network?
assert self._epsilon_0 is not None
assert self._h_initial is not None
assert self._forgetting_offset is not None
assert self._weights_exists is True
assert self._weights is not None
# Convolution of the input...
# Well, this is a convoltion layer
# there needs to be convolution somewhere
input_convolved = torch.nn.functional.fold(
torch.nn.functional.unfold(
input.requires_grad_(True),
kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
dilation=(int(self._dilation[0]), int(self._dilation[1])),
padding=(int(self._padding[0]), int(self._padding[1])),
stride=(int(self._stride[0]), int(self._stride[1])),
),
output_size=tuple(self._output_size.tolist()),
kernel_size=(1, 1),
dilation=(1, 1),
padding=(0, 0),
stride=(1, 1),
)
# We might need the convolved input for other layers
# let us keep it for the future
if self.last_input_store is True:
self.last_input_data = input_convolved.detach().clone()
self.last_input_data /= self.last_input_data.sum(dim=1, keepdim=True)
else:
self.last_input_data = None
epsilon_t_0: torch.Tensor = (
(self.get_epsilon_t(self._number_of_spikes) * self._epsilon_0)
.type(input.dtype)
.to(input.device)
)
if (self._epsilon_xy is None) and (self._epsilon_xy_use is True):
self._epsilon_xy = torch.full(
(
input_convolved.shape[1],
input_convolved.shape[2],
input_convolved.shape[3],
),
float(self._epsilon_xy_intitial),
dtype=self.default_dtype,
device=self.device,
)
if self._epsilon_xy_use is True:
assert self._epsilon_xy is not None
# In the case somebody tried to replace the matrix with wrong dimensions
assert self._epsilon_xy.shape[0] == input_convolved.shape[1]
assert self._epsilon_xy.shape[1] == input_convolved.shape[2]
assert self._epsilon_xy.shape[2] == input_convolved.shape[3]
else:
assert self._epsilon_xy is None
if self.spike_full_layer_input_distribution is False:
spike = self.spike_generator(input_convolved, int(self._number_of_spikes))
else:
input_shape = input.shape
input = (
input.reshape(
(input_shape[0], input_shape[1] * input_shape[2] * input_shape[3])
)
.unsqueeze(-1)
.unsqueeze(-1)
)
spike_unsorted = self.spike_generator(input, int(self._number_of_spikes))
input = (
input.squeeze(-1)
.squeeze(-1)
.reshape(
(input_shape[0], input_shape[1], input_shape[2], input_shape[3])
)
)
spike = self.spikes_sorter(spike_unsorted).to(device=input_convolved.device)
output = self.h_dynamic(
input=input_convolved,
spike=spike,
epsilon_xy=self._epsilon_xy,
epsilon_t_0=epsilon_t_0,
weights=self._weights,
h_initial=self._h_initial,
last_grad_scale=self._last_grad_scale,
labels=labels,
keep_last_grad_scale=self._keep_last_grad_scale,
disable_scale_grade=self._disable_scale_grade,
forgetting_offset=self._forgetting_offset,
)
self._number_of_grad_weight_contributions += (
output.shape[0] * output.shape[-2] * output.shape[-1]
)
return output

View file

@ -1,15 +1,15 @@
import torch
from network.SbS import SbS
from network.SbSLayer import SbSLayer
class SbSReconstruction(torch.nn.Module):
_the_sbs_layer: SbS
_the_sbs_layer: SbSLayer
def __init__(
self,
the_sbs_layer: SbS,
the_sbs_layer: SbSLayer,
) -> None:
super().__init__()

173
network/SortSpikesLayer.py Normal file
View file

@ -0,0 +1,173 @@
import torch
from network.PySortSpikesCPU import SortSpikesCPU
class SortSpikesLayer(torch.nn.Module):
_kernel_size: list[int]
_stride: list[int]
_dilation: list[int]
_padding: list[int]
_output_size: torch.Tensor
_number_of_cpu_processes: int
_input_shape: list[int]
order: torch.Tensor | None = None
order_convoled: torch.Tensor | None = None
indices: torch.Tensor | None = None
def __init__(
self,
kernel_size: list[int],
input_shape: list[int],
output_size: torch.Tensor,
strides: list[int] = [1, 1],
dilation: list[int] = [0, 0],
padding: list[int] = [0, 0],
number_of_cpu_processes: int = 1,
) -> None:
super().__init__()
self._stride = strides
self._dilation = dilation
self._padding = padding
self._kernel_size = kernel_size
self._output_size = output_size
self._number_of_cpu_processes = number_of_cpu_processes
self._input_shape = input_shape
self.sort_spikes = SortSpikesCPU()
self.order = (
torch.arange(
0,
self._input_shape[0] * self._input_shape[1] * self._input_shape[2],
device=torch.device("cpu"),
)
.reshape(
(
1,
self._input_shape[0],
self._input_shape[1],
self._input_shape[2],
)
)
.type(dtype=torch.float32)
)
self.order_convoled = torch.nn.functional.fold(
torch.nn.functional.unfold(
self.order,
kernel_size=(
int(self._kernel_size[0]),
int(self._kernel_size[1]),
),
dilation=(int(self._dilation[0]), int(self._dilation[1])),
padding=(int(self._padding[0]), int(self._padding[1])),
stride=(int(self._stride[0]), int(self._stride[1])),
),
output_size=tuple(self._output_size.tolist()),
kernel_size=(1, 1),
dilation=(1, 1),
padding=(0, 0),
stride=(1, 1),
).type(dtype=torch.int64)
assert self.order_convoled is not None
self.order_convoled = self.order_convoled.reshape(
(
self.order_convoled.shape[1]
* self.order_convoled.shape[2]
* self.order_convoled.shape[3]
)
)
max_length: int = 0
max_range: int = (
self._input_shape[0] * self._input_shape[1] * self._input_shape[2]
)
for id in range(0, max_range):
idx = torch.where(self.order_convoled == id)[0]
max_length = max(max_length, int(idx.shape[0]))
self.indices = torch.full(
(max_range, max_length),
-1,
dtype=torch.int64,
device=torch.device("cpu"),
)
for id in range(0, max_range):
idx = torch.where(self.order_convoled == id)[0]
self.indices[id, 0 : int(idx.shape[0])] = idx
####################################################################
# Forward #
####################################################################
def forward(
    self,
    input: torch.Tensor,
) -> torch.Tensor:
    """Sort incoming spikes onto the output grid with the native CPU helper.

    The work is done in two passes by ``self.sort_spikes``:

    1. ``count`` fills ``spikes_count`` (batch, out_0, out_1); its maximum
       is used as the size of the second dimension of the result.
    2. ``process`` fills the result tensor, which is pre-filled with -1
       (presumably the marker for unused slots -- confirm against the
       native implementation).

    Args:
        input: 4-D tensor whose last two dimensions are both 1. The
            native code receives only a raw pointer and the four sizes.
            NOTE(review): the element type expected by the extension is
            not visible here -- confirm (the outputs are int64).

    Returns:
        An int64 CPU tensor of shape
        (batch, max spike count, output_size[0], output_size[1]).
    """
    assert len(self._input_shape) == 3
    assert input.shape[-2] == 1
    assert input.shape[-1] == 1
    assert self.indices is not None

    batch_size = input.shape[0]
    output_height = int(self._output_size[0])
    output_width = int(self._output_size[1])
    cpu_device = torch.device("cpu")

    spikes_count = torch.zeros(
        (batch_size, output_height, output_width),
        device=cpu_device,
        dtype=torch.int64,
    )

    # The extension walks the buffer through a raw pointer and the sizes
    # only, so the copy must be dense and row-major. A plain ``clone()``
    # preserves the strides of permuted inputs and would hand the native
    # code the elements in the wrong order.
    input_cpu = input.clone(memory_format=torch.contiguous_format).cpu()

    # Pass 1: count the spikes per output position.
    self.sort_spikes.count(
        input_cpu.data_ptr(),  # Input
        int(input_cpu.shape[0]),
        int(input_cpu.shape[1]),
        int(input_cpu.shape[2]),
        int(input_cpu.shape[3]),
        spikes_count.data_ptr(),  # Output
        int(spikes_count.shape[0]),
        int(spikes_count.shape[1]),
        int(spikes_count.shape[2]),
        self.indices.data_ptr(),  # Positions
        int(self.indices.shape[0]),
        int(self.indices.shape[1]),
        int(self._number_of_cpu_processes),
    )

    # The largest count fixes how many slots every output position gets.
    spikes_output = torch.full(
        (
            batch_size,
            int(spikes_count.max()),
            output_height,
            output_width,
        ),
        -1,
        dtype=torch.int64,
        device=cpu_device,
    )

    # Pass 2: write the sorted spikes into the pre-sized output.
    self.sort_spikes.process(
        input_cpu.data_ptr(),  # Input
        int(input_cpu.shape[0]),
        int(input_cpu.shape[1]),
        int(input_cpu.shape[2]),
        int(input_cpu.shape[3]),
        spikes_output.data_ptr(),  # Output
        int(spikes_output.shape[0]),
        int(spikes_output.shape[1]),
        int(spikes_output.shape[2]),
        int(spikes_output.shape[3]),
        self.indices.data_ptr(),  # Positions
        int(self.indices.shape[0]),
        int(self.indices.shape[1]),
        int(self._number_of_cpu_processes),
    )

    return spikes_output

View file

@ -0,0 +1,52 @@
import torch
from network.PyCountSpikesCPU import CountSpikesCPU
class SpikeCountLayer(torch.nn.Module):
    """Layer that delegates spike counting to the ``CountSpikesCPU`` extension.

    ``forward`` allocates a zero-initialised int64 tensor of shape
    (batch, dim_s, height, width) and lets the native routine fill it.
    NOTE(review): judging by the names this is a per-position histogram of
    spike indices -- confirm against the native implementation.
    """

    _number_of_cpu_processes: int

    def __init__(
        self,
        number_of_cpu_processes: int = 1,
    ) -> None:
        """Remember the worker count handed to the native routine.

        Args:
            number_of_cpu_processes: Forwarded unchanged as the last
                argument of ``CountSpikesCPU.process``.
        """
        super().__init__()

        self._number_of_cpu_processes = number_of_cpu_processes

    ####################################################################
    # Forward                                                          #
    ####################################################################
    def forward(self, input: torch.Tensor, dim_s: int) -> torch.Tensor:
        """Run the native counter and return its output on ``input``'s device.

        Args:
            input: 4-D tensor. NOTE(review): the element type expected by
                the extension is not visible here -- confirm.
            dim_s: Size of the second dimension of the result; must be > 0.

        Returns:
            An int64 tensor of shape
            (input.shape[0], dim_s, input.shape[-2], input.shape[-1]),
            moved to the device of ``input``.
        """
        assert input.ndim == 4
        assert dim_s > 0

        # The extension only gets a raw pointer plus the four sizes, so the
        # buffer has to be dense and row-major. ``.cpu()`` alone keeps the
        # strides of permuted or sliced tensors.
        input_cpu = input.cpu().contiguous()

        histogram = torch.zeros(
            (
                int(input.shape[0]),
                int(dim_s),
                int(input.shape[-2]),
                int(input.shape[-1]),
            ),
            dtype=torch.int64,
            device=input_cpu.device,
        )

        count_spikes = CountSpikesCPU()
        count_spikes.process(
            input_cpu.data_ptr(),
            int(input_cpu.shape[0]),
            int(input_cpu.shape[1]),
            int(input_cpu.shape[2]),
            int(input_cpu.shape[3]),
            histogram.data_ptr(),
            int(histogram.shape[1]),
            int(self._number_of_cpu_processes),
        )

        return histogram.to(device=input.device)

View file

@ -3,8 +3,6 @@ import torch
from network.PySpikeGenerationCPU import SpikeGenerationCPU
from network.PySpikeGenerationGPU import SpikeGenerationGPU
# from PyCountSpikesCPU import CountSpikesCPU
global_spike_generation_gpu_setting: list[torch.Tensor] = []
global_spike_size: list[torch.Tensor] = []
global_spike_generation_cpp: list[SpikeGenerationCPU | SpikeGenerationGPU] = []
@ -16,28 +14,21 @@ class SpikeLayer(torch.nn.Module):
_spike_generation_gpu_setting_position: int
_number_of_cpu_processes: int
_number_of_spikes: int
_spikes: torch.Tensor | None = None
_store_spikes: bool
device: torch.device
def __init__(
self,
number_of_spikes: int = 1,
number_of_spikes: int = -1,
number_of_cpu_processes: int = 1,
device: torch.device | None = None,
default_dtype: torch.dtype | None = None,
store_spikes: bool = False,
) -> None:
super().__init__()
assert device is not None
assert default_dtype is not None
self.device = device
self.default_dtype = default_dtype
self._number_of_cpu_processes = number_of_cpu_processes
self._number_of_spikes = number_of_spikes
self._store_spikes = store_spikes
global_spike_generation_gpu_setting.append(torch.tensor([0]))
global_spike_size.append(torch.tensor([0, 0, 0, 0]))
@ -62,7 +53,6 @@ class SpikeLayer(torch.nn.Module):
self,
input: torch.Tensor,
number_of_spikes: int | None = None,
store_spikes: bool | None = None,
) -> torch.Tensor:
if number_of_spikes is None:
@ -80,18 +70,7 @@ class SpikeLayer(torch.nn.Module):
dtype=torch.int64,
)
spikes = self.functional_spike_generation(input, parameter_list)
if (store_spikes is not None) and (store_spikes is True):
self._spikes = spikes.detach().clone()
elif (store_spikes is not None) and (store_spikes is False):
self._spikes = None
elif self._store_spikes is True:
self._spikes = spikes.detach().clone()
else:
self._spikes = None
return spikes
return self.functional_spike_generation(input, parameter_list)
class FunctionalSpikeGeneration(torch.autograd.Function):

View file

@ -3,10 +3,11 @@ import torch
from network.calculate_output_size import calculate_output_size
from network.Parameter import Config
from network.SbS import SbS
from network.SbSLayer import SbSLayer
from network.SplitOnOffLayer import SplitOnOffLayer
from network.Conv2dApproximation import Conv2dApproximation
from network.SbSReconstruction import SbSReconstruction
from network.InputSpikeImage import InputSpikeImage
def build_network(
@ -144,7 +145,7 @@ def build_network(
is_pooling_layer = True
network.append(
SbS(
SbSLayer(
number_of_input_neurons=in_channels,
number_of_neurons=out_channels,
input_size=input_size[-1],
@ -190,7 +191,7 @@ def build_network(
logging.info(f"Layer: {layer_id} -> SbS Reconstruction Layer")
assert layer_id > 0
assert isinstance(network[-1], SbS) is True
assert isinstance(network[-1], SbSLayer) is True
network.append(SbSReconstruction(network[-1]))
network[-1]._w_trainable = False
@ -365,6 +366,36 @@ def build_network(
).tolist()
input_size.append(input_size_temp)
# #############################################################
# Input Spike Image layer:
# #############################################################
elif (
cfg.network_structure.layer_type[layer_id]
.upper()
.startswith("INPUT SPIKE IMAGE")
is True
):
logging.info(f"Layer: {layer_id} -> Input Spike Image Layer")
number_of_spikes: int = -1
if len(cfg.number_of_spikes) > layer_id:
number_of_spikes = cfg.number_of_spikes[layer_id]
elif len(cfg.number_of_spikes) == 1:
number_of_spikes = cfg.number_of_spikes[0]
network.append(
InputSpikeImage(
number_of_spikes=number_of_spikes,
number_of_cpu_processes=cfg.number_of_cpu_processes,
reshape=True,
normalize=True,
device=device,
)
)
input_size.append(input_size[-1])
# #############################################################
# Failure because we didn't find the selected layer type
# #############################################################

View file

@ -1,7 +1,7 @@
# %%
import torch
from network.Parameter import Config
from network.SbS import SbS
from network.SbSLayer import SbSLayer
from network.Conv2dApproximation import Conv2dApproximation
from network.Adam import Adam
@ -20,7 +20,7 @@ def build_optimizer(
for id in range(0, len(network)):
if (isinstance(network[id], SbS) is True) and (
if (isinstance(network[id], SbSLayer) is True) and (
network[id]._w_trainable is True
):
parameter_list_weights.append(network[id]._weights)

View file

@ -3,9 +3,10 @@ import torch
import glob
import numpy as np
from network.SbS import SbS
from network.SbSLayer import SbSLayer
from network.SplitOnOffLayer import SplitOnOffLayer
from network.Conv2dApproximation import Conv2dApproximation
import os
def load_previous_weights(
@ -14,22 +15,28 @@ def load_previous_weights(
logging,
device: torch.device,
default_dtype: torch.dtype,
order_id: float | int | None = None,
) -> None:
if order_id is None:
post_fix: str = ""
else:
post_fix = f"_{order_id}"
for id in range(0, len(network)):
# #################################################
# SbS
# #################################################
if isinstance(network[id], SbS) is True:
# Are there weights that overwrite the initial weights?
file_to_load = glob.glob(overload_path + "/Weight_L" + str(id) + "_*.npy")
if isinstance(network[id], SbSLayer) is True:
filename_wilcard = os.path.join(
overload_path, f"Weight_L{id}_*{post_fix}.npy"
)
file_to_load = glob.glob(filename_wilcard)
if len(file_to_load) > 1:
raise Exception(
f"Too many previous weights files {overload_path}/Weight_L{id}*.npy"
)
raise Exception(f"Too many previous weights files {filename_wilcard}")
if len(file_to_load) == 1:
network[id].weights = torch.tensor(
@ -45,13 +52,13 @@ def load_previous_weights(
# Conv2d weights
# #################################################
# Are there weights that overwrite the initial weights?
file_to_load = glob.glob(overload_path + "/Weight_L" + str(id) + "_*.npy")
filename_wilcard = os.path.join(
overload_path, f"Weight_L{id}_*{post_fix}.npy"
)
file_to_load = glob.glob(filename_wilcard)
if len(file_to_load) > 1:
raise Exception(
f"Too many previous weights files {overload_path}/Weight_L{id}*.npy"
)
raise Exception(f"Too many previous weights files {filename_wilcard}")
if len(file_to_load) == 1:
network[id]._parameters["weight"].data = torch.tensor(
@ -65,13 +72,13 @@ def load_previous_weights(
# Conv2d bias
# #################################################
# Are there biases that overwrite the initial biases?
file_to_load = glob.glob(overload_path + "/Bias_L" + str(id) + "_*.npy")
filename_wilcard = os.path.join(
overload_path, f"Bias_L{id}_*{post_fix}.npy"
)
file_to_load = glob.glob(filename_wilcard)
if len(file_to_load) > 1:
raise Exception(
f"Too many previous weights files {overload_path}/Weight_L{id}*.npy"
)
raise Exception(f"Too many previous weights files {filename_wilcard}")
if len(file_to_load) == 1:
network[id]._parameters["bias"].data = torch.tensor(
@ -87,13 +94,13 @@ def load_previous_weights(
# Approximate Conv2d weights
# #################################################
# Are there weights that overwrite the initial weights?
file_to_load = glob.glob(overload_path + "/Weight_L" + str(id) + "_*.npy")
filename_wilcard = os.path.join(
overload_path, f"Weight_L{id}_*{post_fix}.npy"
)
file_to_load = glob.glob(filename_wilcard)
if len(file_to_load) > 1:
raise Exception(
f"Too many previous weights files {overload_path}/Weight_L{id}*.npy"
)
raise Exception(f"Too many previous weights files {filename_wilcard}")
if len(file_to_load) == 1:
network[id].weights.data = torch.tensor(
@ -107,13 +114,13 @@ def load_previous_weights(
# Approximate Conv2d bias
# #################################################
# Are there biases that overwrite the initial biases?
file_to_load = glob.glob(overload_path + "/Bias_L" + str(id) + "_*.npy")
filename_wilcard = os.path.join(
overload_path, f"Bias_L{id}_*{post_fix}.npy"
)
file_to_load = glob.glob(filename_wilcard)
if len(file_to_load) > 1:
raise Exception(
f"Too many previous weights files {overload_path}/Weight_L{id}*.npy"
)
raise Exception(f"Too many previous weights files {filename_wilcard}")
if len(file_to_load) == 1:
network[id].bias.data = torch.tensor(
@ -127,13 +134,13 @@ def load_previous_weights(
# SplitOnOffLayer
# #################################################
if isinstance(network[id], SplitOnOffLayer) is True:
# Are there weights that overwrite the initial weights?
file_to_load = glob.glob(overload_path + "/Mean_L" + str(id) + "_*.npy")
filename_wilcard = os.path.join(
overload_path, f"Mean_L{id}_*{post_fix}.npy"
)
file_to_load = glob.glob(filename_wilcard)
if len(file_to_load) > 1:
raise Exception(
f"Too many previous mean files {overload_path}/Mean_L{id}*.npy"
)
raise Exception(f"Too many previous mean files {filename_wilcard}")
if len(file_to_load) == 1:
network[id].mean = torch.tensor(

View file

@ -3,7 +3,7 @@ import time
from network.Parameter import Config
from torch.utils.tensorboard import SummaryWriter
from network.SbS import SbS
from network.SbSLayer import SbSLayer
from network.save_weight_and_bias import save_weight_and_bias
from network.SbSReconstruction import SbSReconstruction
@ -19,7 +19,7 @@ def add_weight_and_bias_to_histogram(
# ################################################
# Log the SbS Weights
# ################################################
if isinstance(network[id], SbS) is True:
if isinstance(network[id], SbSLayer) is True:
if network[id]._w_trainable is True:
try:
@ -175,7 +175,7 @@ def forward_pass_train(
.to(device=device)
)
for id in range(0, len(network)):
if isinstance(network[id], SbS) is True:
if isinstance(network[id], SbSLayer) is True:
h_collection.append(network[id](h_collection[-1], labels))
else:
h_collection.append(network[id](h_collection[-1]))
@ -203,7 +203,7 @@ def forward_pass_test(
)
for id in range(0, len(network)):
if (cfg.extract_noisy_pictures is True) or (overwrite_number_of_spikes != -1):
if isinstance(network[id], SbS) is True:
if isinstance(network[id], SbSLayer) is True:
h_collection.append(
network[id](
h_collection[-1],
@ -228,7 +228,7 @@ def run_optimizer(
cfg: Config,
) -> None:
for id in range(0, len(network)):
if isinstance(network[id], SbS) is True:
if isinstance(network[id], SbSLayer) is True:
network[id].update_pre_care()
for optimizer_item in optimizer:
@ -236,7 +236,7 @@ def run_optimizer(
optimizer_item.step()
for id in range(0, len(network)):
if isinstance(network[id], SbS) is True:
if isinstance(network[id], SbSLayer) is True:
network[id].update_after_care(
cfg.learning_parameters.learning_rate_threshold_w
/ float(
@ -288,11 +288,11 @@ def run_lr_scheduler(
def deal_with_gradient_scale(epoch_id: int, mini_batch_number: int, network):
if (epoch_id == 0) and (mini_batch_number == 0):
for id in range(0, len(network)):
if isinstance(network[id], SbS) is True:
if isinstance(network[id], SbSLayer) is True:
network[id].after_batch(True)
else:
for id in range(0, len(network)):
if isinstance(network[id], SbS) is True:
if isinstance(network[id], SbSLayer) is True:
network[id].after_batch()
@ -309,6 +309,7 @@ def loop_train(
tb: SummaryWriter,
lr_scheduler,
last_test_performance: float,
order_id: float | int | None = None,
) -> tuple[float, float, float, float]:
correct_in_minibatch: int = 0
@ -529,7 +530,10 @@ def loop_train(
# Save the Weights and Biases
# ################################################
save_weight_and_bias(
cfg=cfg, network=network, iteration_number=epoch_id
cfg=cfg,
network=network,
iteration_number=epoch_id,
order_id=order_id,
)
# ################################################

View file

@ -3,14 +3,23 @@ import torch
from network.Parameter import Config
import numpy as np
from network.SbS import SbS
from network.SbSLayer import SbSLayer
from network.SplitOnOffLayer import SplitOnOffLayer
from network.Conv2dApproximation import Conv2dApproximation
import os
def save_weight_and_bias(
cfg: Config, network: torch.nn.modules.container.Sequential, iteration_number: int
cfg: Config,
network: torch.nn.modules.container.Sequential,
iteration_number: int,
order_id: float | int | None = None,
) -> None:
if order_id is None:
post_fix: str = ""
else:
post_fix = f"_{order_id}"
for id in range(0, len(network)):
@ -18,11 +27,14 @@ def save_weight_and_bias(
# Save the SbS Weights
# ################################################
if isinstance(network[id], SbS) is True:
if isinstance(network[id], SbSLayer) is True:
if network[id]._w_trainable is True:
np.save(
f"{cfg.weight_path}/Weight_L{id}_S{iteration_number}.npy",
os.path.join(
cfg.weight_path,
f"Weight_L{id}_S{iteration_number}{post_fix}.npy",
),
network[id].weights.detach().cpu().numpy(),
)
@ -34,13 +46,18 @@ def save_weight_and_bias(
if network[id]._w_trainable is True:
# Save the new values
np.save(
f"{cfg.weight_path}/Weight_L{id}_S{iteration_number}.npy",
os.path.join(
cfg.weight_path,
f"Weight_L{id}_S{iteration_number}{post_fix}.npy",
),
network[id]._parameters["weight"].data.detach().cpu().numpy(),
)
# Save the new values
np.save(
f"{cfg.weight_path}/Bias_L{id}_S{iteration_number}.npy",
os.path.join(
cfg.weight_path, f"Bias_L{id}_S{iteration_number}{post_fix}.npy"
),
network[id]._parameters["bias"].data.detach().cpu().numpy(),
)
@ -52,20 +69,28 @@ def save_weight_and_bias(
if network[id]._w_trainable is True:
# Save the new values
np.save(
f"{cfg.weight_path}/Weight_L{id}_S{iteration_number}.npy",
os.path.join(
cfg.weight_path,
f"Weight_L{id}_S{iteration_number}{post_fix}.npy",
),
network[id].weights.data.detach().cpu().numpy(),
)
# Save the new values
if network[id].bias is not None:
np.save(
f"{cfg.weight_path}/Bias_L{id}_S{iteration_number}.npy",
os.path.join(
cfg.weight_path,
f"Bias_L{id}_S{iteration_number}{post_fix}.npy",
),
network[id].bias.data.detach().cpu().numpy(),
)
if isinstance(network[id], SplitOnOffLayer) is True:
np.save(
f"{cfg.weight_path}/Mean_L{id}_S{iteration_number}.npy",
os.path.join(
cfg.weight_path, f"Mean_L{id}_S{iteration_number}{post_fix}.npy"
),
network[id].mean.detach().cpu().numpy(),
)