pynnmf/run_network.py

282 lines
8.2 KiB
Python
Raw Normal View History

2024-05-30 14:08:44 +02:00
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import argh
import time
import torch
from torch.utils.tensorboard import SummaryWriter
from make_network import make_network
from get_the_data import get_the_data
from loss_function import loss_function
from make_optimize import make_optimize
def main(
    lr_initial_nnmf: float = 0.01,
    lr_initial_cnn: float = 0.001,
    lr_initial_cnn_top: float = 0.001,
    iterations: int = 20,
    cnn_top: bool = True,
    use_nnmf: bool = True,
    dataset: str = "CIFAR10",  # "CIFAR10", "FashionMNIST", "MNIST"
    rand_seed: int = 21,
    enable_onoff: bool = False,
    local_learning_0: bool = False,
    local_learning_1: bool = False,
    local_learning_2: bool = False,
    local_learning_3: bool = False,
    local_learning_kl: bool = False,
    skip_connection_0: bool = True,
    skip_connection_1: bool = True,
    skip_connection_2: bool = True,
    skip_connection_3: bool = True,
    use_reconstruction: bool = False,
    max_pool: bool = True,
) -> None:
    """Train the network, evaluate it every epoch and save the model.

    The function seeds torch, selects ``cuda:0`` when available (CPU
    otherwise), builds the data loaders with ``get_the_data``, the network
    with ``make_network`` and the optimizers / learning-rate schedulers with
    ``make_optimize``. It then runs up to 500 epochs. After every epoch the
    schedulers are stepped with the summed training loss, the test set is
    evaluated and scalars are written to TensorBoard (``log_<run name>``).
    Training stops early once the largest current learning rate drops below
    ``1e-9``. In both cases the whole network object is stored with
    ``torch.save`` as ``Model_<run name>.pt``.

    Args:
        lr_initial_nnmf: Forwarded to ``make_optimize`` as the initial
            learning rate of the NNMF optimizer.
        lr_initial_cnn: Forwarded to ``make_optimize`` as the initial
            learning rate of the CNN optimizer.
        lr_initial_cnn_top: Forwarded to ``make_optimize`` as the initial
            learning rate of the CNN-top optimizer.
        iterations: Forwarded to ``make_network`` (presumably the number of
            NNMF iterations -- confirm in ``make_network``).
        cnn_top: Forwarded to ``make_network``.
        use_nnmf: Forwarded to ``make_network``; also selects the run-name
            prefix (``"nnmf"`` or ``"cnn"``).
        dataset: Dataset name handed to ``get_the_data``. ``"MNIST"`` and
            ``"FashionMNIST"`` use 1 input channel and 24x24 inputs; every
            other value uses 3 channels and 28x28 inputs.
        rand_seed: Seed passed to ``torch.manual_seed``.
        enable_onoff: Forwarded to ``make_network``.
        local_learning_0: First entry of the ``local_learning`` list given
            to ``make_network``.
        local_learning_1: Second entry of that list.
        local_learning_2: Third entry of that list.
        local_learning_3: Fourth entry of that list.
        local_learning_kl: Forwarded to ``make_network``.
        skip_connection_0: First entry of the ``skip_connection`` list given
            to ``make_network``.
        skip_connection_1: Second entry of that list.
        skip_connection_2: Third entry of that list.
        skip_connection_3: Fourth entry of that list.
        use_reconstruction: Forwarded to ``make_network``.
        max_pool: Forwarded to ``make_network``.
    """
    # Training stops once every scheduler's learning rate is below this value.
    lr_limit: float = 1e-9

    torch.manual_seed(rand_seed)

    torch_device: torch.device = (
        torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    )
    torch.set_default_dtype(torch.float32)

    # Some parameters
    batch_size_train: int = 500
    batch_size_test: int = 500
    number_of_epoch: int = 500

    prefix: str = "nnmf" if use_nnmf else "cnn"

    # Run name shared by the TensorBoard log directory and the model file.
    default_path: str = (
        f"{prefix}_iter{iterations}_cnntop{cnn_top}_lrcnn{lr_initial_cnn}_lrnnmf{lr_initial_nnmf}_0{local_learning_0}_1{local_learning_1}_2{local_learning_2}_3{local_learning_3}_kl{local_learning_kl}_reco{use_reconstruction}_max{max_pool}"
    )
    log_dir: str = f"log_{default_path}"
    model_path: str = f"Model_{default_path}.pt"

    loss_mode: int = 0
    loss_coeffs_mse: float = 0.5
    loss_coeffs_kldiv: float = 1.0
    print(
        "loss_mode: ",
        loss_mode,
        "loss_coeffs_mse: ",
        loss_coeffs_mse,
        "loss_coeffs_kldiv: ",
        loss_coeffs_kldiv,
    )

    if dataset == "MNIST" or dataset == "FashionMNIST":
        input_number_of_channel: int = 1
        input_dim_x: int = 24
        input_dim_y: int = 24
    else:
        input_number_of_channel = 3
        input_dim_x = 28
        input_dim_y = 28

    train_dataloader, test_dataloader, test_processing_chain, train_processing_chain = (
        get_the_data(
            dataset,
            batch_size_train,
            batch_size_test,
            torch_device,
            input_dim_x,
            input_dim_y,
            flip_p=0.5,
            jitter_brightness=0.5,
            jitter_contrast=0.1,
            jitter_saturation=0.1,
            jitter_hue=0.15,
        )
    )

    network, list_cnn_top_id, list_other_id = make_network(
        use_nnmf=use_nnmf,
        cnn_top=cnn_top,
        input_dim_x=input_dim_x,
        input_dim_y=input_dim_y,
        input_number_of_channel=input_number_of_channel,
        iterations=iterations,
        enable_onoff=enable_onoff,
        local_learning=[
            local_learning_0,
            local_learning_1,
            local_learning_2,
            local_learning_3,
        ],
        local_learning_kl=local_learning_kl,
        skip_connection=[
            skip_connection_0,
            skip_connection_1,
            skip_connection_2,
            skip_connection_3,
        ],
        use_reconstruction=use_reconstruction,
        max_pool=max_pool,
    )
    network = network.to(torch_device)
    print(network)

    (
        optimizer_nnmf,
        optimizer_cnn,
        optimizer_cnn_top,
        lr_scheduler_nnmf,
        lr_scheduler_cnn,
        lr_scheduler_cnn_top,
    ) = make_optimize(
        network=network,
        list_cnn_top_id=list_cnn_top_id,
        list_other_id=list_other_id,
        lr_initial_nnmf=lr_initial_nnmf,
        lr_initial_cnn=lr_initial_cnn,
        lr_initial_cnn_top=lr_initial_cnn_top,
    )

    # Any of the optimizers / schedulers may be None; keep only the active
    # ones, in the fixed order nnmf -> cnn -> cnn top.
    optimizers = [
        optimizer
        for optimizer in (optimizer_nnmf, optimizer_cnn, optimizer_cnn_top)
        if optimizer is not None
    ]
    lr_schedulers = [
        lr_scheduler
        for lr_scheduler in (
            lr_scheduler_nnmf,
            lr_scheduler_cnn,
            lr_scheduler_cnn_top,
        )
        if lr_scheduler is not None
    ]

    tb = SummaryWriter(log_dir=log_dir)
    stopped_by_lr_limit: bool = False
    try:
        for epoch_id in range(number_of_epoch):
            print()
            print(f"Epoch: {epoch_id}")
            t_start: float = time.perf_counter()

            train_loss: float = 0.0
            train_correct: int = 0
            train_number: int = 0
            test_correct: int = 0
            test_number: int = 0

            # Switch the network into training mode
            network.train()

            # This runs in total for one epoch split up into mini-batches
            for image, target in train_dataloader:
                # Clean the gradient
                for optimizer in optimizers:
                    optimizer.zero_grad()

                output = network(train_processing_chain(image))

                loss = loss_function(
                    h=output,
                    labels=target,
                    number_of_output_neurons=output.shape[1],
                    loss_mode=loss_mode,
                    loss_coeffs_mse=loss_coeffs_mse,
                    loss_coeffs_kldiv=loss_coeffs_kldiv,
                )
                assert loss is not None

                train_loss += loss.item()
                # .item() yields a plain Python int, matching the counter type.
                train_correct += int((output.argmax(dim=1) == target).sum().item())
                train_number += target.shape[0]

                # Calculate backprop
                loss.backward()

                # Update the parameter
                for optimizer in optimizers:
                    optimizer.step()

            performance_train_correct: float = 100.0 * train_correct / train_number

            # Update the learning rate; the schedulers are driven by the
            # summed training loss of this epoch.
            for lr_scheduler in lr_schedulers:
                lr_scheduler.step(train_loss)

            print(
                "Actual lr: ",
                "nnmf: ",
                lr_scheduler_nnmf.get_last_lr() if lr_scheduler_nnmf is not None else -1.0,
                "cnn: ",
                lr_scheduler_cnn.get_last_lr() if lr_scheduler_cnn is not None else -1.0,
                "cnn top: ",
                (
                    lr_scheduler_cnn_top.get_last_lr()
                    if lr_scheduler_cnn_top is not None
                    else -1.0
                ),
            )
            t_training: float = time.perf_counter()

            # Switch the network into evaluation mode
            network.eval()
            with torch.no_grad():
                for image, target in test_dataloader:
                    output = network(test_processing_chain(image))
                    test_correct += int((output.argmax(dim=1) == target).sum().item())
                    test_number += target.shape[0]
            t_testing: float = time.perf_counter()

            performance_test_correct: float = 100.0 * test_correct / test_number

            # The logged loss is the sum of the batch losses divided by the
            # number of training samples seen in this epoch.
            tb.add_scalar("Train Loss", train_loss / float(train_number), epoch_id)
            tb.add_scalar("Train Number Correct", train_correct, epoch_id)
            tb.add_scalar("Test Number Correct", test_correct, epoch_id)

            print(
                f"Training: Loss={train_loss / float(train_number):.5f} Correct={performance_train_correct:.2f}%"
            )
            print(f"Testing: Correct={performance_test_correct:.2f}%")
            print(
                f"Time: Training={(t_training - t_start):.1f}sec, Testing={(t_testing - t_training):.1f}sec"
            )
            tb.flush()

            lr_check: list[float] = [
                lr_scheduler.get_last_lr()[0] for lr_scheduler in lr_schedulers
            ]
            # Without any scheduler there is no learning rate to compare (and
            # max() of an empty tensor would raise), so the limit check is
            # skipped and training runs for all epochs.
            if lr_check and float(torch.tensor(lr_check).max()) < lr_limit:
                stopped_by_lr_limit = True
                break

        torch.save(network, model_path)
        if not stopped_by_lr_limit:
            print()
    finally:
        # Close the TensorBoard writer even if training aborts with an error.
        tb.close()

    print("Done (lr_limit)" if stopped_by_lr_limit else "Done (loop end)")
    return
# Script entry point: only when executed directly (not when imported),
# hand main() to argh so it is run as the command-line command.
if __name__ == "__main__":
    argh.dispatch_command(main)