Add files via upload
The config file now has an option that controls whether, during training with leaky ReLU, the activation is switched to ReLU (i.e. leaky ReLU with slope = 0.0) once 100% performance is reached. In cnn_training.py I had to change how config.json is read and loaded because of a strange error when running the .sh file.
This commit is contained in:
parent 475746ad41
commit c4a1737fa7
8 changed files with 73 additions and 20 deletions
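The core of the new switching logic is to set the negative slope of every LeakyReLU module to 0.0 once the rounded test accuracy reaches 100%, which makes the activation behave exactly like plain ReLU. Below is a minimal sketch of that idea, assuming a toy Sequential model and made-up values; only the names switch_leakyR_to_relu, precision_100_percent and switched_to_relu come from this commit:

import torch

# Toy stand-in for the CNN; only the LeakyReLU children matter for the switch.
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=7),
    torch.nn.LeakyReLU(negative_slope=0.1),
    torch.nn.Conv2d(8, 8, kernel_size=5),
    torch.nn.LeakyReLU(negative_slope=0.1),
)

switch_leakyR_to_relu: bool = True  # new config.json flag
precision_100_percent: int = 4      # decimal places used for the 100% check (made-up value)
switched_to_relu: bool = False
previous_test_acc: float = 99.999999  # made-up accuracy from a test loop

if round(previous_test_acc, precision_100_percent) == 100.0:
    if switch_leakyR_to_relu and not switched_to_relu:
        # LeakyReLU with negative_slope 0.0 is equivalent to ReLU.
        for _, module in model.named_children():
            if isinstance(module, torch.nn.LeakyReLU):
                module.negative_slope = 0.0
        switched_to_relu = True
        print(model)  # mirrors the logger.info(model) call in the training loop
    else:
        print("100% test performance reached. Stop training.")

In run_network itself (see the cnn_training.py hunk further down), the same 100% check otherwise stops training.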
@@ -2,7 +2,7 @@ Directory="/home/kk/Documents/Semester4/code/Classic_contour_net_shallow"
 Priority="-500"
 echo $Directory
 mkdir $Directory/argh_log_classic
-for out_channels_idx in {0..0}; do
+for out_channels_idx in {0..6}; do
 for kernel_size_idx in {0..0}; do
 for stride_idx in {0..0}; do
 echo "hostname; cd $Directory ; /home/kk/P3.10/bin/python3 cnn_training.py --idx-conv-out-channels-list $out_channels_idx --idx-conv-kernel-sizes $kernel_size_idx --idx-conv-stride-sizes $stride_idx -s \$JOB_ID" | qsub -o $Directory/argh_log_classic -j y -p $Priority -q gp4u,gp3u -N ClassicTraining
@@ -29,8 +29,11 @@ def main(
 ) -> None:
 config_filenname = "config.json"
 with open(config_filenname, "r") as file_handle:
-config = json.loads(jsmin(file_handle.read()))
+file_contents = file_handle.read()
+f_contents = jsmin(file_contents)
+config = json.loads(f_contents)
+# config = json.loads(jsmin(file_handle.read()))
 
 # get model information:
 output_channels = config["conv_out_channels_list"][idx_conv_out_channels_list]
@@ -81,6 +84,7 @@ def main(
 use_adam=bool(config["use_adam"]),
 use_plot_intermediate=bool(config["use_plot_intermediate"]),
 leak_relu_negative_slope=float(config["leak_relu_negative_slope"]),
+switch_leakyR_to_relu=bool(config["switch_leakyR_to_relu"]),
 scheduler_verbose=bool(config["scheduler_verbose"]),
 scheduler_factor=float(config["scheduler_factor"]),
 precision_100_percent=int(config["precision_100_percent"]),
@@ -118,6 +122,7 @@ def run_network(
 use_adam: bool,
 use_plot_intermediate: bool,
 leak_relu_negative_slope: float,
+switch_leakyR_to_relu: bool,
 scheduler_verbose: bool,
 scheduler_factor: float,
 precision_100_percent: int,
@@ -129,6 +134,9 @@ def run_network(
 device: torch.device = torch.device(device_str)
 torch.set_default_dtype(torch.float32)
 
+# switch to relu if using leaky relu
+switched_to_relu: bool = False
+
 # -------------------------------------------------------------------
 logger.info("-==- START -==-")
@@ -362,8 +370,19 @@ def run_network(
 
 # stop learning: done
 if round(previous_test_acc, precision_100_percent) == 100.0:
-logger.info("100% test performance reached. Stop training.")
-break
+if activation_function == "leaky relu":
+if switch_leakyR_to_relu and not switched_to_relu:
+logger.info(
+"100% test performance reached. Switching to LeakyReLU slope 0.0."
+)
+for name, module in model.named_children():
+if isinstance(module, torch.nn.LeakyReLU):
+module.negative_slope = 0.0
+logger.info(model)
+switched_to_relu = True
+else:
+logger.info("100% test performance reached. Stop training.")
+break
 
 if use_plot_intermediate:
 plot_intermediate(
@@ -19,6 +19,7 @@
 "save_ever_x_epochs": 10, // (10)
 "activation_function": "leaky relu", // tanh, relu, (leaky relu), none
 "leak_relu_negative_slope": 0.1, // (0.1)
+"switch_leakyR_to_relu": true,
 // LR Scheduler ->
 "use_scheduler": true, // (true), false
 "scheduler_verbose": true,
@@ -34,11 +35,41 @@
 "condition": "Coignless",
 "scale_data": 255.0, // (255.0),
 "conv_out_channels_list": [
+[
+32,
+8,
+8
+],
+[
+8,
+8,
+8
+],
+[
+6,
+8,
+8
+],
+[
+4,
+8,
+8
+],
 [
 3,
 8,
 8
-]
 ],
+[
+2,
+8,
+8
+],
+[
+1,
+8,
+8
+]
+],
 "conv_kernel_sizes": [
 [
Binary file not shown.
@@ -21,7 +21,9 @@ def plot_intermediate(
 assert len(train_accuracy) == len(test_losses)
 
 # legend:
-pattern = r"(outChannels\[\d+(?:, \d+)*\]_kernelSize\[\d+(?:, \d+)*\]_)(\w+)(?=_stride)"
+pattern = (
+r"(outChannels\[\d+(?:, \d+)*\]_kernelSize\[\d+(?:, \d+)*\]_)([^_]+)(?=_stride)"
+)
 matches = re.findall(pattern, save_name)
 legend_label = "".join(["".join(match) for match in matches])
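The looser legend regex above is presumably needed because the activation name "leaky relu" contains a space, which \w+ cannot match, so re.findall returned an empty list for such save names; [^_]+ accepts everything up to the next underscore. A quick check of that reading, with a shortened, made-up save_name:

import re

save_name = "outChannels[3, 8, 8]_kernelSize[7, 15]_leaky relu_stride1"

old = r"(outChannels\[\d+(?:, \d+)*\]_kernelSize\[\d+(?:, \d+)*\]_)(\w+)(?=_stride)"
new = r"(outChannels\[\d+(?:, \d+)*\]_kernelSize\[\d+(?:, \d+)*\]_)([^_]+)(?=_stride)"

print(re.findall(old, save_name))  # [] -> the space in "leaky relu" breaks \w+
print(re.findall(new, save_name))  # [('outChannels[3, 8, 8]_kernelSize[7, 15]_', 'leaky relu')]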
@@ -35,8 +37,8 @@ def plot_intermediate(
 plt.figure(figsize=[12, 7])
 plt.subplot(2, 1, 1)
 
-plt.plot(x, np.array(train_accuracy), label="Train: " + str(legend_label))
-plt.plot(x, np.array(test_accuracy), label="Test: " + str(legend_label))
+plt.plot(x, np.array(train_accuracy), label="Train " + str(legend_label))
+plt.plot(x, np.array(test_accuracy), label="Test " + str(legend_label))
 plt.title("Training and Testing Accuracy", fontsize=18)
 plt.xlabel("Epoch", fontsize=18)
 plt.ylabel("Accuracy (\\%)", fontsize=18)
@@ -53,8 +55,8 @@ def plot_intermediate(
 
 # losses
 plt.subplot(2, 1, 2)
-plt.plot(x, np.array(train_losses), label="Train: " + str(legend_label))
-plt.plot(x, np.array(test_losses), label="Test: " + str(legend_label))
+plt.plot(x, np.array(train_losses), label="Train " + str(legend_label))
+plt.plot(x, np.array(test_losses), label="Test " + str(legend_label))
 plt.title("Training and Testing Losses", fontsize=18)
 plt.xlabel("Epoch", fontsize=18)
 plt.ylabel("Loss", fontsize=18)
@@ -13,7 +13,7 @@ from functions.make_cnn import make_cnn # noqa
 device = torch.device("cpu")
 
 # path to NN
-nn = "ArghCNN_numConvLayers3_outChannels[6, 8, 8]_kernelSize[7, 15]_leaky relu_stride1_trainFirstConvLayerTrue_seed287302_Coignless_801Epoch_2807-0857.pt"
+nn = "ArghCNN_numConvLayers3_outChannels[3, 8, 8]_kernelSize[7, 15]_leaky relu_stride1_trainFirstConvLayerTrue_seed290415_Coignless_1307Epoch_3107-0912.pt"
 PATH = f"../trained_models/{nn}"
 
 # load and evaluate model
@@ -4,12 +4,13 @@ import matplotlib.pyplot as plt
 import os
 import glob
 from natsort import natsorted
+import sys
 
 # import numpy as np
 
-layer_id: int = 3
-scale_each_inner: bool = False
-scale_each_outer: bool = False
+layer_id: int = int(sys.argv[1])
+scale_each_inner: bool = True
+scale_each_outer: bool = True
 
 model_path: str = "trained_models"
 filename_list: list = natsorted(glob.glob(os.path.join(model_path, str("*.pt"))))
@@ -47,12 +48,12 @@ v_max_abs = torch.abs(weight_grid[0, ...]).max()
 plt.subplot(3, 1, (1, 2))
 plt.imshow(
 weight_grid[0, ...],
-vmin=-v_max_abs,
-vmax=v_max_abs,
-cmap="cool",
+# vmin=-v_max_abs,
+# vmax=v_max_abs,
+cmap="hot",
 )
 plt.axis("off")
-plt.colorbar()
+#plt.colorbar()
 plt.title("Weights")
 
 plt.subplot(3, 1, 3)
@@ -132,7 +132,7 @@ if __name__ == "__main__":
 )
 assert selection_file_id < len(list_filenames)
 # model_filename: str = str(list_filenames[selection_file_id])
-model_filename: str = "./trained_models/ArghCNN_numConvLayers3_outChannels[6, 8, 8]_kernelSize[7, 15]_leaky relu_stride1_trainFirstConvLayerTrue_seed287302_Coignless_801Epoch_2807-0857.pt"
+model_filename: str = "./trained_models/ArghCNN_numConvLayers3_outChannels[3, 8, 8]_kernelSize[7, 15]_leaky relu_stride1_trainFirstConvLayerTrue_seed290415_Coignless_1307Epoch_3107-0912.pt"
 logger.info(f"Using model file: {model_filename}")
 
 # shorter saving name: