Add files via upload

This commit is contained in:
David Rotermund 2023-03-15 16:46:09 +01:00 committed by GitHub
parent b267c0b8c4
commit 33f92b7429
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -5,6 +5,7 @@ from torch.utils.tensorboard import SummaryWriter
from network.SbSLayer import SbSLayer
from network.NNMFLayer import NNMFLayer
from network.NNMFLayerSbSBP import NNMFLayerSbSBP
from network.save_weight_and_bias import save_weight_and_bias
from network.SbSReconstruction import SbSReconstruction
@ -20,8 +21,10 @@ def add_weight_and_bias_to_histogram(
# ################################################
# Log the SbS Weights
# ################################################
if (isinstance(network[id], SbSLayer) is True) or (
isinstance(network[id], NNMFLayer) is True
if (
(isinstance(network[id], SbSLayer) is True)
or (isinstance(network[id], NNMFLayer) is True)
or (isinstance(network[id], NNMFLayerSbSBP) is True)
):
if network[id]._w_trainable is True:
@ -231,8 +234,10 @@ def run_optimizer(
cfg: Config,
) -> None:
for id in range(0, len(network)):
if (isinstance(network[id], SbSLayer) is True) or (
isinstance(network[id], NNMFLayer) is True
if (
(isinstance(network[id], SbSLayer) is True)
or (isinstance(network[id], NNMFLayer) is True)
or (isinstance(network[id], NNMFLayerSbSBP) is True)
):
network[id].update_pre_care()
@ -241,8 +246,10 @@ def run_optimizer(
optimizer_item.step()
for id in range(0, len(network)):
if (isinstance(network[id], SbSLayer) is True) or (
isinstance(network[id], NNMFLayer) is True
if (
(isinstance(network[id], SbSLayer) is True)
or (isinstance(network[id], NNMFLayer) is True)
or (isinstance(network[id], NNMFLayerSbSBP) is True)
):
network[id].update_after_care(
cfg.learning_parameters.learning_rate_threshold_w
@ -295,11 +302,15 @@ def run_lr_scheduler(
def deal_with_gradient_scale(epoch_id: int, mini_batch_number: int, network):
if (epoch_id == 0) and (mini_batch_number == 0):
for id in range(0, len(network)):
if isinstance(network[id], SbSLayer) is True:
if (isinstance(network[id], SbSLayer) is True) or (
isinstance(network[id], NNMFLayerSbSBP) is True
):
network[id].after_batch(True)
else:
for id in range(0, len(network)):
if isinstance(network[id], SbSLayer) is True:
if (isinstance(network[id], SbSLayer) is True) or (
isinstance(network[id], NNMFLayerSbSBP) is True
):
network[id].after_batch()
@ -775,3 +786,110 @@ def loop_test_reconstruction(
tb.flush()
return performance
def loop_train_h_activity(
cfg: Config,
network: torch.nn.modules.container.Sequential,
my_loader_train: torch.utils.data.dataloader.DataLoader,
the_dataset_train,
device: torch.device,
default_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
first_run = True
position_counter = 0
for h_x, h_x_labels in my_loader_train:
print(f"{position_counter} of {the_dataset_train.number_of_pattern}")
# #####################################################
# The network does the forward pass (training)
# #####################################################
h_collection = forward_pass_train(
input=h_x,
labels=h_x_labels,
the_dataset_train=the_dataset_train,
cfg=cfg,
network=network,
device=device,
default_dtype=default_dtype,
)
h_h: torch.Tensor = h_collection[-1].detach().clone().cpu()
if first_run is True:
h_activity = torch.zeros(
(
the_dataset_train.number_of_pattern,
h_h.shape[-3],
h_h.shape[-2],
h_h.shape[-1],
),
device=device,
dtype=default_dtype,
)
h_labels = torch.zeros(
(the_dataset_train.number_of_pattern),
device=device,
dtype=torch.int,
)
first_run = False
h_activity[
position_counter : position_counter + int(h_h.shape[0]), :, :, :
] = h_h
h_labels[
position_counter : position_counter + int(h_h.shape[0]),
] = h_x_labels
position_counter += int(h_h.shape[0])
return h_activity, h_labels
def loop_train_h_confusion(
cfg: Config,
network: torch.nn.modules.container.Sequential,
my_loader_train: torch.utils.data.dataloader.DataLoader,
the_dataset_train,
optimizer: list,
device: torch.device,
default_dtype: torch.dtype,
logging,
adapt_learning_rate: bool,
tb: SummaryWriter,
lr_scheduler,
last_test_performance: float,
order_id: float | int | None = None,
) -> torch.Tensor:
first_run = True
position_counter = 0
for h_x, h_x_labels in my_loader_train:
print(f"{position_counter} of {the_dataset_train.number_of_pattern}")
# #####################################################
# The network does the forward pass (training)
# #####################################################
h_collection = forward_pass_train(
input=h_x,
labels=h_x_labels,
the_dataset_train=the_dataset_train,
cfg=cfg,
network=network,
device=device,
default_dtype=default_dtype,
)
h_h: torch.Tensor = h_collection[-1]
idx = h_h.argmax(dim=1).squeeze().cpu().numpy()
if first_run is True:
first_run = False
result = torch.zeros((h_h.shape[1], h_h.shape[1]), dtype=torch.int)
for i in range(0, idx.shape[0]):
result[int(h_x_labels[i]), int(idx[i])] += 1
position_counter += int(h_h.shape[0])
return result