Add files via upload
parent b267c0b8c4
commit 33f92b7429
1 changed file with 126 additions and 8 deletions
@@ -5,6 +5,7 @@ from torch.utils.tensorboard import SummaryWriter
 from network.SbSLayer import SbSLayer
 from network.NNMFLayer import NNMFLayer
+from network.NNMFLayerSbSBP import NNMFLayerSbSBP
 from network.save_weight_and_bias import save_weight_and_bias
 from network.SbSReconstruction import SbSReconstruction
@@ -20,8 +21,10 @@ def add_weight_and_bias_to_histogram(
         # ################################################
         # Log the SbS Weights
         # ################################################
-        if (isinstance(network[id], SbSLayer) is True) or (
-            isinstance(network[id], NNMFLayer) is True
+        if (
+            (isinstance(network[id], SbSLayer) is True)
+            or (isinstance(network[id], NNMFLayer) is True)
+            or (isinstance(network[id], NNMFLayerSbSBP) is True)
         ):
             if network[id]._w_trainable is True:
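The same three-way type check recurs in the hunks below. Since isinstance already returns a bool and accepts a tuple of types, the or-chain of "is True" comparisons could be collapsed into one call; a minimal sketch (the helper name is_handled_layer is illustrative, not part of the commit):

from network.SbSLayer import SbSLayer
from network.NNMFLayer import NNMFLayer
from network.NNMFLayerSbSBP import NNMFLayerSbSBP


def is_handled_layer(layer) -> bool:
    # Equivalent to the or-chain of "isinstance(...) is True"
    # checks used throughout this diff.
    return isinstance(layer, (SbSLayer, NNMFLayer, NNMFLayerSbSBP))

If NNMFLayerSbSBP happens to subclass NNMFLayer, the explicit third check would even be redundant, but the diff does not show the class hierarchy, so spelling it out is the safe choice.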
@@ -231,8 +234,10 @@ def run_optimizer(
     cfg: Config,
 ) -> None:
     for id in range(0, len(network)):
-        if (isinstance(network[id], SbSLayer) is True) or (
-            isinstance(network[id], NNMFLayer) is True
+        if (
+            (isinstance(network[id], SbSLayer) is True)
+            or (isinstance(network[id], NNMFLayer) is True)
+            or (isinstance(network[id], NNMFLayerSbSBP) is True)
         ):
             network[id].update_pre_care()
@@ -241,8 +246,10 @@ def run_optimizer(
         optimizer_item.step()
 
     for id in range(0, len(network)):
-        if (isinstance(network[id], SbSLayer) is True) or (
-            isinstance(network[id], NNMFLayer) is True
+        if (
+            (isinstance(network[id], SbSLayer) is True)
+            or (isinstance(network[id], NNMFLayer) is True)
+            or (isinstance(network[id], NNMFLayerSbSBP) is True)
         ):
             network[id].update_after_care(
                 cfg.learning_parameters.learning_rate_threshold_w
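Together, the two hunks above keep run_optimizer's bracketing pattern intact for the new layer type: update_pre_care() runs on every eligible layer before the optimizer steps, and update_after_care(...) runs afterwards with the weight learning-rate threshold from the config. A condensed sketch of the resulting control flow, reusing the hypothetical is_handled_layer helper from above (this is a reading of the diff, not the file's literal body):

def run_optimizer(network, optimizer: list, cfg: Config) -> None:
    # 1) Let each eligible layer prepare for the weight update.
    for layer in network:
        if is_handled_layer(layer):
            layer.update_pre_care()

    # 2) Apply the actual optimizer steps.
    for optimizer_item in optimizer:
        optimizer_item.step()

    # 3) Post-process the updated weights with the configured threshold.
    for layer in network:
        if is_handled_layer(layer):
            layer.update_after_care(
                cfg.learning_parameters.learning_rate_threshold_w
            )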
@@ -295,11 +302,15 @@ def run_lr_scheduler(
 def deal_with_gradient_scale(epoch_id: int, mini_batch_number: int, network):
     if (epoch_id == 0) and (mini_batch_number == 0):
         for id in range(0, len(network)):
-            if isinstance(network[id], SbSLayer) is True:
+            if (isinstance(network[id], SbSLayer) is True) or (
+                isinstance(network[id], NNMFLayerSbSBP) is True
+            ):
                 network[id].after_batch(True)
     else:
         for id in range(0, len(network)):
-            if isinstance(network[id], SbSLayer) is True:
+            if (isinstance(network[id], SbSLayer) is True) or (
+                isinstance(network[id], NNMFLayerSbSBP) is True
+            ):
                 network[id].after_batch()
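In deal_with_gradient_scale, only the first mini-batch of the first epoch passes True to after_batch(), presumably to initialize the gradient scale; every later batch uses the default call. Note that, unlike the earlier hunks, plain NNMFLayer is not included in this check. Assuming after_batch's flag defaults to False (which the two branches imply but the diff does not show), the function reduces to:

def deal_with_gradient_scale(epoch_id: int, mini_batch_number: int, network) -> None:
    # True only for the very first mini-batch of training.
    first_batch = (epoch_id == 0) and (mini_batch_number == 0)
    for layer in network:
        if isinstance(layer, (SbSLayer, NNMFLayerSbSBP)):
            layer.after_batch(first_batch)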
@@ -775,3 +786,110 @@ def loop_test_reconstruction(
     tb.flush()
 
     return performance
+
+
+def loop_train_h_activity(
+    cfg: Config,
+    network: torch.nn.modules.container.Sequential,
+    my_loader_train: torch.utils.data.dataloader.DataLoader,
+    the_dataset_train,
+    device: torch.device,
+    default_dtype: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
+
+    first_run = True
+    position_counter = 0
+    for h_x, h_x_labels in my_loader_train:
+        print(f"{position_counter} of {the_dataset_train.number_of_pattern}")
+        # #####################################################
+        # The network does the forward pass (training)
+        # #####################################################
+        h_collection = forward_pass_train(
+            input=h_x,
+            labels=h_x_labels,
+            the_dataset_train=the_dataset_train,
+            cfg=cfg,
+            network=network,
+            device=device,
+            default_dtype=default_dtype,
+        )
+
+        h_h: torch.Tensor = h_collection[-1].detach().clone().cpu()
+
+        if first_run is True:
+            h_activity = torch.zeros(
+                (
+                    the_dataset_train.number_of_pattern,
+                    h_h.shape[-3],
+                    h_h.shape[-2],
+                    h_h.shape[-1],
+                ),
+                device=device,
+                dtype=default_dtype,
+            )
+
+            h_labels = torch.zeros(
+                (the_dataset_train.number_of_pattern),
+                device=device,
+                dtype=torch.int,
+            )
+            first_run = False
+
+        h_activity[
+            position_counter : position_counter + int(h_h.shape[0]), :, :, :
+        ] = h_h
+
+        h_labels[
+            position_counter : position_counter + int(h_h.shape[0]),
+        ] = h_x_labels
+        position_counter += int(h_h.shape[0])
+
+    return h_activity, h_labels
+
+
+def loop_train_h_confusion(
+    cfg: Config,
+    network: torch.nn.modules.container.Sequential,
+    my_loader_train: torch.utils.data.dataloader.DataLoader,
+    the_dataset_train,
+    optimizer: list,
+    device: torch.device,
+    default_dtype: torch.dtype,
+    logging,
+    adapt_learning_rate: bool,
+    tb: SummaryWriter,
+    lr_scheduler,
+    last_test_performance: float,
+    order_id: float | int | None = None,
+) -> torch.Tensor:
+
+    first_run = True
+
+    position_counter = 0
+    for h_x, h_x_labels in my_loader_train:
+        print(f"{position_counter} of {the_dataset_train.number_of_pattern}")
+        # #####################################################
+        # The network does the forward pass (training)
+        # #####################################################
+        h_collection = forward_pass_train(
+            input=h_x,
+            labels=h_x_labels,
+            the_dataset_train=the_dataset_train,
+            cfg=cfg,
+            network=network,
+            device=device,
+            default_dtype=default_dtype,
+        )
+
+        h_h: torch.Tensor = h_collection[-1]
+        idx = h_h.argmax(dim=1).squeeze().cpu().numpy()
+        if first_run is True:
+            first_run = False
+            result = torch.zeros((h_h.shape[1], h_h.shape[1]), dtype=torch.int)
+
+        for i in range(0, idx.shape[0]):
+            result[int(h_x_labels[i]), int(idx[i])] += 1
+
+        position_counter += int(h_h.shape[0])
+
+    return result
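The new loop_train_h_activity gathers the last layer's output for the entire training set into one preallocated tensor, together with the matching labels, advancing position_counter batch by batch. A hypothetical usage sketch (the save path is illustrative, not from the commit):

h_activity, h_labels = loop_train_h_activity(
    cfg=cfg,
    network=network,
    my_loader_train=my_loader_train,
    the_dataset_train=the_dataset_train,
    device=device,
    default_dtype=default_dtype,
)
# Persist for offline analysis, e.g. clustering or probing (illustrative path):
torch.save({"activity": h_activity.cpu(), "labels": h_labels.cpu()}, "h_activity.pt")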
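loop_train_h_confusion returns a confusion matrix indexed as result[true_label, predicted_label], built from the argmax over the class dimension. Two observations on the body as shown: several parameters (optimizer, tb, lr_scheduler, last_test_performance, order_id) are accepted but never used, presumably to mirror a sibling loop's signature; and h_h.argmax(dim=1).squeeze() collapses every size-1 dimension, so a final batch of size 1 makes idx zero-dimensional and idx.shape[0] raises an IndexError (flattening only the spatial dimensions would avoid that). Deriving overall accuracy from the returned matrix (argument values here are placeholders):

result = loop_train_h_confusion(
    cfg=cfg,
    network=network,
    my_loader_train=my_loader_train,
    the_dataset_train=the_dataset_train,
    optimizer=optimizer,
    device=device,
    default_dtype=default_dtype,
    logging=logging,
    adapt_learning_rate=False,
    tb=tb,
    lr_scheduler=lr_scheduler,
    last_test_performance=-1.0,
)
# Diagonal entries count correct classifications.
correct = int(result.diagonal().sum())
total = int(result.sum())
print(f"training accuracy: {correct / total:.4f}")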