Add files via upload
This commit is contained in:
parent
a564b10114
commit
b346d947b1
2 changed files with 15 additions and 14 deletions
|
@ -4,12 +4,13 @@ import matplotlib.pyplot as plt
|
|||
import os
|
||||
import glob
|
||||
from natsort import natsorted
|
||||
import sys
|
||||
|
||||
# import numpy as np
|
||||
|
||||
layer_id: int = 3
|
||||
scale_each_inner: bool = False
|
||||
scale_each_outer: bool = False
|
||||
layer_id: int = int(sys.argv[1])
|
||||
scale_each_inner: bool = True
|
||||
scale_each_outer: bool = True
|
||||
|
||||
model_path: str = "trained_models"
|
||||
filename_list: list = natsorted(glob.glob(os.path.join(model_path, str("*.pt"))))
|
||||
|
@ -47,12 +48,12 @@ v_max_abs = torch.abs(weight_grid[0, ...]).max()
|
|||
plt.subplot(3, 1, (1, 2))
|
||||
plt.imshow(
|
||||
weight_grid[0, ...],
|
||||
vmin=-v_max_abs,
|
||||
vmax=v_max_abs,
|
||||
cmap="cool",
|
||||
# vmin=-v_max_abs,
|
||||
# vmax=v_max_abs,
|
||||
cmap="hot",
|
||||
)
|
||||
plt.axis("off")
|
||||
plt.colorbar()
|
||||
#plt.colorbar()
|
||||
plt.title("Weights")
|
||||
|
||||
plt.subplot(3, 1, 3)
|
||||
|
|
|
@ -6,8 +6,8 @@
|
|||
],
|
||||
"conv_kernel_size": [
|
||||
11,
|
||||
7,
|
||||
15
|
||||
11,
|
||||
11
|
||||
],
|
||||
"conv_stride_size": [
|
||||
1,
|
||||
|
@ -30,14 +30,14 @@
|
|||
"l_relu_negative_slope": 0.1, // (0.1)
|
||||
// Pooling layer -----------------------------------------------------------
|
||||
"pooling_kernel_size": [
|
||||
3,
|
||||
0,
|
||||
0
|
||||
5,
|
||||
5,
|
||||
5
|
||||
],
|
||||
"pooling_stride": [
|
||||
2,
|
||||
0,
|
||||
0
|
||||
2,
|
||||
2
|
||||
],
|
||||
"pooling_type": "max", // (max), average, none
|
||||
// Softmax layer -----------------------------------------------------------
|
||||
|
|
Loading…
Reference in a new issue