import json
import os
import time

import numpy as np
import torch
from jsmin import jsmin
from torch.utils.tensorboard import SummaryWriter

# make_network / loss_function / make_optimize are not used by main() below;
# they are kept because other code in this file may rely on them.
from tools.make_network import make_network
from tools.get_the_data import get_the_data
from tools.loss_function import loss_function
from tools.make_optimize import make_optimize


def _load_json_config(filename: str) -> dict:
    """Read *filename*, minify it with ``jsmin`` and parse it as JSON.

    Running ``jsmin`` first presumably allows the config files to contain
    JavaScript-style comments, which plain ``json.loads`` would reject.
    """
    with open(filename, "r", encoding="utf-8") as file:
        return json.loads(jsmin(file.read()))


def _count_correct(
    network: torch.nn.Module, dataloader, processing_chain
) -> tuple:
    """Run *network* over *dataloader* and return ``(n_correct, n_samples)``.

    Each batch's images go through *processing_chain* before the forward
    pass; a prediction is the arg-max over dim 1 of the network output.
    """
    n_correct = 0
    n_samples = 0
    with torch.no_grad():
        for image, target in dataloader:
            output = network(processing_chain(image))
            # Align labels with the predictions' device; a no-op when the
            # dataloader already delivers them there.
            target = target.to(output.device)
            n_correct += int((output.argmax(dim=1) == target).sum().item())
            n_samples += int(target.shape[0])
    return n_correct, n_samples


def main(
    rand_seed: int = 21,
    only_print_network: bool = False,
    iterations: int = 20,
    model_iterations: int = 20,
    config_network_filename: str = "config_network.json",
    config_data_filename: str = "config_data.json",
    config_lr_parameter_filename: str = "config_lr_parameter.json",
) -> None:
    """Evaluate a saved model on the test set and log the result to TensorBoard.

    Loads ``Models/Model_seed_{rand_seed}_{model_iterations}.pt``, assigns
    *iterations* to the ``iterations`` attribute of every layer that has
    one, runs the test set once, prints the accuracy and the elapsed time,
    and writes the raw number of correct samples to the TensorBoard log
    directory ``test_log_seed_{rand_seed}_{model_iterations}_{iterations}``.

    Args:
        rand_seed: Seed for torch (CPU and CUDA) and numpy; also part of
            the model file name.
        only_print_network: If True, print the loaded network and return
            without evaluating.
        iterations: Value assigned to ``layer.iterations`` on each layer
            that defines that attribute.
        model_iterations: Part of the model file name selecting which
            saved model to load.
        config_network_filename: Not read by this function.
        config_data_filename: Config with the dataset name, batch sizes and
            augmentation settings passed to ``get_the_data``.
        config_lr_parameter_filename: Config that is loaded (so it must
            exist and parse) but not otherwise used here.

    Raises:
        RuntimeError: If the test dataloader yields no samples.
    """
    os.makedirs("Models", exist_ok=True)

    device: torch.device = (
        torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    )
    torch.set_default_dtype(torch.float32)

    config_data = _load_json_config(config_data_filename)
    config_lr_parameter = _load_json_config(config_lr_parameter_filename)  # unused below

    torch.manual_seed(rand_seed)
    torch.cuda.manual_seed(rand_seed)
    np.random.seed(rand_seed)

    dataset: str = str(config_data["dataset"])
    # Side length of the (square) network input used for this dataset family.
    if dataset in ("MNIST", "FashionMNIST"):
        input_dim_x = input_dim_y = 24
    else:
        input_dim_x = input_dim_y = 28

    _, test_dataloader, _, test_processing_chain = get_the_data(
        dataset,
        int(config_data["batch_size_train"]),
        int(config_data["batch_size_test"]),
        device,
        input_dim_x,
        input_dim_y,
        flip_p=float(config_data["flip_p"]),
        jitter_brightness=float(config_data["jitter_brightness"]),
        jitter_contrast=float(config_data["jitter_contrast"]),
        jitter_saturation=float(config_data["jitter_saturation"]),
        jitter_hue=float(config_data["jitter_hue"]),
        da_auto_mode=bool(config_data["da_auto_mode"]),
    )

    default_path: str = f"seed_{rand_seed}_{model_iterations}"
    log_dir: str = f"test_log_{default_path}_{iterations}"

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load model files from a trusted source.
    # map_location lets a model saved on GPU be evaluated on a CPU-only host.
    network = torch.load(
        f"Models/Model_{default_path}.pt", weights_only=False, map_location=device
    )
    network = network.to(device=device)
    network.eval()

    print(f"Layers are set to {iterations} iterations.")
    # The network is iterated layer by layer (e.g. a torch.nn.Sequential).
    for layer in network:
        if hasattr(layer, "iterations"):
            layer.iterations = iterations

    if only_print_network:
        print(network)
        return

    tb = SummaryWriter(log_dir=log_dir)
    try:
        print()
        t_start: float = time.perf_counter()
        test_correct, test_number = _count_correct(
            network, test_dataloader, test_processing_chain
        )
        t_testing: float = time.perf_counter()

        if test_number == 0:
            raise RuntimeError("The test dataloader yielded no samples.")

        performance_test_correct: float = 100.0 * test_correct / test_number
        tb.add_scalar("Test Number Correct", test_correct, 0)

        print(f"Testing: Correct={performance_test_correct:.2f}%")
        print(f"Time: Testing={(t_testing - t_start):.1f}sec")
        tb.flush()
    finally:
        # Close the writer even if evaluation fails so the event file is released.
        tb.close()