Add files via upload

This commit is contained in:
David Rotermund 2023-07-19 15:47:35 +02:00 committed by GitHub
parent 659fbd071f
commit dfc00ebe06
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -0,0 +1,221 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import os
import datetime
import re
import glob
from natsort import natsorted
mpl.rcParams["text.usetex"] = True
mpl.rcParams["font.family"] = "serif"
from functions.alicorn_data_loader import alicorn_data_loader
from functions.create_logger import create_logger
def performance_pfinkel_plot(
performances_list: list[dict], labels: list[str], save_name: str, logger
) -> None:
figure_path: str = "performance_pfinkel"
assert len(performances_list) == len(labels)
plt.figure(figsize=[14, 10])
# plot accuracy
plt.subplot(2, 1, 1)
for id in range(0, len(labels)):
x_values = np.zeros((len(performances_list[id].keys())))
y_values = np.zeros((len(performances_list[id].keys())))
counter = 0
for id_key in performances_list[id].keys():
x_values[counter] = performances_list[id][id_key]["pfinkel"]
y_values[counter] = performances_list[id][id_key]["test_accuracy"]
counter += 1
plt.plot(x_values, y_values, label=labels[id])
plt.xticks(x_values)
plt.title("Average accuracy", fontsize=18)
plt.xlabel("Path angle (in °)", fontsize=17)
plt.ylabel("Accuracy (\\%)", fontsize=17)
plt.legend(fontsize=14)
# Increase tick label font size
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.grid(True)
# plot loss
plt.subplot(2, 1, 2)
for id in range(0, len(labels)):
x_values = np.zeros((len(performances_list[id].keys())))
y_values = np.zeros((len(performances_list[id].keys())))
counter = 0
for id_key in performances_list[id].keys():
x_values[counter] = performances_list[id][id_key]["pfinkel"]
y_values[counter] = performances_list[id][id_key]["test_losses"]
counter += 1
plt.plot(x_values, y_values, label=labels[id])
plt.xticks(x_values)
plt.title("Average loss", fontsize=18)
plt.xlabel("Path angle (in °)", fontsize=17)
plt.ylabel("Loss", fontsize=17)
plt.legend(fontsize=14)
# Increase tick label font size
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.grid(True)
plt.tight_layout()
logger.info("")
logger.info("Saved in:")
os.makedirs(figure_path, exist_ok=True)
print(
os.path.join(
figure_path,
f"PerformancePfinkel_{save_name}_{current}.pdf",
)
)
plt.savefig(
os.path.join(
figure_path,
f"PerformancePfinkel_{save_name}_{current}.pdf",
),
dpi=300,
bbox_inches="tight",
)
plt.show()
if __name__ == "__main__":
model_path: str = "trained_models"
data_path: str = "/home/kk/Documents/Semester4/code/RenderStimuli/Output/"
selection_file_id: int = 0
# num stimuli per Pfinkel and batch size
stim_per_pfinkel: int = 10000
batch_size: int = 1000
# stimulus condition:
performances_list: list = []
condition: list[str] = ["Coignless", "Natural", "Angular"]
figure_label: list[str] = ["Classic", "Corner", "Bridge"]
# load test data:
num_pfinkel: list = np.arange(0, 100, 10).tolist()
image_scale: float = 255.0
# ------------------------------------------
# create logger:
logger = create_logger(
save_logging_messages=False,
display_logging_messages=True,
)
device_str: str = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.info(f"Using {device_str} device")
device: torch.device = torch.device(device_str)
torch.set_default_dtype(torch.float32)
# current time:
current = datetime.datetime.now().strftime("%d%m-%H%M")
# path to NN
list_filenames: list[str] = natsorted(
list(glob.glob(os.path.join(model_path, "*.pt")))
)
assert selection_file_id < len(list_filenames)
model_filename: str = str(list_filenames[selection_file_id])
logger.info(f"Using model file: {model_filename}")
# shorter saving name:
pattern = r"(outChannels\[.*?\])|(kernelSize\[.*?\])|(_relu)|(_seed\d+)"
matches = re.findall(pattern, model_filename)
save_name = "".join(["".join(match) for match in matches])
# load and evaluate model
model = torch.load(model_filename, map_location=device)
# Set the model to evaluation mode
model.eval()
for selected_condition in condition:
# save performances:
logger.info(f"Condition: {selected_condition}")
performances: dict = {}
for pfinkel in num_pfinkel:
test_loss: float = 0.0
correct: int = 0
pattern_count: int = 0
data_test = alicorn_data_loader(
num_pfinkel=[pfinkel],
load_stimuli_per_pfinkel=stim_per_pfinkel,
condition=selected_condition,
logger=logger,
data_path=data_path,
)
loader = torch.utils.data.DataLoader(
data_test, shuffle=False, batch_size=batch_size
)
# start testing network on new stimuli:
logger.info("")
logger.info(f"-==- Start {selected_condition} " f"Pfinkel {pfinkel}° -==-")
with torch.no_grad():
for batch_num, data in enumerate(loader):
label = data[0].to(device)
image = data[1].type(dtype=torch.float32).to(device)
image /= image_scale
# compute prediction error;
output = model(image)
# Label Typecast:
label = label.to(device)
# loss and optimization
loss = torch.nn.functional.cross_entropy(
output, label, reduction="sum"
)
pattern_count += int(label.shape[0])
test_loss += float(loss)
prediction = output.argmax(dim=1)
correct += prediction.eq(label).sum().item()
total_number_of_pattern: int = int(len(loader)) * int(
label.shape[0]
)
# logging:
logger.info(
(
f"{selected_condition},{pfinkel}° "
"Pfinkel: "
f"[{int(pattern_count)}/{total_number_of_pattern} ({100.0 * pattern_count / total_number_of_pattern:.2f}%)],"
f" Average loss: {test_loss / pattern_count:.3e}, "
"Accuracy: "
f"{100.0 * correct / pattern_count:.2f}% "
)
)
performances[pfinkel] = {
"pfinkel": pfinkel,
"test_accuracy": 100 * correct / pattern_count,
"test_losses": float(loss) / pattern_count,
}
performances_list.append(performances)
performance_pfinkel_plot(
performances_list=performances_list,
labels=figure_label,
save_name=save_name,
logger=logger,
)
logger.info("-==- DONE -==-")