diff --git a/NNMFConv2dP.py b/NNMFConv2dP.py index 9dedda3..5a87ae4 100644 --- a/NNMFConv2dP.py +++ b/NNMFConv2dP.py @@ -24,6 +24,7 @@ class NNMFConv2dP(torch.nn.Module): use_convolution: bool local_learning: bool local_learning_kl: bool + use_reconstruction: bool def __init__( self, @@ -45,6 +46,7 @@ class NNMFConv2dP(torch.nn.Module): use_convolution: bool = False, local_learning: bool = False, local_learning_kl: bool = False, + use_reconstruction: bool = False, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} @@ -79,6 +81,8 @@ class NNMFConv2dP(torch.nn.Module): self.local_learning = local_learning self.local_learning_kl = local_learning_kl + self.use_reconstruction = use_reconstruction + self.weight = torch.nn.parameter.Parameter( torch.empty((out_channels, in_channels, *kernel_size), **factory_kwargs) ) @@ -197,13 +201,12 @@ class NNMFConv2dP(torch.nn.Module): self.local_learning, self.local_learning_kl, ) - self.reco = False - if self.reco: - print(h_dyn.shape) - print(positive_weights.shape) - print(input.shape) - exit() - output = torch.cat((h_dyn, input), dim=1) + + if self.use_reconstruction: + reconstruction = torch.nn.functional.linear( + h_dyn.movedim(1, -1), positive_weights.T + ).movedim(-1, 1) + output = torch.cat((h_dyn, input - reconstruction), dim=1) else: output = torch.cat((h_dyn, input), dim=1) return output diff --git a/make_network.py b/make_network.py index 671957f..d978e26 100644 --- a/make_network.py +++ b/make_network.py @@ -44,6 +44,8 @@ def make_network( p_mode_1: bool = False, p_mode_2: bool = False, p_mode_3: bool = False, + use_reconstruction: bool = False, + max_pool: bool = True, ) -> tuple[torch.nn.Sequential, list[int], list[int]]: if enable_onoff: @@ -78,6 +80,7 @@ def make_network( iterations=iterations, local_learning=local_learning_0, local_learning_kl=local_learning_kl, + use_reconstruction=use_reconstruction, ) ) else: @@ -129,11 +132,22 @@ def make_network( network.append(torch.nn.ReLU()) test_image = network[-1](test_image) - network.append( - torch.nn.MaxPool2d( - kernel_size=kernel_size_pool1, stride=stride_pool1, padding=padding_pool1 + if max_pool: + network.append( + torch.nn.MaxPool2d( + kernel_size=kernel_size_pool1, + stride=stride_pool1, + padding=padding_pool1, + ) + ) + else: + network.append( + torch.nn.AvgPool2d( + kernel_size=kernel_size_pool1, + stride=stride_pool1, + padding=padding_pool1, + ) ) - ) test_image = network[-1](test_image) list_other_id.append(len(network)) @@ -154,6 +168,7 @@ def make_network( iterations=iterations, local_learning=local_learning_1, local_learning_kl=local_learning_kl, + use_reconstruction=use_reconstruction, ) ) else: @@ -205,11 +220,22 @@ def make_network( network.append(torch.nn.ReLU()) test_image = network[-1](test_image) - network.append( - torch.nn.MaxPool2d( - kernel_size=kernel_size_pool2, stride=stride_pool2, padding=padding_pool2 + if max_pool: + network.append( + torch.nn.MaxPool2d( + kernel_size=kernel_size_pool2, + stride=stride_pool2, + padding=padding_pool2, + ) + ) + else: + network.append( + torch.nn.AvgPool2d( + kernel_size=kernel_size_pool2, + stride=stride_pool2, + padding=padding_pool2, + ) ) - ) test_image = network[-1](test_image) list_other_id.append(len(network)) @@ -230,6 +256,7 @@ def make_network( iterations=iterations, local_learning=local_learning_2, local_learning_kl=local_learning_kl, + use_reconstruction=use_reconstruction, ) ) else: @@ -297,6 +324,7 @@ def make_network( iterations=iterations, local_learning=local_learning_3, local_learning_kl=local_learning_kl, + use_reconstruction=use_reconstruction, ) ) else: diff --git a/run_network.py b/run_network.py index ae6241a..bdfab88 100644 --- a/run_network.py +++ b/run_network.py @@ -34,6 +34,8 @@ def main( p_mode_1: bool = False, p_mode_2: bool = False, p_mode_3: bool = False, + use_reconstruction: bool = False, + max_pool: bool = True, ) -> None: lr_limit: float = 1e-9 @@ -56,7 +58,7 @@ def main( prefix = "cnn" default_path: str = ( - f"{prefix}_{iterations}_{cnn_top}_{lr_initial_cnn}_{lr_initial_nnmf}_{local_learning_0}_{local_learning_1}_{local_learning_2}_{local_learning_kl}" + f"{prefix}_{iterations}_{cnn_top}_{lr_initial_cnn}_{lr_initial_nnmf}_{local_learning_0}_{local_learning_1}_{local_learning_2}_{local_learning_kl}_{use_reconstruction}" ) log_dir: str = f"log_{default_path}" @@ -114,6 +116,8 @@ def main( p_mode_1=p_mode_1, p_mode_2=p_mode_2, p_mode_3=p_mode_3, + use_reconstruction=use_reconstruction, + max_pool=max_pool, ) network = network.to(torch_device)