Add files via upload

This commit is contained in:
David Rotermund 2024-05-30 15:53:53 +02:00 committed by GitHub
parent 2c9d3368dd
commit 36bf428306
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 51 additions and 16 deletions

View file

@ -24,6 +24,7 @@ class NNMFConv2dP(torch.nn.Module):
use_convolution: bool
local_learning: bool
local_learning_kl: bool
use_reconstruction: bool
def __init__(
self,
@ -45,6 +46,7 @@ class NNMFConv2dP(torch.nn.Module):
use_convolution: bool = False,
local_learning: bool = False,
local_learning_kl: bool = False,
use_reconstruction: bool = False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
@ -79,6 +81,8 @@ class NNMFConv2dP(torch.nn.Module):
self.local_learning = local_learning
self.local_learning_kl = local_learning_kl
self.use_reconstruction = use_reconstruction
self.weight = torch.nn.parameter.Parameter(
torch.empty((out_channels, in_channels, *kernel_size), **factory_kwargs)
)
@ -197,13 +201,12 @@ class NNMFConv2dP(torch.nn.Module):
self.local_learning,
self.local_learning_kl,
)
self.reco = False
if self.reco:
print(h_dyn.shape)
print(positive_weights.shape)
print(input.shape)
exit()
output = torch.cat((h_dyn, input), dim=1)
if self.use_reconstruction:
reconstruction = torch.nn.functional.linear(
h_dyn.movedim(1, -1), positive_weights.T
).movedim(-1, 1)
output = torch.cat((h_dyn, input - reconstruction), dim=1)
else:
output = torch.cat((h_dyn, input), dim=1)
return output

View file

@ -44,6 +44,8 @@ def make_network(
p_mode_1: bool = False,
p_mode_2: bool = False,
p_mode_3: bool = False,
use_reconstruction: bool = False,
max_pool: bool = True,
) -> tuple[torch.nn.Sequential, list[int], list[int]]:
if enable_onoff:
@ -78,6 +80,7 @@ def make_network(
iterations=iterations,
local_learning=local_learning_0,
local_learning_kl=local_learning_kl,
use_reconstruction=use_reconstruction,
)
)
else:
@ -129,11 +132,22 @@ def make_network(
network.append(torch.nn.ReLU())
test_image = network[-1](test_image)
network.append(
torch.nn.MaxPool2d(
kernel_size=kernel_size_pool1, stride=stride_pool1, padding=padding_pool1
if max_pool:
network.append(
torch.nn.MaxPool2d(
kernel_size=kernel_size_pool1,
stride=stride_pool1,
padding=padding_pool1,
)
)
else:
network.append(
torch.nn.AvgPool2d(
kernel_size=kernel_size_pool1,
stride=stride_pool1,
padding=padding_pool1,
)
)
)
test_image = network[-1](test_image)
list_other_id.append(len(network))
@ -154,6 +168,7 @@ def make_network(
iterations=iterations,
local_learning=local_learning_1,
local_learning_kl=local_learning_kl,
use_reconstruction=use_reconstruction,
)
)
else:
@ -205,11 +220,22 @@ def make_network(
network.append(torch.nn.ReLU())
test_image = network[-1](test_image)
network.append(
torch.nn.MaxPool2d(
kernel_size=kernel_size_pool2, stride=stride_pool2, padding=padding_pool2
if max_pool:
network.append(
torch.nn.MaxPool2d(
kernel_size=kernel_size_pool2,
stride=stride_pool2,
padding=padding_pool2,
)
)
else:
network.append(
torch.nn.AvgPool2d(
kernel_size=kernel_size_pool2,
stride=stride_pool2,
padding=padding_pool2,
)
)
)
test_image = network[-1](test_image)
list_other_id.append(len(network))
@ -230,6 +256,7 @@ def make_network(
iterations=iterations,
local_learning=local_learning_2,
local_learning_kl=local_learning_kl,
use_reconstruction=use_reconstruction,
)
)
else:
@ -297,6 +324,7 @@ def make_network(
iterations=iterations,
local_learning=local_learning_3,
local_learning_kl=local_learning_kl,
use_reconstruction=use_reconstruction,
)
)
else:

View file

@ -34,6 +34,8 @@ def main(
p_mode_1: bool = False,
p_mode_2: bool = False,
p_mode_3: bool = False,
use_reconstruction: bool = False,
max_pool: bool = True,
) -> None:
lr_limit: float = 1e-9
@ -56,7 +58,7 @@ def main(
prefix = "cnn"
default_path: str = (
f"{prefix}_{iterations}_{cnn_top}_{lr_initial_cnn}_{lr_initial_nnmf}_{local_learning_0}_{local_learning_1}_{local_learning_2}_{local_learning_kl}"
f"{prefix}_{iterations}_{cnn_top}_{lr_initial_cnn}_{lr_initial_nnmf}_{local_learning_0}_{local_learning_1}_{local_learning_2}_{local_learning_kl}_{use_reconstruction}"
)
log_dir: str = f"log_{default_path}"
@ -114,6 +116,8 @@ def main(
p_mode_1=p_mode_1,
p_mode_2=p_mode_2,
p_mode_3=p_mode_3,
use_reconstruction=use_reconstruction,
max_pool=max_pool,
)
network = network.to(torch_device)