# pytorch_introduction_2025/example_network.py
import torch
import torchvision # type: ignore
from torchvision.transforms import v2 # type: ignore
import time
import os
number_of_epoch: int = 500
lr_limit: float = 1e-9  # stop training once the learning rate decays below this value
ModelsPath: str = "Models"
os.makedirs(ModelsPath, exist_ok=True)
# Tensorboard
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from torch.utils.tensorboard import SummaryWriter
tb = SummaryWriter(log_dir="run")
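# The scalars logged below can be viewed with: tensorboard --logdir=run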
# Use the GPU if available
device: torch.device = (
torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)
torch.set_default_dtype(torch.float32)
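# float32 is already PyTorch's default dtype; this call just makes the choice explicit.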
# Data augmentation
test_processing_chain = v2.Compose(
    transforms=[
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.CenterCrop((28, 28)),
    ],
)
train_processing_chain = v2.Compose(
    transforms=[
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.RandomCrop((28, 28)),
        v2.AutoAugment(),
    ],
)
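# CIFAR-10 images are 32x32; both chains crop them to 28x28: randomly (plus an
# AutoAugment policy) during training, deterministically centered during testing.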
# Data provider
tv_dataset_train = torchvision.datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=train_processing_chain,
)
tv_dataset_test = torchvision.datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=test_processing_chain,
)
# Data loader
train_data_load = torch.utils.data.DataLoader(
    tv_dataset_train, batch_size=100, shuffle=True
)
test_data_load = torch.utils.data.DataLoader(
    tv_dataset_test, batch_size=100, shuffle=False
)
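# Only the training set is shuffled; with batch_size=100 this gives 500 training and
# 100 test mini-batches per epoch (CIFAR-10 has 50,000 train and 10,000 test images).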
# Network
network = torch.nn.Sequential(
    torch.nn.Conv2d(
        in_channels=3,
        out_channels=32,
        kernel_size=5,
        stride=1,
        padding=0,
    ),
    torch.nn.ReLU(),
    torch.nn.BatchNorm2d(32),
    torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
    torch.nn.Conv2d(
        in_channels=32,
        out_channels=64,
        kernel_size=5,
        stride=1,
        padding=0,
    ),
    torch.nn.ReLU(),
    torch.nn.BatchNorm2d(64),
    torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
    torch.nn.Flatten(
        start_dim=1,
    ),
    torch.nn.Linear(
        in_features=1024,
        out_features=1024,
        bias=True,
    ),
    torch.nn.ReLU(),
    torch.nn.Linear(
        in_features=1024,
        out_features=10,
        bias=True,
    ),
).to(device)
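# The Linear layer's in_features=1024 follows from the shapes: 28 -> conv(5) -> 24 ->
# pool(2) -> 12 -> conv(5) -> 8 -> pool(2) -> 4, and 64 channels * 4 * 4 = 1024.
# A quick check (a sketch, not part of the original script) via a dummy forward pass
# through everything up to and including the Flatten:
#   with torch.no_grad():
#       print(network[:9](torch.zeros(1, 3, 28, 28, device=device)).shape)  # [1, 1024]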
# Optimizer
optimizer = torch.optim.Adam(network.parameters(), lr=0.001)
# LR Scheduler
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
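# With its defaults (factor=0.1, patience=10), ReduceLROnPlateau multiplies the learning
# rate by 0.1 whenever the monitored value has not improved for 10 consecutive epochs.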
# Loss function
loss_function = torch.nn.CrossEntropyLoss()
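# CrossEntropyLoss applies log-softmax internally, which is why the network above ends
# in raw logits instead of a softmax layer.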
# Main loop
for epoch_id in range(0, number_of_epoch):
    print(f"Epoch: {epoch_id}")
    t_start: float = time.perf_counter()
    train_loss: float = 0.0
    train_correct: int = 0
    train_number: int = 0
    test_correct: int = 0
    test_number: int = 0
    # Switch the network into training mode
    network.train()
    # This runs in total for one epoch, split up into mini-batches
    for image, target in train_data_load:
        # Clean the gradient
        optimizer.zero_grad()
        # Run data through the network
        output = network(image.to(device))
        # Measure the loss
        loss = loss_function(output, target.to(device))
        train_loss += loss.item()
        # Classify
        train_correct += (output.argmax(dim=1) == target.to(device)).sum().item()
        train_number += target.shape[0]
        # Backpropagate
        loss.backward()
        # Update the parameters
        optimizer.step()
    # Update the learning rate once per epoch
    lr_scheduler.step(train_loss)
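    # Note: train_loss is the sum of per-batch losses rather than a mean;
    # ReduceLROnPlateau only checks whether the metric improves, so the scale is irrelevant.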
    t_training: float = time.perf_counter()
    # Switch the network into evaluation mode
    network.eval()
    with torch.no_grad():
        for image, target in test_data_load:
            # Run data through the network
            output = network(image.to(device))
            # Classify
            test_correct += (output.argmax(dim=1) == target.to(device)).sum().item()
            test_number += target.shape[0]
    t_testing: float = time.perf_counter()
    performance_test_correct: float = 100.0 * test_correct / test_number
    performance_train_correct: float = 100.0 * train_correct / train_number
    tb.add_scalar("Train Loss", train_loss, epoch_id)
    tb.add_scalar("Train Number Correct", train_correct, epoch_id)
    tb.add_scalar("Test Number Correct", test_correct, epoch_id)
    tb.add_scalar("Error Test", 100.0 - performance_test_correct, epoch_id)
    tb.add_scalar("Error Train", 100.0 - performance_train_correct, epoch_id)
    tb.add_scalar("Learning Rate", optimizer.param_groups[-1]["lr"], epoch_id)
    tb.flush()
    print(
        f"Training: Loss={train_loss:.5f} "
        f"Correct={performance_train_correct:.2f}% "
        f"LR:{optimizer.param_groups[-1]['lr']}"
    )
    print(f"Testing: Correct={performance_test_correct:.2f}%")
    print(
        f"Time: Training={(t_training - t_start):.1f}sec, "
        f"Testing={(t_testing - t_training):.1f}sec"
    )
    # Saving the full module pickles the class; network.state_dict() would be the
    # more portable alternative.
    torch.save(network, os.path.join(ModelsPath, f"Model_CIFAR10_{epoch_id}.pt"))
    print()
    # Stop once the learning rate has decayed below the limit
    if optimizer.param_groups[-1]["lr"] < lr_limit:
        print("Done (lr_limit)")
        break

tb.close()