diff --git a/network/Adam.py b/network/Adam.py index 45bde9b..4ed7baf 100644 --- a/network/Adam.py +++ b/network/Adam.py @@ -15,6 +15,7 @@ class Adam(torch.optim.Optimizer): self, params, sbs_setting: list[bool], + logging, lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999, @@ -41,6 +42,7 @@ class Adam(torch.optim.Optimizer): self.beta2 = beta2 self.eps = eps self.maximize = maximize + self._logging = logging defaults = dict( lr=lr, @@ -149,8 +151,12 @@ class Adam(torch.optim.Optimizer): if sbs_setting[i] is False: param -= step_size * (exp_avg / denom) else: - delta = torch.exp(-step_size * (exp_avg / denom)) - print( - f"{float(delta.min()) - 1.0:.4e} {float(delta.max()) - 1.0:.4e} {lr:.4e}" + # delta = torch.exp(-step_size * (exp_avg / denom)) + delta = torch.tanh(-step_size * (exp_avg / denom)) + delta += 1.0 + delta *= 0.5 + delta += 0.5 + self._logging.info( + f"ADAM: Layer {i} -> dw_min:{float(delta.min()):.4e} dw_max:{float(delta.max()):.4e} lr:{lr:.4e}" ) param *= delta diff --git a/network/Conv2dApproximation.py b/network/Conv2dApproximation.py index 3a3525c..9276d4f 100644 --- a/network/Conv2dApproximation.py +++ b/network/Conv2dApproximation.py @@ -3,6 +3,10 @@ import math from network.CPP.PyMultiApp import MultiApp +global_multiapp_gpu_setting: list[torch.Tensor] = [] +global_multiapp_size: list[torch.Tensor] = [] +global_multiapp_cpp: list[MultiApp] = [] + class Conv2dApproximation(torch.nn.Module): @@ -26,6 +30,9 @@ class Conv2dApproximation(torch.nn.Module): device: torch.device dtype: torch.dtype + multiapp_gpu_setting_position: int = -1 + multiapp_cpp_position: int = -1 + def __init__( self, in_channels: int, @@ -68,6 +75,12 @@ class Conv2dApproximation(torch.nn.Module): self.number_of_trunc_bits = number_of_trunc_bits self.number_of_frac = number_of_frac + global_multiapp_gpu_setting.append(torch.tensor([0])) + global_multiapp_size.append(torch.tensor([0, 0, 0, 0])) + global_multiapp_cpp.append(MultiApp()) + self.multiapp_gpu_setting_position = len(global_multiapp_gpu_setting) - 1 + self.multiapp_cpp_position = len(global_multiapp_cpp) - 1 + if self.use_bias is True: self.bias: torch.nn.parameter.Parameter | None = ( torch.nn.parameter.Parameter( @@ -190,6 +203,8 @@ class Conv2dApproximation(torch.nn.Module): assert input.dim() == 4 assert self.kernel_size is not None + assert self.multiapp_gpu_setting_position != -1 + assert self.multiapp_cpp_position != -1 input_size = torch.Tensor([int(input.shape[-2]), int(input.shape[-1])]).type( dtype=torch.int64 @@ -232,6 +247,8 @@ class Conv2dApproximation(torch.nn.Module): int(self.number_of_trunc_bits), # 1 int(self.number_of_frac), # 2 int(number_of_cpu_processes), # 3 + int(self.multiapp_gpu_setting_position), # 4 + int(self.multiapp_cpp_position), # 5 ], dtype=torch.int64, ) @@ -267,6 +284,8 @@ class FunctionalMultiConv2d(torch.autograd.Function): number_of_trunc_bits = int(parameter_list[1]) number_of_frac = int(parameter_list[2]) number_of_processes = int(parameter_list[3]) + multiapp_gpu_setting_position = int(parameter_list[4]) + multiapp_cpp_position = int(parameter_list[5]) assert input.device == weights.device @@ -278,9 +297,54 @@ class FunctionalMultiConv2d(torch.autograd.Function): ) assert output.is_contiguous() is True - multiplier: MultiApp = MultiApp() + multiapp_profile = global_multiapp_gpu_setting[ + multiapp_gpu_setting_position + ].clone() - multiplier.update_with_init_vector_multi_pattern( + multiapp_size = global_multiapp_size[multiapp_gpu_setting_position].clone() + + if input.device != torch.device("cpu"): + if ( + (multiapp_profile.numel() == 1) + or (multiapp_size[0] != int(output.shape[0])) + or (multiapp_size[1] != int(output.shape[1])) + or (multiapp_size[2] != int(output.shape[2])) + or (multiapp_size[3] != int(output.shape[3])) + ): + multiapp_profile = torch.zeros( + (1, 7), dtype=torch.int64, device=torch.device("cpu") + ) + + global_multiapp_cpp[multiapp_cpp_position].gpu_occupancy_export( + int(output.shape[2]), + int(output.shape[3]), + int(output.shape[0]), + int(output.shape[1]), + multiapp_profile.data_ptr(), + int(multiapp_profile.shape[0]), + int(multiapp_profile.shape[1]), + ) + global_multiapp_gpu_setting[ + multiapp_gpu_setting_position + ] = multiapp_profile.clone() + + multiapp_size[0] = int(output.shape[0]) + multiapp_size[1] = int(output.shape[1]) + multiapp_size[2] = int(output.shape[2]) + multiapp_size[3] = int(output.shape[3]) + + global_multiapp_size[ + multiapp_gpu_setting_position + ] = multiapp_size.clone() + + else: + global_multiapp_cpp[multiapp_cpp_position].gpu_occupancy_import( + multiapp_profile.data_ptr(), + int(multiapp_profile.shape[0]), + int(multiapp_profile.shape[1]), + ) + + global_multiapp_cpp[multiapp_cpp_position].update_entrypoint( input.data_ptr(), weights.data_ptr(), output.data_ptr(), diff --git a/network/Parameter.py b/network/Parameter.py index 77fa859..1469334 100644 --- a/network/Parameter.py +++ b/network/Parameter.py @@ -47,8 +47,8 @@ class LearningParameters: weight_noise_range: list[float] = field(default_factory=list) eps_xy_intitial: float = field(default=0.1) - # disable_scale_grade: bool = field(default=False) - # kepp_last_grad_scale: bool = field(default=True) + disable_scale_grade: bool = field(default=False) + kepp_last_grad_scale: bool = field(default=True) sbs_skip_gradient_calculation: list[bool] = field(default_factory=list) diff --git a/network/SbS.py b/network/SbS.py index 0334ad5..da56e01 100644 --- a/network/SbS.py +++ b/network/SbS.py @@ -4,6 +4,13 @@ from network.CPP.PySpikeGeneration2DManyIP import SpikeGeneration2DManyIP from network.CPP.PyHDynamicCNNManyIP import HDynamicCNNManyIP from network.calculate_output_size import calculate_output_size +global_sbs_gpu_setting: list[torch.Tensor] = [] +global_sbs_size: list[torch.Tensor] = [] +global_sbs_hdynamic_cpp: list[HDynamicCNNManyIP] = [] +global_spike_generation_gpu_setting: list[torch.Tensor] = [] +global_spike_size: list[torch.Tensor] = [] +global_spike_generation_cpp: list[SpikeGeneration2DManyIP] = [] + class SbS(torch.nn.Module): @@ -24,9 +31,9 @@ class SbS(torch.nn.Module): _epsilon_xy_intitial: float _h_initial: torch.Tensor | None = None _w_trainable: bool - # _last_grad_scale: torch.nn.parameter.Parameter - # _keep_last_grad_scale: bool - # _disable_scale_grade: bool + _last_grad_scale: torch.nn.parameter.Parameter + _keep_last_grad_scale: bool + _disable_scale_grade: bool _forgetting_offset: torch.Tensor | None = None _weight_noise_range: list[float] _skip_gradient_calculation: bool @@ -43,6 +50,14 @@ class SbS(torch.nn.Module): _number_of_grad_weight_contributions: float = 0.0 + last_input_store: bool = False + last_input_data: torch.Tensor | None = None + + sbs_gpu_setting_position: int = -1 + sbs_hdynamic_cpp_position: int = -1 + spike_generation_cpp_position: int = -1 + spike_generation_gpu_setting_position: int = -1 + def __init__( self, number_of_input_neurons: int, @@ -60,13 +75,14 @@ class SbS(torch.nn.Module): padding: list[int] = [0, 0], number_of_cpu_processes: int = 1, w_trainable: bool = False, - # keep_last_grad_scale: bool = False, - # disable_scale_grade: bool = True, + keep_last_grad_scale: bool = False, + disable_scale_grade: bool = True, forgetting_offset: float = -1.0, skip_gradient_calculation: bool = False, device: torch.device | None = None, default_dtype: torch.dtype | None = None, gpu_tuning_factor: int = 5, + layer_id: int = -1, ) -> None: super().__init__() @@ -76,9 +92,9 @@ class SbS(torch.nn.Module): self.default_dtype = default_dtype self._w_trainable = bool(w_trainable) - # self._keep_last_grad_scale = bool(keep_last_grad_scale) + self._keep_last_grad_scale = bool(keep_last_grad_scale) self._skip_gradient_calculation = bool(skip_gradient_calculation) - # self._disable_scale_grade = bool(disable_scale_grade) + self._disable_scale_grade = bool(disable_scale_grade) self._epsilon_xy_intitial = float(epsilon_xy_intitial) self._stride = strides self._dilation = dilation @@ -95,6 +111,21 @@ class SbS(torch.nn.Module): assert len(input_size) == 2 self._input_size = input_size + global_sbs_gpu_setting.append(torch.tensor([0])) + global_spike_generation_gpu_setting.append(torch.tensor([0])) + global_sbs_size.append(torch.tensor([0, 0, 0, 0])) + global_spike_size.append(torch.tensor([0, 0, 0, 0])) + + global_sbs_hdynamic_cpp.append(HDynamicCNNManyIP()) + global_spike_generation_cpp.append(SpikeGeneration2DManyIP()) + + self.sbs_gpu_setting_position = len(global_sbs_gpu_setting) - 1 + self.sbs_hdynamic_cpp_position = len(global_sbs_hdynamic_cpp) - 1 + self.spike_generation_cpp_position = len(global_spike_generation_cpp) - 1 + self.spike_generation_gpu_setting_position = ( + len(global_spike_generation_gpu_setting) - 1 + ) + # The GPU hates me... # Too many SbS threads == bad # Thus I need to limit them... @@ -105,10 +136,10 @@ class SbS(torch.nn.Module): else: self._gpu_tuning_factor = 0 - # self._last_grad_scale = torch.nn.parameter.Parameter( - # torch.tensor(-1.0, dtype=self.default_dtype), - # requires_grad=True, - # ) + self._last_grad_scale = torch.nn.parameter.Parameter( + torch.tensor(-1.0, dtype=self.default_dtype), + requires_grad=True, + ) self._forgetting_offset = torch.tensor( forgetting_offset, dtype=self.default_dtype, device=self.device @@ -234,12 +265,12 @@ class SbS(torch.nn.Module): self.threshold_weights(threshold_weight) self.norm_weights() - # def after_batch(self, new_state: bool = False): - # if self._keep_last_grad_scale is True: - # self._last_grad_scale.data = self._last_grad_scale.grad - # self._keep_last_grad_scale = new_state + def after_batch(self, new_state: bool = False): + if self._keep_last_grad_scale is True: + self._last_grad_scale.data = self._last_grad_scale.grad + self._keep_last_grad_scale = new_state - # self._last_grad_scale.grad = torch.zeros_like(self._last_grad_scale.grad) + self._last_grad_scale.grad = torch.zeros_like(self._last_grad_scale.grad) #################################################################### # Helper functions # @@ -339,6 +370,20 @@ class SbS(torch.nn.Module): assert self._weights_exists is True assert self._weights is not None + assert self.sbs_gpu_setting_position != -1 + assert self.sbs_hdynamic_cpp_position != -1 + assert self.spike_generation_cpp_position != -1 + assert self.spike_generation_gpu_setting_position != -1 + + if labels is None: + labels_copy: torch.Tensor = torch.tensor( + [], dtype=torch.int64, device=self.device + ) + else: + labels_copy = ( + labels.detach().clone().type(dtype=torch.int64).to(device=self.device) + ) + input_convolved = torch.nn.functional.fold( torch.nn.functional.unfold( input.requires_grad_(True), @@ -354,6 +399,12 @@ class SbS(torch.nn.Module): stride=(1, 1), ) + if self.last_input_store is True: + self.last_input_data = input_convolved.detach().clone() + self.last_input_data /= self.last_input_data.sum(dim=1, keepdim=True) + else: + self.last_input_data = None + epsilon_t_0: torch.Tensor = ( (self._epsilon_t * self._epsilon_0).type(input.dtype).to(input.device) ) @@ -361,8 +412,8 @@ class SbS(torch.nn.Module): parameter_list = torch.tensor( [ int(self._w_trainable), # 0 - int(0), # int(self._disable_scale_grade), # 1 - int(0), # int(self._keep_last_grad_scale), # 2 + int(self._disable_scale_grade), # 1 + int(self._keep_last_grad_scale), # 2 int(self._skip_gradient_calculation), # 3 int(self._number_of_spikes), # 4 int(self._number_of_cpu_processes), # 5 @@ -371,6 +422,10 @@ class SbS(torch.nn.Module): int(self._gpu_tuning_factor), # 8 int(self._output_layer), # 9 int(self._local_learning), # 10 + int(self.sbs_gpu_setting_position), # 11 + int(self.sbs_hdynamic_cpp_position), # 12 + int(self.spike_generation_cpp_position), # 13 + int(self.spike_generation_gpu_setting_position), # 14 ], dtype=torch.int64, ) @@ -401,8 +456,9 @@ class SbS(torch.nn.Module): self._weights, self._h_initial, parameter_list, - # self._last_grad_scale, + self._last_grad_scale, self._forgetting_offset, + labels_copy, ) self._number_of_grad_weight_contributions += ( @@ -422,8 +478,9 @@ class FunctionalSbS(torch.autograd.Function): weights: torch.Tensor, h_initial: torch.Tensor, parameter_list: torch.Tensor, - # grad_output_scale: torch.Tensor, + grad_output_scale: torch.Tensor, forgetting_offset: torch.Tensor, + labels: torch.Tensor, ) -> torch.Tensor: assert input.dim() == 4 @@ -444,6 +501,11 @@ class FunctionalSbS(torch.autograd.Function): output_size_1: int = int(parameter_list[7]) gpu_tuning_factor: int = int(parameter_list[8]) + sbs_gpu_setting_position = int(parameter_list[11]) + sbs_hdynamic_cpp_position = int(parameter_list[12]) + spike_generation_cpp_position = int(parameter_list[13]) + spike_generation_gpu_setting_position = int(parameter_list[14]) + # ########################################################### # Spike generation # ########################################################### @@ -480,9 +542,59 @@ class FunctionalSbS(torch.autograd.Function): assert spikes.is_contiguous() is True # time_start: float = time.perf_counter() - spike_generation: SpikeGeneration2DManyIP = SpikeGeneration2DManyIP() + spike_generation_profile = global_spike_generation_gpu_setting[ + spike_generation_gpu_setting_position + ].clone() - spike_generation.spike_generation( + spike_generation_size = global_spike_size[ + spike_generation_gpu_setting_position + ].clone() + + if input.device != torch.device("cpu"): + if ( + (spike_generation_profile.numel() == 1) + or (spike_generation_size[0] != int(spikes.shape[0])) + or (spike_generation_size[1] != int(spikes.shape[1])) + or (spike_generation_size[2] != int(spikes.shape[2])) + or (spike_generation_size[3] != int(spikes.shape[3])) + ): + spike_generation_profile = torch.zeros( + (1, 7), dtype=torch.int64, device=torch.device("cpu") + ) + + global_spike_generation_cpp[ + spike_generation_cpp_position + ].gpu_occupancy_export( + int(spikes.shape[2]), + int(spikes.shape[3]), + int(spikes.shape[0]), + int(spikes.shape[1]), + spike_generation_profile.data_ptr(), + int(spike_generation_profile.shape[0]), + int(spike_generation_profile.shape[1]), + ) + global_spike_generation_gpu_setting[ + spike_generation_gpu_setting_position + ] = spike_generation_profile.clone() + + spike_generation_size[0] = int(spikes.shape[0]) + spike_generation_size[1] = int(spikes.shape[1]) + spike_generation_size[2] = int(spikes.shape[2]) + spike_generation_size[3] = int(spikes.shape[3]) + global_spike_size[ + spike_generation_gpu_setting_position + ] = spike_generation_size.clone() + + else: + global_spike_generation_cpp[ + spike_generation_cpp_position + ].gpu_occupancy_import( + spike_generation_profile.data_ptr(), + int(spike_generation_profile.shape[0]), + int(spike_generation_profile.shape[1]), + ) + + global_spike_generation_cpp[spike_generation_cpp_position].spike_generation( input_cumsum.data_ptr(), int(input_cumsum.shape[0]), int(input_cumsum.shape[1]), @@ -536,9 +648,46 @@ class FunctionalSbS(torch.autograd.Function): assert weights.ndim == 2 assert h_initial.ndim == 1 - h_dynamic: HDynamicCNNManyIP = HDynamicCNNManyIP() + sbs_profile = global_sbs_gpu_setting[sbs_gpu_setting_position].clone() - h_dynamic.update( + sbs_size = global_sbs_size[sbs_gpu_setting_position].clone() + + if input.device != torch.device("cpu"): + if ( + (sbs_profile.numel() == 1) + or (sbs_size[0] != int(output.shape[0])) + or (sbs_size[1] != int(output.shape[1])) + or (sbs_size[2] != int(output.shape[2])) + or (sbs_size[3] != int(output.shape[3])) + ): + sbs_profile = torch.zeros( + (14, 7), dtype=torch.int64, device=torch.device("cpu") + ) + + global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_export( + int(output.shape[2]), + int(output.shape[3]), + int(output.shape[0]), + int(output.shape[1]), + sbs_profile.data_ptr(), + int(sbs_profile.shape[0]), + int(sbs_profile.shape[1]), + ) + global_sbs_gpu_setting[sbs_gpu_setting_position] = sbs_profile.clone() + sbs_size[0] = int(output.shape[0]) + sbs_size[1] = int(output.shape[1]) + sbs_size[2] = int(output.shape[2]) + sbs_size[3] = int(output.shape[3]) + global_sbs_size[sbs_gpu_setting_position] = sbs_size.clone() + + else: + global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_import( + sbs_profile.data_ptr(), + int(sbs_profile.shape[0]), + int(sbs_profile.shape[1]), + ) + + global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].update( output.data_ptr(), int(output.shape[0]), int(output.shape[1]), @@ -575,7 +724,8 @@ class FunctionalSbS(torch.autograd.Function): weights, output, parameter_list, - # grad_output_scale, + grad_output_scale, + labels, ) return output @@ -590,9 +740,12 @@ class FunctionalSbS(torch.autograd.Function): weights, output, parameter_list, - # last_grad_scale, + last_grad_scale, + labels, ) = ctx.saved_tensors + assert labels.numel() > 0 + # ############################################## # Default output # ############################################## @@ -603,13 +756,14 @@ class FunctionalSbS(torch.autograd.Function): grad_h_initial = None grad_parameter_list = None grad_forgetting_offset = None + grad_labels = None # ############################################## # Parameters # ############################################## parameter_w_trainable: bool = bool(parameter_list[0]) - # parameter_disable_scale_grade: bool = bool(parameter_list[1]) - # parameter_keep_last_grad_scale: bool = bool(parameter_list[2]) + parameter_disable_scale_grade: bool = bool(parameter_list[1]) + parameter_keep_last_grad_scale: bool = bool(parameter_list[2]) parameter_skip_gradient_calculation: bool = bool(parameter_list[3]) parameter_output_layer: bool = bool(parameter_list[9]) parameter_local_learning: bool = bool(parameter_list[10]) @@ -617,13 +771,13 @@ class FunctionalSbS(torch.autograd.Function): # ############################################## # Dealing with overall scale of the gradient # ############################################## - # if parameter_disable_scale_grade is False: - # if parameter_keep_last_grad_scale is True: - # last_grad_scale = torch.tensor( - # [torch.abs(grad_output).max(), last_grad_scale] - # ).max() - # grad_output /= last_grad_scale - # grad_output_scale = last_grad_scale.clone() + if parameter_disable_scale_grade is False: + if parameter_keep_last_grad_scale is True: + last_grad_scale = torch.tensor( + [torch.abs(grad_output).max(), last_grad_scale] + ).max() + grad_output /= last_grad_scale + grad_output_scale = last_grad_scale.clone() input /= input.sum(dim=1, keepdim=True, dtype=weights.dtype) @@ -640,8 +794,9 @@ class FunctionalSbS(torch.autograd.Function): grad_weights, grad_h_initial, grad_parameter_list, - # grad_output_scale, + grad_output_scale, grad_forgetting_offset, + grad_labels, ) # ################################################# @@ -682,6 +837,41 @@ class FunctionalSbS(torch.autograd.Function): .sum(-1) ) + elif (parameter_output_layer is True) and (parameter_local_learning is True): + + target_one_hot: torch.Tensor = torch.zeros( + ( + labels.shape[0], + output.shape[1], + ), + device=input.device, + dtype=input.dtype, + ) + + target_one_hot.scatter_( + 1, + labels.to(input.device).unsqueeze(1), + torch.ones( + (labels.shape[0], 1), + device=input.device, + dtype=input.dtype, + ), + ) + target_one_hot = target_one_hot.unsqueeze(-1).unsqueeze(-1) + + # (-2 * (input - backprop_bigr).unsqueeze(2) * (target_one_hot-output).unsqueeze(1)) + # (-2 * input.unsqueeze(2) * (target_one_hot-output).unsqueeze(1)) + grad_weights = ( + ( + -2 + * (input - backprop_bigr).unsqueeze(2) + * target_one_hot.unsqueeze(1) + ) + .sum(0) + .sum(-1) + .sum(-1) + ) + else: # ################################################# # Backprop @@ -709,6 +899,7 @@ class FunctionalSbS(torch.autograd.Function): grad_weights, grad_h_initial, grad_parameter_list, - # grad_output_scale, + grad_output_scale, grad_forgetting_offset, + grad_labels, ) diff --git a/network/SbSReconstruction.py b/network/SbSReconstruction.py new file mode 100644 index 0000000..be6420f --- /dev/null +++ b/network/SbSReconstruction.py @@ -0,0 +1,33 @@ +import torch + +from network.SbS import SbS + + +class SbSReconstruction(torch.nn.Module): + + _the_sbs_layer: SbS + + def __init__( + self, + the_sbs_layer: SbS, + ) -> None: + super().__init__() + + self._the_sbs_layer = the_sbs_layer + self.device = self._the_sbs_layer.device + self.default_dtype = self._the_sbs_layer.default_dtype + + def forward(self, input: torch.Tensor) -> torch.Tensor: + + assert self._the_sbs_layer._weights_exists is True + + input_norm = input / input.sum(dim=1, keepdim=True) + + output = ( + self._the_sbs_layer._weights.data.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + * input_norm.unsqueeze(1) + ).sum(dim=2) + + output /= output.sum(dim=1, keepdim=True) + + return output diff --git a/network/build_network.py b/network/build_network.py index 6d5e306..f075a0b 100644 --- a/network/build_network.py +++ b/network/build_network.py @@ -6,6 +6,7 @@ from network.Parameter import Config from network.SbS import SbS from network.SplitOnOffLayer import SplitOnOffLayer from network.Conv2dApproximation import Conv2dApproximation +from network.SbSReconstruction import SbSReconstruction def build_network( @@ -159,12 +160,13 @@ def build_network( padding=padding, number_of_cpu_processes=cfg.number_of_cpu_processes, w_trainable=w_trainable, - # keep_last_grad_scale=cfg.learning_parameters.kepp_last_grad_scale, - # disable_scale_grade=cfg.learning_parameters.disable_scale_grade, + keep_last_grad_scale=cfg.learning_parameters.kepp_last_grad_scale, + disable_scale_grade=cfg.learning_parameters.disable_scale_grade, forgetting_offset=cfg.forgetting_offset, skip_gradient_calculation=sbs_skip_gradient_calculation, device=device, default_dtype=default_dtype, + layer_id=layer_id, ) ) # Adding the x,y output dimensions @@ -178,6 +180,25 @@ def build_network( if cfg.network_structure.layer_type[layer_id].upper().find("LOCAL") != -1: network[-1]._local_learning = True + elif ( + cfg.network_structure.layer_type[layer_id] + .upper() + .startswith("RECONSTRUCTION") + is True + ): + logging.info(f"Layer: {layer_id} -> SbS Reconstruction Layer") + + assert layer_id > 0 + assert isinstance(network[-1], SbS) is True + + network.append(SbSReconstruction(network[-1])) + network[-1]._w_trainable = False + + if layer_id == len(cfg.network_structure.layer_type) - 1: + network[-2].last_input_store = True + + input_size.append(input_size[-1]) + # ############################################################# # Split On Off Layer: # ############################################################# diff --git a/network/build_optimizer.py b/network/build_optimizer.py index d0f96d8..06f1df6 100644 --- a/network/build_optimizer.py +++ b/network/build_optimizer.py @@ -57,10 +57,13 @@ def build_optimizer( optimizer_wf = Adam( parameter_list_weights, parameter_list_sbs, + logging=logging, lr=cfg.learning_parameters.learning_rate_gamma_w, ) else: - optimizer_wf = Adam(parameter_list_weights, parameter_list_sbs) + optimizer_wf = Adam( + parameter_list_weights, parameter_list_sbs, logging=logging + ) elif cfg.learning_parameters.optimizer_name == "SGD": logging.info("Using optimizer: SGD") diff --git a/network/loop_train_test.py b/network/loop_train_test.py index 6801417..fc75c9e 100644 --- a/network/loop_train_test.py +++ b/network/loop_train_test.py @@ -5,6 +5,7 @@ from torch.utils.tensorboard import SummaryWriter from network.SbS import SbS from network.save_weight_and_bias import save_weight_and_bias +from network.SbSReconstruction import SbSReconstruction def add_weight_and_bias_to_histogram( @@ -94,7 +95,7 @@ def loss_function( device=device, dtype=default_dtype, ), - ).unsqueeze(-1).unsqueeze(-1) + ) h_y1 = torch.log(h + 1e-20) @@ -119,6 +120,44 @@ def loss_function( return None +def loss_function_reconstruction( + h_reco: torch.Tensor, + h_input: torch.Tensor, + loss_mode: int = 0, + loss_coeffs_mse: float = 0.0, + loss_coeffs_kldiv: float = 0.0, +) -> torch.Tensor | None: + assert loss_mode >= 0 + assert loss_mode <= 0 + + assert h_reco.ndim == 4 + assert h_input.ndim == 4 + assert h_reco.shape[0] == h_input.shape[0] + assert h_reco.shape[1] == h_input.shape[1] + assert h_reco.shape[2] == h_input.shape[2] + assert h_reco.shape[3] == h_input.shape[3] + + if loss_mode == 0: + + h_reco_log = torch.log(h_reco + 1e-20) + + my_loss: torch.Tensor = ( + torch.nn.functional.mse_loss( + h_reco, + h_input, + reduction="sum", + ) + * loss_coeffs_mse + + torch.nn.functional.kl_div(h_reco_log, h_input + 1e-20, reduction="sum") + * loss_coeffs_kldiv + ) / (loss_coeffs_kldiv + loss_coeffs_mse) + + return my_loss + + else: + return None + + def forward_pass_train( input: torch.Tensor, labels: torch.Tensor, @@ -228,15 +267,15 @@ def run_lr_scheduler( tb.flush() -# def deal_with_gradient_scale(epoch_id: int, mini_batch_number: int, network): -# if (epoch_id == 0) and (mini_batch_number == 0): -# for id in range(0, len(network)): -# if isinstance(network[id], SbS) is True: -# network[id].after_batch(True) -# else: -# for id in range(0, len(network)): -# if isinstance(network[id], SbS) is True: -# network[id].after_batch() +def deal_with_gradient_scale(epoch_id: int, mini_batch_number: int, network): + if (epoch_id == 0) and (mini_batch_number == 0): + for id in range(0, len(network)): + if isinstance(network[id], SbS) is True: + network[id].after_batch(True) + else: + for id in range(0, len(network)): + if isinstance(network[id], SbS) is True: + network[id].after_batch() def loop_train( @@ -318,12 +357,20 @@ def loop_train( if last_test_performance < 0: logging.info("") else: - logging.info( - ( - f"\t\t\tLast test performance: " - f"{last_test_performance/100.0:^6.2%}" + if isinstance(network[-1], SbSReconstruction) is False: + logging.info( + ( + f"\t\t\tLast test performance: " + f"{last_test_performance/100.0:^6.2%}" + ) + ) + else: + logging.info( + ( + f"\t\t\tLast test performance: " + f"{last_test_performance:^6.2e}" + ) ) - ) logging.info("----------------") number_of_pattern_in_minibatch += h_x_labels.shape[0] @@ -345,18 +392,30 @@ def loop_train( # ##################################################### # Calculate the loss function # ##################################################### - my_loss: torch.Tensor | None = loss_function( - h=h_collection[-1], - labels=h_x_labels, - device=device, - default_dtype=default_dtype, - loss_mode=cfg.learning_parameters.loss_mode, - number_of_output_neurons=int( - cfg.network_structure.number_of_output_neurons - ), - loss_coeffs_mse=float(cfg.learning_parameters.loss_coeffs_mse), - loss_coeffs_kldiv=float(cfg.learning_parameters.loss_coeffs_kldiv), - ) + + if isinstance(network[-1], SbSReconstruction) is False: + my_loss: torch.Tensor | None = loss_function( + h=h_collection[-1], + labels=h_x_labels, + device=device, + default_dtype=default_dtype, + loss_mode=cfg.learning_parameters.loss_mode, + number_of_output_neurons=int( + cfg.network_structure.number_of_output_neurons + ), + loss_coeffs_mse=float(cfg.learning_parameters.loss_coeffs_mse), + loss_coeffs_kldiv=float(cfg.learning_parameters.loss_coeffs_kldiv), + ) + else: + assert cfg.learning_parameters.lr_scheduler_use_performance is False + my_loss = loss_function_reconstruction( + h_reco=h_collection[-1], + h_input=network[-2].last_input_data, + loss_mode=cfg.learning_parameters.loss_mode, + loss_coeffs_mse=float(cfg.learning_parameters.loss_coeffs_mse), + loss_coeffs_kldiv=float(cfg.learning_parameters.loss_coeffs_kldiv), + ) + assert my_loss is not None time_after_forward_and_loss: float = time.perf_counter() @@ -374,16 +433,17 @@ def loop_train( # Performance measures # ##################################################### - correct_in_minibatch += ( - (h_collection[-1].argmax(dim=1).squeeze().cpu() == h_x_labels) - .sum() - .item() - ) - full_correct += ( - (h_collection[-1].argmax(dim=1).squeeze().cpu() == h_x_labels) - .sum() - .item() - ) + if isinstance(network[-1], SbSReconstruction) is False: + correct_in_minibatch += ( + (h_collection[-1].argmax(dim=1).squeeze().cpu() == h_x_labels) + .sum() + .item() + ) + full_correct += ( + (h_collection[-1].argmax(dim=1).squeeze().cpu() == h_x_labels) + .sum() + .item() + ) # We measure the scale of the propagated error # during the first minibatch @@ -391,11 +451,11 @@ def loop_train( # the future error with it # Kind of deals with the vanishing / # exploding gradients - # deal_with_gradient_scale( - # epoch_id=epoch_id, - # mini_batch_number=mini_batch_number, - # network=network, - # ) + deal_with_gradient_scale( + epoch_id=epoch_id, + mini_batch_number=mini_batch_number, + network=network, + ) # Measure the time for one mini-batch time_forward += time_after_forward_and_loss - time_mini_batch_start @@ -403,21 +463,38 @@ def loop_train( if number_of_pattern_in_minibatch >= cfg.get_update_after_x_pattern(): - logging.info( - ( - f"{epoch_id:^6}=>{mini_batch_number:^6} " - f"\t\tTraining {number_of_pattern_in_minibatch^6} pattern " - f"with {correct_in_minibatch/number_of_pattern_in_minibatch:^6.2%} " - f"\tForward time: \t{time_forward:^6.2f}sec" + if isinstance(network[-1], SbSReconstruction) is False: + logging.info( + ( + f"{epoch_id:^6}=>{mini_batch_number:^6} " + f"\t\tTraining {number_of_pattern_in_minibatch^6} pattern " + f"with {correct_in_minibatch/number_of_pattern_in_minibatch:^6.2%} " + f"\tForward time: \t{time_forward:^6.2f}sec" + ) ) - ) - logging.info( - ( - f"\t\t\tLoss: {loss_in_minibatch/number_of_pattern_in_minibatch:^15.3e} " - f"\t\t\tBackward time: \t{time_backward:^6.2f}sec " + logging.info( + ( + f"\t\t\tLoss: {loss_in_minibatch/number_of_pattern_in_minibatch:^15.3e} " + f"\t\t\tBackward time: \t{time_backward:^6.2f}sec " + ) + ) + + else: + logging.info( + ( + f"{epoch_id:^6}=>{mini_batch_number:^6} " + f"\t\tTraining {number_of_pattern_in_minibatch^6} pattern " + f"\t\t\tForward time: \t{time_forward:^6.2f}sec" + ) + ) + + logging.info( + ( + f"\t\t\tLoss: {loss_in_minibatch/number_of_pattern_in_minibatch:^15.3e} " + f"\t\t\tBackward time: \t{time_backward:^6.2f}sec " + ) ) - ) my_loss_for_batch = loss_in_minibatch / number_of_pattern_in_minibatch @@ -510,3 +587,65 @@ def loop_test( tb.flush() return performance + + +def loop_test_reconstruction( + epoch_id: int, + cfg: Config, + network: torch.nn.modules.container.Sequential, + my_loader_test: torch.utils.data.dataloader.DataLoader, + the_dataset_test, + device: torch.device, + default_dtype: torch.dtype, + logging, + tb: SummaryWriter, +) -> float: + + test_count: int = 0 + test_loss: float = 0.0 + test_complete: int = the_dataset_test.__len__() + + logging.info("") + logging.info("Testing:") + + for h_x, h_x_labels in my_loader_test: + time_0 = time.perf_counter() + + h_collection = forward_pass_test( + input=h_x, + the_dataset_test=the_dataset_test, + cfg=cfg, + network=network, + device=device, + default_dtype=default_dtype, + ) + + my_loss: torch.Tensor | None = loss_function_reconstruction( + h_reco=h_collection[-1], + h_input=network[-2].last_input_data, + loss_mode=cfg.learning_parameters.loss_mode, + loss_coeffs_mse=float(cfg.learning_parameters.loss_coeffs_mse), + loss_coeffs_kldiv=float(cfg.learning_parameters.loss_coeffs_kldiv), + ) + + assert my_loss is not None + test_count += h_x_labels.shape[0] + test_loss += my_loss.item() + + performance = test_loss / test_count + time_1 = time.perf_counter() + time_measure_a = time_1 - time_0 + + logging.info( + ( + f"\t\t{test_count} of {test_complete}" + f" with {performance:^6.2e} \t Time used: {time_measure_a:^6.2f}sec" + ) + ) + + logging.info("") + + tb.add_scalar("Test Error", performance, epoch_id) + tb.flush() + + return performance