Add files via upload
This commit is contained in:
parent
190666a067
commit
e49dba1ee4
1 changed files with 182 additions and 0 deletions
182
draw_kernels_v2.py
Normal file
182
draw_kernels_v2.py
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
# %%
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.patches as patch
|
||||||
|
import matplotlib as mpl
|
||||||
|
import torch
|
||||||
|
from cycler import cycler
|
||||||
|
|
||||||
|
mpl.rcParams["text.usetex"] = True
|
||||||
|
mpl.rcParams["font.family"] = "serif"
|
||||||
|
|
||||||
|
|
||||||
|
def extract_kernel_stride(model: torch.nn.Sequential) -> list[dict]:
|
||||||
|
result = []
|
||||||
|
for idx, m in enumerate(model.modules()):
|
||||||
|
if isinstance(m, (torch.nn.Conv2d, torch.nn.MaxPool2d)):
|
||||||
|
result.append(
|
||||||
|
{
|
||||||
|
"layer_index": idx,
|
||||||
|
"layer_type": type(m).__name__,
|
||||||
|
"kernel_size": m.kernel_size,
|
||||||
|
"stride": m.stride,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_kernel_size(
|
||||||
|
kernel: np.ndarray, stride: np.ndarray
|
||||||
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
df: np.ndarray = np.cumprod(
|
||||||
|
(
|
||||||
|
np.concatenate(
|
||||||
|
(np.array(1)[np.newaxis], stride.astype(dtype=np.int64)[:-1]), axis=0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
f = 1 + np.cumsum((kernel.astype(dtype=np.int64) - 1) * df)
|
||||||
|
|
||||||
|
print(f"Receptive field sizes: {f} ")
|
||||||
|
return f, df
|
||||||
|
|
||||||
|
|
||||||
|
def draw_kernel(
|
||||||
|
image: np.ndarray, model: torch.nn.Sequential, ignore_output_conv_layer: bool
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Call function after creating the model-to-be-trained.
|
||||||
|
"""
|
||||||
|
assert image.shape[0] == 200
|
||||||
|
assert image.shape[1] == 200
|
||||||
|
|
||||||
|
# list of colors to choose from:
|
||||||
|
prop_cycle = plt.rcParams["axes.prop_cycle"]
|
||||||
|
colors = prop_cycle.by_key()["color"]
|
||||||
|
edge_color_cycler = iter(
|
||||||
|
cycler(color=["sienna", "orange", "gold", "bisque"] + colors)
|
||||||
|
)
|
||||||
|
kernel_sizes: list[int] = []
|
||||||
|
stride_sizes: list[int] = []
|
||||||
|
layer_type: list[str] = []
|
||||||
|
|
||||||
|
# extract kernel and stride information
|
||||||
|
model_info: list[dict] = extract_kernel_stride(model)
|
||||||
|
|
||||||
|
# iterate over kernels to plot on image
|
||||||
|
for layer in model_info:
|
||||||
|
kernel_sizes.append(layer["kernel_size"])
|
||||||
|
stride_sizes.append(layer["stride"])
|
||||||
|
layer_type.append(layer["layer_type"])
|
||||||
|
|
||||||
|
# change tuples to list items:
|
||||||
|
kernel_array: np.ndarray = np.array([k[0] if isinstance(k, tuple) else k for k in kernel_sizes]) # type: ignore
|
||||||
|
stride_array: np.ndarray = np.array([s[0] if isinstance(s, tuple) else s for s in stride_sizes]) # type: ignore
|
||||||
|
|
||||||
|
# calculate values
|
||||||
|
kernels, strides = calculate_kernel_size(kernel_array, stride_array)
|
||||||
|
|
||||||
|
# position first kernel
|
||||||
|
start_x: int = 4
|
||||||
|
start_y: int = 15
|
||||||
|
|
||||||
|
# general plot structure:
|
||||||
|
plt.ion()
|
||||||
|
_, ax = plt.subplots()
|
||||||
|
ax.imshow(image, cmap="gray")
|
||||||
|
ax.tick_params(axis="both", which="major", labelsize=15)
|
||||||
|
|
||||||
|
if ignore_output_conv_layer:
|
||||||
|
number_of_layers: int = len(kernels) - 1
|
||||||
|
else:
|
||||||
|
number_of_layers = len(kernels)
|
||||||
|
|
||||||
|
for i in range(0, number_of_layers):
|
||||||
|
edgecolor = next(edge_color_cycler)["color"]
|
||||||
|
# draw kernel
|
||||||
|
kernel = patch.Rectangle(
|
||||||
|
(start_x, start_y),
|
||||||
|
kernels[i],
|
||||||
|
kernels[i],
|
||||||
|
linewidth=1.2,
|
||||||
|
edgecolor=edgecolor,
|
||||||
|
facecolor="none",
|
||||||
|
label=layer_type[i],
|
||||||
|
)
|
||||||
|
ax.add_patch(kernel)
|
||||||
|
|
||||||
|
# draw stride
|
||||||
|
stride = patch.Rectangle(
|
||||||
|
(start_x + strides[i], start_y + strides[i]),
|
||||||
|
kernels[i],
|
||||||
|
kernels[i],
|
||||||
|
linewidth=1.2,
|
||||||
|
edgecolor=edgecolor,
|
||||||
|
facecolor="none",
|
||||||
|
linestyle="dashed",
|
||||||
|
)
|
||||||
|
ax.add_patch(stride)
|
||||||
|
|
||||||
|
# add distance of next drawing
|
||||||
|
start_x += 14
|
||||||
|
start_y += 10
|
||||||
|
|
||||||
|
# final plot
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.legend(loc="upper right", fontsize=11)
|
||||||
|
plt.show(block=True)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
from jsmin import jsmin
|
||||||
|
|
||||||
|
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
sys.path.append(parent_dir)
|
||||||
|
from functions.alicorn_data_loader import alicorn_data_loader
|
||||||
|
from functions.make_cnn_v2 import make_cnn
|
||||||
|
from functions.create_logger import create_logger
|
||||||
|
|
||||||
|
ignore_output_conv_layer: bool = True
|
||||||
|
network_config_filename = "network_0.json"
|
||||||
|
config_filenname = "config_v2.json"
|
||||||
|
with open(config_filenname, "r") as file_handle:
|
||||||
|
config = json.loads(jsmin(file_handle.read()))
|
||||||
|
|
||||||
|
logger = create_logger(
|
||||||
|
save_logging_messages=False,
|
||||||
|
display_logging_messages=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# test image:
|
||||||
|
data_test = alicorn_data_loader(
|
||||||
|
num_pfinkel=[0],
|
||||||
|
load_stimuli_per_pfinkel=10,
|
||||||
|
condition=str(config["condition"]),
|
||||||
|
data_path=str(config["data_path"]),
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert data_test.__len__() > 0
|
||||||
|
input_shape = data_test.__getitem__(0)[1].shape
|
||||||
|
|
||||||
|
model = make_cnn(
|
||||||
|
network_config_filename=network_config_filename,
|
||||||
|
logger=logger,
|
||||||
|
input_shape=input_shape,
|
||||||
|
)
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
# test_image = torch.zeros((1, *input_shape), dtype=torch.float32)
|
||||||
|
|
||||||
|
image = data_test.__getitem__(6)[1].squeeze(0)
|
||||||
|
|
||||||
|
# call function:
|
||||||
|
draw_kernel(
|
||||||
|
image=image.numpy(),
|
||||||
|
model=model,
|
||||||
|
ignore_output_conv_layer=ignore_output_conv_layer,
|
||||||
|
)
|
Loading…
Reference in a new issue