Add files via upload
This commit is contained in:
parent
2c9d3368dd
commit
36bf428306
3 changed files with 51 additions and 16 deletions
|
@ -24,6 +24,7 @@ class NNMFConv2dP(torch.nn.Module):
|
|||
use_convolution: bool
|
||||
local_learning: bool
|
||||
local_learning_kl: bool
|
||||
use_reconstruction: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -45,6 +46,7 @@ class NNMFConv2dP(torch.nn.Module):
|
|||
use_convolution: bool = False,
|
||||
local_learning: bool = False,
|
||||
local_learning_kl: bool = False,
|
||||
use_reconstruction: bool = False,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
|
@ -79,6 +81,8 @@ class NNMFConv2dP(torch.nn.Module):
|
|||
self.local_learning = local_learning
|
||||
self.local_learning_kl = local_learning_kl
|
||||
|
||||
self.use_reconstruction = use_reconstruction
|
||||
|
||||
self.weight = torch.nn.parameter.Parameter(
|
||||
torch.empty((out_channels, in_channels, *kernel_size), **factory_kwargs)
|
||||
)
|
||||
|
@ -197,13 +201,12 @@ class NNMFConv2dP(torch.nn.Module):
|
|||
self.local_learning,
|
||||
self.local_learning_kl,
|
||||
)
|
||||
self.reco = False
|
||||
if self.reco:
|
||||
print(h_dyn.shape)
|
||||
print(positive_weights.shape)
|
||||
print(input.shape)
|
||||
exit()
|
||||
output = torch.cat((h_dyn, input), dim=1)
|
||||
|
||||
if self.use_reconstruction:
|
||||
reconstruction = torch.nn.functional.linear(
|
||||
h_dyn.movedim(1, -1), positive_weights.T
|
||||
).movedim(-1, 1)
|
||||
output = torch.cat((h_dyn, input - reconstruction), dim=1)
|
||||
else:
|
||||
output = torch.cat((h_dyn, input), dim=1)
|
||||
return output
|
||||
|
|
|
@ -44,6 +44,8 @@ def make_network(
|
|||
p_mode_1: bool = False,
|
||||
p_mode_2: bool = False,
|
||||
p_mode_3: bool = False,
|
||||
use_reconstruction: bool = False,
|
||||
max_pool: bool = True,
|
||||
) -> tuple[torch.nn.Sequential, list[int], list[int]]:
|
||||
|
||||
if enable_onoff:
|
||||
|
@ -78,6 +80,7 @@ def make_network(
|
|||
iterations=iterations,
|
||||
local_learning=local_learning_0,
|
||||
local_learning_kl=local_learning_kl,
|
||||
use_reconstruction=use_reconstruction,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
@ -129,9 +132,20 @@ def make_network(
|
|||
network.append(torch.nn.ReLU())
|
||||
test_image = network[-1](test_image)
|
||||
|
||||
if max_pool:
|
||||
network.append(
|
||||
torch.nn.MaxPool2d(
|
||||
kernel_size=kernel_size_pool1, stride=stride_pool1, padding=padding_pool1
|
||||
kernel_size=kernel_size_pool1,
|
||||
stride=stride_pool1,
|
||||
padding=padding_pool1,
|
||||
)
|
||||
)
|
||||
else:
|
||||
network.append(
|
||||
torch.nn.AvgPool2d(
|
||||
kernel_size=kernel_size_pool1,
|
||||
stride=stride_pool1,
|
||||
padding=padding_pool1,
|
||||
)
|
||||
)
|
||||
test_image = network[-1](test_image)
|
||||
|
@ -154,6 +168,7 @@ def make_network(
|
|||
iterations=iterations,
|
||||
local_learning=local_learning_1,
|
||||
local_learning_kl=local_learning_kl,
|
||||
use_reconstruction=use_reconstruction,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
@ -205,9 +220,20 @@ def make_network(
|
|||
network.append(torch.nn.ReLU())
|
||||
test_image = network[-1](test_image)
|
||||
|
||||
if max_pool:
|
||||
network.append(
|
||||
torch.nn.MaxPool2d(
|
||||
kernel_size=kernel_size_pool2, stride=stride_pool2, padding=padding_pool2
|
||||
kernel_size=kernel_size_pool2,
|
||||
stride=stride_pool2,
|
||||
padding=padding_pool2,
|
||||
)
|
||||
)
|
||||
else:
|
||||
network.append(
|
||||
torch.nn.AvgPool2d(
|
||||
kernel_size=kernel_size_pool2,
|
||||
stride=stride_pool2,
|
||||
padding=padding_pool2,
|
||||
)
|
||||
)
|
||||
test_image = network[-1](test_image)
|
||||
|
@ -230,6 +256,7 @@ def make_network(
|
|||
iterations=iterations,
|
||||
local_learning=local_learning_2,
|
||||
local_learning_kl=local_learning_kl,
|
||||
use_reconstruction=use_reconstruction,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
@ -297,6 +324,7 @@ def make_network(
|
|||
iterations=iterations,
|
||||
local_learning=local_learning_3,
|
||||
local_learning_kl=local_learning_kl,
|
||||
use_reconstruction=use_reconstruction,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -34,6 +34,8 @@ def main(
|
|||
p_mode_1: bool = False,
|
||||
p_mode_2: bool = False,
|
||||
p_mode_3: bool = False,
|
||||
use_reconstruction: bool = False,
|
||||
max_pool: bool = True,
|
||||
) -> None:
|
||||
|
||||
lr_limit: float = 1e-9
|
||||
|
@ -56,7 +58,7 @@ def main(
|
|||
prefix = "cnn"
|
||||
|
||||
default_path: str = (
|
||||
f"{prefix}_{iterations}_{cnn_top}_{lr_initial_cnn}_{lr_initial_nnmf}_{local_learning_0}_{local_learning_1}_{local_learning_2}_{local_learning_kl}"
|
||||
f"{prefix}_{iterations}_{cnn_top}_{lr_initial_cnn}_{lr_initial_nnmf}_{local_learning_0}_{local_learning_1}_{local_learning_2}_{local_learning_kl}_{use_reconstruction}"
|
||||
)
|
||||
log_dir: str = f"log_{default_path}"
|
||||
|
||||
|
@ -114,6 +116,8 @@ def main(
|
|||
p_mode_1=p_mode_1,
|
||||
p_mode_2=p_mode_2,
|
||||
p_mode_3=p_mode_3,
|
||||
use_reconstruction=use_reconstruction,
|
||||
max_pool=max_pool,
|
||||
)
|
||||
network = network.to(torch_device)
|
||||
|
||||
|
|
Loading…
Reference in a new issue