127 lines
3.6 KiB
Python
127 lines
3.6 KiB
Python
import time
|
|
import numpy as np
|
|
import torch
|
|
|
|
import json
|
|
from jsmin import jsmin
|
|
import os
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from tools.make_network import make_network
|
|
from tools.get_the_data import get_the_data
|
|
from tools.loss_function import loss_function
|
|
from tools.make_optimize import make_optimize
|
|
|
|
|
|
def _read_minified_json(filename: str) -> dict:
    """Read *filename*, strip it with ``jsmin`` and parse the result as JSON.

    ``jsmin`` removes comments/whitespace, so the config files may carry
    JavaScript-style comments that plain ``json`` would reject.
    """
    with open(filename, "r", encoding="utf-8") as file:
        return json.loads(jsmin(file.read()))


def main(
    rand_seed: int = 21,
    only_print_network: bool = False,
    iterations: int = 20,
    model_iterations: int = 20,
    config_network_filename: str = "config_network.json",
    config_data_filename: str = "config_data.json",
    config_lr_parameter_filename: str = "config_lr_parameter.json",
) -> None:
    """Load a saved model and report its accuracy on the test split.

    The model is loaded from ``Models/Model_seed_{rand_seed}_{model_iterations}.pt``.
    Every layer exposing an ``iterations`` attribute gets it set to
    *iterations*. The model is then evaluated on the test dataloader built by
    ``get_the_data``, and the number of correct predictions is logged to
    TensorBoard under ``test_log_seed_{rand_seed}_{model_iterations}_{iterations}``.

    Args:
        rand_seed: Seed for torch (CPU and CUDA) and NumPy; also part of the
            model filename and log directory.
        only_print_network: If True, print the loaded network and return
            without evaluating.
        iterations: Value assigned to ``layer.iterations`` on every layer that
            has that attribute; also part of the log directory name.
        model_iterations: Selects which saved model file is loaded.
        config_network_filename: Not read by this function (kept in the
            signature for interface compatibility).
        config_data_filename: JSON (jsmin-stripped) file with the dataset name,
            batch sizes and augmentation parameters passed to ``get_the_data``.
        config_lr_parameter_filename: JSON (jsmin-stripped) file that is read
            (so it must exist and parse) but whose content is not used here.

    Raises:
        RuntimeError: If the test dataloader yields no samples.
    """
    os.makedirs("Models", exist_ok=True)

    device: torch.device = (
        torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    )
    torch.set_default_dtype(torch.float32)

    # Some parameters
    config_data = _read_minified_json(config_data_filename)
    # Read for parity with the training setup; the content is not used below.
    _read_minified_json(config_lr_parameter_filename)

    torch.manual_seed(rand_seed)
    torch.cuda.manual_seed(rand_seed)
    np.random.seed(rand_seed)

    dataset_name = str(config_data["dataset"])
    # Input size handed to get_the_data; presumably the crop size of the
    # images (NOTE(review): confirm against tools.get_the_data).
    if dataset_name in ("MNIST", "FashionMNIST"):
        input_dim_x: int = 24
        input_dim_y: int = 24
    else:
        input_dim_x = 28
        input_dim_y = 28

    _train_dataloader, test_dataloader, _train_processing_chain, test_processing_chain = (
        get_the_data(
            dataset_name,
            int(config_data["batch_size_train"]),
            int(config_data["batch_size_test"]),
            device,
            input_dim_x,
            input_dim_y,
            flip_p=float(config_data["flip_p"]),
            jitter_brightness=float(config_data["jitter_brightness"]),
            jitter_contrast=float(config_data["jitter_contrast"]),
            jitter_saturation=float(config_data["jitter_saturation"]),
            jitter_hue=float(config_data["jitter_hue"]),
            da_auto_mode=bool(config_data["da_auto_mode"]),
        )
    )

    default_path: str = f"seed_{rand_seed}_{model_iterations}"
    log_dir: str = f"test_log_{default_path}_{iterations}"

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load model files from a trusted source.
    network = torch.load(f"Models/Model_{default_path}.pt", weights_only=False)
    network = network.to(device=device)
    # Switch the network into evaluation mode.
    network.eval()

    print(f"Layers are set to {iterations} iterations.")
    # Assumes the saved object is iterable over its layers
    # (e.g. torch.nn.Sequential) - TODO confirm.
    for layer in network:
        if hasattr(layer, "iterations"):
            layer.iterations = iterations

    if only_print_network:
        print(network)
        return

    tb = SummaryWriter(log_dir=log_dir)
    try:
        print()
        t_start: float = time.perf_counter()

        test_correct: int = 0
        test_number: int = 0

        with torch.no_grad():
            for image, target in test_dataloader:
                output = network(test_processing_chain(image))
                predicted = output.argmax(dim=1)
                # .to() is a no-op when target already lives on that device.
                test_correct += int(
                    (predicted == target.to(predicted.device)).sum().item()
                )
                test_number += target.shape[0]

        t_testing = time.perf_counter()

        if test_number == 0:
            raise RuntimeError("The test dataloader yielded no samples.")
        performance_test_correct: float = 100.0 * test_correct / test_number

        tb.add_scalar("Test Number Correct", test_correct, 0)
        print(f"Testing: Correct={performance_test_correct:.2f}%")
        print(
            f"Time: Testing={(t_testing - t_start):.1f}sec"
        )

        tb.flush()
    finally:
        # Always release the writer, even if evaluation fails.
        tb.close()

    return
|