Add files via upload
This commit is contained in:
parent
b267c0b8c4
commit
33f92b7429
1 changed files with 126 additions and 8 deletions
|
@ -5,6 +5,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from network.SbSLayer import SbSLayer
|
from network.SbSLayer import SbSLayer
|
||||||
from network.NNMFLayer import NNMFLayer
|
from network.NNMFLayer import NNMFLayer
|
||||||
|
from network.NNMFLayerSbSBP import NNMFLayerSbSBP
|
||||||
from network.save_weight_and_bias import save_weight_and_bias
|
from network.save_weight_and_bias import save_weight_and_bias
|
||||||
from network.SbSReconstruction import SbSReconstruction
|
from network.SbSReconstruction import SbSReconstruction
|
||||||
|
|
||||||
|
@ -20,8 +21,10 @@ def add_weight_and_bias_to_histogram(
|
||||||
# ################################################
|
# ################################################
|
||||||
# Log the SbS Weights
|
# Log the SbS Weights
|
||||||
# ################################################
|
# ################################################
|
||||||
if (isinstance(network[id], SbSLayer) is True) or (
|
if (
|
||||||
isinstance(network[id], NNMFLayer) is True
|
(isinstance(network[id], SbSLayer) is True)
|
||||||
|
or (isinstance(network[id], NNMFLayer) is True)
|
||||||
|
or (isinstance(network[id], NNMFLayerSbSBP) is True)
|
||||||
):
|
):
|
||||||
if network[id]._w_trainable is True:
|
if network[id]._w_trainable is True:
|
||||||
|
|
||||||
|
@ -231,8 +234,10 @@ def run_optimizer(
|
||||||
cfg: Config,
|
cfg: Config,
|
||||||
) -> None:
|
) -> None:
|
||||||
for id in range(0, len(network)):
|
for id in range(0, len(network)):
|
||||||
if (isinstance(network[id], SbSLayer) is True) or (
|
if (
|
||||||
isinstance(network[id], NNMFLayer) is True
|
(isinstance(network[id], SbSLayer) is True)
|
||||||
|
or (isinstance(network[id], NNMFLayer) is True)
|
||||||
|
or (isinstance(network[id], NNMFLayerSbSBP) is True)
|
||||||
):
|
):
|
||||||
network[id].update_pre_care()
|
network[id].update_pre_care()
|
||||||
|
|
||||||
|
@ -241,8 +246,10 @@ def run_optimizer(
|
||||||
optimizer_item.step()
|
optimizer_item.step()
|
||||||
|
|
||||||
for id in range(0, len(network)):
|
for id in range(0, len(network)):
|
||||||
if (isinstance(network[id], SbSLayer) is True) or (
|
if (
|
||||||
isinstance(network[id], NNMFLayer) is True
|
(isinstance(network[id], SbSLayer) is True)
|
||||||
|
or (isinstance(network[id], NNMFLayer) is True)
|
||||||
|
or (isinstance(network[id], NNMFLayerSbSBP) is True)
|
||||||
):
|
):
|
||||||
network[id].update_after_care(
|
network[id].update_after_care(
|
||||||
cfg.learning_parameters.learning_rate_threshold_w
|
cfg.learning_parameters.learning_rate_threshold_w
|
||||||
|
@ -295,11 +302,15 @@ def run_lr_scheduler(
|
||||||
def deal_with_gradient_scale(epoch_id: int, mini_batch_number: int, network):
|
def deal_with_gradient_scale(epoch_id: int, mini_batch_number: int, network):
|
||||||
if (epoch_id == 0) and (mini_batch_number == 0):
|
if (epoch_id == 0) and (mini_batch_number == 0):
|
||||||
for id in range(0, len(network)):
|
for id in range(0, len(network)):
|
||||||
if isinstance(network[id], SbSLayer) is True:
|
if (isinstance(network[id], SbSLayer) is True) or (
|
||||||
|
isinstance(network[id], NNMFLayerSbSBP) is True
|
||||||
|
):
|
||||||
network[id].after_batch(True)
|
network[id].after_batch(True)
|
||||||
else:
|
else:
|
||||||
for id in range(0, len(network)):
|
for id in range(0, len(network)):
|
||||||
if isinstance(network[id], SbSLayer) is True:
|
if (isinstance(network[id], SbSLayer) is True) or (
|
||||||
|
isinstance(network[id], NNMFLayerSbSBP) is True
|
||||||
|
):
|
||||||
network[id].after_batch()
|
network[id].after_batch()
|
||||||
|
|
||||||
|
|
||||||
|
@ -775,3 +786,110 @@ def loop_test_reconstruction(
|
||||||
tb.flush()
|
tb.flush()
|
||||||
|
|
||||||
return performance
|
return performance
|
||||||
|
|
||||||
|
|
||||||
|
def loop_train_h_activity(
|
||||||
|
cfg: Config,
|
||||||
|
network: torch.nn.modules.container.Sequential,
|
||||||
|
my_loader_train: torch.utils.data.dataloader.DataLoader,
|
||||||
|
the_dataset_train,
|
||||||
|
device: torch.device,
|
||||||
|
default_dtype: torch.dtype,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
first_run = True
|
||||||
|
position_counter = 0
|
||||||
|
for h_x, h_x_labels in my_loader_train:
|
||||||
|
print(f"{position_counter} of {the_dataset_train.number_of_pattern}")
|
||||||
|
# #####################################################
|
||||||
|
# The network does the forward pass (training)
|
||||||
|
# #####################################################
|
||||||
|
h_collection = forward_pass_train(
|
||||||
|
input=h_x,
|
||||||
|
labels=h_x_labels,
|
||||||
|
the_dataset_train=the_dataset_train,
|
||||||
|
cfg=cfg,
|
||||||
|
network=network,
|
||||||
|
device=device,
|
||||||
|
default_dtype=default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
h_h: torch.Tensor = h_collection[-1].detach().clone().cpu()
|
||||||
|
|
||||||
|
if first_run is True:
|
||||||
|
h_activity = torch.zeros(
|
||||||
|
(
|
||||||
|
the_dataset_train.number_of_pattern,
|
||||||
|
h_h.shape[-3],
|
||||||
|
h_h.shape[-2],
|
||||||
|
h_h.shape[-1],
|
||||||
|
),
|
||||||
|
device=device,
|
||||||
|
dtype=default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
h_labels = torch.zeros(
|
||||||
|
(the_dataset_train.number_of_pattern),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int,
|
||||||
|
)
|
||||||
|
first_run = False
|
||||||
|
|
||||||
|
h_activity[
|
||||||
|
position_counter : position_counter + int(h_h.shape[0]), :, :, :
|
||||||
|
] = h_h
|
||||||
|
|
||||||
|
h_labels[
|
||||||
|
position_counter : position_counter + int(h_h.shape[0]),
|
||||||
|
] = h_x_labels
|
||||||
|
position_counter += int(h_h.shape[0])
|
||||||
|
|
||||||
|
return h_activity, h_labels
|
||||||
|
|
||||||
|
|
||||||
|
def loop_train_h_confusion(
|
||||||
|
cfg: Config,
|
||||||
|
network: torch.nn.modules.container.Sequential,
|
||||||
|
my_loader_train: torch.utils.data.dataloader.DataLoader,
|
||||||
|
the_dataset_train,
|
||||||
|
optimizer: list,
|
||||||
|
device: torch.device,
|
||||||
|
default_dtype: torch.dtype,
|
||||||
|
logging,
|
||||||
|
adapt_learning_rate: bool,
|
||||||
|
tb: SummaryWriter,
|
||||||
|
lr_scheduler,
|
||||||
|
last_test_performance: float,
|
||||||
|
order_id: float | int | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
first_run = True
|
||||||
|
|
||||||
|
position_counter = 0
|
||||||
|
for h_x, h_x_labels in my_loader_train:
|
||||||
|
print(f"{position_counter} of {the_dataset_train.number_of_pattern}")
|
||||||
|
# #####################################################
|
||||||
|
# The network does the forward pass (training)
|
||||||
|
# #####################################################
|
||||||
|
h_collection = forward_pass_train(
|
||||||
|
input=h_x,
|
||||||
|
labels=h_x_labels,
|
||||||
|
the_dataset_train=the_dataset_train,
|
||||||
|
cfg=cfg,
|
||||||
|
network=network,
|
||||||
|
device=device,
|
||||||
|
default_dtype=default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
h_h: torch.Tensor = h_collection[-1]
|
||||||
|
idx = h_h.argmax(dim=1).squeeze().cpu().numpy()
|
||||||
|
if first_run is True:
|
||||||
|
first_run = False
|
||||||
|
result = torch.zeros((h_h.shape[1], h_h.shape[1]), dtype=torch.int)
|
||||||
|
|
||||||
|
for i in range(0, idx.shape[0]):
|
||||||
|
result[int(h_x_labels[i]), int(idx[i])] += 1
|
||||||
|
|
||||||
|
position_counter += int(h_h.shape[0])
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
Loading…
Reference in a new issue