From dfc00ebe06040de019d502f12bba569eaee37d2a Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Wed, 19 Jul 2023 15:47:35 +0200 Subject: [PATCH] Add files via upload --- performance_pfinkel_plots.py | 221 +++++++++++++++++++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 performance_pfinkel_plots.py diff --git a/performance_pfinkel_plots.py b/performance_pfinkel_plots.py new file mode 100644 index 0000000..6609027 --- /dev/null +++ b/performance_pfinkel_plots.py @@ -0,0 +1,221 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt +import matplotlib as mpl +import os +import datetime +import re +import glob +from natsort import natsorted + +mpl.rcParams["text.usetex"] = True +mpl.rcParams["font.family"] = "serif" + +from functions.alicorn_data_loader import alicorn_data_loader +from functions.create_logger import create_logger + + +def performance_pfinkel_plot( + performances_list: list[dict], labels: list[str], save_name: str, logger +) -> None: + figure_path: str = "performance_pfinkel" + assert len(performances_list) == len(labels) + + plt.figure(figsize=[14, 10]) + # plot accuracy + plt.subplot(2, 1, 1) + for id in range(0, len(labels)): + x_values = np.zeros((len(performances_list[id].keys()))) + y_values = np.zeros((len(performances_list[id].keys()))) + + counter = 0 + for id_key in performances_list[id].keys(): + x_values[counter] = performances_list[id][id_key]["pfinkel"] + y_values[counter] = performances_list[id][id_key]["test_accuracy"] + counter += 1 + + plt.plot(x_values, y_values, label=labels[id]) + plt.xticks(x_values) + plt.title("Average accuracy", fontsize=18) + plt.xlabel("Path angle (in °)", fontsize=17) + plt.ylabel("Accuracy (\\%)", fontsize=17) + plt.legend(fontsize=14) + + # Increase tick label font size + plt.xticks(fontsize=16) + plt.yticks(fontsize=16) + plt.grid(True) + + # plot loss + plt.subplot(2, 1, 2) + for id in range(0, len(labels)): + x_values = np.zeros((len(performances_list[id].keys()))) + y_values = np.zeros((len(performances_list[id].keys()))) + + counter = 0 + for id_key in performances_list[id].keys(): + x_values[counter] = performances_list[id][id_key]["pfinkel"] + y_values[counter] = performances_list[id][id_key]["test_losses"] + counter += 1 + + plt.plot(x_values, y_values, label=labels[id]) + + plt.xticks(x_values) + plt.title("Average loss", fontsize=18) + plt.xlabel("Path angle (in °)", fontsize=17) + plt.ylabel("Loss", fontsize=17) + plt.legend(fontsize=14) + + # Increase tick label font size + plt.xticks(fontsize=16) + plt.yticks(fontsize=16) + plt.grid(True) + + plt.tight_layout() + logger.info("") + logger.info("Saved in:") + + os.makedirs(figure_path, exist_ok=True) + print( + os.path.join( + figure_path, + f"PerformancePfinkel_{save_name}_{current}.pdf", + ) + ) + plt.savefig( + os.path.join( + figure_path, + f"PerformancePfinkel_{save_name}_{current}.pdf", + ), + dpi=300, + bbox_inches="tight", + ) + plt.show() + + +if __name__ == "__main__": + model_path: str = "trained_models" + data_path: str = "/home/kk/Documents/Semester4/code/RenderStimuli/Output/" + selection_file_id: int = 0 + + # num stimuli per Pfinkel and batch size + stim_per_pfinkel: int = 10000 + batch_size: int = 1000 + # stimulus condition: + performances_list: list = [] + condition: list[str] = ["Coignless", "Natural", "Angular"] + figure_label: list[str] = ["Classic", "Corner", "Bridge"] + # load test data: + num_pfinkel: list = np.arange(0, 100, 10).tolist() + image_scale: float = 255.0 + + # ------------------------------------------ + + # create logger: + logger = create_logger( + save_logging_messages=False, + display_logging_messages=True, + ) + + device_str: str = "cuda:0" if torch.cuda.is_available() else "cpu" + logger.info(f"Using {device_str} device") + device: torch.device = torch.device(device_str) + torch.set_default_dtype(torch.float32) + + # current time: + current = datetime.datetime.now().strftime("%d%m-%H%M") + + # path to NN + list_filenames: list[str] = natsorted( + list(glob.glob(os.path.join(model_path, "*.pt"))) + ) + assert selection_file_id < len(list_filenames) + model_filename: str = str(list_filenames[selection_file_id]) + logger.info(f"Using model file: {model_filename}") + + # shorter saving name: + pattern = r"(outChannels\[.*?\])|(kernelSize\[.*?\])|(_relu)|(_seed\d+)" + matches = re.findall(pattern, model_filename) + save_name = "".join(["".join(match) for match in matches]) + + # load and evaluate model + model = torch.load(model_filename, map_location=device) + + # Set the model to evaluation mode + model.eval() + + for selected_condition in condition: + # save performances: + logger.info(f"Condition: {selected_condition}") + performances: dict = {} + for pfinkel in num_pfinkel: + test_loss: float = 0.0 + correct: int = 0 + pattern_count: int = 0 + + data_test = alicorn_data_loader( + num_pfinkel=[pfinkel], + load_stimuli_per_pfinkel=stim_per_pfinkel, + condition=selected_condition, + logger=logger, + data_path=data_path, + ) + loader = torch.utils.data.DataLoader( + data_test, shuffle=False, batch_size=batch_size + ) + + # start testing network on new stimuli: + logger.info("") + logger.info(f"-==- Start {selected_condition} " f"Pfinkel {pfinkel}° -==-") + with torch.no_grad(): + for batch_num, data in enumerate(loader): + label = data[0].to(device) + image = data[1].type(dtype=torch.float32).to(device) + image /= image_scale + + # compute prediction error; + output = model(image) + + # Label Typecast: + label = label.to(device) + + # loss and optimization + loss = torch.nn.functional.cross_entropy( + output, label, reduction="sum" + ) + pattern_count += int(label.shape[0]) + test_loss += float(loss) + prediction = output.argmax(dim=1) + correct += prediction.eq(label).sum().item() + + total_number_of_pattern: int = int(len(loader)) * int( + label.shape[0] + ) + + # logging: + logger.info( + ( + f"{selected_condition},{pfinkel}° " + "Pfinkel: " + f"[{int(pattern_count)}/{total_number_of_pattern} ({100.0 * pattern_count / total_number_of_pattern:.2f}%)]," + f" Average loss: {test_loss / pattern_count:.3e}, " + "Accuracy: " + f"{100.0 * correct / pattern_count:.2f}% " + ) + ) + + performances[pfinkel] = { + "pfinkel": pfinkel, + "test_accuracy": 100 * correct / pattern_count, + "test_losses": float(loss) / pattern_count, + } + + performances_list.append(performances) + + performance_pfinkel_plot( + performances_list=performances_list, + labels=figure_label, + save_name=save_name, + logger=logger, + ) + logger.info("-==- DONE -==-")