From b346d947b1cfd69c213d50f2b7479b9635c23de1 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Fri, 28 Jul 2023 00:11:46 +0200 Subject: [PATCH] Add files via upload --- inspect_weights_conv_x.py | 15 ++++++++------- network_0.json | 14 +++++++------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/inspect_weights_conv_x.py b/inspect_weights_conv_x.py index cf99f67..a145f5f 100644 --- a/inspect_weights_conv_x.py +++ b/inspect_weights_conv_x.py @@ -4,12 +4,13 @@ import matplotlib.pyplot as plt import os import glob from natsort import natsorted +import sys # import numpy as np -layer_id: int = 3 -scale_each_inner: bool = False -scale_each_outer: bool = False +layer_id: int = int(sys.argv[1]) +scale_each_inner: bool = True +scale_each_outer: bool = True model_path: str = "trained_models" filename_list: list = natsorted(glob.glob(os.path.join(model_path, str("*.pt")))) @@ -47,12 +48,12 @@ v_max_abs = torch.abs(weight_grid[0, ...]).max() plt.subplot(3, 1, (1, 2)) plt.imshow( weight_grid[0, ...], - vmin=-v_max_abs, - vmax=v_max_abs, - cmap="cool", +# vmin=-v_max_abs, +# vmax=v_max_abs, + cmap="hot", ) plt.axis("off") -plt.colorbar() +#plt.colorbar() plt.title("Weights") plt.subplot(3, 1, 3) diff --git a/network_0.json b/network_0.json index 6fed93d..d0d63ec 100644 --- a/network_0.json +++ b/network_0.json @@ -6,8 +6,8 @@ ], "conv_kernel_size": [ 11, - 7, - 15 + 11, + 11 ], "conv_stride_size": [ 1, @@ -30,14 +30,14 @@ "l_relu_negative_slope": 0.1, // (0.1) // Pooling layer ----------------------------------------------------------- "pooling_kernel_size": [ - 3, - 0, - 0 + 5, + 5, + 5 ], "pooling_stride": [ 2, - 0, - 0 + 2, + 2 ], "pooling_type": "max", // (max), average, none // Softmax layer -----------------------------------------------------------