kk_contour_net_shallow/thesis code/shallow net/functions/plot_intermediate.py

85 lines
2.5 KiB
Python
Raw Permalink Normal View History

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import os
import re
# Global figure styling, applied once at import time: route all text through
# LaTeX (this needs a working LaTeX installation) and use a serif font family.
mpl.rcParams.update(
    {
        "text.usetex": True,
        "font.family": "serif",
    }
)
def _epoch_ticks(max_epochs: int, reduction_factor: int) -> np.ndarray:
    """Return x-axis tick positions: epoch 1 plus every ``stepsize``-th epoch."""
    # Clamp to 1: the floor division yields 0 when reduction_factor exceeds
    # max_epochs (or when there are no epochs), and np.arange cannot step by 0.
    stepsize = max(1, max_epochs // reduction_factor)
    ticks = np.concatenate(([1], np.arange(stepsize, max_epochs + 1, stepsize)))
    # np.unique drops the duplicated 1 that appears when stepsize == 1.
    return np.unique(ticks)


def _draw_panel(
    row: int,
    x: np.ndarray,
    train_values: list[float],
    test_values: list[float],
    legend_label: str,
    title: str,
    ylabel: str,
    ticks: np.ndarray,
) -> None:
    """Draw one train/test curve pair into row ``row`` of the 2x1 subplot grid."""
    plt.subplot(2, 1, row)
    plt.plot(x, np.array(train_values), label="Train: " + legend_label)
    plt.plot(x, np.array(test_values), label="Test: " + legend_label)
    plt.title(title, fontsize=18)
    plt.xlabel("Epoch", fontsize=18)
    plt.ylabel(ylabel, fontsize=18)
    plt.legend(fontsize=14)
    # Tick labels are the tick values themselves, drawn with a larger font.
    plt.xticks(ticks, ticks, fontsize=16)
    plt.yticks(fontsize=16)
    plt.grid(True)


def plot_intermediate(
    train_accuracy: list[float],
    test_accuracy: list[float],
    train_losses: list[float],
    test_losses: list[float],
    save_name: str,
    reduction_factor: int = 1,
) -> None:
    """Plot per-epoch accuracy and loss curves, save them as a PDF and show them.

    The figure has two stacked panels (accuracy on top, loss below), each with
    a train and a test curve over epochs ``1..len(train_accuracy)``. It is
    written to ``performance_plots/performance_<save_name>.pdf`` (the directory
    is created if missing) and then displayed with ``plt.show()``.

    Args:
        train_accuracy: Training accuracy, one value per epoch.
        test_accuracy: Test accuracy, one value per epoch.
        train_losses: Training loss, one value per epoch.
        test_losses: Test loss, one value per epoch.
        save_name: Name embedded in the output file name. If it contains an
            ``outChannels[...]_kernelSize[...]_<token>`` fragment directly
            followed by ``_stride``, that fragment is used in the legend.
        reduction_factor: Controls x-tick density; ticks are placed at epoch 1
            and at every ``len(train_accuracy) // reduction_factor``-th epoch
            (at least every epoch).

    Raises:
        ValueError: If the four histories differ in length, or if
            ``reduction_factor`` is smaller than 1.
    """
    # Validate up front with real exceptions; ``assert`` would be stripped
    # when Python runs with -O.
    max_epochs: int = len(train_accuracy)
    other_lengths = (len(test_accuracy), len(train_losses), len(test_losses))
    if any(length != max_epochs for length in other_lengths):
        raise ValueError(
            "train_accuracy, test_accuracy, train_losses and test_losses "
            f"must have the same length, got {(max_epochs, *other_lengths)}"
        )
    if reduction_factor < 1:
        raise ValueError(
            f"reduction_factor must be a positive integer, got {reduction_factor}"
        )

    # legend: reuse the architecture fragment of save_name (empty if absent)
    pattern = (
        r"(outChannels\[\d+(?:, \d+)*\]_kernelSize\[\d+(?:, \d+)*\]_)([^_]+)(?=_stride)"
    )
    legend_label = "".join(
        "".join(groups) for groups in re.findall(pattern, save_name)
    )

    x = np.arange(1, max_epochs + 1)
    ticks = _epoch_ticks(max_epochs, reduction_factor)

    plt.figure(figsize=[12, 7])
    # The percent sign is escaped because the module enables LaTeX rendering.
    _draw_panel(
        1,
        x,
        train_accuracy,
        test_accuracy,
        legend_label,
        "Training and Testing Accuracy",
        "Accuracy (\\%)",
        ticks,
    )
    _draw_panel(
        2,
        x,
        train_losses,
        test_losses,
        legend_label,
        "Training and Testing Losses",
        "Loss",
        ticks,
    )
    plt.tight_layout()

    os.makedirs("performance_plots", exist_ok=True)
    plt.savefig(
        os.path.join(
            "performance_plots",
            f"performance_{save_name}.pdf",
        ),
        dpi=300,
        bbox_inches="tight",
    )
    # NOTE(review): the figure is not closed after show(); if this is called
    # repeatedly on a non-interactive backend, figures may accumulate - confirm
    # whether callers want an explicit plt.close() here.
    plt.show()