Add files via upload

parent 5ac2f1dc96 · commit d23e2edd8d
14 changed files with 1406 additions and 84 deletions
@ -12,6 +12,8 @@ class DatasetMaster(torch.utils.data.Dataset, ABC):
|
|||
pattern_storage: np.ndarray
|
||||
number_of_pattern: int
|
||||
mean: list[float]
|
||||
initial_size: list[int]
|
||||
channel_size: int
|
||||
|
||||
# Initialize
|
||||
def __init__(
|
||||
|
@ -36,6 +38,9 @@ class DatasetMaster(torch.utils.data.Dataset, ABC):
|
|||
|
||||
self.mean = []
|
||||
|
||||
self.initial_size = [0, 0]
|
||||
self.channel_size = 0
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.number_of_pattern
|
||||
|
||||
|
@ -74,6 +79,9 @@ class DatasetMNIST(DatasetMaster):
|
|||
mean = self.pattern_storage.mean(3).mean(2).mean(0)
|
||||
self.mean = [*mean]
|
||||
|
||||
self.initial_size = [28, 28]
|
||||
self.channel_size = 1
|
||||
|
||||
def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
|
||||
|
||||
image = self.pattern_storage[index, 0:1, :, :]
|
||||
|
@ -154,6 +162,9 @@ class DatasetFashionMNIST(DatasetMaster):
|
|||
mean = self.pattern_storage.mean(3).mean(2).mean(0)
|
||||
self.mean = [*mean]
|
||||
|
||||
self.initial_size = [28, 28]
|
||||
self.channel_size = 1
|
||||
|
||||
def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
|
||||
|
||||
image = self.pattern_storage[index, 0:1, :, :]
|
||||
|
@ -240,6 +251,9 @@ class DatasetCIFAR(DatasetMaster):
|
|||
mean = self.pattern_storage.mean(3).mean(2).mean(0)
|
||||
self.mean = [*mean]
|
||||
|
||||
self.initial_size = [32, 32]
|
||||
self.channel_size = 3
|
||||
|
||||
def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
|
||||
|
||||
image = self.pattern_storage[index, :, :, :]
|
||||
|
|
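A quick way to sanity-check the per-channel mean these dataset classes store in self.mean (a sketch using NumPy only; the random array below merely stands in for pattern_storage with layout (N, C, H, W)):

import numpy as np

pattern_storage = np.random.rand(100, 3, 32, 32).astype(np.float32)  # stand-in for self.pattern_storage
mean = pattern_storage.mean(3).mean(2).mean(0)  # reduce W, then H, then the batch axis
assert mean.shape == (3,)  # one entry per channel, matching self.mean = [*mean]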
network/HDynamicLayer.py (new file, 451 lines)
|
@ -0,0 +1,451 @@
|
|||
import torch
|
||||
|
||||
from network.PyHDynamicCNNCPU import HDynamicCNNCPU
|
||||
from network.PyHDynamicCNNGPU import HDynamicCNNGPU
|
||||
|
||||
global_sbs_gpu_setting: list[torch.Tensor] = []
|
||||
global_sbs_size: list[torch.Tensor] = []
|
||||
global_sbs_hdynamic_cpp: list[HDynamicCNNCPU | HDynamicCNNGPU] = []
|
||||
|
||||
|
||||
class HDynamicLayer(torch.nn.Module):
|
||||
|
||||
_sbs_gpu_setting_position: int
|
||||
_sbs_hdynamic_cpp_position: int
|
||||
_gpu_tuning_factor: int
|
||||
_number_of_cpu_processes: int
|
||||
_output_size: list[int]
|
||||
_w_trainable: bool
|
||||
_output_layer: bool
|
||||
_local_learning: bool
|
||||
device: torch.device
|
||||
default_dtype: torch.dtype
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_size: list[int],
|
||||
output_layer: bool = False,
|
||||
local_learning: bool = False,
|
||||
number_of_cpu_processes: int = 1,
|
||||
w_trainable: bool = False,
|
||||
skip_gradient_calculation: bool = False,
|
||||
device: torch.device | None = None,
|
||||
default_dtype: torch.dtype | None = None,
|
||||
gpu_tuning_factor: int = 5,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
assert device is not None
|
||||
self.device = device
|
||||
self.default_dtype = default_dtype
|
||||
|
||||
self._gpu_tuning_factor = int(gpu_tuning_factor)
|
||||
self._number_of_cpu_processes = int(number_of_cpu_processes)
|
||||
self._w_trainable = bool(w_trainable)
|
||||
self._skip_gradient_calculation = bool(skip_gradient_calculation)
|
||||
self._output_size = output_size
|
||||
self._output_layer = bool(output_layer)
|
||||
self._local_learning = bool(local_learning)
|
||||
|
||||
global_sbs_gpu_setting.append(torch.tensor([0]))
|
||||
global_sbs_size.append(torch.tensor([0, 0, 0, 0]))
|
||||
|
||||
if device == torch.device("cpu"):
|
||||
global_sbs_hdynamic_cpp.append(HDynamicCNNCPU())
|
||||
else:
|
||||
global_sbs_hdynamic_cpp.append(HDynamicCNNGPU())
|
||||
|
||||
self._sbs_gpu_setting_position = len(global_sbs_gpu_setting) - 1
|
||||
self._sbs_hdynamic_cpp_position = len(global_sbs_hdynamic_cpp) - 1
|
||||
|
||||
self.functional_sbs = FunctionalSbS.apply
|
||||
|
||||
####################################################################
|
||||
# Forward #
|
||||
####################################################################
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
spike: torch.Tensor,
|
||||
epsilon_xy: torch.Tensor,
|
||||
epsilon_t_0: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
h_initial: torch.Tensor,
|
||||
last_grad_scale: torch.Tensor,
|
||||
labels: torch.Tensor | None = None,
|
||||
keep_last_grad_scale: bool = False,
|
||||
disable_scale_grade: bool = True,
|
||||
forgetting_offset: float = -1.0,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if labels is None:
|
||||
labels_copy: torch.Tensor = torch.tensor(
|
||||
[], dtype=torch.int64, device=self.device
|
||||
)
|
||||
else:
|
||||
labels_copy = (
|
||||
labels.detach().clone().type(dtype=torch.int64).to(device=self.device)
|
||||
)
|
||||
|
||||
if (spike.shape[-2] * spike.shape[-1]) > self._gpu_tuning_factor:
|
||||
gpu_tuning_factor = self._gpu_tuning_factor
|
||||
else:
|
||||
gpu_tuning_factor = 0
|
||||
|
||||
parameter_list = torch.tensor(
|
||||
[
|
||||
int(self._number_of_cpu_processes), # 0
|
||||
int(self._output_size[0]), # 1
|
||||
int(self._output_size[1]), # 2
|
||||
int(gpu_tuning_factor), # 3
|
||||
int(self._sbs_gpu_setting_position), # 4
|
||||
int(self._sbs_hdynamic_cpp_position), # 5
|
||||
int(self._w_trainable), # 6
|
||||
int(disable_scale_grade), # 7
|
||||
int(keep_last_grad_scale), # 8
|
||||
int(self._skip_gradient_calculation), # 9
|
||||
int(self._output_layer), # 10
|
||||
int(self._local_learning), # 11
|
||||
],
|
||||
dtype=torch.int64,
|
||||
)
|
||||
|
||||
# SbS forward functional
|
||||
return self.functional_sbs(
|
||||
input,
|
||||
spike,
|
||||
epsilon_xy,
|
||||
epsilon_t_0,
|
||||
weights,
|
||||
h_initial,
|
||||
parameter_list,
|
||||
last_grad_scale,
|
||||
torch.tensor(
|
||||
forgetting_offset, device=self.device, dtype=self.default_dtype
|
||||
),
|
||||
labels_copy,
|
||||
)
|
||||
|
||||
|
||||
class FunctionalSbS(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward( # type: ignore
|
||||
ctx,
|
||||
input: torch.Tensor,
|
||||
spikes: torch.Tensor,
|
||||
epsilon_xy: torch.Tensor | None,
|
||||
epsilon_t_0: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
h_initial: torch.Tensor,
|
||||
parameter_list: torch.Tensor,
|
||||
grad_output_scale: torch.Tensor,
|
||||
forgetting_offset: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
number_of_spikes: int = int(spikes.shape[1])
|
||||
|
||||
if input.device == torch.device("cpu"):
|
||||
hdyn_number_of_cpu_processes: int = int(parameter_list[0])
|
||||
else:
|
||||
hdyn_number_of_cpu_processes = -1
|
||||
|
||||
output_size_0: int = int(parameter_list[1])
|
||||
output_size_1: int = int(parameter_list[2])
|
||||
gpu_tuning_factor: int = int(parameter_list[3])
|
||||
|
||||
sbs_gpu_setting_position = int(parameter_list[4])
|
||||
sbs_hdynamic_cpp_position = int(parameter_list[5])
|
||||
|
||||
# ###########################################################
|
||||
# H dynamic
|
||||
# ###########################################################
|
||||
|
||||
assert epsilon_t_0.ndim == 1
|
||||
assert epsilon_t_0.shape[0] >= number_of_spikes
|
||||
|
||||
# ############################################
|
||||
# Make space for the results
|
||||
# ############################################
|
||||
|
||||
output = torch.empty(
|
||||
(
|
||||
int(input.shape[0]),
|
||||
int(weights.shape[1]),
|
||||
output_size_0,
|
||||
output_size_1,
|
||||
),
|
||||
dtype=input.dtype,
|
||||
device=input.device,
|
||||
)
|
||||
|
||||
assert output.is_contiguous() is True
|
||||
if epsilon_xy is not None:
|
||||
assert epsilon_xy.is_contiguous() is True
|
||||
assert epsilon_xy.ndim == 3
|
||||
assert epsilon_t_0.is_contiguous() is True
|
||||
assert weights.is_contiguous() is True
|
||||
assert spikes.is_contiguous() is True
|
||||
assert h_initial.is_contiguous() is True
|
||||
|
||||
assert weights.ndim == 2
|
||||
assert h_initial.ndim == 1
|
||||
|
||||
sbs_profile = global_sbs_gpu_setting[sbs_gpu_setting_position].clone()
|
||||
|
||||
sbs_size = global_sbs_size[sbs_gpu_setting_position].clone()
|
||||
|
||||
if input.device != torch.device("cpu"):
|
||||
if (
|
||||
(sbs_profile.numel() == 1)
|
||||
or (sbs_size[0] != int(output.shape[0]))
|
||||
or (sbs_size[1] != int(output.shape[1]))
|
||||
or (sbs_size[2] != int(output.shape[2]))
|
||||
or (sbs_size[3] != int(output.shape[3]))
|
||||
):
|
||||
sbs_profile = torch.zeros(
|
||||
(14, 7), dtype=torch.int64, device=torch.device("cpu")
|
||||
)
|
||||
|
||||
global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_export(
|
||||
int(output.shape[2]),
|
||||
int(output.shape[3]),
|
||||
int(output.shape[0]),
|
||||
int(output.shape[1]),
|
||||
sbs_profile.data_ptr(),
|
||||
int(sbs_profile.shape[0]),
|
||||
int(sbs_profile.shape[1]),
|
||||
)
|
||||
global_sbs_gpu_setting[sbs_gpu_setting_position] = sbs_profile.clone()
|
||||
sbs_size[0] = int(output.shape[0])
|
||||
sbs_size[1] = int(output.shape[1])
|
||||
sbs_size[2] = int(output.shape[2])
|
||||
sbs_size[3] = int(output.shape[3])
|
||||
global_sbs_size[sbs_gpu_setting_position] = sbs_size.clone()
|
||||
|
||||
else:
|
||||
global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_import(
|
||||
sbs_profile.data_ptr(),
|
||||
int(sbs_profile.shape[0]),
|
||||
int(sbs_profile.shape[1]),
|
||||
)
|
||||
|
||||
global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].update(
|
||||
output.data_ptr(),
|
||||
int(output.shape[0]),
|
||||
int(output.shape[1]),
|
||||
int(output.shape[2]),
|
||||
int(output.shape[3]),
|
||||
epsilon_xy.data_ptr() if epsilon_xy is not None else int(0),
|
||||
int(epsilon_xy.shape[0]) if epsilon_xy is not None else int(0),
|
||||
int(epsilon_xy.shape[1]) if epsilon_xy is not None else int(0),
|
||||
int(epsilon_xy.shape[2]) if epsilon_xy is not None else int(0),
|
||||
epsilon_t_0.data_ptr(),
|
||||
int(epsilon_t_0.shape[0]),
|
||||
weights.data_ptr(),
|
||||
int(weights.shape[0]),
|
||||
int(weights.shape[1]),
|
||||
spikes.data_ptr(),
|
||||
int(spikes.shape[0]),
|
||||
int(spikes.shape[1]),
|
||||
int(spikes.shape[2]),
|
||||
int(spikes.shape[3]),
|
||||
h_initial.data_ptr(),
|
||||
int(h_initial.shape[0]),
|
||||
hdyn_number_of_cpu_processes,
|
||||
float(forgetting_offset.cpu().item()),
|
||||
int(gpu_tuning_factor),
|
||||
)
|
||||
|
||||
# ###########################################################
|
||||
# Save the necessary data for the backward pass
|
||||
# ###########################################################
|
||||
|
||||
ctx.save_for_backward(
|
||||
input,
|
||||
weights,
|
||||
output,
|
||||
parameter_list,
|
||||
grad_output_scale,
|
||||
labels,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# ##############################################
|
||||
# Get the variables back
|
||||
# ##############################################
|
||||
(
|
||||
input,
|
||||
weights,
|
||||
output,
|
||||
parameter_list,
|
||||
last_grad_scale,
|
||||
labels,
|
||||
) = ctx.saved_tensors
|
||||
|
||||
assert labels.numel() > 0
|
||||
|
||||
# ##############################################
|
||||
# Default output
|
||||
# ##############################################
|
||||
grad_input = None
|
||||
grad_spikes = None
|
||||
grad_eps_xy = None
|
||||
grad_epsilon_t_0 = None
|
||||
grad_weights = None
|
||||
grad_h_initial = None
|
||||
grad_parameter_list = None
|
||||
grad_forgetting_offset = None
|
||||
grad_labels = None
|
||||
|
||||
# ##############################################
|
||||
# Parameters
|
||||
# ##############################################
|
||||
parameter_w_trainable: bool = bool(parameter_list[6])
|
||||
parameter_disable_scale_grade: bool = bool(parameter_list[7])
|
||||
parameter_keep_last_grad_scale: bool = bool(parameter_list[8])
|
||||
parameter_skip_gradient_calculation: bool = bool(parameter_list[9])
|
||||
parameter_output_layer: bool = bool(parameter_list[10])
|
||||
parameter_local_learning: bool = bool(parameter_list[11])
|
||||
|
||||
# ##############################################
|
||||
# Dealing with overall scale of the gradient
|
||||
# ##############################################
|
||||
if parameter_disable_scale_grade is False:
|
||||
if parameter_keep_last_grad_scale is True:
|
||||
last_grad_scale = torch.tensor(
|
||||
[torch.abs(grad_output).max(), last_grad_scale]
|
||||
).max()
|
||||
grad_output /= last_grad_scale
|
||||
grad_output_scale = last_grad_scale.clone()
|
||||
|
||||
input /= input.sum(dim=1, keepdim=True, dtype=weights.dtype)
|
||||
|
||||
# #################################################
|
||||
# User doesn't want us to calculate the gradients
|
||||
# #################################################
|
||||
|
||||
if parameter_skip_gradient_calculation is True:
|
||||
|
||||
return (
|
||||
grad_input,
|
||||
grad_spikes,
|
||||
grad_eps_xy,
|
||||
grad_epsilon_t_0,
|
||||
grad_weights,
|
||||
grad_h_initial,
|
||||
grad_parameter_list,
|
||||
grad_output_scale,
|
||||
grad_forgetting_offset,
|
||||
grad_labels,
|
||||
)
|
||||
|
||||
# #################################################
|
||||
# Calculate backprop error (grad_input)
|
||||
# #################################################
|
||||
|
||||
backprop_r: torch.Tensor = weights.unsqueeze(0).unsqueeze(-1).unsqueeze(
|
||||
-1
|
||||
) * output.unsqueeze(1)
|
||||
|
||||
backprop_bigr: torch.Tensor = backprop_r.sum(dim=2)
|
||||
|
||||
backprop_z: torch.Tensor = backprop_r * (
|
||||
1.0 / (backprop_bigr + 1e-20)
|
||||
).unsqueeze(2)
|
||||
grad_input: torch.Tensor = (backprop_z * grad_output.unsqueeze(1)).sum(2)
|
||||
del backprop_z
|
||||
|
||||
# #################################################
|
||||
# Calculate weight gradient (grad_weights)
|
||||
# #################################################
|
||||
|
||||
if parameter_w_trainable is False:
|
||||
|
||||
# #################################################
|
||||
# We don't train this weight
|
||||
# #################################################
|
||||
grad_weights = None
|
||||
|
||||
elif (parameter_output_layer is False) and (parameter_local_learning is True):
|
||||
# #################################################
|
||||
# Local learning
|
||||
# #################################################
|
||||
grad_weights = (
|
||||
(-2 * (input - backprop_bigr).unsqueeze(2) * output.unsqueeze(1))
|
||||
.sum(0)
|
||||
.sum(-1)
|
||||
.sum(-1)
|
||||
)
|
||||
|
||||
elif (parameter_output_layer is True) and (parameter_local_learning is True):
|
||||
|
||||
target_one_hot: torch.Tensor = torch.zeros(
|
||||
(
|
||||
labels.shape[0],
|
||||
output.shape[1],
|
||||
),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
|
||||
target_one_hot.scatter_(
|
||||
1,
|
||||
labels.to(input.device).unsqueeze(1),
|
||||
torch.ones(
|
||||
(labels.shape[0], 1),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
),
|
||||
)
|
||||
target_one_hot = target_one_hot.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
# (-2 * (input - backprop_bigr).unsqueeze(2) * (target_one_hot-output).unsqueeze(1))
|
||||
# (-2 * input.unsqueeze(2) * (target_one_hot-output).unsqueeze(1))
|
||||
grad_weights = (
|
||||
(
|
||||
-2
|
||||
* (input - backprop_bigr).unsqueeze(2)
|
||||
* target_one_hot.unsqueeze(1)
|
||||
)
|
||||
.sum(0)
|
||||
.sum(-1)
|
||||
.sum(-1)
|
||||
)
|
||||
|
||||
else:
|
||||
# #################################################
|
||||
# Backprop
|
||||
# #################################################
|
||||
backprop_f: torch.Tensor = output.unsqueeze(1) * (
|
||||
input / (backprop_bigr**2 + 1e-20)
|
||||
).unsqueeze(2)
|
||||
|
||||
result_omega: torch.Tensor = backprop_bigr.unsqueeze(
|
||||
2
|
||||
) * grad_output.unsqueeze(1)
|
||||
result_omega -= (backprop_r * grad_output.unsqueeze(1)).sum(2).unsqueeze(2)
|
||||
result_omega *= backprop_f
|
||||
del backprop_f
|
||||
grad_weights = result_omega.sum(0).sum(-1).sum(-1)
|
||||
del result_omega
|
||||
|
||||
del backprop_bigr
|
||||
del backprop_r
|
||||
|
||||
return (
|
||||
grad_input,
|
||||
grad_spikes,
|
||||
grad_eps_xy,
|
||||
grad_epsilon_t_0,
|
||||
grad_weights,
|
||||
grad_h_initial,
|
||||
grad_parameter_list,
|
||||
grad_output_scale,
|
||||
grad_forgetting_offset,
|
||||
grad_labels,
|
||||
)
|
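A minimal construction sketch for the new HDynamicLayer (parameter values are illustrative, not taken from the commit; the layer needs the compiled PyHDynamicCNNCPU/PyHDynamicCNNGPU extensions from this repository at import time):

import torch
from network.HDynamicLayer import HDynamicLayer

h_dynamic = HDynamicLayer(
    output_size=[24, 24],        # illustrative spatial output size
    number_of_cpu_processes=4,
    device=torch.device("cpu"),  # device must be given, the constructor asserts it
    default_dtype=torch.float32,
)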
network/InputSpikeImage.py (new file, 104 lines)
|
@ -0,0 +1,104 @@
|
|||
import torch
|
||||
|
||||
from network.SpikeLayer import SpikeLayer
|
||||
from network.SpikeCountLayer import SpikeCountLayer
|
||||
|
||||
|
||||
class InputSpikeImage(torch.nn.Module):
|
||||
|
||||
_reshape: bool
|
||||
_normalize: bool
|
||||
_device: torch.device
|
||||
|
||||
number_of_spikes: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
number_of_spikes: int = -1,
|
||||
number_of_cpu_processes: int = 1,
|
||||
reshape: bool = False,
|
||||
normalize: bool = True,
|
||||
device: torch.device | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
assert device is not None
|
||||
self._device = device
|
||||
|
||||
self._reshape = bool(reshape)
|
||||
self._normalize = bool(normalize)
|
||||
|
||||
self.number_of_spikes = int(number_of_spikes)
|
||||
|
||||
if device != torch.device("cpu"):
|
||||
number_of_cpu_processes_spike_generator = 0
|
||||
else:
|
||||
number_of_cpu_processes_spike_generator = number_of_cpu_processes
|
||||
|
||||
self.spike_generator = SpikeLayer(
|
||||
number_of_cpu_processes=number_of_cpu_processes_spike_generator,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.spike_count = SpikeCountLayer(
|
||||
number_of_cpu_processes=number_of_cpu_processes
|
||||
)
|
||||
|
||||
####################################################################
|
||||
# Forward #
|
||||
####################################################################
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
if self.number_of_spikes < 1:
|
||||
return input
|
||||
|
||||
input_shape: list[int] = [
|
||||
int(input.shape[0]),
|
||||
int(input.shape[1]),
|
||||
int(input.shape[2]),
|
||||
int(input.shape[3]),
|
||||
]
|
||||
|
||||
if self._reshape is True:
|
||||
input_work = (
|
||||
input.detach()
|
||||
.clone()
|
||||
.to(self._device)
|
||||
.reshape(
|
||||
(input_shape[0], input_shape[1] * input_shape[2] * input_shape[3])
|
||||
)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
)
|
||||
else:
|
||||
input_work = input.detach().clone().to(self._device)
|
||||
|
||||
spikes = self.spike_generator(
|
||||
input=input_work, number_of_spikes=self.number_of_spikes
|
||||
)
|
||||
|
||||
if self._reshape is True:
|
||||
dim_s: int = input_shape[1] * input_shape[2] * input_shape[3]
|
||||
else:
|
||||
dim_s = input_shape[1]
|
||||
|
||||
output: torch.Tensor = self.spike_count(spikes, dim_s)
|
||||
|
||||
if self._reshape is True:
|
||||
|
||||
output = (
|
||||
output.squeeze(-1)
|
||||
.squeeze(-1)
|
||||
.reshape(
|
||||
(input_shape[0], input_shape[1], input_shape[2], input_shape[3])
|
||||
)
|
||||
)
|
||||
|
||||
if self._normalize is True:
|
||||
output = output.type(dtype=input_work.dtype)
|
||||
output = output / output.sum(dim=-1, keepdim=True).sum(
|
||||
dim=-2, keepdim=True
|
||||
).sum(dim=-3, keepdim=True)
|
||||
|
||||
return output
|
|
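A usage sketch for InputSpikeImage, assuming the compiled spike-generation and spike-counting extensions are built; the tensor sizes and parameter values are illustrative only:

import torch
from network.InputSpikeImage import InputSpikeImage

input_layer = InputSpikeImage(
    number_of_spikes=1024,       # with a value < 1 the layer just passes the input through
    number_of_cpu_processes=4,
    reshape=True,
    normalize=True,
    device=torch.device("cpu"),
)
images = torch.rand((10, 1, 28, 28))
rates = input_layer(images)      # reshaped back to (10, 1, 28, 28) because reshape=True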
@ -3,20 +3,22 @@ export
|
|||
|
||||
all:
|
||||
cd h_dynamic_cnn_cpu_cpp && $(MAKE) all
|
||||
cd h_dynamic_cnn_gpu_cpp && $(MAKE) all
|
||||
cd h_dynamic_cnn_gpu_cpp_v1 && $(MAKE) all
|
||||
cd spike_generation_cpu_cpp && $(MAKE) all
|
||||
cd spike_generation_gpu_cpp_v2 && $(MAKE) all
|
||||
cd multiplication_approximation_cpu_cpp && $(MAKE) all
|
||||
cd multiplication_approximation_gpu_cpp && $(MAKE) all
|
||||
cd count_spikes_cpu_cpp && $(MAKE) all
|
||||
cd sort_spikes_cpu_cpp && $(MAKE) all
|
||||
$(PYBIN)python3 pybind11_auto_pyi.py
|
||||
|
||||
clean:
|
||||
cd h_dynamic_cnn_cpu_cpp && $(MAKE) clean
|
||||
cd h_dynamic_cnn_gpu_cpp && $(MAKE) clean
|
||||
cd h_dynamic_cnn_gpu_cpp_v1 && $(MAKE) clean
|
||||
cd spike_generation_cpu_cpp && $(MAKE) clean
|
||||
cd spike_generation_gpu_cpp_v2 && $(MAKE) clean
|
||||
cd multiplication_approximation_cpu_cpp && $(MAKE) clean
|
||||
cd multiplication_approximation_gpu_cpp && $(MAKE) clean
|
||||
cd count_spikes_cpu_cpp && $(MAKE) clean
|
||||
cd sort_spikes_cpu_cpp && $(MAKE) clean
|
||||
|
||||
|
|
network/SbSLayer.py (new file, 480 lines)
|
@ -0,0 +1,480 @@
|
|||
import torch
|
||||
|
||||
from network.SpikeLayer import SpikeLayer
|
||||
from network.HDynamicLayer import HDynamicLayer
|
||||
|
||||
from network.calculate_output_size import calculate_output_size
|
||||
from network.SortSpikesLayer import SortSpikesLayer
|
||||
|
||||
|
||||
class SbSLayer(torch.nn.Module):
|
||||
|
||||
_epsilon_xy: torch.Tensor | None = None
|
||||
_epsilon_xy_use: bool
|
||||
_epsilon_0: float
|
||||
_weights: torch.nn.parameter.Parameter
|
||||
_weights_exists: bool = False
|
||||
_kernel_size: list[int]
|
||||
_stride: list[int]
|
||||
_dilation: list[int]
|
||||
_padding: list[int]
|
||||
_output_size: torch.Tensor
|
||||
_number_of_spikes: int
|
||||
_number_of_cpu_processes: int
|
||||
_number_of_neurons: int
|
||||
_number_of_input_neurons: int
|
||||
_epsilon_xy_intitial: float
|
||||
_h_initial: torch.Tensor | None = None
|
||||
_w_trainable: bool
|
||||
_last_grad_scale: torch.nn.parameter.Parameter
|
||||
_keep_last_grad_scale: bool
|
||||
_disable_scale_grade: bool
|
||||
_forgetting_offset: float
|
||||
_weight_noise_range: list[float]
|
||||
_skip_gradient_calculation: bool
|
||||
_is_pooling_layer: bool
|
||||
_input_size: list[int]
|
||||
_output_layer: bool = False
|
||||
_local_learning: bool = False
|
||||
|
||||
device: torch.device
|
||||
default_dtype: torch.dtype
|
||||
_gpu_tuning_factor: int
|
||||
|
||||
_max_grad_weights: torch.Tensor | None = None
|
||||
|
||||
_number_of_grad_weight_contributions: float = 0.0
|
||||
|
||||
last_input_store: bool = False
|
||||
last_input_data: torch.Tensor | None = None
|
||||
|
||||
_cooldown_after_number_of_spikes: int = -1
|
||||
_reduction_cooldown: float = 1.0
|
||||
_layer_id: int = -1
|
||||
|
||||
spike_full_layer_input_distribution: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
number_of_input_neurons: int,
|
||||
number_of_neurons: int,
|
||||
input_size: list[int],
|
||||
forward_kernel_size: list[int],
|
||||
number_of_spikes: int,
|
||||
epsilon_xy_intitial: float = 0.1,
|
||||
epsilon_xy_use: bool = False,
|
||||
epsilon_0: float = 1.0,
|
||||
weight_noise_range: list[float] = [0.0, 1.0],
|
||||
is_pooling_layer: bool = False,
|
||||
strides: list[int] = [1, 1],
|
||||
dilation: list[int] = [0, 0],
|
||||
padding: list[int] = [0, 0],
|
||||
number_of_cpu_processes: int = 1,
|
||||
w_trainable: bool = False,
|
||||
keep_last_grad_scale: bool = False,
|
||||
disable_scale_grade: bool = True,
|
||||
forgetting_offset: float = -1.0,
|
||||
skip_gradient_calculation: bool = False,
|
||||
device: torch.device | None = None,
|
||||
default_dtype: torch.dtype | None = None,
|
||||
gpu_tuning_factor: int = 10,
|
||||
layer_id: int = -1,
|
||||
cooldown_after_number_of_spikes: int = -1,
|
||||
reduction_cooldown: float = 1.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
assert device is not None
|
||||
assert default_dtype is not None
|
||||
self.device = device
|
||||
self.default_dtype = default_dtype
|
||||
|
||||
self._w_trainable = bool(w_trainable)
|
||||
self._keep_last_grad_scale = bool(keep_last_grad_scale)
|
||||
self._skip_gradient_calculation = bool(skip_gradient_calculation)
|
||||
self._disable_scale_grade = bool(disable_scale_grade)
|
||||
self._epsilon_xy_intitial = float(epsilon_xy_intitial)
|
||||
self._stride = strides
|
||||
self._dilation = dilation
|
||||
self._padding = padding
|
||||
self._kernel_size = forward_kernel_size
|
||||
self._number_of_input_neurons = int(number_of_input_neurons)
|
||||
self._number_of_neurons = int(number_of_neurons)
|
||||
self._epsilon_0 = float(epsilon_0)
|
||||
self._number_of_cpu_processes = int(number_of_cpu_processes)
|
||||
self._number_of_spikes = int(number_of_spikes)
|
||||
self._weight_noise_range = weight_noise_range
|
||||
self._is_pooling_layer = bool(is_pooling_layer)
|
||||
self._cooldown_after_number_of_spikes = int(cooldown_after_number_of_spikes)
|
||||
self._reduction_cooldown = float(reduction_cooldown)
|
||||
self._layer_id = layer_id
|
||||
self._epsilon_xy_use = epsilon_xy_use
|
||||
|
||||
assert len(input_size) == 2
|
||||
self._input_size = input_size
|
||||
|
||||
# The GPU hates me...
|
||||
# Too many SbS threads == bad
|
||||
# Thus I need to limit them...
|
||||
# (Reminder: We cannot access the mini-batch size here,
|
||||
# which is part of the GPU thread size calculation...)
|
||||
|
||||
self._last_grad_scale = torch.nn.parameter.Parameter(
|
||||
torch.tensor(-1.0, dtype=self.default_dtype),
|
||||
requires_grad=True,
|
||||
)
|
||||
|
||||
self._forgetting_offset = float(forgetting_offset)
|
||||
|
||||
self._output_size = calculate_output_size(
|
||||
value=input_size,
|
||||
kernel_size=self._kernel_size,
|
||||
stride=self._stride,
|
||||
dilation=self._dilation,
|
||||
padding=self._padding,
|
||||
)
|
||||
|
||||
self.set_h_init_to_uniform()
|
||||
|
||||
self.spike_generator = SpikeLayer(
|
||||
number_of_spikes=self._number_of_spikes,
|
||||
number_of_cpu_processes=self._number_of_cpu_processes,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.h_dynamic = HDynamicLayer(
|
||||
output_size=self._output_size.tolist(),
|
||||
output_layer=self._output_layer,
|
||||
local_learning=self._local_learning,
|
||||
number_of_cpu_processes=number_of_cpu_processes,
|
||||
w_trainable=w_trainable,
|
||||
skip_gradient_calculation=skip_gradient_calculation,
|
||||
device=device,
|
||||
default_dtype=self.default_dtype,
|
||||
gpu_tuning_factor=gpu_tuning_factor,
|
||||
)
|
||||
|
||||
assert len(input_size) >= 2
|
||||
self.spikes_sorter = SortSpikesLayer(
|
||||
kernel_size=self._kernel_size,
|
||||
input_shape=[
|
||||
self._number_of_input_neurons,
|
||||
int(input_size[0]),
|
||||
int(input_size[1]),
|
||||
],
|
||||
output_size=self._output_size.clone(),
|
||||
strides=self._stride,
|
||||
dilation=self._dilation,
|
||||
padding=self._padding,
|
||||
number_of_cpu_processes=number_of_cpu_processes,
|
||||
)
|
||||
|
||||
# TODO: TEST
|
||||
if layer_id == 0:
|
||||
self.spike_full_layer_input_distribution = True
|
||||
|
||||
# ###############################################################
|
||||
# Initialize the weights
|
||||
# ###############################################################
|
||||
|
||||
if self._is_pooling_layer is True:
|
||||
self.weights = self._make_pooling_weights()
|
||||
|
||||
else:
|
||||
assert len(self._weight_noise_range) == 2
|
||||
weights = torch.empty(
|
||||
(
|
||||
int(self._kernel_size[0])
|
||||
* int(self._kernel_size[1])
|
||||
* int(self._number_of_input_neurons),
|
||||
int(self._number_of_neurons),
|
||||
),
|
||||
dtype=self.default_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
torch.nn.init.uniform_(
|
||||
weights,
|
||||
a=float(self._weight_noise_range[0]),
|
||||
b=float(self._weight_noise_range[1]),
|
||||
)
|
||||
self.weights = weights
|
||||
|
||||
####################################################################
|
||||
# Variables in and out #
|
||||
####################################################################
|
||||
|
||||
def get_epsilon_t(self, number_of_spikes: int):
|
||||
"""Generates the time series of the basic epsilon."""
|
||||
t = (
|
||||
torch.arange(
|
||||
0, number_of_spikes, dtype=self.default_dtype, device=self.device
|
||||
)
|
||||
+ 1
|
||||
)
|
||||
|
||||
# torch.ones((number_of_spikes), dtype=self.default_dtype, device=self.device
|
||||
epsilon_t: torch.Tensor = t ** (-1.0 / 2.0)
|
||||
|
||||
if (self._cooldown_after_number_of_spikes < number_of_spikes) and (
|
||||
self._cooldown_after_number_of_spikes >= 0
|
||||
):
|
||||
epsilon_t[
|
||||
self._cooldown_after_number_of_spikes : number_of_spikes
|
||||
] /= self._reduction_cooldown
|
||||
return epsilon_t
|
||||
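# For spike index t = 1..number_of_spikes this yields epsilon_t = t**(-0.5)
# (1.0, 0.707..., 0.577..., ...); forward() later scales the series by epsilon_0,
# and every entry from _cooldown_after_number_of_spikes onward is divided by
# _reduction_cooldown.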
|
||||
@property
|
||||
def weights(self) -> torch.Tensor | None:
|
||||
if self._weights_exists is False:
|
||||
return None
|
||||
else:
|
||||
return self._weights
|
||||
|
||||
@weights.setter
|
||||
def weights(self, value: torch.Tensor):
|
||||
assert value is not None
|
||||
assert torch.is_tensor(value) is True
|
||||
assert value.dim() == 2
|
||||
temp: torch.Tensor = (
|
||||
value.detach()
|
||||
.clone(memory_format=torch.contiguous_format)
|
||||
.type(dtype=self.default_dtype)
|
||||
.to(device=self.device)
|
||||
)
|
||||
temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype)
|
||||
if self._weights_exists is False:
|
||||
self._weights = torch.nn.parameter.Parameter(temp, requires_grad=True)
|
||||
self._weights_exists = True
|
||||
else:
|
||||
self._weights.data = temp
|
||||
|
||||
@property
|
||||
def h_initial(self) -> torch.Tensor | None:
|
||||
return self._h_initial
|
||||
|
||||
@h_initial.setter
|
||||
def h_initial(self, value: torch.Tensor):
|
||||
assert value is not None
|
||||
assert torch.is_tensor(value) is True
|
||||
assert value.dim() == 1
|
||||
assert value.dtype == self.default_dtype
|
||||
self._h_initial = (
|
||||
value.detach()
|
||||
.clone(memory_format=torch.contiguous_format)
|
||||
.type(dtype=self.default_dtype)
|
||||
.to(device=self.device)
|
||||
.requires_grad_(False)
|
||||
)
|
||||
|
||||
def update_pre_care(self):
|
||||
|
||||
if self._weights.grad is not None:
|
||||
assert self._number_of_grad_weight_contributions > 0
|
||||
self._weights.grad /= self._number_of_grad_weight_contributions
|
||||
self._number_of_grad_weight_contributions = 0.0
|
||||
|
||||
def update_after_care(self, threshold_weight: float):
|
||||
|
||||
if self._w_trainable is True:
|
||||
self.norm_weights()
|
||||
self.threshold_weights(threshold_weight)
|
||||
self.norm_weights()
|
||||
|
||||
def after_batch(self, new_state: bool = False):
|
||||
if self._keep_last_grad_scale is True:
|
||||
self._last_grad_scale.data = self._last_grad_scale.grad
|
||||
self._keep_last_grad_scale = new_state
|
||||
|
||||
self._last_grad_scale.grad = torch.zeros_like(self._last_grad_scale.grad)
|
||||
|
||||
####################################################################
|
||||
# Helper functions #
|
||||
####################################################################
|
||||
|
||||
def _make_pooling_weights(self) -> torch.Tensor:
|
||||
"""For generating the pooling weights."""
|
||||
|
||||
assert self._number_of_neurons is not None
|
||||
assert self._kernel_size is not None
|
||||
|
||||
weights: torch.Tensor = torch.zeros(
|
||||
(
|
||||
int(self._kernel_size[0]),
|
||||
int(self._kernel_size[1]),
|
||||
int(self._number_of_neurons),
|
||||
int(self._number_of_neurons),
|
||||
),
|
||||
dtype=self.default_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
for i in range(0, int(self._number_of_neurons)):
|
||||
weights[:, :, i, i] = 1.0
|
||||
|
||||
weights = weights.moveaxis(-1, 0).moveaxis(-1, 1)
|
||||
|
||||
weights = torch.nn.functional.unfold(
|
||||
input=weights,
|
||||
kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
|
||||
dilation=(1, 1),
|
||||
padding=(0, 0),
|
||||
stride=(1, 1),
|
||||
).squeeze()
|
||||
|
||||
weights = torch.moveaxis(weights, 0, 1)
|
||||
|
||||
return weights
|
||||
|
||||
def set_h_init_to_uniform(self) -> None:
|
||||
|
||||
assert self._number_of_neurons > 2
|
||||
|
||||
self.h_initial: torch.Tensor = torch.full(
|
||||
(self._number_of_neurons,),
|
||||
(1.0 / float(self._number_of_neurons)),
|
||||
dtype=self.default_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def norm_weights(self) -> None:
|
||||
assert self._weights_exists is True
|
||||
temp: torch.Tensor = (
|
||||
self._weights.data.detach()
|
||||
.clone(memory_format=torch.contiguous_format)
|
||||
.type(dtype=self.default_dtype)
|
||||
.to(device=self.device)
|
||||
)
|
||||
temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype)
|
||||
self._weights.data = temp
|
||||
|
||||
def threshold_weights(self, threshold: float) -> None:
|
||||
assert self._weights_exists is True
|
||||
assert threshold >= 0
|
||||
|
||||
torch.clamp(
|
||||
self._weights.data,
|
||||
min=float(threshold),
|
||||
max=None,
|
||||
out=self._weights.data,
|
||||
)
|
||||
|
||||
####################################################################
|
||||
# Forward #
|
||||
####################################################################
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
labels: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# Are we happy with the input?
|
||||
assert input is not None
|
||||
assert torch.is_tensor(input) is True
|
||||
assert input.dim() == 4
|
||||
assert input.dtype == self.default_dtype
|
||||
assert input.shape[1] == self._number_of_input_neurons
|
||||
assert input.shape[2] == self._input_size[0]
|
||||
assert input.shape[3] == self._input_size[1]
|
||||
|
||||
# Are we happy with the rest of the network?
|
||||
assert self._epsilon_0 is not None
|
||||
assert self._h_initial is not None
|
||||
assert self._forgetting_offset is not None
|
||||
assert self._weights_exists is True
|
||||
assert self._weights is not None
|
||||
|
||||
# Convolution of the input...
|
||||
# Well, this is a convolution layer
|
||||
# there needs to be convolution somewhere
|
||||
input_convolved = torch.nn.functional.fold(
|
||||
torch.nn.functional.unfold(
|
||||
input.requires_grad_(True),
|
||||
kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
|
||||
dilation=(int(self._dilation[0]), int(self._dilation[1])),
|
||||
padding=(int(self._padding[0]), int(self._padding[1])),
|
||||
stride=(int(self._stride[0]), int(self._stride[1])),
|
||||
),
|
||||
output_size=tuple(self._output_size.tolist()),
|
||||
kernel_size=(1, 1),
|
||||
dilation=(1, 1),
|
||||
padding=(0, 0),
|
||||
stride=(1, 1),
|
||||
)
|
||||
|
||||
# We might need the convolved input for other layers
|
||||
# let us keep it for the future
|
||||
if self.last_input_store is True:
|
||||
self.last_input_data = input_convolved.detach().clone()
|
||||
self.last_input_data /= self.last_input_data.sum(dim=1, keepdim=True)
|
||||
else:
|
||||
self.last_input_data = None
|
||||
|
||||
epsilon_t_0: torch.Tensor = (
|
||||
(self.get_epsilon_t(self._number_of_spikes) * self._epsilon_0)
|
||||
.type(input.dtype)
|
||||
.to(input.device)
|
||||
)
|
||||
|
||||
if (self._epsilon_xy is None) and (self._epsilon_xy_use is True):
|
||||
self._epsilon_xy = torch.full(
|
||||
(
|
||||
input_convolved.shape[1],
|
||||
input_convolved.shape[2],
|
||||
input_convolved.shape[3],
|
||||
),
|
||||
float(self._epsilon_xy_intitial),
|
||||
dtype=self.default_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if self._epsilon_xy_use is True:
|
||||
assert self._epsilon_xy is not None
|
||||
# In case somebody tried to replace the matrix with wrong dimensions
|
||||
assert self._epsilon_xy.shape[0] == input_convolved.shape[1]
|
||||
assert self._epsilon_xy.shape[1] == input_convolved.shape[2]
|
||||
assert self._epsilon_xy.shape[2] == input_convolved.shape[3]
|
||||
else:
|
||||
assert self._epsilon_xy is None
|
||||
|
||||
if self.spike_full_layer_input_distribution is False:
|
||||
spike = self.spike_generator(input_convolved, int(self._number_of_spikes))
|
||||
else:
|
||||
input_shape = input.shape
|
||||
input = (
|
||||
input.reshape(
|
||||
(input_shape[0], input_shape[1] * input_shape[2] * input_shape[3])
|
||||
)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
)
|
||||
spike_unsorted = self.spike_generator(input, int(self._number_of_spikes))
|
||||
input = (
|
||||
input.squeeze(-1)
|
||||
.squeeze(-1)
|
||||
.reshape(
|
||||
(input_shape[0], input_shape[1], input_shape[2], input_shape[3])
|
||||
)
|
||||
)
|
||||
spike = self.spikes_sorter(spike_unsorted).to(device=input_convolved.device)
|
||||
|
||||
output = self.h_dynamic(
|
||||
input=input_convolved,
|
||||
spike=spike,
|
||||
epsilon_xy=self._epsilon_xy,
|
||||
epsilon_t_0=epsilon_t_0,
|
||||
weights=self._weights,
|
||||
h_initial=self._h_initial,
|
||||
last_grad_scale=self._last_grad_scale,
|
||||
labels=labels,
|
||||
keep_last_grad_scale=self._keep_last_grad_scale,
|
||||
disable_scale_grade=self._disable_scale_grade,
|
||||
forgetting_offset=self._forgetting_offset,
|
||||
)
|
||||
|
||||
self._number_of_grad_weight_contributions += (
|
||||
output.shape[0] * output.shape[-2] * output.shape[-1]
|
||||
)
|
||||
|
||||
return output
|
|
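A construction sketch for SbSLayer with illustrative parameter values (the compiled CPU/GPU kernels from this repository are required, and device plus default_dtype are mandatory arguments):

import torch
from network.SbSLayer import SbSLayer

sbs = SbSLayer(
    number_of_input_neurons=2,
    number_of_neurons=32,
    input_size=[28, 28],
    forward_kernel_size=[5, 5],
    number_of_spikes=1024,
    w_trainable=True,
    device=torch.device("cpu"),
    default_dtype=torch.float32,
)
output = sbs(torch.rand((10, 2, 28, 28), dtype=torch.float32))  # illustrative batch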
@ -1,15 +1,15 @@
|
|||
import torch
|
||||
|
||||
from network.SbS import SbS
|
||||
from network.SbSLayer import SbSLayer
|
||||
|
||||
|
||||
class SbSReconstruction(torch.nn.Module):
|
||||
|
||||
_the_sbs_layer: SbS
|
||||
_the_sbs_layer: SbSLayer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
the_sbs_layer: SbS,
|
||||
the_sbs_layer: SbSLayer,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
|
network/SortSpikesLayer.py (new file, 173 lines)
|
@ -0,0 +1,173 @@
|
|||
import torch
|
||||
|
||||
from network.PySortSpikesCPU import SortSpikesCPU
|
||||
|
||||
|
||||
class SortSpikesLayer(torch.nn.Module):
|
||||
|
||||
_kernel_size: list[int]
|
||||
_stride: list[int]
|
||||
_dilation: list[int]
|
||||
_padding: list[int]
|
||||
_output_size: torch.Tensor
|
||||
_number_of_cpu_processes: int
|
||||
_input_shape: list[int]
|
||||
|
||||
order: torch.Tensor | None = None
|
||||
order_convoled: torch.Tensor | None = None
|
||||
indices: torch.Tensor | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: list[int],
|
||||
input_shape: list[int],
|
||||
output_size: torch.Tensor,
|
||||
strides: list[int] = [1, 1],
|
||||
dilation: list[int] = [0, 0],
|
||||
padding: list[int] = [0, 0],
|
||||
number_of_cpu_processes: int = 1,
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
|
||||
self._stride = strides
|
||||
self._dilation = dilation
|
||||
self._padding = padding
|
||||
self._kernel_size = kernel_size
|
||||
self._output_size = output_size
|
||||
self._number_of_cpu_processes = number_of_cpu_processes
|
||||
self._input_shape = input_shape
|
||||
|
||||
self.sort_spikes = SortSpikesCPU()
|
||||
|
||||
self.order = (
|
||||
torch.arange(
|
||||
0,
|
||||
self._input_shape[0] * self._input_shape[1] * self._input_shape[2],
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
.reshape(
|
||||
(
|
||||
1,
|
||||
self._input_shape[0],
|
||||
self._input_shape[1],
|
||||
self._input_shape[2],
|
||||
)
|
||||
)
|
||||
.type(dtype=torch.float32)
|
||||
)
|
||||
|
||||
self.order_convoled = torch.nn.functional.fold(
|
||||
torch.nn.functional.unfold(
|
||||
self.order,
|
||||
kernel_size=(
|
||||
int(self._kernel_size[0]),
|
||||
int(self._kernel_size[1]),
|
||||
),
|
||||
dilation=(int(self._dilation[0]), int(self._dilation[1])),
|
||||
padding=(int(self._padding[0]), int(self._padding[1])),
|
||||
stride=(int(self._stride[0]), int(self._stride[1])),
|
||||
),
|
||||
output_size=tuple(self._output_size.tolist()),
|
||||
kernel_size=(1, 1),
|
||||
dilation=(1, 1),
|
||||
padding=(0, 0),
|
||||
stride=(1, 1),
|
||||
).type(dtype=torch.int64)
|
||||
|
||||
assert self.order_convoled is not None
|
||||
|
||||
self.order_convoled = self.order_convoled.reshape(
|
||||
(
|
||||
self.order_convoled.shape[1]
|
||||
* self.order_convoled.shape[2]
|
||||
* self.order_convoled.shape[3]
|
||||
)
|
||||
)
|
||||
|
||||
max_length: int = 0
|
||||
max_range: int = (
|
||||
self._input_shape[0] * self._input_shape[1] * self._input_shape[2]
|
||||
)
|
||||
for id in range(0, max_range):
|
||||
idx = torch.where(self.order_convoled == id)[0]
|
||||
max_length = max(max_length, int(idx.shape[0]))
|
||||
|
||||
self.indices = torch.full(
|
||||
(max_range, max_length),
|
||||
-1,
|
||||
dtype=torch.int64,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
|
||||
for id in range(0, max_range):
|
||||
idx = torch.where(self.order_convoled == id)[0]
|
||||
self.indices[id, 0 : int(idx.shape[0])] = idx
|
||||
|
||||
####################################################################
|
||||
# Forward #
|
||||
####################################################################
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert len(self._input_shape) == 3
|
||||
assert input.shape[-2] == 1
|
||||
assert input.shape[-1] == 1
|
||||
assert self.indices is not None
|
||||
|
||||
spikes_count = torch.zeros(
|
||||
(input.shape[0], int(self._output_size[0]), int(self._output_size[1])),
|
||||
device=torch.device("cpu"),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
|
||||
input_cpu = input.clone().cpu()
|
||||
|
||||
self.sort_spikes.count(
|
||||
input_cpu.data_ptr(), # Input
|
||||
int(input_cpu.shape[0]),
|
||||
int(input_cpu.shape[1]),
|
||||
int(input_cpu.shape[2]),
|
||||
int(input_cpu.shape[3]),
|
||||
spikes_count.data_ptr(), # Output
|
||||
int(spikes_count.shape[0]),
|
||||
int(spikes_count.shape[1]),
|
||||
int(spikes_count.shape[2]),
|
||||
self.indices.data_ptr(), # Positions
|
||||
int(self.indices.shape[0]),
|
||||
int(self.indices.shape[1]),
|
||||
int(self._number_of_cpu_processes),
|
||||
)
|
||||
|
||||
spikes_output = torch.full(
|
||||
(
|
||||
input.shape[0],
|
||||
int(spikes_count.max()),
|
||||
int(self._output_size[0]),
|
||||
int(self._output_size[1]),
|
||||
),
|
||||
-1,
|
||||
dtype=torch.int64,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
|
||||
self.sort_spikes.process(
|
||||
input_cpu.data_ptr(), # Input
|
||||
int(input_cpu.shape[0]),
|
||||
int(input_cpu.shape[1]),
|
||||
int(input_cpu.shape[2]),
|
||||
int(input_cpu.shape[3]),
|
||||
spikes_output.data_ptr(), # Output
|
||||
int(spikes_output.shape[0]),
|
||||
int(spikes_output.shape[1]),
|
||||
int(spikes_output.shape[2]),
|
||||
int(spikes_output.shape[3]),
|
||||
self.indices.data_ptr(), # Positions
|
||||
int(self.indices.shape[0]),
|
||||
int(self.indices.shape[1]),
|
||||
int(self._number_of_cpu_processes),
|
||||
)
|
||||
|
||||
return spikes_output
|
network/SpikeCountLayer.py (new file, 52 lines)
|
@ -0,0 +1,52 @@
|
|||
import torch
|
||||
|
||||
from network.PyCountSpikesCPU import CountSpikesCPU
|
||||
|
||||
|
||||
class SpikeCountLayer(torch.nn.Module):
|
||||
_number_of_cpu_processes: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
number_of_cpu_processes: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._number_of_cpu_processes = number_of_cpu_processes
|
||||
|
||||
####################################################################
|
||||
# Forward #
|
||||
####################################################################
|
||||
|
||||
def forward(self, input: torch.Tensor, dim_s: int) -> torch.Tensor:
|
||||
|
||||
assert input.ndim == 4
|
||||
assert dim_s > 0
|
||||
|
||||
input_cpu = input.cpu()
|
||||
|
||||
histogram = torch.zeros(
|
||||
(
|
||||
int(input.shape[0]),
|
||||
int(dim_s),
|
||||
int(input.shape[-2]),
|
||||
int(input.shape[-1]),
|
||||
),
|
||||
dtype=torch.int64,
|
||||
device=input_cpu.device,
|
||||
)
|
||||
|
||||
count_spikes = CountSpikesCPU()
|
||||
|
||||
count_spikes.process(
|
||||
input_cpu.data_ptr(),
|
||||
int(input_cpu.shape[0]),
|
||||
int(input_cpu.shape[1]),
|
||||
int(input_cpu.shape[2]),
|
||||
int(input_cpu.shape[3]),
|
||||
histogram.data_ptr(),
|
||||
int(histogram.shape[1]),
|
||||
int(self._number_of_cpu_processes),
|
||||
)
|
||||
|
||||
return histogram.to(device=input.device)
|
|
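A usage sketch for SpikeCountLayer; the spike tensor below is an illustrative stand-in for the index tensor produced by SpikeLayer (assumed here to hold integer indices into the dim_s input positions):

import torch
from network.SpikeCountLayer import SpikeCountLayer

counter = SpikeCountLayer(number_of_cpu_processes=4)
spikes = torch.randint(0, 784, (10, 1024, 1, 1))  # stand-in: 1024 spike indices per pattern
histogram = counter(spikes, dim_s=784)            # -> (10, 784, 1, 1) int64 spike counts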
@ -3,8 +3,6 @@ import torch
|
|||
from network.PySpikeGenerationCPU import SpikeGenerationCPU
|
||||
from network.PySpikeGenerationGPU import SpikeGenerationGPU
|
||||
|
||||
# from PyCountSpikesCPU import CountSpikesCPU
|
||||
|
||||
global_spike_generation_gpu_setting: list[torch.Tensor] = []
|
||||
global_spike_size: list[torch.Tensor] = []
|
||||
global_spike_generation_cpp: list[SpikeGenerationCPU | SpikeGenerationGPU] = []
|
||||
|
@ -16,28 +14,21 @@ class SpikeLayer(torch.nn.Module):
|
|||
_spike_generation_gpu_setting_position: int
|
||||
_number_of_cpu_processes: int
|
||||
_number_of_spikes: int
|
||||
|
||||
_spikes: torch.Tensor | None = None
|
||||
_store_spikes: bool
|
||||
device: torch.device
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
number_of_spikes: int = 1,
|
||||
number_of_spikes: int = -1,
|
||||
number_of_cpu_processes: int = 1,
|
||||
device: torch.device | None = None,
|
||||
default_dtype: torch.dtype | None = None,
|
||||
store_spikes: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
assert device is not None
|
||||
assert default_dtype is not None
|
||||
self.device = device
|
||||
self.default_dtype = default_dtype
|
||||
|
||||
self._number_of_cpu_processes = number_of_cpu_processes
|
||||
self._number_of_spikes = number_of_spikes
|
||||
self._store_spikes = store_spikes
|
||||
|
||||
global_spike_generation_gpu_setting.append(torch.tensor([0]))
|
||||
global_spike_size.append(torch.tensor([0, 0, 0, 0]))
|
||||
|
@ -62,7 +53,6 @@ class SpikeLayer(torch.nn.Module):
|
|||
self,
|
||||
input: torch.Tensor,
|
||||
number_of_spikes: int | None = None,
|
||||
store_spikes: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if number_of_spikes is None:
|
||||
|
@ -80,18 +70,7 @@ class SpikeLayer(torch.nn.Module):
|
|||
dtype=torch.int64,
|
||||
)
|
||||
|
||||
spikes = self.functional_spike_generation(input, parameter_list)
|
||||
|
||||
if (store_spikes is not None) and (store_spikes is True):
|
||||
self._spikes = spikes.detach().clone()
|
||||
elif (store_spikes is not None) and (store_spikes is False):
|
||||
self._spikes = None
|
||||
elif self._store_spikes is True:
|
||||
self._spikes = spikes.detach().clone()
|
||||
else:
|
||||
self._spikes = None
|
||||
|
||||
return spikes
|
||||
return self.functional_spike_generation(input, parameter_list)
|
||||
|
||||
|
||||
class FunctionalSpikeGeneration(torch.autograd.Function):
|
||||
|
|
|
@ -3,10 +3,11 @@ import torch
|
|||
|
||||
from network.calculate_output_size import calculate_output_size
|
||||
from network.Parameter import Config
|
||||
from network.SbS import SbS
|
||||
from network.SbSLayer import SbSLayer
|
||||
from network.SplitOnOffLayer import SplitOnOffLayer
|
||||
from network.Conv2dApproximation import Conv2dApproximation
|
||||
from network.SbSReconstruction import SbSReconstruction
|
||||
from network.InputSpikeImage import InputSpikeImage
|
||||
|
||||
|
||||
def build_network(
|
||||
|
@ -144,7 +145,7 @@ def build_network(
|
|||
is_pooling_layer = True
|
||||
|
||||
network.append(
|
||||
SbS(
|
||||
SbSLayer(
|
||||
number_of_input_neurons=in_channels,
|
||||
number_of_neurons=out_channels,
|
||||
input_size=input_size[-1],
|
||||
|
@ -190,7 +191,7 @@ def build_network(
|
|||
logging.info(f"Layer: {layer_id} -> SbS Reconstruction Layer")
|
||||
|
||||
assert layer_id > 0
|
||||
assert isinstance(network[-1], SbS) is True
|
||||
assert isinstance(network[-1], SbSLayer) is True
|
||||
|
||||
network.append(SbSReconstruction(network[-1]))
|
||||
network[-1]._w_trainable = False
|
||||
|
@ -365,6 +366,36 @@ def build_network(
|
|||
).tolist()
|
||||
input_size.append(input_size_temp)
|
||||
|
||||
# #############################################################
|
||||
# Input spike image layer:
|
||||
# #############################################################
|
||||
|
||||
elif (
|
||||
cfg.network_structure.layer_type[layer_id]
|
||||
.upper()
|
||||
.startswith("INPUT SPIKE IMAGE")
|
||||
is True
|
||||
):
|
||||
logging.info(f"Layer: {layer_id} -> Input Spike Image Layer")
|
||||
|
||||
number_of_spikes: int = -1
|
||||
if len(cfg.number_of_spikes) > layer_id:
|
||||
number_of_spikes = cfg.number_of_spikes[layer_id]
|
||||
elif len(cfg.number_of_spikes) == 1:
|
||||
number_of_spikes = cfg.number_of_spikes[0]
|
||||
|
||||
network.append(
|
||||
InputSpikeImage(
|
||||
number_of_spikes=number_of_spikes,
|
||||
number_of_cpu_processes=cfg.number_of_cpu_processes,
|
||||
reshape=True,
|
||||
normalize=True,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
input_size.append(input_size[-1])
|
||||
|
||||
# #############################################################
|
||||
# Failure because we didn't find the selected layer type
|
||||
# #############################################################
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# %%
|
||||
import torch
|
||||
from network.Parameter import Config
|
||||
from network.SbS import SbS
|
||||
from network.SbSLayer import SbSLayer
|
||||
from network.Conv2dApproximation import Conv2dApproximation
|
||||
from network.Adam import Adam
|
||||
|
||||
|
@ -20,7 +20,7 @@ def build_optimizer(
|
|||
|
||||
for id in range(0, len(network)):
|
||||
|
||||
if (isinstance(network[id], SbS) is True) and (
|
||||
if (isinstance(network[id], SbSLayer) is True) and (
|
||||
network[id]._w_trainable is True
|
||||
):
|
||||
parameter_list_weights.append(network[id]._weights)
|
||||
|
|
|
@ -3,9 +3,10 @@ import torch
|
|||
import glob
|
||||
import numpy as np
|
||||
|
||||
from network.SbS import SbS
|
||||
from network.SbSLayer import SbSLayer
|
||||
from network.SplitOnOffLayer import SplitOnOffLayer
|
||||
from network.Conv2dApproximation import Conv2dApproximation
|
||||
import os
|
||||
|
||||
|
||||
def load_previous_weights(
|
||||
|
@ -14,22 +15,28 @@ def load_previous_weights(
|
|||
logging,
|
||||
device: torch.device,
|
||||
default_dtype: torch.dtype,
|
||||
order_id: float | int | None = None,
|
||||
) -> None:
|
||||
|
||||
if order_id is None:
|
||||
post_fix: str = ""
|
||||
else:
|
||||
post_fix = f"_{order_id}"
|
||||
|
||||
for id in range(0, len(network)):
|
||||
|
||||
# #################################################
|
||||
# SbS
|
||||
# #################################################
|
||||
|
||||
if isinstance(network[id], SbS) is True:
|
||||
# Are there weights that overwrite the initial weights?
|
||||
file_to_load = glob.glob(overload_path + "/Weight_L" + str(id) + "_*.npy")
|
||||
if isinstance(network[id], SbSLayer) is True:
|
||||
filename_wilcard = os.path.join(
|
||||
overload_path, f"Weight_L{id}_*{post_fix}.npy"
|
||||
)
|
||||
file_to_load = glob.glob(filename_wilcard)
|
||||
|
||||
if len(file_to_load) > 1:
|
||||
raise Exception(
|
||||
f"Too many previous weights files {overload_path}/Weight_L{id}*.npy"
|
||||
)
|
||||
raise Exception(f"Too many previous weights files {filename_wilcard}")
|
||||
|
||||
if len(file_to_load) == 1:
|
||||
network[id].weights = torch.tensor(
|
||||
|
@ -45,13 +52,13 @@ def load_previous_weights(
|
|||
# Conv2d weights
|
||||
# #################################################
|
||||
|
||||
# Are there weights that overwrite the initial weights?
|
||||
file_to_load = glob.glob(overload_path + "/Weight_L" + str(id) + "_*.npy")
|
||||
filename_wilcard = os.path.join(
|
||||
overload_path, f"Weight_L{id}_*{post_fix}.npy"
|
||||
)
|
||||
file_to_load = glob.glob(filename_wilcard)
|
||||
|
||||
if len(file_to_load) > 1:
|
||||
raise Exception(
|
||||
f"Too many previous weights files {overload_path}/Weight_L{id}*.npy"
|
||||
)
|
||||
raise Exception(f"Too many previous weights files {filename_wilcard}")
|
||||
|
||||
if len(file_to_load) == 1:
|
||||
network[id]._parameters["weight"].data = torch.tensor(
|
||||
|
@ -65,13 +72,13 @@ def load_previous_weights(
|
|||
# Conv2d bias
|
||||
# #################################################
|
||||
|
||||
# Are there bias values that overwrite the initial biases?
|
||||
file_to_load = glob.glob(overload_path + "/Bias_L" + str(id) + "_*.npy")
|
||||
filename_wilcard = os.path.join(
|
||||
overload_path, f"Bias_L{id}_*{post_fix}.npy"
|
||||
)
|
||||
file_to_load = glob.glob(filename_wilcard)
|
||||
|
||||
if len(file_to_load) > 1:
|
||||
raise Exception(
|
||||
f"Too many previous weights files {overload_path}/Weight_L{id}*.npy"
|
||||
)
|
||||
raise Exception(f"Too many previous weights files {filename_wilcard}")
|
||||
|
||||
if len(file_to_load) == 1:
|
||||
network[id]._parameters["bias"].data = torch.tensor(
|
||||
|
@ -87,13 +94,13 @@ def load_previous_weights(
|
|||
# Approximate Conv2d weights
|
||||
# #################################################
|
||||
|
||||
# Are there weights that overwrite the initial weights?
|
||||
file_to_load = glob.glob(overload_path + "/Weight_L" + str(id) + "_*.npy")
|
||||
filename_wilcard = os.path.join(
|
||||
overload_path, f"Weight_L{id}_*{post_fix}.npy"
|
||||
)
|
||||
file_to_load = glob.glob(filename_wilcard)
|
||||
|
||||
if len(file_to_load) > 1:
|
||||
raise Exception(
|
||||
f"Too many previous weights files {overload_path}/Weight_L{id}*.npy"
|
||||
)
|
||||
raise Exception(f"Too many previous weights files {filename_wilcard}")
|
||||
|
||||
if len(file_to_load) == 1:
|
||||
network[id].weights.data = torch.tensor(
|
||||
|
@ -107,13 +114,13 @@ def load_previous_weights(
|
|||
# Approximate Conv2d bias
|
||||
# #################################################
|
||||
|
||||
# Are there bias values that overwrite the initial biases?
|
||||
file_to_load = glob.glob(overload_path + "/Bias_L" + str(id) + "_*.npy")
|
||||
filename_wilcard = os.path.join(
|
||||
overload_path, f"Bias_L{id}_*{post_fix}.npy"
|
||||
)
|
||||
file_to_load = glob.glob(filename_wilcard)
|
||||
|
||||
if len(file_to_load) > 1:
|
||||
raise Exception(
|
||||
f"Too many previous weights files {overload_path}/Weight_L{id}*.npy"
|
||||
)
|
||||
raise Exception(f"Too many previous weights files {filename_wilcard}")
|
||||
|
||||
if len(file_to_load) == 1:
|
||||
network[id].bias.data = torch.tensor(
|
||||
|
@ -127,13 +134,13 @@ def load_previous_weights(
|
|||
# SplitOnOffLayer
|
||||
# #################################################
|
||||
if isinstance(network[id], SplitOnOffLayer) is True:
|
||||
# Are there weights that overwrite the initial weights?
|
||||
file_to_load = glob.glob(overload_path + "/Mean_L" + str(id) + "_*.npy")
|
||||
filename_wilcard = os.path.join(
|
||||
overload_path, f"Mean_L{id}_*{post_fix}.npy"
|
||||
)
|
||||
file_to_load = glob.glob(filename_wilcard)
|
||||
|
||||
if len(file_to_load) > 1:
|
||||
raise Exception(
|
||||
f"Too many previous mean files {overload_path}/Mean_L{id}*.npy"
|
||||
)
|
||||
raise Exception(f"Too many previous mean files {filename_wilcard}")
|
||||
|
||||
if len(file_to_load) == 1:
|
||||
network[id].mean = torch.tensor(
|
||||
|
|
|
@ -3,7 +3,7 @@ import time
|
|||
from network.Parameter import Config
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from network.SbS import SbS
|
||||
from network.SbSLayer import SbSLayer
|
||||
from network.save_weight_and_bias import save_weight_and_bias
|
||||
from network.SbSReconstruction import SbSReconstruction
|
||||
|
||||
|
@ -19,7 +19,7 @@ def add_weight_and_bias_to_histogram(
|
|||
# ################################################
|
||||
# Log the SbS Weights
|
||||
# ################################################
|
||||
if isinstance(network[id], SbS) is True:
|
||||
if isinstance(network[id], SbSLayer) is True:
|
||||
if network[id]._w_trainable is True:
|
||||
|
||||
try:
|
||||
|
@ -175,7 +175,7 @@ def forward_pass_train(
|
|||
.to(device=device)
|
||||
)
|
||||
for id in range(0, len(network)):
|
||||
if isinstance(network[id], SbS) is True:
|
||||
if isinstance(network[id], SbSLayer) is True:
|
||||
h_collection.append(network[id](h_collection[-1], labels))
|
||||
else:
|
||||
h_collection.append(network[id](h_collection[-1]))
|
||||
|
@ -203,7 +203,7 @@ def forward_pass_test(
|
|||
)
|
||||
for id in range(0, len(network)):
|
||||
if (cfg.extract_noisy_pictures is True) or (overwrite_number_of_spikes != -1):
|
||||
if isinstance(network[id], SbS) is True:
|
||||
if isinstance(network[id], SbSLayer) is True:
|
||||
h_collection.append(
|
||||
network[id](
|
||||
h_collection[-1],
|
||||
|
@ -228,7 +228,7 @@ def run_optimizer(
|
|||
cfg: Config,
|
||||
) -> None:
|
||||
for id in range(0, len(network)):
|
||||
if isinstance(network[id], SbS) is True:
|
||||
if isinstance(network[id], SbSLayer) is True:
|
||||
network[id].update_pre_care()
|
||||
|
||||
for optimizer_item in optimizer:
|
||||
|
@ -236,7 +236,7 @@ def run_optimizer(
|
|||
optimizer_item.step()
|
||||
|
||||
for id in range(0, len(network)):
|
||||
if isinstance(network[id], SbS) is True:
|
||||
if isinstance(network[id], SbSLayer) is True:
|
||||
network[id].update_after_care(
|
||||
cfg.learning_parameters.learning_rate_threshold_w
|
||||
/ float(
|
||||
|
@ -288,11 +288,11 @@ def run_lr_scheduler(
|
|||
def deal_with_gradient_scale(epoch_id: int, mini_batch_number: int, network):
|
||||
if (epoch_id == 0) and (mini_batch_number == 0):
|
||||
for id in range(0, len(network)):
|
||||
if isinstance(network[id], SbS) is True:
|
||||
if isinstance(network[id], SbSLayer) is True:
|
||||
network[id].after_batch(True)
|
||||
else:
|
||||
for id in range(0, len(network)):
|
||||
if isinstance(network[id], SbS) is True:
|
||||
if isinstance(network[id], SbSLayer) is True:
|
||||
network[id].after_batch()
|
||||
|
||||
|
||||
|
@ -309,6 +309,7 @@ def loop_train(
|
|||
tb: SummaryWriter,
|
||||
lr_scheduler,
|
||||
last_test_performance: float,
|
||||
order_id: float | int | None = None,
|
||||
) -> tuple[float, float, float, float]:
|
||||
|
||||
correct_in_minibatch: int = 0
|
||||
|
@ -529,7 +530,10 @@ def loop_train(
|
|||
# Save the Weights and Biases
|
||||
# ################################################
|
||||
save_weight_and_bias(
|
||||
cfg=cfg, network=network, iteration_number=epoch_id
|
||||
cfg=cfg,
|
||||
network=network,
|
||||
iteration_number=epoch_id,
|
||||
order_id=order_id,
|
||||
)
|
||||
|
||||
# ################################################
|
||||
|
|
|
@ -3,14 +3,23 @@ import torch
|
|||
from network.Parameter import Config
|
||||
|
||||
import numpy as np
|
||||
from network.SbS import SbS
|
||||
from network.SbSLayer import SbSLayer
|
||||
from network.SplitOnOffLayer import SplitOnOffLayer
|
||||
from network.Conv2dApproximation import Conv2dApproximation
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def save_weight_and_bias(
|
||||
cfg: Config, network: torch.nn.modules.container.Sequential, iteration_number: int
|
||||
cfg: Config,
|
||||
network: torch.nn.modules.container.Sequential,
|
||||
iteration_number: int,
|
||||
order_id: float | int | None = None,
|
||||
) -> None:
|
||||
if order_id is None:
|
||||
post_fix: str = ""
|
||||
else:
|
||||
post_fix = f"_{order_id}"
|
||||
|
||||
for id in range(0, len(network)):
|
||||
|
||||
|
@ -18,11 +27,14 @@ def save_weight_and_bias(
|
|||
# Save the SbS Weights
|
||||
# ################################################
|
||||
|
||||
if isinstance(network[id], SbS) is True:
|
||||
if isinstance(network[id], SbSLayer) is True:
|
||||
if network[id]._w_trainable is True:
|
||||
|
||||
np.save(
|
||||
f"{cfg.weight_path}/Weight_L{id}_S{iteration_number}.npy",
|
||||
os.path.join(
|
||||
cfg.weight_path,
|
||||
f"Weight_L{id}_S{iteration_number}{post_fix}.npy",
|
||||
),
|
||||
network[id].weights.detach().cpu().numpy(),
|
||||
)
|
||||
|
||||
|
@ -34,13 +46,18 @@ def save_weight_and_bias(
|
|||
if network[id]._w_trainable is True:
|
||||
# Save the new values
|
||||
np.save(
|
||||
f"{cfg.weight_path}/Weight_L{id}_S{iteration_number}.npy",
|
||||
os.path.join(
|
||||
cfg.weight_path,
|
||||
f"Weight_L{id}_S{iteration_number}{post_fix}.npy",
|
||||
),
|
||||
network[id]._parameters["weight"].data.detach().cpu().numpy(),
|
||||
)
|
||||
|
||||
# Save the new values
|
||||
np.save(
|
||||
f"{cfg.weight_path}/Bias_L{id}_S{iteration_number}.npy",
|
||||
os.path.join(
|
||||
cfg.weight_path, f"Bias_L{id}_S{iteration_number}{post_fix}.npy"
|
||||
),
|
||||
network[id]._parameters["bias"].data.detach().cpu().numpy(),
|
||||
)
|
||||
|
||||
|
@ -52,20 +69,28 @@ def save_weight_and_bias(
|
|||
if network[id]._w_trainable is True:
|
||||
# Save the new values
|
||||
np.save(
|
||||
f"{cfg.weight_path}/Weight_L{id}_S{iteration_number}.npy",
|
||||
os.path.join(
|
||||
cfg.weight_path,
|
||||
f"Weight_L{id}_S{iteration_number}{post_fix}.npy",
|
||||
),
|
||||
network[id].weights.data.detach().cpu().numpy(),
|
||||
)
|
||||
|
||||
# Save the new values
|
||||
if network[id].bias is not None:
|
||||
np.save(
|
||||
f"{cfg.weight_path}/Bias_L{id}_S{iteration_number}.npy",
|
||||
os.path.join(
|
||||
cfg.weight_path,
|
||||
f"Bias_L{id}_S{iteration_number}{post_fix}.npy",
|
||||
),
|
||||
network[id].bias.data.detach().cpu().numpy(),
|
||||
)
|
||||
|
||||
if isinstance(network[id], SplitOnOffLayer) is True:
|
||||
|
||||
np.save(
|
||||
f"{cfg.weight_path}/Mean_L{id}_S{iteration_number}.npy",
|
||||
os.path.join(
|
||||
cfg.weight_path, f"Mean_L{id}_S{iteration_number}{post_fix}.npy"
|
||||
),
|
||||
network[id].mean.detach().cpu().numpy(),
|
||||
)
|
||||
|
|
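The resulting file naming convention, for illustration (layer id 3, epoch 10): without an order_id the files are written as Weight_L3_S10.npy / Bias_L3_S10.npy / Mean_L3_S10.npy inside cfg.weight_path; with order_id=2 the post_fix adds a trailing "_2". "Weights" below merely stands in for cfg.weight_path:

import os

post_fix = "_2"  # from order_id = 2; an empty string when order_id is None
print(os.path.join("Weights", f"Weight_L{3}_S{10}{post_fix}.npy"))  # Weights/Weight_L3_S10_2.npy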