From 9bdadfef7cac03b484014361983346331106695f Mon Sep 17 00:00:00 2001
From: David Rotermund <54365609+davrot@users.noreply.github.com>
Date: Wed, 26 Jul 2023 12:42:54 +0200
Subject: [PATCH] Add files via upload

---
 network/AutoNNMFLayer.py | 503 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 503 insertions(+)
 create mode 100644 network/AutoNNMFLayer.py

diff --git a/network/AutoNNMFLayer.py b/network/AutoNNMFLayer.py
new file mode 100644
index 0000000..20bdf85
--- /dev/null
+++ b/network/AutoNNMFLayer.py
@@ -0,0 +1,503 @@
import torch

from network.calculate_output_size import calculate_output_size


class AutoNNMFLayer(torch.nn.Module):
    _epsilon_0: float
    _weights: torch.nn.parameter.Parameter
    _weights_exists: bool = False
    _kernel_size: list[int]
    _stride: list[int]
    _dilation: list[int]
    _padding: list[int]
    _output_size: torch.Tensor
    _inbetween_size: torch.Tensor
    _number_of_neurons: int
    _number_of_input_neurons: int
    _h_initial: torch.Tensor | None = None
    _w_trainable: bool
    _weight_noise_range: list[float]
    _input_size: list[int]
    _output_layer: bool = False
    _number_of_iterations: int
    _local_learning: bool = False

    device: torch.device
    default_dtype: torch.dtype

    _number_of_grad_weight_contributions: float = 0.0

    last_input_store: bool = False
    last_input_data: torch.Tensor | None = None

    _layer_id: int = -1

    _skip_gradient_calculation: bool

    _keep_last_grad_scale: bool
    _disable_scale_grade: bool
    _last_grad_scale: torch.nn.parameter.Parameter

    def __init__(
        self,
        number_of_input_neurons: int,
        number_of_neurons: int,
        input_size: list[int],
        forward_kernel_size: list[int],
        number_of_iterations: int,
        epsilon_0: float = 1.0,
        weight_noise_range: list[float] = [0.0, 1.0],
        strides: list[int] = [1, 1],
        dilation: list[int] = [0, 0],
        padding: list[int] = [0, 0],
        w_trainable: bool = False,
        device: torch.device | None = None,
        default_dtype: torch.dtype | None = None,
        layer_id: int = -1,
        local_learning: bool = False,
        output_layer: bool = False,
        skip_gradient_calculation: bool = False,
        keep_last_grad_scale: bool = False,
        disable_scale_grade: bool = True,
    ) -> None:
        super().__init__()

        assert device is not None
        assert default_dtype is not None
        self.device = device
        self.default_dtype = default_dtype

        self._w_trainable = bool(w_trainable)
        self._stride = strides
        self._dilation = dilation
        self._padding = padding
        self._kernel_size = forward_kernel_size
        self._number_of_input_neurons = int(number_of_input_neurons)
        self._number_of_neurons = int(number_of_neurons)
        self._epsilon_0 = float(epsilon_0)
        self._number_of_iterations = int(number_of_iterations)
        self._weight_noise_range = weight_noise_range
        self._layer_id = layer_id
        self._local_learning = local_learning
        self._output_layer = output_layer
        self._skip_gradient_calculation = skip_gradient_calculation

        self._keep_last_grad_scale = bool(keep_last_grad_scale)
        self._disable_scale_grade = bool(disable_scale_grade)
        self._last_grad_scale = torch.nn.parameter.Parameter(
            torch.tensor(-1.0, dtype=self.default_dtype),
            requires_grad=True,
        )

        assert len(input_size) == 2
        self._input_size = input_size
        self._output_size = torch.tensor(input_size)

        self._inbetween_size = calculate_output_size(
            value=input_size,
            kernel_size=self._kernel_size,
            stride=self._stride,
            dilation=self._dilation,
            padding=self._padding,
        )
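
        # NOTE (reviewer comment, not part of the original upload):
        # calculate_output_size lives in network/calculate_output_size.py and is
        # not included in this patch. It is assumed here to implement the usual
        # convolution geometry, roughly
        #     out = floor((in + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
        # per spatial dimension, e.g. input_size=[28, 28] with kernel [5, 5],
        # stride [1, 1], padding [0, 0] and dilation treated as 1 would give an
        # in-between size of [24, 24]. This is only an assumption about that helper.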

        self.set_h_init_to_uniform()

        # ###############################################################
        # Initialize the weights
        # ###############################################################

        assert len(self._weight_noise_range) == 2
        weights = torch.empty(
            (
                int(self._kernel_size[0])
                * int(self._kernel_size[1])
                * int(self._number_of_input_neurons),
                int(self._number_of_neurons),
            ),
            dtype=self.default_dtype,
            device=self.device,
        )

        torch.nn.init.uniform_(
            weights,
            a=float(self._weight_noise_range[0]),
            b=float(self._weight_noise_range[1]),
        )
        self.weights = weights

        self.functional_nnmf_sbs_bp = FunctionalAutoNNMF.apply

    @property
    def weights(self) -> torch.Tensor | None:
        if self._weights_exists is False:
            return None
        else:
            return self._weights

    @weights.setter
    def weights(self, value: torch.Tensor):
        assert value is not None
        assert torch.is_tensor(value) is True
        assert value.dim() == 2
        temp: torch.Tensor = (
            value.detach()
            .clone(memory_format=torch.contiguous_format)
            .type(dtype=self.default_dtype)
            .to(device=self.device)
        )
        temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype)
        if self._weights_exists is False:
            self._weights = torch.nn.parameter.Parameter(temp, requires_grad=True)
            self._weights_exists = True
        else:
            self._weights.data = temp

    @property
    def h_initial(self) -> torch.Tensor | None:
        return self._h_initial

    @h_initial.setter
    def h_initial(self, value: torch.Tensor):
        assert value is not None
        assert torch.is_tensor(value) is True
        assert value.dim() == 1
        assert value.dtype == self.default_dtype
        self._h_initial = (
            value.detach()
            .clone(memory_format=torch.contiguous_format)
            .type(dtype=self.default_dtype)
            .to(device=self.device)
            .requires_grad_(False)
        )

    def update_pre_care(self):
        if self._weights.grad is not None:
            assert self._number_of_grad_weight_contributions > 0
            self._weights.grad /= self._number_of_grad_weight_contributions
            self._number_of_grad_weight_contributions = 0.0

    def update_after_care(self, threshold_weight: float):
        if self._w_trainable is True:
            self.norm_weights()
            self.threshold_weights(threshold_weight)
            self.norm_weights()

    def after_batch(self, new_state: bool = False):
        if self._keep_last_grad_scale is True:
            self._last_grad_scale.data = self._last_grad_scale.grad  # type: ignore
            self._keep_last_grad_scale = new_state

        self._last_grad_scale.grad = torch.zeros_like(self._last_grad_scale.grad)  # type: ignore

    def set_h_init_to_uniform(self) -> None:
        assert self._number_of_neurons > 2

        self.h_initial: torch.Tensor = torch.full(
            (self._number_of_neurons,),
            (1.0 / float(self._number_of_neurons)),
            dtype=self.default_dtype,
            device=self.device,
        )

    def norm_weights(self) -> None:
        assert self._weights_exists is True
        temp: torch.Tensor = (
            self._weights.data.detach()
            .clone(memory_format=torch.contiguous_format)
            .type(dtype=self.default_dtype)
            .to(device=self.device)
        )
        temp /= temp.sum(dim=0, keepdim=True, dtype=self.default_dtype)
        self._weights.data = temp

    def threshold_weights(self, threshold: float) -> None:
        assert self._weights_exists is True
        assert threshold >= 0

        torch.clamp(
            self._weights.data,
            min=float(threshold),
            max=None,
            out=self._weights.data,
        )
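
    # NOTE (reviewer comment, not part of the original upload): the weights
    # setter and norm_weights() keep every column of the weight matrix
    # normalised, i.e. weights[:, j].sum() == 1 for each output neuron j, where
    # the rows run over the flattened kernel_x * kernel_y * input-channel axis.
    # update_after_care() renormalises, clips small entries via
    # threshold_weights(), and renormalises again so the clipping does not break
    # the sum-to-one property. update_pre_care() averages the accumulated weight
    # gradient over the number of (sample x position) contributions counted in
    # forward() before an optimizer step.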

    ####################################################################
    # Forward                                                          #
    ####################################################################

    def forward(
        self,
        input: torch.Tensor,
    ) -> torch.Tensor:
        # Are we happy with the input?
        assert input is not None
        assert torch.is_tensor(input) is True
        assert input.dim() == 4
        assert input.dtype == self.default_dtype
        assert input.shape[1] == self._number_of_input_neurons
        assert input.shape[2] == self._input_size[0]
        assert input.shape[3] == self._input_size[1]

        # Are we happy with the rest of the network?
        assert self._epsilon_0 is not None
        assert self._h_initial is not None
        assert self._weights_exists is True
        assert self._weights is not None

        # Convolution of the input...
        # Well, this is a convolution layer,
        # so there needs to be a convolution somewhere.
        input_convolved = torch.nn.functional.fold(
            torch.nn.functional.unfold(
                input.requires_grad_(True),
                kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
                dilation=(int(self._dilation[0]), int(self._dilation[1])),
                padding=(int(self._padding[0]), int(self._padding[1])),
                stride=(int(self._stride[0]), int(self._stride[1])),
            ),
            output_size=tuple(self._inbetween_size.tolist()),
            kernel_size=(1, 1),
            dilation=(1, 1),
            padding=(0, 0),
            stride=(1, 1),
        )

        # We might need the convolved input for other layers;
        # let us keep it for the future.
        if self.last_input_store is True:
            self.last_input_data = input_convolved.detach().clone()
            self.last_input_data /= self.last_input_data.sum(dim=1, keepdim=True)
        else:
            self.last_input_data = None

        input_convolved = input_convolved / (
            input_convolved.sum(dim=1, keepdim=True) + 1e-20
        )

        parameter_list = torch.tensor(
            [
                int(self._inbetween_size[0]),  # 0
                int(self._inbetween_size[1]),  # 1
                int(self._number_of_iterations),  # 2
                int(self._w_trainable),  # 3
                int(self._skip_gradient_calculation),  # 4
                int(self._keep_last_grad_scale),  # 5
                int(self._disable_scale_grade),  # 6
                int(self._local_learning),  # 7
                int(self._output_layer),  # 8
            ],
            dtype=torch.int64,
        )

        epsilon_0 = torch.tensor(self._epsilon_0)

        h = self.functional_nnmf_sbs_bp(
            input_convolved,
            epsilon_0,
            self._weights,
            self._h_initial,
            parameter_list,
            self._last_grad_scale,
        )

        self._number_of_grad_weight_contributions += (
            h.shape[0] * h.shape[-2] * h.shape[-1]
        )

        # The decoding weight will not be optimized!!!
        # (self._weights.detach().clone())
        output_convolved = (
            h.unsqueeze(1)
            * (self._weights.detach().clone())
            .unsqueeze(0)
            .unsqueeze(-1)
            .unsqueeze(-1)
        ).sum(dim=2)

        output = torch.nn.functional.fold(
            torch.nn.functional.unfold(
                output_convolved,
                kernel_size=(1, 1),
                dilation=1,
                padding=0,
                stride=1,
            ),
            output_size=(int(input.shape[-2]), int(input.shape[-1])),
            kernel_size=(int(self._kernel_size[0]), int(self._kernel_size[1])),
            dilation=(int(self._dilation[0]), int(self._dilation[1])),
            padding=(int(self._padding[0]), int(self._padding[1])),
            stride=(int(self._stride[0]), int(self._stride[1])),
        )

        return output


class FunctionalAutoNNMF(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore
        ctx,
        input: torch.Tensor,
        epsilon_0: torch.Tensor,
        weights: torch.Tensor,
        h_initial: torch.Tensor,
        parameter_list: torch.Tensor,
        grad_output_scale: torch.Tensor,
    ) -> torch.Tensor:
        output_size_0: int = int(parameter_list[0])
        output_size_1: int = int(parameter_list[1])
        number_of_iterations: int = int(parameter_list[2])

        _epsilon_0 = epsilon_0.item()

        h = torch.tile(
            h_initial.unsqueeze(0).unsqueeze(-1).unsqueeze(-1),
            dims=[
                int(input.shape[0]),
                1,
                int(output_size_0),
                int(output_size_1),
            ],
        )

        for _ in range(0, number_of_iterations):
            h_w = h.unsqueeze(1) * weights.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            h_w = h_w / (h_w.sum(dim=2, keepdim=True) + 1e-20)
            h_w = (h_w * input.unsqueeze(2)).sum(dim=1)
            if _epsilon_0 > 0:
                h = h + _epsilon_0 * h_w
            else:
                h = h_w
            h = h / (h.sum(dim=1, keepdim=True) + 1e-20)

        ctx.save_for_backward(
            input,
            weights,
            h,
            parameter_list,
            grad_output_scale,
        )

        return h

    @staticmethod
    def backward(ctx, grad_output):
        # ##############################################
        # Get the variables back
        # ##############################################
        (
            input,
            weights,
            output,
            parameter_list,
            last_grad_scale,
        ) = ctx.saved_tensors

        # ##############################################
        # Default output
        # ##############################################
        grad_input = None
        grad_epsilon_0 = None
        grad_weights = None
        grad_h_initial = None
        grad_parameter_list = None

        # ##############################################
        # Parameters
        # ##############################################
        parameter_w_trainable: bool = bool(parameter_list[3])
        parameter_skip_gradient_calculation: bool = bool(parameter_list[4])
        parameter_keep_last_grad_scale: bool = bool(parameter_list[5])
        parameter_disable_scale_grade: bool = bool(parameter_list[6])
        parameter_local_learning: bool = bool(parameter_list[7])
        parameter_output_layer: bool = bool(parameter_list[8])

        # ##############################################
        # Dealing with the overall scale of the gradient
        # ##############################################
        if parameter_disable_scale_grade is False:
            if parameter_keep_last_grad_scale is True:
                last_grad_scale = torch.tensor(
                    [torch.abs(grad_output).max(), last_grad_scale]
                ).max()
            grad_output /= last_grad_scale
        grad_output_scale = last_grad_scale.clone()

        input /= input.sum(dim=1, keepdim=True, dtype=weights.dtype) + 1e-20

        # #################################################
        # The user doesn't want us to calculate the gradients
        # #################################################

        if parameter_skip_gradient_calculation is True:
            return (
                grad_input,
                grad_epsilon_0,
                grad_weights,
                grad_h_initial,
                grad_parameter_list,
                grad_output_scale,
            )

        # #################################################
        # Calculate the backprop error (grad_input)
        # #################################################

        backprop_r: torch.Tensor = weights.unsqueeze(0).unsqueeze(-1).unsqueeze(
            -1
        ) * output.unsqueeze(1)

        backprop_bigr: torch.Tensor = backprop_r.sum(dim=2)

        backprop_z: torch.Tensor = backprop_r * (
            1.0 / (backprop_bigr + 1e-20)
        ).unsqueeze(2)
        grad_input = (backprop_z * grad_output.unsqueeze(1)).sum(2)
        del backprop_z

        # #################################################
        # Calculate the weight gradient (grad_weights)
        # #################################################

        if parameter_w_trainable is False:
            # #################################################
            # We don't train this weight
            # #################################################
            grad_weights = None

        elif (parameter_output_layer is False) and (parameter_local_learning is True):
            # #################################################
            # Local learning
            # #################################################
            grad_weights = (
                (-2 * (input - backprop_bigr).unsqueeze(2) * output.unsqueeze(1))
                .sum(0)
                .sum(-1)
                .sum(-1)
            )

        else:
            # #################################################
            # Backprop
            # #################################################
            backprop_f: torch.Tensor = output.unsqueeze(1) * (
                input / (backprop_bigr**2 + 1e-20)
            ).unsqueeze(2)

            result_omega: torch.Tensor = backprop_bigr.unsqueeze(
                2
            ) * grad_output.unsqueeze(1)
            result_omega -= (backprop_r * grad_output.unsqueeze(1)).sum(2).unsqueeze(2)
            result_omega *= backprop_f
            del backprop_f
            grad_weights = result_omega.sum(0).sum(-1).sum(-1)
            del result_omega

        del backprop_bigr
        del backprop_r

        return (
            grad_input,
            grad_epsilon_0,
            grad_weights,
            grad_h_initial,
            grad_parameter_list,
            grad_output_scale,
        )
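
Reviewer note, not part of the patch: a minimal smoke test might look like the sketch below. It assumes that network/calculate_output_size.py (imported by the module but not included in this commit) returns the usual convolution output size as a tensor, and every parameter value is made up for illustration. Because the decoder multiplies the latent activity h with the detached encoding weights and folds the result back to the input geometry, the returned tensor should have the same shape as the input.

    import torch

    from network.AutoNNMFLayer import AutoNNMFLayer

    # Illustrative values only; device and default_dtype are mandatory.
    # dilation=[1, 1] is passed explicitly because torch.nn.functional.unfold
    # expects a dilation of at least 1.
    layer = AutoNNMFLayer(
        number_of_input_neurons=1,
        number_of_neurons=32,
        input_size=[24, 24],
        forward_kernel_size=[5, 5],
        number_of_iterations=20,
        dilation=[1, 1],
        device=torch.device("cpu"),
        default_dtype=torch.float32,
    )

    x = torch.rand((2, 1, 24, 24), dtype=torch.float32)
    reconstruction = layer(x)
    print(reconstruction.shape)  # expected: torch.Size([2, 1, 24, 24])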