diff --git a/network/Dataset.py b/network/Dataset.py index 44c314c..83b18fb 100644 --- a/network/Dataset.py +++ b/network/Dataset.py @@ -12,6 +12,8 @@ class DatasetMaster(torch.utils.data.Dataset, ABC): pattern_storage: np.ndarray number_of_pattern: int mean: list[float] + initial_size: list[int] + channel_size: int # Initialize def __init__( @@ -36,6 +38,9 @@ class DatasetMaster(torch.utils.data.Dataset, ABC): self.mean = [] + self.initial_size = [0, 0] + self.channel_size = 0 + def __len__(self) -> int: return self.number_of_pattern @@ -74,6 +79,9 @@ class DatasetMNIST(DatasetMaster): mean = self.pattern_storage.mean(3).mean(2).mean(0) self.mean = [*mean] + self.initial_size = [28, 28] + self.channel_size = 1 + def __getitem__(self, index: int) -> tuple[torch.Tensor, int]: image = self.pattern_storage[index, 0:1, :, :] @@ -154,6 +162,9 @@ class DatasetFashionMNIST(DatasetMaster): mean = self.pattern_storage.mean(3).mean(2).mean(0) self.mean = [*mean] + self.initial_size = [28, 28] + self.channel_size = 1 + def __getitem__(self, index: int) -> tuple[torch.Tensor, int]: image = self.pattern_storage[index, 0:1, :, :] @@ -240,6 +251,9 @@ class DatasetCIFAR(DatasetMaster): mean = self.pattern_storage.mean(3).mean(2).mean(0) self.mean = [*mean] + self.initial_size = [32, 32] + self.channel_size = 3 + def __getitem__(self, index: int) -> tuple[torch.Tensor, int]: image = self.pattern_storage[index, :, :, :] diff --git a/network/HDynamicLayer.py b/network/HDynamicLayer.py new file mode 100644 index 0000000..9d00f80 --- /dev/null +++ b/network/HDynamicLayer.py @@ -0,0 +1,451 @@ +import torch + +from network.PyHDynamicCNNCPU import HDynamicCNNCPU +from network.PyHDynamicCNNGPU import HDynamicCNNGPU + +global_sbs_gpu_setting: list[torch.Tensor] = [] +global_sbs_size: list[torch.Tensor] = [] +global_sbs_hdynamic_cpp: list[HDynamicCNNCPU | HDynamicCNNGPU] = [] + + +class HDynamicLayer(torch.nn.Module): + + _sbs_gpu_setting_position: int + _sbs_hdynamic_cpp_position: int + _gpu_tuning_factor: int + _number_of_cpu_processes: int + _output_size: list[int] + _w_trainable: bool + _output_layer: bool + _local_learning: bool + device: torch.device + default_dtype: torch.dtype + + def __init__( + self, + output_size: list[int], + output_layer: bool = False, + local_learning: bool = False, + number_of_cpu_processes: int = 1, + w_trainable: bool = False, + skip_gradient_calculation: bool = False, + device: torch.device | None = None, + default_dtype: torch.dtype | None = None, + gpu_tuning_factor: int = 5, + ) -> None: + super().__init__() + + assert device is not None + self.device = device + self.default_dtype = default_dtype + + self._gpu_tuning_factor = int(gpu_tuning_factor) + self._number_of_cpu_processes = int(number_of_cpu_processes) + self._w_trainable = bool(w_trainable) + self._skip_gradient_calculation = bool(skip_gradient_calculation) + self._output_size = output_size + self._output_layer = bool(output_layer) + self._local_learning = bool(local_learning) + + global_sbs_gpu_setting.append(torch.tensor([0])) + global_sbs_size.append(torch.tensor([0, 0, 0, 0])) + + if device == torch.device("cpu"): + global_sbs_hdynamic_cpp.append(HDynamicCNNCPU()) + else: + global_sbs_hdynamic_cpp.append(HDynamicCNNGPU()) + + self._sbs_gpu_setting_position = len(global_sbs_gpu_setting) - 1 + self._sbs_hdynamic_cpp_position = len(global_sbs_hdynamic_cpp) - 1 + + self.functional_sbs = FunctionalSbS.apply + + #################################################################### + # Forward # + #################################################################### + + def forward( + self, + input: torch.Tensor, + spike: torch.Tensor, + epsilon_xy: torch.Tensor, + epsilon_t_0: torch.Tensor, + weights: torch.Tensor, + h_initial: torch.Tensor, + last_grad_scale: torch.Tensor, + labels: torch.Tensor | None = None, + keep_last_grad_scale: bool = False, + disable_scale_grade: bool = True, + forgetting_offset: float = -1.0, + ) -> torch.Tensor: + + if labels is None: + labels_copy: torch.Tensor = torch.tensor( + [], dtype=torch.int64, device=self.device + ) + else: + labels_copy = ( + labels.detach().clone().type(dtype=torch.int64).to(device=self.device) + ) + + if (spike.shape[-2] * spike.shape[-1]) > self._gpu_tuning_factor: + gpu_tuning_factor = self._gpu_tuning_factor + else: + gpu_tuning_factor = 0 + + parameter_list = torch.tensor( + [ + int(self._number_of_cpu_processes), # 0 + int(self._output_size[0]), # 1 + int(self._output_size[1]), # 2 + int(gpu_tuning_factor), # 3 + int(self._sbs_gpu_setting_position), # 4 + int(self._sbs_hdynamic_cpp_position), # 5 + int(self._w_trainable), # 6 + int(disable_scale_grade), # 7 + int(keep_last_grad_scale), # 8 + int(self._skip_gradient_calculation), # 9 + int(self._output_layer), # 10 + int(self._local_learning), # 11 + ], + dtype=torch.int64, + ) + + # SbS forward functional + return self.functional_sbs( + input, + spike, + epsilon_xy, + epsilon_t_0, + weights, + h_initial, + parameter_list, + last_grad_scale, + torch.tensor( + forgetting_offset, device=self.device, dtype=self.default_dtype + ), + labels_copy, + ) + + +class FunctionalSbS(torch.autograd.Function): + @staticmethod + def forward( # type: ignore + ctx, + input: torch.Tensor, + spikes: torch.Tensor, + epsilon_xy: torch.Tensor | None, + epsilon_t_0: torch.Tensor, + weights: torch.Tensor, + h_initial: torch.Tensor, + parameter_list: torch.Tensor, + grad_output_scale: torch.Tensor, + forgetting_offset: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + + number_of_spikes: int = int(spikes.shape[1]) + + if input.device == torch.device("cpu"): + hdyn_number_of_cpu_processes: int = int(parameter_list[0]) + else: + hdyn_number_of_cpu_processes = -1 + + output_size_0: int = int(parameter_list[1]) + output_size_1: int = int(parameter_list[2]) + gpu_tuning_factor: int = int(parameter_list[3]) + + sbs_gpu_setting_position = int(parameter_list[4]) + sbs_hdynamic_cpp_position = int(parameter_list[5]) + + # ########################################################### + # H dynamic + # ########################################################### + + assert epsilon_t_0.ndim == 1 + assert epsilon_t_0.shape[0] >= number_of_spikes + + # ############################################ + # Make space for the results + # ############################################ + + output = torch.empty( + ( + int(input.shape[0]), + int(weights.shape[1]), + output_size_0, + output_size_1, + ), + dtype=input.dtype, + device=input.device, + ) + + assert output.is_contiguous() is True + if epsilon_xy is not None: + assert epsilon_xy.is_contiguous() is True + assert epsilon_xy.ndim == 3 + assert epsilon_t_0.is_contiguous() is True + assert weights.is_contiguous() is True + assert spikes.is_contiguous() is True + assert h_initial.is_contiguous() is True + + assert weights.ndim == 2 + assert h_initial.ndim == 1 + + sbs_profile = global_sbs_gpu_setting[sbs_gpu_setting_position].clone() + + sbs_size = global_sbs_size[sbs_gpu_setting_position].clone() + + if input.device != torch.device("cpu"): + if ( + (sbs_profile.numel() == 1) + or (sbs_size[0] != int(output.shape[0])) + or (sbs_size[1] != int(output.shape[1])) + or (sbs_size[2] != int(output.shape[2])) + or (sbs_size[3] != int(output.shape[3])) + ): + sbs_profile = torch.zeros( + (14, 7), dtype=torch.int64, device=torch.device("cpu") + ) + + global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_export( + int(output.shape[2]), + int(output.shape[3]), + int(output.shape[0]), + int(output.shape[1]), + sbs_profile.data_ptr(), + int(sbs_profile.shape[0]), + int(sbs_profile.shape[1]), + ) + global_sbs_gpu_setting[sbs_gpu_setting_position] = sbs_profile.clone() + sbs_size[0] = int(output.shape[0]) + sbs_size[1] = int(output.shape[1]) + sbs_size[2] = int(output.shape[2]) + sbs_size[3] = int(output.shape[3]) + global_sbs_size[sbs_gpu_setting_position] = sbs_size.clone() + + else: + global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_import( + sbs_profile.data_ptr(), + int(sbs_profile.shape[0]), + int(sbs_profile.shape[1]), + ) + + global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].update( + output.data_ptr(), + int(output.shape[0]), + int(output.shape[1]), + int(output.shape[2]), + int(output.shape[3]), + epsilon_xy.data_ptr() if epsilon_xy is not None else int(0), + int(epsilon_xy.shape[0]) if epsilon_xy is not None else int(0), + int(epsilon_xy.shape[1]) if epsilon_xy is not None else int(0), + int(epsilon_xy.shape[2]) if epsilon_xy is not None else int(0), + epsilon_t_0.data_ptr(), + int(epsilon_t_0.shape[0]), + weights.data_ptr(), + int(weights.shape[0]), + int(weights.shape[1]), + spikes.data_ptr(), + int(spikes.shape[0]), + int(spikes.shape[1]), + int(spikes.shape[2]), + int(spikes.shape[3]), + h_initial.data_ptr(), + int(h_initial.shape[0]), + hdyn_number_of_cpu_processes, + float(forgetting_offset.cpu().item()), + int(gpu_tuning_factor), + ) + + # ########################################################### + # Save the necessary data for the backward pass + # ########################################################### + + ctx.save_for_backward( + input, + weights, + output, + parameter_list, + grad_output_scale, + labels, + ) + + return output + + @staticmethod + def backward(ctx, grad_output): + # ############################################## + # Get the variables back + # ############################################## + ( + input, + weights, + output, + parameter_list, + last_grad_scale, + labels, + ) = ctx.saved_tensors + + assert labels.numel() > 0 + + # ############################################## + # Default output + # ############################################## + grad_input = None + grad_spikes = None + grad_eps_xy = None + grad_epsilon_t_0 = None + grad_weights = None + grad_h_initial = None + grad_parameter_list = None + grad_forgetting_offset = None + grad_labels = None + + # ############################################## + # Parameters + # ############################################## + parameter_w_trainable: bool = bool(parameter_list[6]) + parameter_disable_scale_grade: bool = bool(parameter_list[7]) + parameter_keep_last_grad_scale: bool = bool(parameter_list[8]) + parameter_skip_gradient_calculation: bool = bool(parameter_list[9]) + parameter_output_layer: bool = bool(parameter_list[10]) + parameter_local_learning: bool = bool(parameter_list[11]) + + # ############################################## + # Dealing with overall scale of the gradient + # ############################################## + if parameter_disable_scale_grade is False: + if parameter_keep_last_grad_scale is True: + last_grad_scale = torch.tensor( + [torch.abs(grad_output).max(), last_grad_scale] + ).max() + grad_output /= last_grad_scale + grad_output_scale = last_grad_scale.clone() + + input /= input.sum(dim=1, keepdim=True, dtype=weights.dtype) + + # ################################################# + # User doesn't want us to calculate the gradients + # ################################################# + + if parameter_skip_gradient_calculation is True: + + return ( + grad_input, + grad_spikes, + grad_eps_xy, + grad_epsilon_t_0, + grad_weights, + grad_h_initial, + grad_parameter_list, + grad_output_scale, + grad_forgetting_offset, + grad_labels, + ) + + # ################################################# + # Calculate backprop error (grad_input) + # ################################################# + + backprop_r: torch.Tensor = weights.unsqueeze(0).unsqueeze(-1).unsqueeze( + -1 + ) * output.unsqueeze(1) + + backprop_bigr: torch.Tensor = backprop_r.sum(dim=2) + + backprop_z: torch.Tensor = backprop_r * ( + 1.0 / (backprop_bigr + 1e-20) + ).unsqueeze(2) + grad_input: torch.Tensor = (backprop_z * grad_output.unsqueeze(1)).sum(2) + del backprop_z + + # ################################################# + # Calculate weight gradient (grad_weights) + # ################################################# + + if parameter_w_trainable is False: + + # ################################################# + # We don't train this weight + # ################################################# + grad_weights = None + + elif (parameter_output_layer is False) and (parameter_local_learning is True): + # ################################################# + # Local learning + # ################################################# + grad_weights = ( + (-2 * (input - backprop_bigr).unsqueeze(2) * output.unsqueeze(1)) + .sum(0) + .sum(-1) + .sum(-1) + ) + + elif (parameter_output_layer is True) and (parameter_local_learning is True): + + target_one_hot: torch.Tensor = torch.zeros( + ( + labels.shape[0], + output.shape[1], + ), + device=input.device, + dtype=input.dtype, + ) + + target_one_hot.scatter_( + 1, + labels.to(input.device).unsqueeze(1), + torch.ones( + (labels.shape[0], 1), + device=input.device, + dtype=input.dtype, + ), + ) + target_one_hot = target_one_hot.unsqueeze(-1).unsqueeze(-1) + + # (-2 * (input - backprop_bigr).unsqueeze(2) * (target_one_hot-output).unsqueeze(1)) + # (-2 * input.unsqueeze(2) * (target_one_hot-output).unsqueeze(1)) + grad_weights = ( + ( + -2 + * (input - backprop_bigr).unsqueeze(2) + * target_one_hot.unsqueeze(1) + ) + .sum(0) + .sum(-1) + .sum(-1) + ) + + else: + # ################################################# + # Backprop + # ################################################# + backprop_f: torch.Tensor = output.unsqueeze(1) * ( + input / (backprop_bigr**2 + 1e-20) + ).unsqueeze(2) + + result_omega: torch.Tensor = backprop_bigr.unsqueeze( + 2 + ) * grad_output.unsqueeze(1) + result_omega -= (backprop_r * grad_output.unsqueeze(1)).sum(2).unsqueeze(2) + result_omega *= backprop_f + del backprop_f + grad_weights = result_omega.sum(0).sum(-1).sum(-1) + del result_omega + + del backprop_bigr + del backprop_r + + return ( + grad_input, + grad_spikes, + grad_eps_xy, + grad_epsilon_t_0, + grad_weights, + grad_h_initial, + grad_parameter_list, + grad_output_scale, + grad_forgetting_offset, + grad_labels, + ) diff --git a/network/InputSpikeImage.py b/network/InputSpikeImage.py new file mode 100644 index 0000000..03de7da --- /dev/null +++ b/network/InputSpikeImage.py @@ -0,0 +1,104 @@ +import torch + +from network.SpikeLayer import SpikeLayer +from network.SpikeCountLayer import SpikeCountLayer + + +class InputSpikeImage(torch.nn.Module): + + _reshape: bool + _normalize: bool + _device: torch.device + + number_of_spikes: int + + def __init__( + self, + number_of_spikes: int = -1, + number_of_cpu_processes: int = 1, + reshape: bool = False, + normalize: bool = True, + device: torch.device | None = None, + ) -> None: + super().__init__() + + assert device is not None + self._device = device + + self._reshape = bool(reshape) + self._normalize = bool(normalize) + + self.number_of_spikes = int(number_of_spikes) + + if device != torch.device("cpu"): + number_of_cpu_processes_spike_generator = 0 + else: + number_of_cpu_processes_spike_generator = number_of_cpu_processes + + self.spike_generator = SpikeLayer( + number_of_cpu_processes=number_of_cpu_processes_spike_generator, + device=device, + ) + + self.spike_count = SpikeCountLayer( + number_of_cpu_processes=number_of_cpu_processes + ) + + #################################################################### + # Forward # + #################################################################### + + def forward(self, input: torch.Tensor) -> torch.Tensor: + + if self.number_of_spikes < 1: + return input + + input_shape: list[int] = [ + int(input.shape[0]), + int(input.shape[1]), + int(input.shape[2]), + int(input.shape[3]), + ] + + if self._reshape is True: + input_work = ( + input.detach() + .clone() + .to(self._device) + .reshape( + (input_shape[0], input_shape[1] * input_shape[2] * input_shape[3]) + ) + .unsqueeze(-1) + .unsqueeze(-1) + ) + else: + input_work = input.detach().clone().to(self._device) + + spikes = self.spike_generator( + input=input_work, number_of_spikes=self.number_of_spikes + ) + + if self._reshape is True: + dim_s: int = input_shape[1] * input_shape[2] * input_shape[3] + else: + dim_s = input_shape[1] + + output: torch.Tensor = self.spike_count(spikes, dim_s) + + if self._reshape is True: + + output = ( + output.squeeze(-1) + .squeeze(-1) + .reshape( + (input_shape[0], input_shape[1], input_shape[2], input_shape[3]) + ) + ) + + if self._normalize is True: + output = output.type(dtype=input_work.dtype) + output = output / output.sum(dim=-1, keepdim=True).sum( + dim=-2, keepdim=True + ).sum(dim=-3, keepdim=True) + + return output diff --git a/network/Makefile b/network/Makefile index 3f14481..d122941 100644 --- a/network/Makefile +++ b/network/Makefile @@ -3,20 +3,22 @@ export all: cd h_dynamic_cnn_cpu_cpp && $(MAKE) all - cd h_dynamic_cnn_gpu_cpp && $(MAKE) all + cd h_dynamic_cnn_gpu_cpp_v1 && $(MAKE) all cd spike_generation_cpu_cpp && $(MAKE) all cd spike_generation_gpu_cpp_v2 && $(MAKE) all cd multiplication_approximation_cpu_cpp && $(MAKE) all cd multiplication_approximation_gpu_cpp && $(MAKE) all cd count_spikes_cpu_cpp && $(MAKE) all + cd sort_spikes_cpu_cpp && $(MAKE) all $(PYBIN)python3 pybind11_auto_pyi.py clean: cd h_dynamic_cnn_cpu_cpp && $(MAKE) clean - cd h_dynamic_cnn_gpu_cpp && $(MAKE) clean + cd h_dynamic_cnn_gpu_cpp_v1 && $(MAKE) clean cd spike_generation_cpu_cpp && $(MAKE) clean cd spike_generation_gpu_cpp_v2 && $(MAKE) clean cd multiplication_approximation_cpu_cpp && $(MAKE) clean cd multiplication_approximation_gpu_cpp && $(MAKE) clean cd count_spikes_cpu_cpp && $(MAKE) clean + cd sort_spikes_cpu_cpp && $(MAKE) clean diff --git a/network/SbSLayer.py b/network/SbSLayer.py new file mode 100644 index 0000000..c4db2a8 --- /dev/null +++ b/network/SbSLayer.py @@ -0,0 +1,480 @@ +import torch + +from network.SpikeLayer import SpikeLayer +from network.HDynamicLayer import HDynamicLayer + +from network.calculate_output_size import calculate_output_size +from network.SortSpikesLayer import SortSpikesLayer + + +class SbSLayer(torch.nn.Module): + + _epsilon_xy: torch.Tensor | None = None + _epsilon_xy_use: bool + _epsilon_0: float + _weights: torch.nn.parameter.Parameter + _weights_exists: bool = False + _kernel_size: list[int] + _stride: list[int] + _dilation: list[int] + _padding: list[int] + _output_size: torch.Tensor + _number_of_spikes: int + _number_of_cpu_processes: int + _number_of_neurons: int + _number_of_input_neurons: int + _epsilon_xy_intitial: float + _h_initial: torch.Tensor | None = None + _w_trainable: bool + _last_grad_scale: torch.nn.parameter.Parameter + _keep_last_grad_scale: bool + _disable_scale_grade: bool + _forgetting_offset: float + _weight_noise_range: list[float] + _skip_gradient_calculation: bool + _is_pooling_layer: bool + _input_size: list[int] + _output_layer: bool = False + _local_learning: bool = False + + device: torch.device + default_dtype: torch.dtype + _gpu_tuning_factor: int + + _max_grad_weights: torch.Tensor | None = None + + _number_of_grad_weight_contributions: float = 0.0 + + last_input_store: bool = False + last_input_data: torch.Tensor | None = None + + _cooldown_after_number_of_spikes: int = -1 + _reduction_cooldown: float = 1.0 + _layer_id: int = -1 + + spike_full_layer_input_distribution: bool = False + + def __init__( + self, + number_of_input_neurons: int, + number_of_neurons: int, + input_size: list[int], + forward_kernel_size: list[int], + number_of_spikes: int, + epsilon_xy_intitial: float = 0.1, + epsilon_xy_use: bool = False, + epsilon_0: float = 1.0, + weight_noise_range: list[float] = [0.0, 1.0], + is_pooling_layer: bool = False, + strides: list[int] = [1, 1], + dilation: list[int] = [0, 0], + padding: list[int] = [0, 0], + number_of_cpu_processes: int = 1, + w_trainable: bool = False, + keep_last_grad_scale: bool = False, + disable_scale_grade: bool = True, + forgetting_offset: float = -1.0, + skip_gradient_calculation: bool = False, + device: torch.device | None = None, + default_dtype: torch.dtype | None = None, + gpu_tuning_factor: int = 10, + layer_id: int = -1, + cooldown_after_number_of_spikes: int = -1, + reduction_cooldown: float = 1.0, + ) -> None: + super().__init__() + + assert device is not None + assert default_dtype is not None + self.device = device + self.default_dtype = default_dtype + + self._w_trainable = bool(w_trainable) + self._keep_last_grad_scale = bool(keep_last_grad_scale) + self._skip_gradient_calculation = bool(skip_gradient_calculation) + self._disable_scale_grade = bool(disable_scale_grade) + self._epsilon_xy_intitial = float(epsilon_xy_intitial) + self._stride = strides + self._dilation = dilation + self._padding = padding + self._kernel_size = forward_kernel_size + self._number_of_input_neurons = int(number_of_input_neurons) + self._number_of_neurons = int(number_of_neurons) + self._epsilon_0 = float(epsilon_0) + self._number_of_cpu_processes = int(number_of_cpu_processes) + self._number_of_spikes = int(number_of_spikes) + self._weight_noise_range = weight_noise_range + self._is_pooling_layer = bool(is_pooling_layer) + self._cooldown_after_number_of_spikes = int(cooldown_after_number_of_spikes) + self.reduction_cooldown = float(reduction_cooldown) + self._layer_id = layer_id + self._epsilon_xy_use = epsilon_xy_use + + assert len(input_size) == 2 + self._input_size = input_size + + # The GPU hates me... + # Too many SbS threads == bad + # Thus I need to limit them... + # (Reminder: We cannot access the mini-batch size here, + # which is part of the GPU thread size calculation...) + + self._last_grad_scale = torch.nn.parameter.Parameter( + torch.tensor(-1.0, dtype=self.default_dtype), + requires_grad=True, + ) + + self._forgetting_offset = float(forgetting_offset) + + self._output_size = calculate_output_size( + value=input_size, + kernel_size=self._kernel_size, + stride=self._stride, + dilation=self._dilation, + padding=self._padding, + ) + + self.set_h_init_to_uniform() + + self.spike_generator = SpikeLayer( + number_of_spikes=self._number_of_spikes, + number_of_cpu_processes=self._number_of_cpu_processes, + device=self.device, + ) + + self.h_dynamic = HDynamicLayer( + output_size=self._output_size.tolist(), + output_layer=self._output_layer, + local_learning=self._local_learning, + number_of_cpu_processes=number_of_cpu_processes, + w_trainable=w_trainable, + skip_gradient_calculation=skip_gradient_calculation, + device=device, + default_dtype=self.default_dtype, + gpu_tuning_factor=gpu_tuning_factor, + ) + + assert len(input_size) >= 2 + self.spikes_sorter = SortSpikesLayer( + kernel_size=self._kernel_size, + input_shape=[ + self._number_of_input_neurons, + int(input_size[0]), + int(input_size[1]), + ], + output_size=self._output_size.clone(), + strides=self._stride, + dilation=self._dilation, + padding=self._padding, + number_of_cpu_processes=number_of_cpu_processes, + ) + + # TODO: TEST + if layer_id == 0: + self.spike_full_layer_input_distribution = True + + # ############################################################### + # Initialize the weights + # ############################################################### + + if self._is_pooling_layer is True: + self.weights = self._make_pooling_weights() + + else: + assert len(self._weight_noise_range) == 2 + weights = torch.empty( + ( + int(self._kernel_size[0]) + * int(self._kernel_size[1]) + * int(self._number_of_input_neurons), + int(self._number_of_neurons), + ), + dtype=self.default_dtype, + device=self.device, + ) + + torch.nn.init.uniform_( + weights, + a=float(self._weight_noise_range[0]), + b=float(self._weight_noise_range[1]), + ) + self.weights = weights + + #################################################################### + # Variables in and out # + #################################################################### + + def get_epsilon_t(self, number_of_spikes: int): + """Generates the time series of the basic epsilon.""" + t = ( + torch.arange( + 0, number_of_spikes, dtype=self.default_dtype, device=self.device + ) + + 1 + ) + + # torch.ones((number_of_spikes), dtype=self.default_dtype, device=self.device + epsilon_t: torch.Tensor = t ** (-1.0 / 2.0) + + if (self._cooldown_after_number_of_spikes < number_of_spikes) and ( + self._cooldown_after_number_of_spikes >= 0 + ): + epsilon_t[ + self._cooldown_after_number_of_spikes : number_of_spikes + ] /= self._reduction_cooldown + return epsilon_t + + @property + def weights(self) -> torch.Tensor | None: + if self._weights_exists is False: + return None + else: + return self._weights + + @weights.setter + def weights(self, value: torch.Tensor): + assert value is not None + assert torch.is_tensor(value) is True + assert value.dim() == 2 + temp: torch.Tensor = ( + value.detach() + .clone(memory_format=torch.contiguous_format) + .type(dtype=self.default_dtype) + .to(device=self.device) + ) + temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype) + if self._weights_exists is False: + self._weights = torch.nn.parameter.Parameter(temp, requires_grad=True) + self._weights_exists = True + else: + self._weights.data = temp + + @property + def h_initial(self) -> torch.Tensor | None: + return self._h_initial + + @h_initial.setter + def h_initial(self, value: torch.Tensor): + assert value is not None + assert torch.is_tensor(value) is True + assert value.dim() == 1 + assert value.dtype == self.default_dtype + self._h_initial = ( + value.detach() + .clone(memory_format=torch.contiguous_format) + .type(dtype=self.default_dtype) + .to(device=self.device) + .requires_grad_(False) + ) + + def update_pre_care(self): + + if self._weights.grad is not None: + assert self._number_of_grad_weight_contributions > 0 + self._weights.grad /= self._number_of_grad_weight_contributions + self._number_of_grad_weight_contributions = 0.0 + + def update_after_care(self, threshold_weight: float): + + if self._w_trainable is True: + self.norm_weights() + self.threshold_weights(threshold_weight) + self.norm_weights() + + def after_batch(self, new_state: bool = False): + if self._keep_last_grad_scale is True: + self._last_grad_scale.data = self._last_grad_scale.grad + self._keep_last_grad_scale = new_state + + self._last_grad_scale.grad = torch.zeros_like(self._last_grad_scale.grad) + + #################################################################### + # Helper functions # + #################################################################### + + def _make_pooling_weights(self) -> torch.Tensor: + """For generating the pooling weights.""" + + assert self._number_of_neurons is not None + assert self._kernel_size is not None + + weights: torch.Tensor = torch.zeros( + ( + int(self._kernel_size[0]), + int(self._kernel_size[1]), + int(self._number_of_neurons), + int(self._number_of_neurons), + ), + dtype=self.default_dtype, + device=self.device, + ) + + for i in range(0, int(self._number_of_neurons)): + weights[:, :, i, i] = 1.0 + + weights = weights.moveaxis(-1, 0).moveaxis(-1, 1) + + weights = torch.nn.functional.unfold( + input=weights, + kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])), + dilation=(1, 1), + padding=(0, 0), + stride=(1, 1), + ).squeeze() + + weights = torch.moveaxis(weights, 0, 1) + + return weights + + def set_h_init_to_uniform(self) -> None: + + assert self._number_of_neurons > 2 + + self.h_initial: torch.Tensor = torch.full( + (self._number_of_neurons,), + (1.0 / float(self._number_of_neurons)), + dtype=self.default_dtype, + device=self.device, + ) + + def norm_weights(self) -> None: + assert self._weights_exists is True + temp: torch.Tensor = ( + self._weights.data.detach() + .clone(memory_format=torch.contiguous_format) + .type(dtype=self.default_dtype) + .to(device=self.device) + ) + temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype) + self._weights.data = temp + + def threshold_weights(self, threshold: float) -> None: + assert self._weights_exists is True + assert threshold >= 0 + + torch.clamp( + self._weights.data, + min=float(threshold), + max=None, + out=self._weights.data, + ) + + #################################################################### + # Forward # + #################################################################### + + def forward( + self, + input: torch.Tensor, + labels: torch.Tensor | None = None, + ) -> torch.Tensor: + + # Are we happy with the input? + assert input is not None + assert torch.is_tensor(input) is True + assert input.dim() == 4 + assert input.dtype == self.default_dtype + assert input.shape[1] == self._number_of_input_neurons + assert input.shape[2] == self._input_size[0] + assert input.shape[3] == self._input_size[1] + + # Are we happy with the rest of the network? + assert self._epsilon_0 is not None + assert self._h_initial is not None + assert self._forgetting_offset is not None + assert self._weights_exists is True + assert self._weights is not None + + # Convolution of the input... + # Well, this is a convoltion layer + # there needs to be convolution somewhere + input_convolved = torch.nn.functional.fold( + torch.nn.functional.unfold( + input.requires_grad_(True), + kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])), + dilation=(int(self._dilation[0]), int(self._dilation[1])), + padding=(int(self._padding[0]), int(self._padding[1])), + stride=(int(self._stride[0]), int(self._stride[1])), + ), + output_size=tuple(self._output_size.tolist()), + kernel_size=(1, 1), + dilation=(1, 1), + padding=(0, 0), + stride=(1, 1), + ) + + # We might need the convolved input for other layers + # let us keep it for the future + if self.last_input_store is True: + self.last_input_data = input_convolved.detach().clone() + self.last_input_data /= self.last_input_data.sum(dim=1, keepdim=True) + else: + self.last_input_data = None + + epsilon_t_0: torch.Tensor = ( + (self.get_epsilon_t(self._number_of_spikes) * self._epsilon_0) + .type(input.dtype) + .to(input.device) + ) + + if (self._epsilon_xy is None) and (self._epsilon_xy_use is True): + self._epsilon_xy = torch.full( + ( + input_convolved.shape[1], + input_convolved.shape[2], + input_convolved.shape[3], + ), + float(self._epsilon_xy_intitial), + dtype=self.default_dtype, + device=self.device, + ) + + if self._epsilon_xy_use is True: + assert self._epsilon_xy is not None + # In the case somebody tried to replace the matrix with wrong dimensions + assert self._epsilon_xy.shape[0] == input_convolved.shape[1] + assert self._epsilon_xy.shape[1] == input_convolved.shape[2] + assert self._epsilon_xy.shape[2] == input_convolved.shape[3] + else: + assert self._epsilon_xy is None + + if self.spike_full_layer_input_distribution is False: + spike = self.spike_generator(input_convolved, int(self._number_of_spikes)) + else: + input_shape = input.shape + input = ( + input.reshape( + (input_shape[0], input_shape[1] * input_shape[2] * input_shape[3]) + ) + .unsqueeze(-1) + .unsqueeze(-1) + ) + spike_unsorted = self.spike_generator(input, int(self._number_of_spikes)) + input = ( + input.squeeze(-1) + .squeeze(-1) + .reshape( + (input_shape[0], input_shape[1], input_shape[2], input_shape[3]) + ) + ) + spike = self.spikes_sorter(spike_unsorted).to(device=input_convolved.device) + + output = self.h_dynamic( + input=input_convolved, + spike=spike, + epsilon_xy=self._epsilon_xy, + epsilon_t_0=epsilon_t_0, + weights=self._weights, + h_initial=self._h_initial, + last_grad_scale=self._last_grad_scale, + labels=labels, + keep_last_grad_scale=self._keep_last_grad_scale, + disable_scale_grade=self._disable_scale_grade, + forgetting_offset=self._forgetting_offset, + ) + + self._number_of_grad_weight_contributions += ( + output.shape[0] * output.shape[-2] * output.shape[-1] + ) + + return output diff --git a/network/SbSReconstruction.py b/network/SbSReconstruction.py index be6420f..0bdb52d 100644 --- a/network/SbSReconstruction.py +++ b/network/SbSReconstruction.py @@ -1,15 +1,15 @@ import torch -from network.SbS import SbS +from network.SbSLayer import SbSLayer class SbSReconstruction(torch.nn.Module): - _the_sbs_layer: SbS + _the_sbs_layer: SbSLayer def __init__( self, - the_sbs_layer: SbS, + the_sbs_layer: SbSLayer, ) -> None: super().__init__() diff --git a/network/SortSpikesLayer.py b/network/SortSpikesLayer.py new file mode 100644 index 0000000..a4c4424 --- /dev/null +++ b/network/SortSpikesLayer.py @@ -0,0 +1,173 @@ +import torch + +from network.PySortSpikesCPU import SortSpikesCPU + + +class SortSpikesLayer(torch.nn.Module): + + _kernel_size: list[int] + _stride: list[int] + _dilation: list[int] + _padding: list[int] + _output_size: torch.Tensor + _number_of_cpu_processes: int + _input_shape: list[int] + + order: torch.Tensor | None = None + order_convoled: torch.Tensor | None = None + indices: torch.Tensor | None = None + + def __init__( + self, + kernel_size: list[int], + input_shape: list[int], + output_size: torch.Tensor, + strides: list[int] = [1, 1], + dilation: list[int] = [0, 0], + padding: list[int] = [0, 0], + number_of_cpu_processes: int = 1, + ) -> None: + + super().__init__() + + self._stride = strides + self._dilation = dilation + self._padding = padding + self._kernel_size = kernel_size + self._output_size = output_size + self._number_of_cpu_processes = number_of_cpu_processes + self._input_shape = input_shape + + self.sort_spikes = SortSpikesCPU() + + self.order = ( + torch.arange( + 0, + self._input_shape[0] * self._input_shape[1] * self._input_shape[2], + device=torch.device("cpu"), + ) + .reshape( + ( + 1, + self._input_shape[0], + self._input_shape[1], + self._input_shape[2], + ) + ) + .type(dtype=torch.float32) + ) + + self.order_convoled = torch.nn.functional.fold( + torch.nn.functional.unfold( + self.order, + kernel_size=( + int(self._kernel_size[0]), + int(self._kernel_size[1]), + ), + dilation=(int(self._dilation[0]), int(self._dilation[1])), + padding=(int(self._padding[0]), int(self._padding[1])), + stride=(int(self._stride[0]), int(self._stride[1])), + ), + output_size=tuple(self._output_size.tolist()), + kernel_size=(1, 1), + dilation=(1, 1), + padding=(0, 0), + stride=(1, 1), + ).type(dtype=torch.int64) + + assert self.order_convoled is not None + + self.order_convoled = self.order_convoled.reshape( + ( + self.order_convoled.shape[1] + * self.order_convoled.shape[2] + * self.order_convoled.shape[3] + ) + ) + + max_length: int = 0 + max_range: int = ( + self._input_shape[0] * self._input_shape[1] * self._input_shape[2] + ) + for id in range(0, max_range): + idx = torch.where(self.order_convoled == id)[0] + max_length = max(max_length, int(idx.shape[0])) + + self.indices = torch.full( + (max_range, max_length), + -1, + dtype=torch.int64, + device=torch.device("cpu"), + ) + + for id in range(0, max_range): + idx = torch.where(self.order_convoled == id)[0] + self.indices[id, 0 : int(idx.shape[0])] = idx + + #################################################################### + # Forward # + #################################################################### + def forward( + self, + input: torch.Tensor, + ) -> torch.Tensor: + + assert len(self._input_shape) == 3 + assert input.shape[-2] == 1 + assert input.shape[-1] == 1 + assert self.indices is not None + + spikes_count = torch.zeros( + (input.shape[0], int(self._output_size[0]), int(self._output_size[1])), + device=torch.device("cpu"), + dtype=torch.int64, + ) + + input_cpu = input.clone().cpu() + + self.sort_spikes.count( + input_cpu.data_ptr(), # Input + int(input_cpu.shape[0]), + int(input_cpu.shape[1]), + int(input_cpu.shape[2]), + int(input_cpu.shape[3]), + spikes_count.data_ptr(), # Output + int(spikes_count.shape[0]), + int(spikes_count.shape[1]), + int(spikes_count.shape[2]), + self.indices.data_ptr(), # Positions + int(self.indices.shape[0]), + int(self.indices.shape[1]), + int(self._number_of_cpu_processes), + ) + + spikes_output = torch.full( + ( + input.shape[0], + int(spikes_count.max()), + int(self._output_size[0]), + int(self._output_size[1]), + ), + -1, + dtype=torch.int64, + device=torch.device("cpu"), + ) + + self.sort_spikes.process( + input_cpu.data_ptr(), # Input + int(input_cpu.shape[0]), + int(input_cpu.shape[1]), + int(input_cpu.shape[2]), + int(input_cpu.shape[3]), + spikes_output.data_ptr(), # Output + int(spikes_output.shape[0]), + int(spikes_output.shape[1]), + int(spikes_output.shape[2]), + int(spikes_output.shape[3]), + self.indices.data_ptr(), # Positions + int(self.indices.shape[0]), + int(self.indices.shape[1]), + int(self._number_of_cpu_processes), + ) + + return spikes_output diff --git a/network/SpikeCountLayer.py b/network/SpikeCountLayer.py new file mode 100644 index 0000000..0d4c6dc --- /dev/null +++ b/network/SpikeCountLayer.py @@ -0,0 +1,52 @@ +import torch + +from network.PyCountSpikesCPU import CountSpikesCPU + + +class SpikeCountLayer(torch.nn.Module): + _number_of_cpu_processes: int + + def __init__( + self, + number_of_cpu_processes: int = 1, + ) -> None: + super().__init__() + + self._number_of_cpu_processes = number_of_cpu_processes + + #################################################################### + # Forward # + #################################################################### + + def forward(self, input: torch.Tensor, dim_s: int) -> torch.Tensor: + + assert input.ndim == 4 + assert dim_s > 0 + + input_cpu = input.cpu() + + histogram = torch.zeros( + ( + int(input.shape[0]), + int(dim_s), + int(input.shape[-2]), + int(input.shape[-1]), + ), + dtype=torch.int64, + device=input_cpu.device, + ) + + count_spikes = CountSpikesCPU() + + count_spikes.process( + input_cpu.data_ptr(), + int(input_cpu.shape[0]), + int(input_cpu.shape[1]), + int(input_cpu.shape[2]), + int(input_cpu.shape[3]), + histogram.data_ptr(), + int(histogram.shape[1]), + int(self._number_of_cpu_processes), + ) + + return histogram.to(device=input.device) diff --git a/network/SpikeLayer.py b/network/SpikeLayer.py index f7f7060..25a86d6 100644 --- a/network/SpikeLayer.py +++ b/network/SpikeLayer.py @@ -3,8 +3,6 @@ import torch from network.PySpikeGenerationCPU import SpikeGenerationCPU from network.PySpikeGenerationGPU import SpikeGenerationGPU -# from PyCountSpikesCPU import CountSpikesCPU - global_spike_generation_gpu_setting: list[torch.Tensor] = [] global_spike_size: list[torch.Tensor] = [] global_spike_generation_cpp: list[SpikeGenerationCPU | SpikeGenerationGPU] = [] @@ -16,28 +14,21 @@ class SpikeLayer(torch.nn.Module): _spike_generation_gpu_setting_position: int _number_of_cpu_processes: int _number_of_spikes: int - - _spikes: torch.Tensor | None = None - _store_spikes: bool + device: torch.device def __init__( self, - number_of_spikes: int = 1, + number_of_spikes: int = -1, number_of_cpu_processes: int = 1, device: torch.device | None = None, - default_dtype: torch.dtype | None = None, - store_spikes: bool = False, ) -> None: super().__init__() assert device is not None - assert default_dtype is not None self.device = device - self.default_dtype = default_dtype self._number_of_cpu_processes = number_of_cpu_processes self._number_of_spikes = number_of_spikes - self._store_spikes = store_spikes global_spike_generation_gpu_setting.append(torch.tensor([0])) global_spike_size.append(torch.tensor([0, 0, 0, 0])) @@ -62,7 +53,6 @@ class SpikeLayer(torch.nn.Module): self, input: torch.Tensor, number_of_spikes: int | None = None, - store_spikes: bool | None = None, ) -> torch.Tensor: if number_of_spikes is None: @@ -80,18 +70,7 @@ class SpikeLayer(torch.nn.Module): dtype=torch.int64, ) - spikes = self.functional_spike_generation(input, parameter_list) - - if (store_spikes is not None) and (store_spikes is True): - self._spikes = spikes.detach().clone() - elif (store_spikes is not None) and (store_spikes is False): - self._spikes = None - elif self._store_spikes is True: - self._spikes = spikes.detach().clone() - else: - self._spikes = None - - return spikes + return self.functional_spike_generation(input, parameter_list) class FunctionalSpikeGeneration(torch.autograd.Function): diff --git a/network/build_network.py b/network/build_network.py index c506eb2..532116e 100644 --- a/network/build_network.py +++ b/network/build_network.py @@ -3,10 +3,11 @@ import torch from network.calculate_output_size import calculate_output_size from network.Parameter import Config -from network.SbS import SbS +from network.SbSLayer import SbSLayer from network.SplitOnOffLayer import SplitOnOffLayer from network.Conv2dApproximation import Conv2dApproximation from network.SbSReconstruction import SbSReconstruction +from network.InputSpikeImage import InputSpikeImage def build_network( @@ -144,7 +145,7 @@ def build_network( is_pooling_layer = True network.append( - SbS( + SbSLayer( number_of_input_neurons=in_channels, number_of_neurons=out_channels, input_size=input_size[-1], @@ -190,7 +191,7 @@ def build_network( logging.info(f"Layer: {layer_id} -> SbS Reconstruction Layer") assert layer_id > 0 - assert isinstance(network[-1], SbS) is True + assert isinstance(network[-1], SbSLayer) is True network.append(SbSReconstruction(network[-1])) network[-1]._w_trainable = False @@ -365,6 +366,36 @@ def build_network( ).tolist() input_size.append(input_size_temp) + # ############################################################# + # Approx CONV2D layer: + # ############################################################# + + elif ( + cfg.network_structure.layer_type[layer_id] + .upper() + .startswith("INPUT SPIKE IMAGE") + is True + ): + logging.info(f"Layer: {layer_id} -> Input Spike Image Layer") + + number_of_spikes: int = -1 + if len(cfg.number_of_spikes) > layer_id: + number_of_spikes = cfg.number_of_spikes[layer_id] + elif len(cfg.number_of_spikes) == 1: + number_of_spikes = cfg.number_of_spikes[0] + + network.append( + InputSpikeImage( + number_of_spikes=number_of_spikes, + number_of_cpu_processes=cfg.number_of_cpu_processes, + reshape=True, + normalize=True, + device=device, + ) + ) + + input_size.append(input_size[-1]) + # ############################################################# # Failure becaue we didn't found the selection of layer # ############################################################# diff --git a/network/build_optimizer.py b/network/build_optimizer.py index 06f1df6..91888fe 100644 --- a/network/build_optimizer.py +++ b/network/build_optimizer.py @@ -1,7 +1,7 @@ # %% import torch from network.Parameter import Config -from network.SbS import SbS +from network.SbSLayer import SbSLayer from network.Conv2dApproximation import Conv2dApproximation from network.Adam import Adam @@ -20,7 +20,7 @@ def build_optimizer( for id in range(0, len(network)): - if (isinstance(network[id], SbS) is True) and ( + if (isinstance(network[id], SbSLayer) is True) and ( network[id]._w_trainable is True ): parameter_list_weights.append(network[id]._weights) diff --git a/network/load_previous_weights.py b/network/load_previous_weights.py index 636164f..80fd13c 100644 --- a/network/load_previous_weights.py +++ b/network/load_previous_weights.py @@ -3,9 +3,10 @@ import torch import glob import numpy as np -from network.SbS import SbS +from network.SbSLayer import SbSLayer from network.SplitOnOffLayer import SplitOnOffLayer from network.Conv2dApproximation import Conv2dApproximation +import os def load_previous_weights( @@ -14,22 +15,28 @@ def load_previous_weights( logging, device: torch.device, default_dtype: torch.dtype, + order_id: float | int | None = None, ) -> None: + if order_id is None: + post_fix: str = "" + else: + post_fix = f"_{order_id}" + for id in range(0, len(network)): # ################################################# # SbS # ################################################# - if isinstance(network[id], SbS) is True: - # Are there weights that overwrite the initial weights? - file_to_load = glob.glob(overload_path + "/Weight_L" + str(id) + "_*.npy") + if isinstance(network[id], SbSLayer) is True: + filename_wilcard = os.path.join( + overload_path, f"Weight_L{id}_*{post_fix}.npy" + ) + file_to_load = glob.glob(filename_wilcard) if len(file_to_load) > 1: - raise Exception( - f"Too many previous weights files {overload_path}/Weight_L{id}*.npy" - ) + raise Exception(f"Too many previous weights files {filename_wilcard}") if len(file_to_load) == 1: network[id].weights = torch.tensor( @@ -45,13 +52,13 @@ def load_previous_weights( # Conv2d weights # ################################################# - # Are there weights that overwrite the initial weights? - file_to_load = glob.glob(overload_path + "/Weight_L" + str(id) + "_*.npy") + filename_wilcard = os.path.join( + overload_path, f"Weight_L{id}_*{post_fix}.npy" + ) + file_to_load = glob.glob(filename_wilcard) if len(file_to_load) > 1: - raise Exception( - f"Too many previous weights files {overload_path}/Weight_L{id}*.npy" - ) + raise Exception(f"Too many previous weights files {filename_wilcard}") if len(file_to_load) == 1: network[id]._parameters["weight"].data = torch.tensor( @@ -65,13 +72,13 @@ def load_previous_weights( # Conv2d bias # ################################################# - # Are there biases that overwrite the initial weights? - file_to_load = glob.glob(overload_path + "/Bias_L" + str(id) + "_*.npy") + filename_wilcard = os.path.join( + overload_path, f"Bias_L{id}_*{post_fix}.npy" + ) + file_to_load = glob.glob(filename_wilcard) if len(file_to_load) > 1: - raise Exception( - f"Too many previous weights files {overload_path}/Weight_L{id}*.npy" - ) + raise Exception(f"Too many previous weights files {filename_wilcard}") if len(file_to_load) == 1: network[id]._parameters["bias"].data = torch.tensor( @@ -87,13 +94,13 @@ def load_previous_weights( # Approximate Conv2d weights # ################################################# - # Are there weights that overwrite the initial weights? - file_to_load = glob.glob(overload_path + "/Weight_L" + str(id) + "_*.npy") + filename_wilcard = os.path.join( + overload_path, f"Weight_L{id}_*{post_fix}.npy" + ) + file_to_load = glob.glob(filename_wilcard) if len(file_to_load) > 1: - raise Exception( - f"Too many previous weights files {overload_path}/Weight_L{id}*.npy" - ) + raise Exception(f"Too many previous weights files {filename_wilcard}") if len(file_to_load) == 1: network[id].weights.data = torch.tensor( @@ -107,13 +114,13 @@ def load_previous_weights( # Approximate Conv2d bias # ################################################# - # Are there biases that overwrite the initial weights? - file_to_load = glob.glob(overload_path + "/Bias_L" + str(id) + "_*.npy") + filename_wilcard = os.path.join( + overload_path, f"Bias_L{id}_*{post_fix}.npy" + ) + file_to_load = glob.glob(filename_wilcard) if len(file_to_load) > 1: - raise Exception( - f"Too many previous weights files {overload_path}/Weight_L{id}*.npy" - ) + raise Exception(f"Too many previous weights files {filename_wilcard}") if len(file_to_load) == 1: network[id].bias.data = torch.tensor( @@ -127,13 +134,13 @@ def load_previous_weights( # SplitOnOffLayer # ################################################# if isinstance(network[id], SplitOnOffLayer) is True: - # Are there weights that overwrite the initial weights? - file_to_load = glob.glob(overload_path + "/Mean_L" + str(id) + "_*.npy") + filename_wilcard = os.path.join( + overload_path, f"Mean_L{id}_*{post_fix}.npy" + ) + file_to_load = glob.glob(filename_wilcard) if len(file_to_load) > 1: - raise Exception( - f"Too many previous mean files {overload_path}/Mean_L{id}*.npy" - ) + raise Exception(f"Too many previous mean files {filename_wilcard}") if len(file_to_load) == 1: network[id].mean = torch.tensor( diff --git a/network/loop_train_test.py b/network/loop_train_test.py index 2593af8..bc594de 100644 --- a/network/loop_train_test.py +++ b/network/loop_train_test.py @@ -3,7 +3,7 @@ import time from network.Parameter import Config from torch.utils.tensorboard import SummaryWriter -from network.SbS import SbS +from network.SbSLayer import SbSLayer from network.save_weight_and_bias import save_weight_and_bias from network.SbSReconstruction import SbSReconstruction @@ -19,7 +19,7 @@ def add_weight_and_bias_to_histogram( # ################################################ # Log the SbS Weights # ################################################ - if isinstance(network[id], SbS) is True: + if isinstance(network[id], SbSLayer) is True: if network[id]._w_trainable is True: try: @@ -175,7 +175,7 @@ def forward_pass_train( .to(device=device) ) for id in range(0, len(network)): - if isinstance(network[id], SbS) is True: + if isinstance(network[id], SbSLayer) is True: h_collection.append(network[id](h_collection[-1], labels)) else: h_collection.append(network[id](h_collection[-1])) @@ -203,7 +203,7 @@ def forward_pass_test( ) for id in range(0, len(network)): if (cfg.extract_noisy_pictures is True) or (overwrite_number_of_spikes != -1): - if isinstance(network[id], SbS) is True: + if isinstance(network[id], SbSLayer) is True: h_collection.append( network[id]( h_collection[-1], @@ -228,7 +228,7 @@ def run_optimizer( cfg: Config, ) -> None: for id in range(0, len(network)): - if isinstance(network[id], SbS) is True: + if isinstance(network[id], SbSLayer) is True: network[id].update_pre_care() for optimizer_item in optimizer: @@ -236,7 +236,7 @@ def run_optimizer( optimizer_item.step() for id in range(0, len(network)): - if isinstance(network[id], SbS) is True: + if isinstance(network[id], SbSLayer) is True: network[id].update_after_care( cfg.learning_parameters.learning_rate_threshold_w / float( @@ -288,11 +288,11 @@ def run_lr_scheduler( def deal_with_gradient_scale(epoch_id: int, mini_batch_number: int, network): if (epoch_id == 0) and (mini_batch_number == 0): for id in range(0, len(network)): - if isinstance(network[id], SbS) is True: + if isinstance(network[id], SbSLayer) is True: network[id].after_batch(True) else: for id in range(0, len(network)): - if isinstance(network[id], SbS) is True: + if isinstance(network[id], SbSLayer) is True: network[id].after_batch() @@ -309,6 +309,7 @@ def loop_train( tb: SummaryWriter, lr_scheduler, last_test_performance: float, + order_id: float | int | None = None, ) -> tuple[float, float, float, float]: correct_in_minibatch: int = 0 @@ -529,7 +530,10 @@ def loop_train( # Save the Weights and Biases # ################################################ save_weight_and_bias( - cfg=cfg, network=network, iteration_number=epoch_id + cfg=cfg, + network=network, + iteration_number=epoch_id, + order_id=order_id, ) # ################################################ diff --git a/network/save_weight_and_bias.py b/network/save_weight_and_bias.py index ee42451..ad4e178 100644 --- a/network/save_weight_and_bias.py +++ b/network/save_weight_and_bias.py @@ -3,14 +3,23 @@ import torch from network.Parameter import Config import numpy as np -from network.SbS import SbS +from network.SbSLayer import SbSLayer from network.SplitOnOffLayer import SplitOnOffLayer from network.Conv2dApproximation import Conv2dApproximation +import os + def save_weight_and_bias( - cfg: Config, network: torch.nn.modules.container.Sequential, iteration_number: int + cfg: Config, + network: torch.nn.modules.container.Sequential, + iteration_number: int, + order_id: float | int | None = None, ) -> None: + if order_id is None: + post_fix: str = "" + else: + post_fix = f"_{order_id}" for id in range(0, len(network)): @@ -18,11 +27,14 @@ def save_weight_and_bias( # Save the SbS Weights # ################################################ - if isinstance(network[id], SbS) is True: + if isinstance(network[id], SbSLayer) is True: if network[id]._w_trainable is True: np.save( - f"{cfg.weight_path}/Weight_L{id}_S{iteration_number}.npy", + os.path.join( + cfg.weight_path, + f"Weight_L{id}_S{iteration_number}{post_fix}.npy", + ), network[id].weights.detach().cpu().numpy(), ) @@ -34,13 +46,18 @@ def save_weight_and_bias( if network[id]._w_trainable is True: # Save the new values np.save( - f"{cfg.weight_path}/Weight_L{id}_S{iteration_number}.npy", + os.path.join( + cfg.weight_path, + f"Weight_L{id}_S{iteration_number}{post_fix}.npy", + ), network[id]._parameters["weight"].data.detach().cpu().numpy(), ) # Save the new values np.save( - f"{cfg.weight_path}/Bias_L{id}_S{iteration_number}.npy", + os.path.join( + cfg.weight_path, f"Bias_L{id}_S{iteration_number}{post_fix}.npy" + ), network[id]._parameters["bias"].data.detach().cpu().numpy(), ) @@ -52,20 +69,28 @@ def save_weight_and_bias( if network[id]._w_trainable is True: # Save the new values np.save( - f"{cfg.weight_path}/Weight_L{id}_S{iteration_number}.npy", + os.path.join( + cfg.weight_path, + f"Weight_L{id}_S{iteration_number}{post_fix}.npy", + ), network[id].weights.data.detach().cpu().numpy(), ) # Save the new values if network[id].bias is not None: np.save( - f"{cfg.weight_path}/Bias_L{id}_S{iteration_number}.npy", + os.path.join( + cfg.weight_path, + f"Bias_L{id}_S{iteration_number}{post_fix}.npy", + ), network[id].bias.data.detach().cpu().numpy(), ) if isinstance(network[id], SplitOnOffLayer) is True: np.save( - f"{cfg.weight_path}/Mean_L{id}_S{iteration_number}.npy", + os.path.join( + cfg.weight_path, f"Mean_L{id}_S{iteration_number}{post_fix}.npy" + ), network[id].mean.detach().cpu().numpy(), )