# File: kk_contour_net_shallow/performance_pfinkel_plots.py
# Last modified: 2023-07-19 15:47:35 +02:00
# (Source-viewer metadata: 221 lines, 7.1 KiB, Python)

import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import os
import datetime
import re
import glob
from natsort import natsorted
# Render all figure text through LaTeX, using a serif font family.
mpl.rcParams.update({"text.usetex": True, "font.family": "serif"})
from functions.alicorn_data_loader import alicorn_data_loader
from functions.create_logger import create_logger
def performance_pfinkel_plot(
    performances_list: list[dict],
    labels: list[str],
    save_name: str,
    logger,
    timestamp: "str | None" = None,
) -> None:
    """Plot test accuracy and test loss against path angle for several runs.

    Builds a 2x1 figure: the top panel shows accuracy, the bottom panel
    shows loss, with one curve per entry of ``performances_list``. The
    figure is saved as a PDF in the ``performance_pfinkel`` directory
    (created if missing), then shown and closed.

    Args:
        performances_list: One dict per curve. Each dict maps an arbitrary
            key to a record holding ``"pfinkel"`` (x value) and
            ``"test_accuracy"`` / ``"test_losses"`` (y values).
        labels: Legend label for each entry of ``performances_list``;
            must have the same length.
        save_name: Fragment inserted into the output file name.
        logger: Logger used to report where the figure was saved.
        timestamp: Suffix for the output file name. Defaults to the
            current local time formatted as ``"%d%m-%H%M"``.

    Raises:
        AssertionError: If ``performances_list`` and ``labels`` differ in
            length.
    """
    figure_path: str = "performance_pfinkel"
    assert len(performances_list) == len(labels)

    if timestamp is None:
        # Compute locally instead of relying on a module-level ``current``
        # that only exists when this file is executed as a script.
        timestamp = datetime.datetime.now().strftime("%d%m-%H%M")

    plt.figure(figsize=[14, 10])

    # (subplot index, record key, panel title, y-axis label)
    panels = (
        (1, "test_accuracy", "Average accuracy", "Accuracy (\\%)"),
        (2, "test_losses", "Average loss", "Loss"),
    )
    for subplot_id, value_key, title, y_label in panels:
        plt.subplot(2, 1, subplot_id)
        for performances, label in zip(performances_list, labels):
            records = list(performances.values())
            x_values = np.array([rec["pfinkel"] for rec in records], dtype=float)
            y_values = np.array([rec[value_key] for rec in records], dtype=float)
            plt.plot(x_values, y_values, label=label)
            plt.xticks(x_values)
        plt.title(title, fontsize=18)
        plt.xlabel("Path angle (in °)", fontsize=17)
        plt.ylabel(y_label, fontsize=17)
        plt.legend(fontsize=14)
        # Increase tick label font size
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.grid(True)

    plt.tight_layout()

    os.makedirs(figure_path, exist_ok=True)
    file_name = os.path.join(
        figure_path,
        f"PerformancePfinkel_{save_name}_{timestamp}.pdf",
    )
    plt.savefig(file_name, dpi=300, bbox_inches="tight")

    # Report the location through the provided logger (was a bare print).
    logger.info("")
    logger.info("Saved in:")
    logger.info(file_name)

    plt.show()
    # Release the figure so repeated calls with a non-interactive backend
    # do not accumulate open figures.
    plt.close()
if __name__ == "__main__":
    model_path: str = "trained_models"
    data_path: str = "/home/kk/Documents/Semester4/code/RenderStimuli/Output/"
    selection_file_id: int = 0

    # num stimuli per Pfinkel and batch size
    stim_per_pfinkel: int = 10000
    batch_size: int = 1000

    # stimulus condition (each condition is plotted with the label at the
    # same index in figure_label):
    performances_list: list = []
    condition: list[str] = ["Coignless", "Natural", "Angular"]
    figure_label: list[str] = ["Classic", "Corner", "Bridge"]

    # load test data:
    num_pfinkel: list = np.arange(0, 100, 10).tolist()
    image_scale: float = 255.0

    # ------------------------------------------

    # create logger:
    logger = create_logger(
        save_logging_messages=False,
        display_logging_messages=True,
    )

    device_str: str = "cuda:0" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using {device_str} device")
    device: torch.device = torch.device(device_str)
    torch.set_default_dtype(torch.float32)

    # current time; kept as a module-level name so code that reads the
    # global ``current`` (the plot's file-name suffix) can find it.
    current = datetime.datetime.now().strftime("%d%m-%H%M")

    # path to NN
    list_filenames: list[str] = natsorted(
        glob.glob(os.path.join(model_path, "*.pt"))
    )
    assert selection_file_id < len(
        list_filenames
    ), f"No model file with index {selection_file_id} in {model_path!r}"
    model_filename: str = str(list_filenames[selection_file_id])
    logger.info(f"Using model file: {model_filename}")

    # shorter saving name: keep only the matched architecture/seed fragments
    # of the model file name.
    pattern = r"(outChannels\[.*?\])|(kernelSize\[.*?\])|(_relu)|(_seed\d+)"
    matches = re.findall(pattern, model_filename)
    save_name = "".join("".join(match) for match in matches)

    # load and evaluate model
    # NOTE(review): torch.load unpickles the whole model object -- only load
    # trusted files. Newer PyTorch releases default to weights_only=True,
    # which may reject a pickled nn.Module; confirm for your torch version.
    model = torch.load(model_filename, map_location=device)
    # Set the model to evaluation mode
    model.eval()

    for selected_condition in condition:
        # save performances:
        logger.info(f"Condition: {selected_condition}")
        performances: dict = {}
        for pfinkel in num_pfinkel:
            test_loss: float = 0.0
            correct: int = 0
            pattern_count: int = 0

            data_test = alicorn_data_loader(
                num_pfinkel=[pfinkel],
                load_stimuli_per_pfinkel=stim_per_pfinkel,
                condition=selected_condition,
                logger=logger,
                data_path=data_path,
            )
            loader = torch.utils.data.DataLoader(
                data_test, shuffle=False, batch_size=batch_size
            )
            # Total number of stimuli in the test set. The last batch may be
            # smaller than batch_size, so len(loader) * batch size over-counts.
            total_number_of_pattern: int = len(data_test)

            # start testing network on new stimuli:
            logger.info("")
            logger.info(f"-==- Start {selected_condition} " f"Pfinkel {pfinkel}° -==-")
            with torch.no_grad():
                for data in loader:
                    label = data[0].to(device)
                    image = data[1].type(dtype=torch.float32).to(device)
                    image /= image_scale

                    # compute prediction error;
                    output = model(image)

                    # Summed (not averaged) per-batch loss, so the running
                    # total can be divided by the number of patterns.
                    loss = torch.nn.functional.cross_entropy(
                        output, label, reduction="sum"
                    )
                    pattern_count += int(label.shape[0])
                    test_loss += float(loss)
                    prediction = output.argmax(dim=1)
                    correct += prediction.eq(label).sum().item()

                    # logging (running totals after each batch):
                    logger.info(
                        (
                            f"{selected_condition},{pfinkel}° "
                            "Pfinkel: "
                            f"[{int(pattern_count)}/{total_number_of_pattern} ({100.0 * pattern_count / total_number_of_pattern:.2f}%)],"
                            f" Average loss: {test_loss / pattern_count:.3e}, "
                            "Accuracy: "
                            f"{100.0 * correct / pattern_count:.2f}% "
                        )
                    )

            if pattern_count == 0:
                # Nothing was loaded; avoid dividing by zero below.
                logger.info(
                    f"No stimuli for {selected_condition}, Pfinkel {pfinkel}° -- skipped"
                )
                continue

            performances[pfinkel] = {
                "pfinkel": pfinkel,
                "test_accuracy": 100 * correct / pattern_count,
                # Mean per-pattern loss over the whole test set (the summed
                # loss of all batches, not just the last batch).
                "test_losses": test_loss / pattern_count,
            }
        performances_list.append(performances)

    performance_pfinkel_plot(
        performances_list=performances_list,
        labels=figure_label,
        save_name=save_name,
        logger=logger,
    )
    logger.info("-==- DONE -==-")