PySpikeBySpike_04_2025/tools/run_network_test.py
2025-04-08 15:20:45 +02:00

127 lines
3.6 KiB
Python

import time
import numpy as np
import torch
import json
from jsmin import jsmin
import os
from torch.utils.tensorboard import SummaryWriter
from tools.make_network import make_network
from tools.get_the_data import get_the_data
from tools.loss_function import loss_function
from tools.make_optimize import make_optimize
def _load_json_config(filename: str) -> dict:
    """Read and parse a JSON config file.

    The raw text is passed through ``jsmin`` before ``json.loads``. jsmin is a
    JavaScript minifier, so presumably this is what allows the config files to
    carry ``//`` / ``/* */`` comments -- confirm against the config files.
    """
    with open(filename, "r", encoding="utf-8") as file:
        return json.loads(jsmin(file.read()))


def _count_correct(network, dataloader, processing_chain) -> "tuple[int, int]":
    """Run ``network`` over every batch of ``dataloader`` and count top-1 hits.

    Each batch ``(image, target)`` is passed through ``processing_chain`` before
    the network. The predicted class is ``output.argmax(dim=1)``.

    Returns:
        ``(number_correct, number_of_samples)`` as plain Python ints.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for image, target in dataloader:
            output = network(processing_chain(image))
            predicted = output.argmax(dim=1)
            # .to() is a no-op if the loader already yields tensors on the
            # network's device; otherwise it avoids a device-mismatch error.
            correct += int((predicted == target.to(predicted.device)).sum().item())
            total += int(target.shape[0])
    return correct, total


def main(
    rand_seed: int = 21,
    only_print_network: bool = False,
    iterations: int = 20,
    model_iterations: int = 20,
    config_network_filename: str = "config_network.json",
    config_data_filename: str = "config_data.json",
    config_lr_parameter_filename: str = "config_lr_parameter.json",
) -> None:
    """Evaluate a saved network on the test set and log the accuracy.

    Loads ``Models/Model_seed_{rand_seed}_{model_iterations}.pt``, sets every
    layer that has an ``iterations`` attribute to ``iterations``, runs the
    test set through the network and reports the top-1 accuracy on stdout.
    The raw number of correct samples is written to TensorBoard under
    ``test_log_seed_{rand_seed}_{model_iterations}_{iterations}``.

    Args:
        rand_seed: Seed for torch (CPU and CUDA) and NumPy; also part of the
            model file name.
        only_print_network: If True, print the loaded network and terminate
            the process (via ``SystemExit``) without evaluating.
        iterations: Value assigned to each layer's ``iterations`` attribute
            for this test run.
        model_iterations: Iteration count encoded in the model file name.
        config_network_filename: Not used by this script; kept so the
            signature matches the other tools.
        config_data_filename: JSON config with the dataset name, batch sizes
            and augmentation settings.
        config_lr_parameter_filename: JSON config that is parsed but whose
            contents are not used here.

    Raises:
        RuntimeError: If the test dataloader yields no samples.
    """
    os.makedirs("Models", exist_ok=True)
    device: torch.device = (
        torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    )
    torch.set_default_dtype(torch.float32)

    config_data = _load_json_config(config_data_filename)
    # Parsed but otherwise unused in this script; kept so a missing or
    # malformed file still fails exactly as before.
    _load_json_config(config_lr_parameter_filename)

    torch.manual_seed(rand_seed)
    torch.cuda.manual_seed(rand_seed)
    np.random.seed(rand_seed)

    dataset: str = str(config_data["dataset"])
    # Spatial size handed to get_the_data -- presumably the crop size used
    # during training; confirm in tools.get_the_data.
    if dataset in ("MNIST", "FashionMNIST"):
        input_dim_x, input_dim_y = 24, 24
    else:
        input_dim_x, input_dim_y = 28, 28

    # Only the test loader and its processing chain are needed here.
    _, test_dataloader, _, test_processing_chain = get_the_data(
        dataset,
        int(config_data["batch_size_train"]),
        int(config_data["batch_size_test"]),
        device,
        input_dim_x,
        input_dim_y,
        flip_p=float(config_data["flip_p"]),
        jitter_brightness=float(config_data["jitter_brightness"]),
        jitter_contrast=float(config_data["jitter_contrast"]),
        jitter_saturation=float(config_data["jitter_saturation"]),
        jitter_hue=float(config_data["jitter_hue"]),
        da_auto_mode=bool(config_data["da_auto_mode"]),
    )

    model_tag: str = f"seed_{rand_seed}_{model_iterations}"
    log_dir: str = f"test_log_{model_tag}_{iterations}"

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load model files from a trusted source. map_location lets a model
    # that was saved on a GPU be loaded on a CPU-only machine.
    network = torch.load(
        f"Models/Model_{model_tag}.pt", weights_only=False, map_location=device
    )
    network = network.to(device=device)
    network.eval()

    print(f"Layers are set to {iterations} iterations.")
    for layer in network:
        if hasattr(layer, "iterations"):
            layer.iterations = iterations

    if only_print_network:
        import sys

        print(network)
        # sys.exit instead of the site-module exit(), which is meant for the
        # interactive shell and is absent when Python runs with -S.
        sys.exit()

    tb = SummaryWriter(log_dir=log_dir)
    try:
        print()
        t_start: float = time.perf_counter()
        test_correct, test_number = _count_correct(
            network, test_dataloader, test_processing_chain
        )
        t_testing: float = time.perf_counter()

        if test_number == 0:
            raise RuntimeError(
                "The test dataloader yielded no samples; cannot compute accuracy."
            )
        performance_test_correct: float = 100.0 * test_correct / test_number

        tb.add_scalar("Test Number Correct", test_correct, 0)
        print(f"Testing: Correct={performance_test_correct:.2f}%")
        print(f"Time: Testing={(t_testing - t_start):.1f}sec")
        tb.flush()
    finally:
        # Always release the event-file writer, even if evaluation fails.
        tb.close()