Add files via upload
This commit is contained in:
parent 659fbd071f
commit dfc00ebe06
1 changed file with 221 additions and 0 deletions
performance_pfinkel_plots.py (new file, +221)
@@ -0,0 +1,221 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import os
import datetime
import re
import glob
from natsort import natsorted

# LaTeX text rendering (requires a working LaTeX installation)
mpl.rcParams["text.usetex"] = True
mpl.rcParams["font.family"] = "serif"

from functions.alicorn_data_loader import alicorn_data_loader
from functions.create_logger import create_logger


def performance_pfinkel_plot(
    performances_list: list[dict], labels: list[str], save_name: str, logger
) -> None:
    figure_path: str = "performance_pfinkel"
    assert len(performances_list) == len(labels)

    plt.figure(figsize=[14, 10])

    # plot accuracy
    plt.subplot(2, 1, 1)
    for idx in range(0, len(labels)):
        x_values = np.zeros((len(performances_list[idx].keys())))
        y_values = np.zeros((len(performances_list[idx].keys())))

        counter = 0
        for id_key in performances_list[idx].keys():
            x_values[counter] = performances_list[idx][id_key]["pfinkel"]
            y_values[counter] = performances_list[idx][id_key]["test_accuracy"]
            counter += 1

        plt.plot(x_values, y_values, label=labels[idx])
    plt.xticks(x_values)
    plt.title("Average accuracy", fontsize=18)
    plt.xlabel("Path angle (in °)", fontsize=17)
    plt.ylabel("Accuracy (\\%)", fontsize=17)
    plt.legend(fontsize=14)

    # increase tick label font size
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.grid(True)

    # plot loss
    plt.subplot(2, 1, 2)
    for idx in range(0, len(labels)):
        x_values = np.zeros((len(performances_list[idx].keys())))
        y_values = np.zeros((len(performances_list[idx].keys())))

        counter = 0
        for id_key in performances_list[idx].keys():
            x_values[counter] = performances_list[idx][id_key]["pfinkel"]
            y_values[counter] = performances_list[idx][id_key]["test_losses"]
            counter += 1

        plt.plot(x_values, y_values, label=labels[idx])

    plt.xticks(x_values)
    plt.title("Average loss", fontsize=18)
    plt.xlabel("Path angle (in °)", fontsize=17)
    plt.ylabel("Loss", fontsize=17)
    plt.legend(fontsize=14)

    # increase tick label font size
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.grid(True)

    plt.tight_layout()

    os.makedirs(figure_path, exist_ok=True)
    # "current" is the timestamp set in the __main__ block below
    save_path = os.path.join(
        figure_path, f"PerformancePfinkel_{save_name}_{current}.pdf"
    )
    logger.info("")
    logger.info("Saved in:")
    logger.info(save_path)
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()
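

# Usage sketch (illustrative values only, not produced by this script):
# each entry of performances_list maps a Pfinkel angle to the dict built in
# the __main__ block below, e.g.
#   demo = {0: {"pfinkel": 0, "test_accuracy": 98.2, "test_losses": 0.05},
#           90: {"pfinkel": 90, "test_accuracy": 61.7, "test_losses": 0.9}}
#   performance_pfinkel_plot([demo], ["Classic"], "demo", logger)
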
if __name__ == "__main__":
    model_path: str = "trained_models"
    data_path: str = "/home/kk/Documents/Semester4/code/RenderStimuli/Output/"
    selection_file_id: int = 0

    # number of stimuli per Pfinkel angle and batch size
    stim_per_pfinkel: int = 10000
    batch_size: int = 1000

    # stimulus conditions and their figure labels:
    performances_list: list = []
    condition: list[str] = ["Coignless", "Natural", "Angular"]
    figure_label: list[str] = ["Classic", "Corner", "Bridge"]

    # test data parameters:
    num_pfinkel: list = np.arange(0, 100, 10).tolist()
    image_scale: float = 255.0

    # ------------------------------------------

    # create logger:
    logger = create_logger(
        save_logging_messages=False,
        display_logging_messages=True,
    )

    device_str: str = "cuda:0" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using {device_str} device")
    device: torch.device = torch.device(device_str)
    torch.set_default_dtype(torch.float32)

    # current time (used in the figure filename):
    current = datetime.datetime.now().strftime("%d%m-%H%M")

    # path to the trained networks:
    list_filenames: list[str] = natsorted(
        list(glob.glob(os.path.join(model_path, "*.pt")))
    )
    assert selection_file_id < len(list_filenames)
    model_filename: str = str(list_filenames[selection_file_id])
    logger.info(f"Using model file: {model_filename}")

    # shorter saving name:
    pattern = r"(outChannels\[.*?\])|(kernelSize\[.*?\])|(_relu)|(_seed\d+)"
    matches = re.findall(pattern, model_filename)
    save_name = "".join(["".join(match) for match in matches])
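    # For illustration (hypothetical filename, not from this repo): a model
    # file "cnn_outChannels[3, 8, 8]_kernelSize[7, 15]_relu_seed42.pt" would
    # yield save_name "outChannels[3, 8, 8]kernelSize[7, 15]_relu_seed42".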

    # load and evaluate model
    model = torch.load(model_filename, map_location=device)

    # set the model to evaluation mode (disables dropout, freezes
    # batch-norm statistics)
    model.eval()

    for selected_condition in condition:
        # save performances:
        logger.info(f"Condition: {selected_condition}")
        performances: dict = {}
        for pfinkel in num_pfinkel:
            test_loss: float = 0.0
            correct: int = 0
            pattern_count: int = 0

            data_test = alicorn_data_loader(
                num_pfinkel=[pfinkel],
                load_stimuli_per_pfinkel=stim_per_pfinkel,
                condition=selected_condition,
                logger=logger,
                data_path=data_path,
            )
            loader = torch.utils.data.DataLoader(
                data_test, shuffle=False, batch_size=batch_size
            )

            # start testing network on new stimuli:
            logger.info("")
            logger.info(f"-==- Start {selected_condition} Pfinkel {pfinkel}° -==-")
            with torch.no_grad():
                for batch_num, data in enumerate(loader):
                    label = data[0].to(device)
                    image = data[1].type(dtype=torch.float32).to(device)
                    image /= image_scale

                    # compute predictions
                    output = model(image)

                    # test loss (no optimization at evaluation time)
                    loss = torch.nn.functional.cross_entropy(
                        output, label, reduction="sum"
                    )
                    pattern_count += int(label.shape[0])
                    test_loss += float(loss)
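                    # reduction="sum" means test_loss accumulates the total
                    # loss over all samples, so test_loss / pattern_count is
                    # the per-sample average reported below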
                    prediction = output.argmax(dim=1)
                    correct += prediction.eq(label).sum().item()

                # total number of test patterns (robust to a smaller final batch)
                total_number_of_pattern: int = len(data_test)

                # logging:
                logger.info(
                    (
                        f"{selected_condition},{pfinkel}° "
                        "Pfinkel: "
                        f"[{int(pattern_count)}/{total_number_of_pattern} "
                        f"({100.0 * pattern_count / total_number_of_pattern:.2f}%)],"
                        f" Average loss: {test_loss / pattern_count:.3e}, "
                        "Accuracy: "
                        f"{100.0 * correct / pattern_count:.2f}% "
                    )
                )

            performances[pfinkel] = {
                "pfinkel": pfinkel,
                "test_accuracy": 100.0 * correct / pattern_count,
                "test_losses": test_loss / pattern_count,
            }

        performances_list.append(performances)

    performance_pfinkel_plot(
        performances_list=performances_list,
        labels=figure_label,
        save_name=save_name,
        logger=logger,
    )
    logger.info("-==- DONE -==-")