diff --git a/NNMFConv2d.py b/NNMFConv2d.py new file mode 100644 index 0000000..51b21eb --- /dev/null +++ b/NNMFConv2d.py @@ -0,0 +1,502 @@ +import torch +from non_linear_weigth_function import non_linear_weigth_function + + +class NNMFConv2d(torch.nn.Module): + + in_channels: int + out_channels: int + kernel_size: tuple[int, ...] + stride: tuple[int, ...] + padding: str | tuple[int, ...] + dilation: tuple[int, ...] + weight: torch.Tensor + bias: None | torch.Tensor + output_size: None | torch.Tensor = None + convolution_contribution_map: None | torch.Tensor = None + iterations: int + convolution_contribution_map_enable: bool + epsilon: float | None + init_min: float + init_max: float + beta: torch.Tensor | None + positive_function_type: int + use_convolution: bool + local_learning: bool + local_learning_kl: bool + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: tuple[int, int], + stride: tuple[int, int] = (1, 1), + padding: str | tuple[int, int] = (0, 0), + dilation: tuple[int, int] = (1, 1), + device=None, + dtype=None, + iterations: int = 20, + convolution_contribution_map_enable: bool = False, + epsilon: float | None = None, + init_min: float = 0.0, + init_max: float = 1.0, + beta: float | None = None, + positive_function_type: int = 0, + use_convolution: bool = False, + local_learning: bool = False, + local_learning_kl: bool = False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + + super().__init__() + + valid_padding_strings = {"same", "valid"} + if isinstance(padding, str): + if padding not in valid_padding_strings: + raise ValueError( + f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}" + ) + if padding == "same" and any(s != 1 for s in stride): + raise ValueError( + "padding='same' is not supported for strided convolutions" + ) + + self.positive_function_type = positive_function_type + self.init_min = init_min + self.init_max = init_max + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + + self.stride = stride + self.padding = padding + self.dilation = dilation + + self.iterations = iterations + self.convolution_contribution_map_enable = convolution_contribution_map_enable + + self.local_learning = local_learning + self.local_learning_kl = local_learning_kl + + self.weight = torch.nn.parameter.Parameter( + torch.empty((out_channels, in_channels, *kernel_size), **factory_kwargs) + ) + + if beta is not None: + self.beta = torch.nn.parameter.Parameter(torch.empty((1), **factory_kwargs)) + self.beta.data[0] = beta + else: + self.beta = None + + self.reset_parameters() + self.functional_nnmf_conv2d = FunctionalNNMFConv2d.apply + + self.epsilon = epsilon + self.use_convolution = use_convolution + + def extra_repr(self) -> str: + s: str = f"{self.in_channels}, {self.out_channels}" + + s += f", kernel_size={self.kernel_size}" + s += f", stride={self.stride}, iterations={self.iterations}" + s += f", epsilon={self.epsilon}" + s += f", use_convolution={self.use_convolution}" + + if self.use_convolution: + s += f", ccmap={self.convolution_contribution_map_enable}" + + s += f", pfunctype={self.positive_function_type}" + s += f", local_learning={self.local_learning}" + + if self.local_learning: + s += f", local_learning_kl={self.local_learning_kl}" + + if self.padding != (0,) * len(self.padding): + s += f", padding={self.padding}" + + if self.dilation != (1,) * len(self.dilation): + s += f", dilation={self.dilation}" + + return s + + def reset_parameters(self) 
-> None: + torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + + if input.ndim == 2: + input = input.unsqueeze(-1) + if input.ndim == 3: + input = input.unsqueeze(-1) + + if self.output_size is None: + self.output_size = torch.tensor( + torch.nn.functional.conv2d( + torch.zeros( + 1, + input.shape[1], + input.shape[2], + input.shape[3], + device=self.weight.device, + dtype=self.weight.dtype, + ), + torch.zeros_like(self.weight), + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + ).shape, + requires_grad=False, + ) + assert self.output_size is not None + + if self.use_convolution is False: + + input = torch.nn.functional.fold( + torch.nn.functional.unfold( + input.requires_grad_(True), + kernel_size=self.kernel_size, + dilation=self.dilation, + padding=self.padding, + stride=self.stride, + ), + output_size=self.output_size[-2:], + kernel_size=(1, 1), + dilation=(1, 1), + padding=(0, 0), + stride=(1, 1), + ) + + if ( + (self.convolution_contribution_map is None) + and (self.convolution_contribution_map_enable) + and (self.use_convolution) + ): + + self.convolution_contribution_map = torch.nn.functional.conv_transpose2d( + torch.full( + self.output_size.tolist(), + 1.0 / float(self.output_size[1]), + dtype=self.weight.dtype, + device=self.weight.device, + requires_grad=False, + ), + torch.ones_like(self.weight, requires_grad=False), + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + ) * ( + (input.shape[1] * input.shape[2] * input.shape[3]) + / (self.weight.shape[1] * self.weight.shape[2] * self.weight.shape[3]) + ) + + if self.convolution_contribution_map_enable and self.use_convolution: + assert self.convolution_contribution_map is not None + + positive_weights = non_linear_weigth_function( + self.weight, self.beta, self.positive_function_type + ) + positive_weights = positive_weights / ( + positive_weights.sum((1, 2, 3), keepdim=True) + 10e-20 + ) + + if self.use_convolution is False: + positive_weights = positive_weights.reshape( + positive_weights.shape[0], + positive_weights.shape[1] + * positive_weights.shape[2] + * positive_weights.shape[3], + ) + + # Prepare input + if self.use_convolution: + input = input / (input.sum((1, 2, 3), keepdim=True) + 10e-20) + if self.convolution_contribution_map is not None: + input = input * self.convolution_contribution_map + else: + input = input / (input.sum(dim=1, keepdim=True) + 10e-20) + + return self.functional_nnmf_conv2d( + input, + positive_weights, + self.output_size, + self.iterations, + self.stride, + self.padding, + self.dilation, + self.epsilon, + self.use_convolution, + self.local_learning, + self.local_learning_kl, + ) + + +class FunctionalNNMFConv2d(torch.autograd.Function): + @staticmethod + def forward( # type: ignore + ctx, + input: torch.Tensor, + weight: torch.Tensor, + output_size: torch.Tensor, + iterations: int, + stride: tuple[int, int], + padding: str | tuple[int, int], + dilation: tuple[int, int], + epsilon: float | None, + use_convolution: bool, + local_learning: bool, + local_learning_kl: bool, + ) -> torch.Tensor: + + # Prepare h + output_size[0] = input.shape[0] + h = torch.full( + output_size.tolist(), + 1.0 / float(output_size[1]), + device=input.device, + dtype=input.dtype, + ) + + if use_convolution: + for _ in range(0, iterations): + factor_x_div_r: torch.Tensor = input / ( + torch.nn.functional.conv_transpose2d( + h, + weight, + stride=stride, + padding=padding, + dilation=dilation, + ) + + 
10e-20 + ) + + if epsilon is None: + h *= torch.nn.functional.conv2d( + factor_x_div_r, + weight, + stride=stride, + padding=padding, + dilation=dilation, + ) + else: + h *= 1 + epsilon * torch.nn.functional.conv2d( + factor_x_div_r, + weight, + stride=stride, + padding=padding, + dilation=dilation, + ) + + h /= h.sum(1, keepdim=True) + 10e-20 + else: + h = h.movedim(1, -1) + input = input.movedim(1, -1) + for _ in range(0, iterations): + reconstruction = torch.nn.functional.linear(h, weight.T) + reconstruction += 1e-20 + if epsilon is None: + h *= torch.nn.functional.linear((input / reconstruction), weight) + else: + h *= 1 + epsilon * torch.nn.functional.linear( + (input / reconstruction), weight + ) + h /= h.sum(-1, keepdim=True) + 10e-20 + h = h.movedim(-1, 1) + input = input.movedim(-1, 1) + + # ########################################################### + # Save the necessary data for the backward pass + # ########################################################### + ctx.save_for_backward(input, weight, h) + + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.use_convolution = use_convolution + ctx.local_learning = local_learning + ctx.local_learning_kl = local_learning_kl + + assert torch.isfinite(h).all() + return h + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output: torch.Tensor) -> tuple[ # type: ignore + torch.Tensor | None, + torch.Tensor | None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ]: + + # ############################################## + # Default values + # ############################################## + grad_input: torch.Tensor | None = None + grad_weight: torch.Tensor | None = None + + # ############################################## + # Get the variables back + # ############################################## + (input, weight, h) = ctx.saved_tensors + + if ctx.use_convolution: + big_r: torch.Tensor = torch.nn.functional.conv_transpose2d( + h, + weight, + stride=ctx.stride, + padding=ctx.padding, + dilation=ctx.dilation, + ) + big_r_div = 1.0 / (big_r + 1e-20) + + factor_x_div_r: torch.Tensor = input * big_r_div + + grad_input = ( + torch.nn.functional.conv_transpose2d( + (h * grad_output), + weight, + stride=ctx.stride, + padding=ctx.padding, + dilation=ctx.dilation, + ) + * big_r_div + ) + + del big_r_div + if ctx.local_learning is False: + del big_r + grad_weight = -torch.nn.functional.conv2d( + (factor_x_div_r * grad_input).movedim(0, 1), + h.movedim(0, 1), + stride=ctx.dilation, + padding=ctx.padding, + dilation=ctx.stride, + ) + + grad_weight += torch.nn.functional.conv2d( + factor_x_div_r.movedim(0, 1), + (h * grad_output).movedim(0, 1), + stride=ctx.dilation, + padding=ctx.padding, + dilation=ctx.stride, + ) + + else: + if ctx.local_learning_kl: + grad_weight = -torch.nn.functional.conv2d( + factor_x_div_r.movedim(0, 1), + h.movedim(0, 1), + stride=ctx.dilation, + padding=ctx.padding, + dilation=ctx.stride, + ) + else: + grad_weight = -torch.nn.functional.conv2d( + (2 * (input - big_r)).movedim(0, 1), + h.movedim(0, 1), + stride=ctx.dilation, + padding=ctx.padding, + dilation=ctx.stride, + ) + grad_weight = grad_weight.movedim(0, 1) + else: + h = h.movedim(1, -1) + grad_output = grad_output.movedim(1, -1) + input = input.movedim(1, -1) + big_r = torch.nn.functional.linear(h, weight.T) + big_r_div = 1.0 / (big_r + 1e-20) + + factor_x_div_r = input * big_r_div + + grad_input = ( + torch.nn.functional.linear(h * grad_output, weight.T) * big_r_div + ) + + del 
big_r_div + + if ctx.local_learning is False: + del big_r + + grad_weight = -torch.nn.functional.linear( + h.reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + h.shape[3], + ).T, + (factor_x_div_r * grad_input) + .reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + grad_input.shape[3], + ) + .T, + ) + + grad_weight += torch.nn.functional.linear( + (h * grad_output) + .reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + h.shape[3], + ) + .T, + factor_x_div_r.reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + grad_input.shape[3], + ).T, + ) + + else: + if ctx.local_learning_kl: + grad_weight = -torch.nn.functional.linear( + h.reshape( + grad_input.shape[0] + * grad_input.shape[1] + * grad_input.shape[2], + h.shape[3], + ).T, + factor_x_div_r.reshape( + grad_input.shape[0] + * grad_input.shape[1] + * grad_input.shape[2], + grad_input.shape[3], + ).T, + ) + else: + grad_weight = -torch.nn.functional.linear( + h.reshape( + grad_input.shape[0] + * grad_input.shape[1] + * grad_input.shape[2], + h.shape[3], + ).T, + (2 * (input - big_r)) + .reshape( + grad_input.shape[0] + * grad_input.shape[1] + * grad_input.shape[2], + grad_input.shape[3], + ) + .T, + ) + grad_input = grad_input.movedim(-1, 1) + assert torch.isfinite(grad_input).all() + assert torch.isfinite(grad_weight).all() + + return ( + grad_input, + grad_weight, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) diff --git a/NNMFConv2dP.py b/NNMFConv2dP.py new file mode 100644 index 0000000..9dedda3 --- /dev/null +++ b/NNMFConv2dP.py @@ -0,0 +1,480 @@ +import torch +from non_linear_weigth_function import non_linear_weigth_function + + +class NNMFConv2dP(torch.nn.Module): + + in_channels: int + out_channels: int + kernel_size: tuple[int, ...] + stride: tuple[int, ...] + padding: str | tuple[int, ...] + dilation: tuple[int, ...] 
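# --- editor's note: illustrative sketch, not part of this patch -----------
# Both NNMF layers in this patch iterate the classic multiplicative NNMF
# update h <- h * (W @ (x / (W^T @ h))), followed by renormalizing h onto
# the simplex. A minimal dense version of that inner loop, with
# hypothetical shapes (32 output units, 128 inputs, 20 iterations):
#
#   import torch
#   x = torch.rand(128)                     # non-negative input vector
#   w = torch.rand(32, 128)
#   w = w / w.sum(dim=1, keepdim=True)      # rows of W sum to 1
#   h = torch.full((32,), 1.0 / 32)         # uniform initialization
#   for _ in range(20):
#       reconstruction = w.T @ h + 1e-20    # R = W^T h
#       h = h * (w @ (x / reconstruction))  # multiplicative update
#       h = h / (h.sum() + 1e-20)           # keep h normalized
# --- end editor's note -----------------------------------------------------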
+ weight: torch.Tensor + bias: None | torch.Tensor + output_size: None | torch.Tensor = None + convolution_contribution_map: None | torch.Tensor = None + iterations: int + convolution_contribution_map_enable: bool + epsilon: float | None + init_min: float + init_max: float + beta: torch.Tensor | None + positive_function_type: int + use_convolution: bool + local_learning: bool + local_learning_kl: bool + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: tuple[int, int], + stride: tuple[int, int] = (1, 1), + padding: str | tuple[int, int] = (0, 0), + dilation: tuple[int, int] = (1, 1), + device=None, + dtype=None, + iterations: int = 20, + convolution_contribution_map_enable: bool = False, + epsilon: float | None = None, + init_min: float = 0.0, + init_max: float = 1.0, + beta: float | None = None, + positive_function_type: int = 0, + use_convolution: bool = False, + local_learning: bool = False, + local_learning_kl: bool = False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + + super().__init__() + + valid_padding_strings = {"same", "valid"} + if isinstance(padding, str): + if padding not in valid_padding_strings: + raise ValueError( + f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}" + ) + if padding == "same" and any(s != 1 for s in stride): + raise ValueError( + "padding='same' is not supported for strided convolutions" + ) + + self.positive_function_type = positive_function_type + self.init_min = init_min + self.init_max = init_max + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + + self.stride = stride + self.padding = padding + self.dilation = dilation + + self.iterations = iterations + self.convolution_contribution_map_enable = convolution_contribution_map_enable + + self.local_learning = local_learning + self.local_learning_kl = local_learning_kl + + self.weight = torch.nn.parameter.Parameter( + torch.empty((out_channels, in_channels, *kernel_size), **factory_kwargs) + ) + + if beta is not None: + self.beta = torch.nn.parameter.Parameter(torch.empty((1), **factory_kwargs)) + self.beta.data[0] = beta + else: + self.beta = None + + self.reset_parameters() + self.functional_nnmf_conv2d = FunctionalNNMFConv2dP.apply + + self.epsilon = epsilon + self.use_convolution = use_convolution + + assert self.use_convolution is False + + def extra_repr(self) -> str: + s: str = f"{self.in_channels}, {self.out_channels}" + + s += f", kernel_size={self.kernel_size}" + s += f", stride={self.stride}, iterations={self.iterations}" + s += f", epsilon={self.epsilon}" + s += f", use_convolution={self.use_convolution}" + + if self.use_convolution: + s += f", ccmap={self.convolution_contribution_map_enable}" + + s += f", pfunctype={self.positive_function_type}" + s += f", local_learning={self.local_learning}" + + if self.local_learning: + s += f", local_learning_kl={self.local_learning_kl}" + + if self.padding != (0,) * len(self.padding): + s += f", padding={self.padding}" + + if self.dilation != (1,) * len(self.dilation): + s += f", dilation={self.dilation}" + + return s + + def reset_parameters(self) -> None: + torch.nn.init.uniform_(self.weight, a=self.init_min, b=self.init_max) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + + if input.ndim == 2: + input = input.unsqueeze(-1) + if input.ndim == 3: + input = input.unsqueeze(-1) + + if self.output_size is None: + self.output_size = torch.tensor( + torch.nn.functional.conv2d( + torch.zeros( + 1, + 
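# editor's note: below, a zero-filled dummy sample is pushed through
# conv2d exactly once, solely to record the layer's output shape in
# self.output_size; no real data flows through this probe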
input.shape[1], + input.shape[2], + input.shape[3], + device=self.weight.device, + dtype=self.weight.dtype, + ), + torch.zeros_like(self.weight), + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + ).shape, + requires_grad=False, + ) + assert self.output_size is not None + + input = torch.nn.functional.fold( + torch.nn.functional.unfold( + input.requires_grad_(True), + kernel_size=self.kernel_size, + dilation=self.dilation, + padding=self.padding, + stride=self.stride, + ), + output_size=self.output_size[-2:], + kernel_size=(1, 1), + dilation=(1, 1), + padding=(0, 0), + stride=(1, 1), + ) + + positive_weights = non_linear_weigth_function( + self.weight, self.beta, self.positive_function_type + ) + positive_weights = positive_weights / ( + positive_weights.sum((1, 2, 3), keepdim=True) + 10e-20 + ) + + positive_weights = positive_weights.reshape( + positive_weights.shape[0], + positive_weights.shape[1] + * positive_weights.shape[2] + * positive_weights.shape[3], + ) + + # Prepare input + input = input / (input.sum(dim=1, keepdim=True) + 10e-20) + + h_dyn = self.functional_nnmf_conv2d( + input, + positive_weights, + self.output_size, + self.iterations, + self.stride, + self.padding, + self.dilation, + self.epsilon, + self.use_convolution, + self.local_learning, + self.local_learning_kl, + ) + self.reco = False + if self.reco: + print(h_dyn.shape) + print(positive_weights.shape) + print(input.shape) + exit() + output = torch.cat((h_dyn, input), dim=1) + else: + output = torch.cat((h_dyn, input), dim=1) + return output + + +class FunctionalNNMFConv2dP(torch.autograd.Function): + @staticmethod + def forward( # type: ignore + ctx, + input: torch.Tensor, + weight: torch.Tensor, + output_size: torch.Tensor, + iterations: int, + stride: tuple[int, int], + padding: str | tuple[int, int], + dilation: tuple[int, int], + epsilon: float | None, + use_convolution: bool, + local_learning: bool, + local_learning_kl: bool, + ) -> torch.Tensor: + + # Prepare h + output_size[0] = input.shape[0] + h = torch.full( + output_size.tolist(), + 1.0 / float(output_size[1]), + device=input.device, + dtype=input.dtype, + ) + + if use_convolution: + for _ in range(0, iterations): + factor_x_div_r: torch.Tensor = input / ( + torch.nn.functional.conv_transpose2d( + h, + weight, + stride=stride, + padding=padding, + dilation=dilation, + ) + + 10e-20 + ) + + if epsilon is None: + h *= torch.nn.functional.conv2d( + factor_x_div_r, + weight, + stride=stride, + padding=padding, + dilation=dilation, + ) + else: + h *= 1 + epsilon * torch.nn.functional.conv2d( + factor_x_div_r, + weight, + stride=stride, + padding=padding, + dilation=dilation, + ) + + h /= h.sum(1, keepdim=True) + 10e-20 + else: + h = h.movedim(1, -1) + input = input.movedim(1, -1) + for _ in range(0, iterations): + reconstruction = torch.nn.functional.linear(h, weight.T) + reconstruction += 1e-20 + if epsilon is None: + h *= torch.nn.functional.linear((input / reconstruction), weight) + else: + h *= 1 + epsilon * torch.nn.functional.linear( + (input / reconstruction), weight + ) + h /= h.sum(-1, keepdim=True) + 10e-20 + h = h.movedim(-1, 1) + input = input.movedim(-1, 1) + + # ########################################################### + # Save the necessary data for the backward pass + # ########################################################### + ctx.save_for_backward(input, weight, h) + + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.use_convolution = use_convolution + ctx.local_learning = local_learning + 
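# --- editor's note: illustrative sketch, not part of this patch -----------
# The use_convolution=False path in NNMFConv2dP.forward above rewrites the
# convolution as a matrix product: unfold() pulls every kernel_size patch
# into the channel dimension, and fold() with a (1, 1) kernel lays the
# patches back out on the output grid. Hypothetical shapes:
#
#   import torch
#   x = torch.randn(1, 3, 8, 8)
#   cols = torch.nn.functional.unfold(x, kernel_size=(5, 5))  # (1, 75, 16)
#   x2 = torch.nn.functional.fold(cols, output_size=(4, 4),
#                                 kernel_size=(1, 1))          # (1, 75, 4, 4)
#
# Each output pixel of x2 now carries its flattened 3x5x5 patch, so the
# NNMF iteration can run with torch.nn.functional.linear over channels.
# --- end editor's note -----------------------------------------------------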
ctx.local_learning_kl = local_learning_kl + + assert torch.isfinite(h).all() + return h + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output: torch.Tensor) -> tuple[ # type: ignore + torch.Tensor | None, + torch.Tensor | None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ]: + + # ############################################## + # Default values + # ############################################## + grad_input: torch.Tensor | None = None + grad_weight: torch.Tensor | None = None + + # ############################################## + # Get the variables back + # ############################################## + (input, weight, h) = ctx.saved_tensors + + if ctx.use_convolution: + big_r: torch.Tensor = torch.nn.functional.conv_transpose2d( + h, + weight, + stride=ctx.stride, + padding=ctx.padding, + dilation=ctx.dilation, + ) + big_r_div = 1.0 / (big_r + 1e-20) + + factor_x_div_r: torch.Tensor = input * big_r_div + + grad_input = ( + torch.nn.functional.conv_transpose2d( + (h * grad_output), + weight, + stride=ctx.stride, + padding=ctx.padding, + dilation=ctx.dilation, + ) + * big_r_div + ) + + del big_r_div + if ctx.local_learning is False: + del big_r + grad_weight = -torch.nn.functional.conv2d( + (factor_x_div_r * grad_input).movedim(0, 1), + h.movedim(0, 1), + stride=ctx.dilation, + padding=ctx.padding, + dilation=ctx.stride, + ) + + grad_weight += torch.nn.functional.conv2d( + factor_x_div_r.movedim(0, 1), + (h * grad_output).movedim(0, 1), + stride=ctx.dilation, + padding=ctx.padding, + dilation=ctx.stride, + ) + + else: + if ctx.local_learning_kl: + grad_weight = -torch.nn.functional.conv2d( + factor_x_div_r.movedim(0, 1), + h.movedim(0, 1), + stride=ctx.dilation, + padding=ctx.padding, + dilation=ctx.stride, + ) + else: + grad_weight = -torch.nn.functional.conv2d( + (2 * (input - big_r)).movedim(0, 1), + h.movedim(0, 1), + stride=ctx.dilation, + padding=ctx.padding, + dilation=ctx.stride, + ) + grad_weight = grad_weight.movedim(0, 1) + else: + h = h.movedim(1, -1) + grad_output = grad_output.movedim(1, -1) + input = input.movedim(1, -1) + big_r = torch.nn.functional.linear(h, weight.T) + big_r_div = 1.0 / (big_r + 1e-20) + + factor_x_div_r = input * big_r_div + + grad_input = ( + torch.nn.functional.linear(h * grad_output, weight.T) * big_r_div + ) + + del big_r_div + + if ctx.local_learning is False: + del big_r + + grad_weight = -torch.nn.functional.linear( + h.reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + h.shape[3], + ).T, + (factor_x_div_r * grad_input) + .reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + grad_input.shape[3], + ) + .T, + ) + + grad_weight += torch.nn.functional.linear( + (h * grad_output) + .reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + h.shape[3], + ) + .T, + factor_x_div_r.reshape( + grad_input.shape[0] * grad_input.shape[1] * grad_input.shape[2], + grad_input.shape[3], + ).T, + ) + + else: + if ctx.local_learning_kl: + grad_weight = -torch.nn.functional.linear( + h.reshape( + grad_input.shape[0] + * grad_input.shape[1] + * grad_input.shape[2], + h.shape[3], + ).T, + factor_x_div_r.reshape( + grad_input.shape[0] + * grad_input.shape[1] + * grad_input.shape[2], + grad_input.shape[3], + ).T, + ) + else: + grad_weight = -torch.nn.functional.linear( + h.reshape( + grad_input.shape[0] + * grad_input.shape[1] + * grad_input.shape[2], + h.shape[3], + ).T, + (2 * (input - big_r)) + .reshape( + 
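# editor's note: here (and in the branches above) the (batch, H, W, C)
# tensors are flattened to (batch*H*W, C) so that each weight gradient
# reduces to a single matrix product over all spatial positions at once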
grad_input.shape[0] + * grad_input.shape[1] + * grad_input.shape[2], + grad_input.shape[3], + ) + .T, + ) + grad_input = grad_input.movedim(-1, 1) + assert torch.isfinite(grad_input).all() + assert torch.isfinite(grad_weight).all() + + return ( + grad_input, + grad_weight, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) diff --git a/SplitOnOffLayer.py b/SplitOnOffLayer.py new file mode 100644 index 0000000..b501092 --- /dev/null +++ b/SplitOnOffLayer.py @@ -0,0 +1,23 @@ +import torch + +class SplitOnOffLayer(torch.nn.Module): + + def __init__( + self, + ) -> None: + super().__init__() + + + #################################################################### + # Forward # + #################################################################### + + def forward(self, input: torch.Tensor) -> torch.Tensor: + assert input.ndim == 4 + + temp = input - 0.5 + temp_a = torch.nn.functional.relu(temp) + temp_b = torch.nn.functional.relu(-temp) + output = torch.cat((temp_a, temp_b), dim=1) + + return output diff --git a/convert_log_to_numpy.py b/convert_log_to_numpy.py new file mode 100644 index 0000000..0ff4b8f --- /dev/null +++ b/convert_log_to_numpy.py @@ -0,0 +1,29 @@ +import os +import glob + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +from tensorboard.backend.event_processing import event_accumulator +import numpy as np + + +def get_data(path: str = "log_cnn"): + acc = event_accumulator.EventAccumulator(path) + acc.Reload() + + which_scalar = "Test Number Correct" + te = acc.Scalars(which_scalar) + + np_temp = np.zeros((len(te), 2)) + + for id in range(0, len(te)): + np_temp[id, 0] = te[id].step + np_temp[id, 1] = te[id].value + + return np_temp + + +for path in glob.glob("log_*"): + print(path) + data = get_data(path) + np.save("data_" + path + ".npy", data) diff --git a/data_loader.py b/data_loader.py new file mode 100644 index 0000000..420a1c1 --- /dev/null +++ b/data_loader.py @@ -0,0 +1,27 @@ +import torch + + +def data_loader( + pattern: torch.Tensor, + labels: torch.Tensor, + batch_size: int = 128, + shuffle: bool = True, + torch_device: torch.device = torch.device("cpu"), +) -> torch.utils.data.dataloader.DataLoader: + + assert pattern.ndim >= 3 + + pattern_storage: torch.Tensor = pattern.to(torch_device).type(torch.float32) + if pattern_storage.ndim == 3: + pattern_storage = pattern_storage.unsqueeze(1) + pattern_storage /= pattern_storage.max() + + label_storage: torch.Tensor = labels.to(torch_device).type(torch.int64) + + dataloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(pattern_storage, label_storage), + batch_size=batch_size, + shuffle=shuffle, + ) + + return dataloader diff --git a/get_the_data.py b/get_the_data.py new file mode 100644 index 0000000..5ca1168 --- /dev/null +++ b/get_the_data.py @@ -0,0 +1,115 @@ +import torch +import torchvision # type: ignore +from data_loader import data_loader + + +def get_the_data( + dataset: str, + batch_size_train: int, + batch_size_test: int, + torch_device: torch.device, + input_dim_x: int, + input_dim_y: int, + flip_p: float = 0.5, + jitter_brightness: float = 0.5, + jitter_contrast: float = 0.1, + jitter_saturation: float = 0.1, + jitter_hue: float = 0.15, +) -> tuple[ + data_loader, + data_loader, + torchvision.transforms.Compose, + torchvision.transforms.Compose, +]: + if dataset == "MNIST": + tv_dataset_train = torchvision.datasets.MNIST( + root="data", train=True, download=True + ) + tv_dataset_test = torchvision.datasets.MNIST( + root="data", train=False, download=True + ) + elif 
dataset == "FashionMNIST": + tv_dataset_train = torchvision.datasets.FashionMNIST( + root="data", train=True, download=True + ) + tv_dataset_test = torchvision.datasets.FashionMNIST( + root="data", train=False, download=True + ) + elif dataset == "CIFAR10": + tv_dataset_train = torchvision.datasets.CIFAR10( + root="data", train=True, download=True + ) + tv_dataset_test = torchvision.datasets.CIFAR10( + root="data", train=False, download=True + ) + else: + raise NotImplementedError("This dataset is not implemented.") + + if dataset == "MNIST" or dataset == "FashionMNIST": + + train_dataloader = data_loader( + torch_device=torch_device, + batch_size=batch_size_train, + pattern=tv_dataset_train.data, + labels=tv_dataset_train.targets, + shuffle=True, + ) + + test_dataloader = data_loader( + torch_device=torch_device, + batch_size=batch_size_test, + pattern=tv_dataset_test.data, + labels=tv_dataset_test.targets, + shuffle=False, + ) + + # Data augmentation filter + test_processing_chain = torchvision.transforms.Compose( + transforms=[torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))], + ) + + train_processing_chain = torchvision.transforms.Compose( + transforms=[torchvision.transforms.RandomCrop((input_dim_x, input_dim_y))], + ) + else: + + train_dataloader = data_loader( + torch_device=torch_device, + batch_size=batch_size_train, + pattern=torch.tensor(tv_dataset_train.data).movedim(-1, 1), + labels=torch.tensor(tv_dataset_train.targets), + shuffle=True, + ) + + test_dataloader = data_loader( + torch_device=torch_device, + batch_size=batch_size_test, + pattern=torch.tensor(tv_dataset_test.data).movedim(-1, 1), + labels=torch.tensor(tv_dataset_test.targets), + shuffle=False, + ) + + # Data augmentation filter + test_processing_chain = torchvision.transforms.Compose( + transforms=[torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))], + ) + + train_processing_chain = torchvision.transforms.Compose( + transforms=[ + torchvision.transforms.RandomCrop((input_dim_x, input_dim_y)), + torchvision.transforms.RandomHorizontalFlip(p=flip_p), + torchvision.transforms.ColorJitter( + brightness=jitter_brightness, + contrast=jitter_contrast, + saturation=jitter_saturation, + hue=jitter_hue, + ), + ], + ) + + return ( + train_dataloader, + test_dataloader, + test_processing_chain, + train_processing_chain, + ) diff --git a/loss_function.py b/loss_function.py new file mode 100644 index 0000000..e256840 --- /dev/null +++ b/loss_function.py @@ -0,0 +1,64 @@ +import torch + + +# loss_mode == 0: "normal" SbS loss function mixture +# loss_mode == 1: cross_entropy +def loss_function( + h: torch.Tensor, + labels: torch.Tensor, + loss_mode: int = 0, + number_of_output_neurons: int = 10, + loss_coeffs_mse: float = 0.0, + loss_coeffs_kldiv: float = 0.0, +) -> torch.Tensor | None: + + assert loss_mode >= 0 + assert loss_mode <= 1 + + assert h.ndim == 2 + + if loss_mode == 0: + + # Convert label into one hot + target_one_hot: torch.Tensor = torch.zeros( + ( + labels.shape[0], + number_of_output_neurons, + ), + device=h.device, + dtype=h.dtype, + ) + + target_one_hot.scatter_( + 1, + labels.to(h.device).unsqueeze(1), + torch.ones( + (labels.shape[0], 1), + device=h.device, + dtype=h.dtype, + ), + ) + + my_loss: torch.Tensor = ((h - target_one_hot) ** 2).sum(dim=0).mean( + dim=0 + ) * loss_coeffs_mse + + my_loss = ( + my_loss + + ( + (target_one_hot * torch.log((target_one_hot + 1e-20) / (h + 1e-20))) + .sum(dim=0) + .mean(dim=0) + ) + * loss_coeffs_kldiv + ) + + my_loss = my_loss / 
(abs(loss_coeffs_kldiv) + abs(loss_coeffs_mse)) + + return my_loss + + elif loss_mode == 1: + my_loss = torch.nn.functional.cross_entropy(h, labels.to(h.device)) + return my_loss + else: + return None diff --git a/make_network.py b/make_network.py new file mode 100644 index 0000000..671957f --- /dev/null +++ b/make_network.py @@ -0,0 +1,354 @@ +import torch +from NNMFConv2d import NNMFConv2d +from NNMFConv2dP import NNMFConv2dP +from SplitOnOffLayer import SplitOnOffLayer + + +def make_network( + use_nnmf: bool, + cnn_top: bool, + input_dim_x: int, + input_dim_y: int, + input_number_of_channel: int, + iterations: int, + init_min: float = 0.0, + init_max: float = 1.0, + use_convolution: bool = False, + convolution_contribution_map_enable: bool = False, + epsilon: bool | None = None, + positive_function_type: int = 0, + beta: float | None = None, + number_of_output_channels_conv1: int = 32, + number_of_output_channels_conv2: int = 64, + number_of_output_channels_flatten2: int = 96, + number_of_output_channels_full1: int = 10, + kernel_size_conv1: tuple[int, int] = (5, 5), + kernel_size_pool1: tuple[int, int] = (2, 2), + kernel_size_conv2: tuple[int, int] = (5, 5), + kernel_size_pool2: tuple[int, int] = (2, 2), + stride_conv1: tuple[int, int] = (1, 1), + stride_pool1: tuple[int, int] = (2, 2), + stride_conv2: tuple[int, int] = (1, 1), + stride_pool2: tuple[int, int] = (2, 2), + padding_conv1: int = 0, + padding_pool1: int = 0, + padding_conv2: int = 0, + padding_pool2: int = 0, + enable_onoff: bool = False, + local_learning_0: bool = False, + local_learning_1: bool = False, + local_learning_2: bool = False, + local_learning_3: bool = False, + local_learning_kl: bool = True, + p_mode_0: bool = False, + p_mode_1: bool = False, + p_mode_2: bool = False, + p_mode_3: bool = False, +) -> tuple[torch.nn.Sequential, list[int], list[int]]: + + if enable_onoff: + input_number_of_channel *= 2 + + list_cnn_top_id: list[int] = [] + list_other_id: list[int] = [] + + test_image = torch.ones((1, input_number_of_channel, input_dim_x, input_dim_y)) + + network = torch.nn.Sequential() + + if enable_onoff: + network.append(SplitOnOffLayer()) + test_image = network[-1](test_image) + + list_other_id.append(len(network)) + if use_nnmf: + if p_mode_0: + network.append( + NNMFConv2dP( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_conv1, + kernel_size=kernel_size_conv1, + convolution_contribution_map_enable=convolution_contribution_map_enable, + epsilon=epsilon, + positive_function_type=positive_function_type, + init_min=init_min, + init_max=init_max, + beta=beta, + use_convolution=use_convolution, + iterations=iterations, + local_learning=local_learning_0, + local_learning_kl=local_learning_kl, + ) + ) + else: + network.append( + NNMFConv2d( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_conv1, + kernel_size=kernel_size_conv1, + convolution_contribution_map_enable=convolution_contribution_map_enable, + epsilon=epsilon, + positive_function_type=positive_function_type, + init_min=init_min, + init_max=init_max, + beta=beta, + use_convolution=use_convolution, + iterations=iterations, + local_learning=local_learning_0, + local_learning_kl=local_learning_kl, + ) + ) + test_image = network[-1](test_image) + else: + network.append( + torch.nn.Conv2d( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_conv1, + kernel_size=kernel_size_conv1, + stride=stride_conv1, + padding=padding_conv1, + ) + ) + test_image = network[-1](test_image) + 
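# --- editor's note: illustrative sketch, not part of this patch -----------
# make_network() threads a dummy test_image through each layer right after
# appending it; network[-1](test_image) keeps the running output shape so
# the next layer's in_channels and the final full-map kernel sizes can be
# derived without hand-computing them. Minimal form of the pattern:
#
#   import torch
#   net = torch.nn.Sequential()
#   probe = torch.ones((1, 3, 28, 28))
#   net.append(torch.nn.Conv2d(probe.shape[1], 32, kernel_size=(5, 5)))
#   probe = net[-1](probe)               # probe.shape -> (1, 32, 24, 24)
#   net.append(torch.nn.Conv2d(probe.shape[1], 64, kernel_size=(5, 5)))
# --- end editor's note -----------------------------------------------------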
network.append(torch.nn.ReLU()) + test_image = network[-1](test_image) + + if cnn_top: + list_cnn_top_id.append(len(network)) + network.append( + torch.nn.Conv2d( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_conv1, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ) + ) + test_image = network[-1](test_image) + network.append(torch.nn.ReLU()) + test_image = network[-1](test_image) + + network.append( + torch.nn.MaxPool2d( + kernel_size=kernel_size_pool1, stride=stride_pool1, padding=padding_pool1 + ) + ) + test_image = network[-1](test_image) + + list_other_id.append(len(network)) + if use_nnmf: + if p_mode_1: + network.append( + NNMFConv2dP( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_conv2, + kernel_size=kernel_size_conv2, + convolution_contribution_map_enable=convolution_contribution_map_enable, + epsilon=epsilon, + positive_function_type=positive_function_type, + init_min=init_min, + init_max=init_max, + beta=beta, + use_convolution=use_convolution, + iterations=iterations, + local_learning=local_learning_1, + local_learning_kl=local_learning_kl, + ) + ) + else: + network.append( + NNMFConv2d( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_conv2, + kernel_size=kernel_size_conv2, + convolution_contribution_map_enable=convolution_contribution_map_enable, + epsilon=epsilon, + positive_function_type=positive_function_type, + init_min=init_min, + init_max=init_max, + beta=beta, + use_convolution=use_convolution, + iterations=iterations, + local_learning=local_learning_1, + local_learning_kl=local_learning_kl, + ) + ) + test_image = network[-1](test_image) + else: + network.append( + torch.nn.Conv2d( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_conv2, + kernel_size=kernel_size_conv2, + stride=stride_conv2, + padding=padding_conv2, + ) + ) + test_image = network[-1](test_image) + network.append(torch.nn.ReLU()) + test_image = network[-1](test_image) + + if cnn_top: + list_cnn_top_id.append(len(network)) + network.append( + torch.nn.Conv2d( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_conv2, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ) + ) + test_image = network[-1](test_image) + network.append(torch.nn.ReLU()) + test_image = network[-1](test_image) + + network.append( + torch.nn.MaxPool2d( + kernel_size=kernel_size_pool2, stride=stride_pool2, padding=padding_pool2 + ) + ) + test_image = network[-1](test_image) + + list_other_id.append(len(network)) + if use_nnmf: + if p_mode_2: + network.append( + NNMFConv2dP( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_flatten2, + kernel_size=(test_image.shape[2], test_image.shape[3]), + convolution_contribution_map_enable=convolution_contribution_map_enable, + epsilon=epsilon, + positive_function_type=positive_function_type, + init_min=init_min, + init_max=init_max, + beta=beta, + use_convolution=use_convolution, + iterations=iterations, + local_learning=local_learning_2, + local_learning_kl=local_learning_kl, + ) + ) + else: + network.append( + NNMFConv2d( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_flatten2, + kernel_size=(test_image.shape[2], test_image.shape[3]), + convolution_contribution_map_enable=convolution_contribution_map_enable, + epsilon=epsilon, + positive_function_type=positive_function_type, + init_min=init_min, + init_max=init_max, + beta=beta, + 
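# editor's note: kernel_size was set above to the full remaining feature
# map (test_image.shape[2], test_image.shape[3]), so this NNMF stage acts
# as a fully connected "flatten" layer rather than a sliding convolution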
use_convolution=use_convolution, + iterations=iterations, + local_learning=local_learning_2, + local_learning_kl=local_learning_kl, + ) + ) + test_image = network[-1](test_image) + else: + network.append( + torch.nn.Conv2d( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_flatten2, + kernel_size=(test_image.shape[2], test_image.shape[3]), + ) + ) + test_image = network[-1](test_image) + network.append(torch.nn.ReLU()) + test_image = network[-1](test_image) + + if cnn_top: + list_cnn_top_id.append(len(network)) + network.append( + torch.nn.Conv2d( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_flatten2, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ) + ) + test_image = network[-1](test_image) + network.append(torch.nn.ReLU()) + test_image = network[-1](test_image) + + list_other_id.append(len(network)) + if use_nnmf: + if p_mode_3: + network.append( + NNMFConv2dP( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_full1, + kernel_size=(test_image.shape[2], test_image.shape[3]), + convolution_contribution_map_enable=convolution_contribution_map_enable, + epsilon=epsilon, + positive_function_type=positive_function_type, + init_min=init_min, + init_max=init_max, + beta=beta, + use_convolution=use_convolution, + iterations=iterations, + local_learning=local_learning_3, + local_learning_kl=local_learning_kl, + ) + ) + else: + network.append( + NNMFConv2d( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_full1, + kernel_size=(test_image.shape[2], test_image.shape[3]), + convolution_contribution_map_enable=convolution_contribution_map_enable, + epsilon=epsilon, + positive_function_type=positive_function_type, + init_min=init_min, + init_max=init_max, + beta=beta, + use_convolution=use_convolution, + iterations=iterations, + local_learning=local_learning_3, + local_learning_kl=local_learning_kl, + ) + ) + test_image = network[-1](test_image) + else: + network.append( + torch.nn.Conv2d( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_full1, + kernel_size=(test_image.shape[2], test_image.shape[3]), + ) + ) + test_image = network[-1](test_image) + if cnn_top: + network.append(torch.nn.ReLU()) + test_image = network[-1](test_image) + + if cnn_top: + list_cnn_top_id.append(len(network)) + network.append( + torch.nn.Conv2d( + in_channels=test_image.shape[1], + out_channels=number_of_output_channels_full1, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ) + ) + test_image = network[-1](test_image) + + network.append(torch.nn.Flatten()) + test_image = network[-1](test_image) + + network.append(torch.nn.Softmax(dim=1)) + test_image = network[-1](test_image) + + return network, list_cnn_top_id, list_other_id diff --git a/make_optimize.py b/make_optimize.py new file mode 100644 index 0000000..54b3072 --- /dev/null +++ b/make_optimize.py @@ -0,0 +1,103 @@ +import torch +from NNMFConv2d import NNMFConv2d +from NNMFConv2dP import NNMFConv2dP + + +def make_optimize( + network: torch.nn.Sequential, + list_cnn_top_id: list[int], + list_other_id: list[int], + lr_initial_nnmf: float = 0.01, + lr_initial_cnn: float = 0.001, + lr_initial_cnn_top: float = 0.001, + eps=1e-10, +) -> tuple[ + torch.optim.Adam | None, + torch.optim.Adam | None, + torch.optim.Adam | None, + torch.optim.lr_scheduler.ReduceLROnPlateau | None, + torch.optim.lr_scheduler.ReduceLROnPlateau | None, + torch.optim.lr_scheduler.ReduceLROnPlateau | None, +]: + + 
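# --- editor's note: illustrative sketch, not part of this patch -----------
# The loop below initializes every 1x1 "cnn top" convolution as an
# identity map (torch.eye written into the weight, bias zeroed), so at
# epoch 0 the readout passes activities through unchanged. A standalone
# check of the idea, for the case of equal in/out channel counts:
#
#   import torch
#   conv = torch.nn.Conv2d(8, 8, kernel_size=(1, 1), bias=True)
#   with torch.no_grad():
#       conv.bias.zero_()
#       conv.weight.zero_()
#       conv.weight[:, :, 0, 0] = torch.eye(8)
#   x = torch.randn(2, 8, 5, 5)
#   assert torch.allclose(conv(x), x, atol=1e-6)
# --- end editor's note -----------------------------------------------------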
list_cnn_top: list = [] + # Init the cnn top layers 1x1 conv2d layers + for layerid in list_cnn_top_id: + for netp in network[layerid].parameters(): + with torch.no_grad(): + if netp.ndim == 1: + netp.data *= 0 + if netp.ndim == 4: + assert netp.shape[-2] == 1 + assert netp.shape[-1] == 1 + netp[: netp.shape[0], : netp.shape[0], 0, 0] = torch.eye( + netp.shape[0], dtype=netp.dtype, device=netp.device + ) + netp[netp.shape[0] :, :, 0, 0] = 0 + netp[:, netp.shape[0] :, 0, 0] = 0 + + list_cnn_top.append(netp) + + list_cnn: list = [] + list_nnmf: list = [] + for layerid in list_other_id: + if isinstance(network[layerid], torch.nn.Conv2d): + for netp in network[layerid].parameters(): + list_cnn.append(netp) + + if isinstance(network[layerid], (NNMFConv2d, NNMFConv2dP)): + for netp in network[layerid].parameters(): + list_nnmf.append(netp) + + # The optimizer + if len(list_nnmf) > 0: + optimizer_nnmf: torch.optim.Adam | None = torch.optim.Adam( + list_nnmf, lr=lr_initial_nnmf + ) + else: + optimizer_nnmf = None + + if len(list_cnn) > 0: + optimizer_cnn: torch.optim.Adam | None = torch.optim.Adam( + list_cnn, lr=lr_initial_cnn + ) + else: + optimizer_cnn = None + + if len(list_cnn_top) > 0: + optimizer_cnn_top: torch.optim.Adam | None = torch.optim.Adam( + list_cnn_top, lr=lr_initial_cnn_top + ) + else: + optimizer_cnn_top = None + + # The LR Scheduler + if optimizer_nnmf is not None: + lr_scheduler_nnmf: torch.optim.lr_scheduler.ReduceLROnPlateau | None = ( + torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_nnmf, eps=eps) + ) + else: + lr_scheduler_nnmf = None + + if optimizer_cnn is not None: + lr_scheduler_cnn: torch.optim.lr_scheduler.ReduceLROnPlateau | None = ( + torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_cnn, eps=eps) + ) + else: + lr_scheduler_cnn = None + + if optimizer_cnn_top is not None: + lr_scheduler_cnn_top: torch.optim.lr_scheduler.ReduceLROnPlateau | None = ( + torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_cnn_top, eps=eps) + ) + else: + lr_scheduler_cnn_top = None + + return ( + optimizer_nnmf, + optimizer_cnn, + optimizer_cnn_top, + lr_scheduler_nnmf, + lr_scheduler_cnn, + lr_scheduler_cnn_top, + ) diff --git a/non_linear_weigth_function.py b/non_linear_weigth_function.py new file mode 100644 index 0000000..053a9b6 --- /dev/null +++ b/non_linear_weigth_function.py @@ -0,0 +1,26 @@ +import torch + + +def non_linear_weigth_function( + weight: torch.Tensor, beta: torch.Tensor | None, positive_function_type: int +) -> torch.Tensor: + + if positive_function_type == 0: + positive_weights = torch.abs(weight) + + elif positive_function_type == 1: + assert beta is not None + positive_weights = weight + max_value = torch.abs(positive_weights).max() + if max_value > 80: + positive_weights = 80.0 * positive_weights / max_value + positive_weights = torch.exp((torch.tanh(beta) + 1.0) * 0.5 * positive_weights) + + elif positive_function_type == 2: + assert beta is not None + positive_weights = (torch.tanh(beta * weight) + 1.0) * 0.5 + + else: + positive_weights = weight + + return positive_weights diff --git a/plot.py b/plot.py new file mode 100644 index 0000000..3f02184 --- /dev/null +++ b/plot.py @@ -0,0 +1,46 @@ +import numpy as np +import matplotlib.pyplot as plt + + +data = np.load("data_log_cnn_20_True_0.001_0.01_True_True_True_True.npy") +plt.loglog(data[:, 0], 100.0 * (1.0 - data[:, 1] / 10000.0), "k", label="CNN + CNN Top") + +data = np.load("data_log_cnn_20_False_0.001_0.01_True_True_True_True.npy") +plt.loglog(data[:, 0], 100.0 * (1.0 - data[:, 1] / 10000.0), 
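# editor's note: data[:, 1] holds the "Test Number Correct" scalar out of
# the 10000-sample test set, so this expression converts it to error in %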
"k--", label="CNN") + +data = np.load("data_log_nnmf_20_True_0.001_0.01_True_True_True_True.npy") +plt.loglog( + data[:, 0], + 100.0 * (1.0 - data[:, 1] / 10000.0), + "r", + label="NNMF + CNN Top (Iter 20, KL)", +) + +data = np.load("data_log_nnmf_20_False_0.001_0.01_True_True_True_True.npy") +plt.loglog( + data[:, 0], + 100.0 * (1.0 - data[:, 1] / 10000.0), + "r--", + label="NNMF (Iter 20, KL)", +) + +data = np.load("data_log_nnmf_20_True_0.001_0.01_True_True_True_False.npy") +plt.loglog( + data[:, 0], + 100.0 * (1.0 - data[:, 1] / 10000.0), + "b", + label="NNMF + CNN Top (Iter 20, MSE)", +) + +data = np.load("data_log_nnmf_20_False_0.001_0.01_True_True_True_False.npy") +plt.loglog( + data[:, 0], + 100.0 * (1.0 - data[:, 1] / 10000.0), + "b--", + label="NNMF (Iter 20, MSE)", +) + +plt.legend() +plt.xlabel("Epoch") +plt.ylabel("Error [%]") +plt.show() diff --git a/run_network.py b/run_network.py new file mode 100644 index 0000000..ae6241a --- /dev/null +++ b/run_network.py @@ -0,0 +1,272 @@ +import os + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +import argh + +import time +import torch + +from torch.utils.tensorboard import SummaryWriter + +from make_network import make_network +from get_the_data import get_the_data +from loss_function import loss_function +from make_optimize import make_optimize + + +def main( + lr_initial_nnmf: float = 0.01, + lr_initial_cnn: float = 0.001, + lr_initial_cnn_top: float = 0.001, + iterations: int = 20, + cnn_top: bool = True, + use_nnmf: bool = True, + dataset: str = "CIFAR10", # "CIFAR10", "FashionMNIST", "MNIST" + rand_seed: int = 21, + enable_onoff: bool = False, + local_learning_0: bool = False, + local_learning_1: bool = False, + local_learning_2: bool = False, + local_learning_3: bool = False, + local_learning_kl: bool = False, + p_mode_0: bool = False, + p_mode_1: bool = False, + p_mode_2: bool = False, + p_mode_3: bool = False, +) -> None: + + lr_limit: float = 1e-9 + + torch.manual_seed(rand_seed) + + torch_device: torch.device = ( + torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + ) + torch.set_default_dtype(torch.float32) + + # Some parameters + batch_size_train: int = 500 + batch_size_test: int = 500 + number_of_epoch: int = 500 + + if use_nnmf: + prefix: str = "nnmf" + else: + prefix = "cnn" + + default_path: str = ( + f"{prefix}_{iterations}_{cnn_top}_{lr_initial_cnn}_{lr_initial_nnmf}_{local_learning_0}_{local_learning_1}_{local_learning_2}_{local_learning_kl}" + ) + log_dir: str = f"log_{default_path}" + + loss_mode: int = 0 + loss_coeffs_mse: float = 0.5 + loss_coeffs_kldiv: float = 1.0 + print( + "loss_mode: ", + loss_mode, + "loss_coeffs_mse: ", + loss_coeffs_mse, + "loss_coeffs_kldiv: ", + loss_coeffs_kldiv, + ) + + if dataset == "MNIST" or dataset == "FashionMNIST": + input_number_of_channel: int = 1 + input_dim_x: int = 24 + input_dim_y: int = 24 + else: + input_number_of_channel = 3 + input_dim_x = 28 + input_dim_y = 28 + + train_dataloader, test_dataloader, test_processing_chain, train_processing_chain = ( + get_the_data( + dataset, + batch_size_train, + batch_size_test, + torch_device, + input_dim_x, + input_dim_y, + flip_p=0.5, + jitter_brightness=0.5, + jitter_contrast=0.1, + jitter_saturation=0.1, + jitter_hue=0.15, + ) + ) + + network, list_cnn_top_id, list_other_id = make_network( + use_nnmf=use_nnmf, + cnn_top=cnn_top, + input_dim_x=input_dim_x, + input_dim_y=input_dim_y, + input_number_of_channel=input_number_of_channel, + iterations=iterations, + enable_onoff=enable_onoff, + 
local_learning_0=local_learning_0, + local_learning_1=local_learning_1, + local_learning_2=local_learning_2, + local_learning_3=local_learning_3, + local_learning_kl=local_learning_kl, + p_mode_0=p_mode_0, + p_mode_1=p_mode_1, + p_mode_2=p_mode_2, + p_mode_3=p_mode_3, + ) + network = network.to(torch_device) + + print(network) + + ( + optimizer_nnmf, + optimizer_cnn, + optimizer_cnn_top, + lr_scheduler_nnmf, + lr_scheduler_cnn, + lr_scheduler_cnn_top, + ) = make_optimize( + network=network, + list_cnn_top_id=list_cnn_top_id, + list_other_id=list_other_id, + lr_initial_nnmf=lr_initial_nnmf, + lr_initial_cnn=lr_initial_cnn, + lr_initial_cnn_top=lr_initial_cnn_top, + ) + + tb = SummaryWriter(log_dir=log_dir) + + for epoch_id in range(0, number_of_epoch): + print() + print(f"Epoch: {epoch_id}") + t_start: float = time.perf_counter() + + train_loss: float = 0.0 + train_correct: int = 0 + train_number: int = 0 + test_correct: int = 0 + test_number: int = 0 + + # Switch the network into training mode + network.train() + + # This runs in total for one epoch split up into mini-batches + for image, target in train_dataloader: + # Clean the gradient + if optimizer_nnmf is not None: + optimizer_nnmf.zero_grad() + if optimizer_cnn is not None: + optimizer_cnn.zero_grad() + if optimizer_cnn_top is not None: + optimizer_cnn_top.zero_grad() + + output = network(train_processing_chain(image)) + + loss = loss_function( + h=output, + labels=target, + number_of_output_neurons=output.shape[1], + loss_mode=loss_mode, + loss_coeffs_mse=loss_coeffs_mse, + loss_coeffs_kldiv=loss_coeffs_kldiv, + ) + + assert loss is not None + train_loss += loss.item() + train_correct += (output.argmax(dim=1) == target).sum().cpu().numpy() + train_number += target.shape[0] + + # Calculate backprop + loss.backward() + + # Update the parameter + if optimizer_nnmf is not None: + optimizer_nnmf.step() + if optimizer_cnn is not None: + optimizer_cnn.step() + if optimizer_cnn_top is not None: + optimizer_cnn_top.step() + + perfomance_train_correct: float = 100.0 * train_correct / train_number + # Update the learning rate + if lr_scheduler_nnmf is not None: + lr_scheduler_nnmf.step(train_loss) + + if lr_scheduler_cnn is not None: + lr_scheduler_cnn.step(train_loss) + + if lr_scheduler_cnn_top is not None: + lr_scheduler_cnn_top.step(train_loss) + + print( + "Actual lr: ", + "nnmf: ", + lr_scheduler_nnmf.get_last_lr() if lr_scheduler_nnmf is not None else -1.0, + "cnn: ", + lr_scheduler_cnn.get_last_lr() if lr_scheduler_cnn is not None else -1.0, + "cnn top: ", + ( + lr_scheduler_cnn_top.get_last_lr() + if lr_scheduler_cnn_top is not None + else -1.0 + ), + ) + t_training: float = time.perf_counter() + + # Switch the network into evalution mode + network.eval() + + with torch.no_grad(): + + for image, target in test_dataloader: + output = network(test_processing_chain(image)) + + test_correct += (output.argmax(dim=1) == target).sum().cpu().numpy() + test_number += target.shape[0] + + t_testing = time.perf_counter() + + perfomance_test_correct: float = 100.0 * test_correct / test_number + + tb.add_scalar("Train Loss", train_loss / float(train_number), epoch_id) + tb.add_scalar("Train Number Correct", train_correct, epoch_id) + tb.add_scalar("Test Number Correct", test_correct, epoch_id) + + print( + f"Training: Loss={train_loss / float(train_number):.5f} Correct={perfomance_train_correct:.2f}%" + ) + print(f"Testing: Correct={perfomance_test_correct:.2f}%") + print( + f"Time: Training={(t_training - t_start):.1f}sec, Testing={(t_testing - 
t_training):.1f}sec" + ) + + tb.flush() + + lr_check: list[float] = [] + if lr_scheduler_nnmf is not None: + lr_check.append(lr_scheduler_nnmf.get_last_lr()[0]) + if lr_scheduler_cnn is not None: + lr_check.append(lr_scheduler_cnn.get_last_lr()[0]) + if lr_scheduler_cnn_top is not None: + lr_check.append(lr_scheduler_cnn_top.get_last_lr()[0]) + + lr_check_max = float(torch.tensor(lr_check).max()) + + if lr_check_max < lr_limit: + torch.save(network, f"Model_{default_path}.pt") + tb.close() + print("Done (lr_limit)") + return + + torch.save(network, f"Model_{default_path}.pt") + print() + + tb.close() + print("Done (loop end)") + + return + + +if __name__ == "__main__": + argh.dispatch_command(main)
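# --- editor's note: illustrative usage, not part of this patch ------------
# run_network.py exposes main() through argh, so keyword arguments become
# command-line flags. Assuming argh's usual kwarg-to-flag mapping
# (underscores turned into dashes), a short smoke run could look like:
#
#   python run_network.py --dataset MNIST --iterations 5 --rand-seed 21
#
# or, bypassing the CLI entirely:
#
#   from run_network import main
#   main(dataset="MNIST", iterations=5, use_nnmf=True, cnn_top=False)
# --- end editor's note -----------------------------------------------------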