From d93cb0a9b3617e52212c5f2b3cb91f7038c4c5c0 Mon Sep 17 00:00:00 2001 From: davrot Date: Tue, 10 Dec 2024 12:48:21 +0100 Subject: [PATCH] =?UTF-8?q?Dateien=20nach=20=E2=80=9E/=E2=80=9C=20hochlade?= =?UTF-8?q?n?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- convert_log_to_numpy.py | 31 +++++++++++++++++++++++++++++++ plot.py | 15 +++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 convert_log_to_numpy.py create mode 100644 plot.py diff --git a/convert_log_to_numpy.py b/convert_log_to_numpy.py new file mode 100644 index 0000000..05a5427 --- /dev/null +++ b/convert_log_to_numpy.py @@ -0,0 +1,31 @@ +import os +import glob + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +from tensorboard.backend.event_processing import event_accumulator # type: ignore +import numpy as np + + +def get_data(path: str = "log_cnn"): + acc = event_accumulator.EventAccumulator(path) + acc.Reload() + + which_scalar = "Test Number Correct" + te = acc.Scalars(which_scalar) + + np_temp = np.zeros((len(te), 2)) + + for id in range(0, len(te)): + np_temp[id, 0] = te[id].step + np_temp[id, 1] = te[id].value + + print(np_temp[:, 1] / 100) + np_temp = np.nan_to_num(np_temp) + return np_temp + + +for path in glob.glob("log_*"): + print(path) + data = get_data(path) + np.save("data_" + path + ".npy", data) diff --git a/plot.py b/plot.py new file mode 100644 index 0000000..ad22d33 --- /dev/null +++ b/plot.py @@ -0,0 +1,15 @@ +import numpy as np +import matplotlib.pyplot as plt + +data = np.load("data_log.npy") +plt.loglog( + data[:, 0], + 100.0 * (1.0 - data[:, 1] / 10000.0), + "k", +) + +plt.legend() +plt.xlabel("Epoch") +plt.ylabel("Error [%]") +plt.title("CIFAR10") +plt.show()