diff --git a/draw_kernels_v2.py b/draw_kernels_v2.py new file mode 100644 index 0000000..f88a670 --- /dev/null +++ b/draw_kernels_v2.py @@ -0,0 +1,182 @@ +# %% +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.patches as patch +import matplotlib as mpl +import torch +from cycler import cycler + +mpl.rcParams["text.usetex"] = True +mpl.rcParams["font.family"] = "serif" + + +def extract_kernel_stride(model: torch.nn.Sequential) -> list[dict]: + result = [] + for idx, m in enumerate(model.modules()): + if isinstance(m, (torch.nn.Conv2d, torch.nn.MaxPool2d)): + result.append( + { + "layer_index": idx, + "layer_type": type(m).__name__, + "kernel_size": m.kernel_size, + "stride": m.stride, + } + ) + return result + + +def calculate_kernel_size( + kernel: np.ndarray, stride: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + df: np.ndarray = np.cumprod( + ( + np.concatenate( + (np.array(1)[np.newaxis], stride.astype(dtype=np.int64)[:-1]), axis=0 + ) + ) + ) + f = 1 + np.cumsum((kernel.astype(dtype=np.int64) - 1) * df) + + print(f"Receptive field sizes: {f} ") + return f, df + + +def draw_kernel( + image: np.ndarray, model: torch.nn.Sequential, ignore_output_conv_layer: bool +) -> None: + """ + Call function after creating the model-to-be-trained. + """ + assert image.shape[0] == 200 + assert image.shape[1] == 200 + + # list of colors to choose from: + prop_cycle = plt.rcParams["axes.prop_cycle"] + colors = prop_cycle.by_key()["color"] + edge_color_cycler = iter( + cycler(color=["sienna", "orange", "gold", "bisque"] + colors) + ) + kernel_sizes: list[int] = [] + stride_sizes: list[int] = [] + layer_type: list[str] = [] + + # extract kernel and stride information + model_info: list[dict] = extract_kernel_stride(model) + + # iterate over kernels to plot on image + for layer in model_info: + kernel_sizes.append(layer["kernel_size"]) + stride_sizes.append(layer["stride"]) + layer_type.append(layer["layer_type"]) + + # change tuples to list items: + kernel_array: np.ndarray = np.array([k[0] if isinstance(k, tuple) else k for k in kernel_sizes]) # type: ignore + stride_array: np.ndarray = np.array([s[0] if isinstance(s, tuple) else s for s in stride_sizes]) # type: ignore + + # calculate values + kernels, strides = calculate_kernel_size(kernel_array, stride_array) + + # position first kernel + start_x: int = 4 + start_y: int = 15 + + # general plot structure: + plt.ion() + _, ax = plt.subplots() + ax.imshow(image, cmap="gray") + ax.tick_params(axis="both", which="major", labelsize=15) + + if ignore_output_conv_layer: + number_of_layers: int = len(kernels) - 1 + else: + number_of_layers = len(kernels) + + for i in range(0, number_of_layers): + edgecolor = next(edge_color_cycler)["color"] + # draw kernel + kernel = patch.Rectangle( + (start_x, start_y), + kernels[i], + kernels[i], + linewidth=1.2, + edgecolor=edgecolor, + facecolor="none", + label=layer_type[i], + ) + ax.add_patch(kernel) + + # draw stride + stride = patch.Rectangle( + (start_x + strides[i], start_y + strides[i]), + kernels[i], + kernels[i], + linewidth=1.2, + edgecolor=edgecolor, + facecolor="none", + linestyle="dashed", + ) + ax.add_patch(stride) + + # add distance of next drawing + start_x += 14 + start_y += 10 + + # final plot + plt.tight_layout() + plt.legend(loc="upper right", fontsize=11) + plt.show(block=True) + + +# %% +if __name__ == "__main__": + import os + import sys + import json + from jsmin import jsmin + + parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + sys.path.append(parent_dir) + from functions.alicorn_data_loader import alicorn_data_loader + from functions.make_cnn_v2 import make_cnn + from functions.create_logger import create_logger + + ignore_output_conv_layer: bool = True + network_config_filename = "network_0.json" + config_filenname = "config_v2.json" + with open(config_filenname, "r") as file_handle: + config = json.loads(jsmin(file_handle.read())) + + logger = create_logger( + save_logging_messages=False, + display_logging_messages=False, + ) + + # test image: + data_test = alicorn_data_loader( + num_pfinkel=[0], + load_stimuli_per_pfinkel=10, + condition=str(config["condition"]), + data_path=str(config["data_path"]), + logger=logger, + ) + + assert data_test.__len__() > 0 + input_shape = data_test.__getitem__(0)[1].shape + + model = make_cnn( + network_config_filename=network_config_filename, + logger=logger, + input_shape=input_shape, + ) + print(model) + + # test_image = torch.zeros((1, *input_shape), dtype=torch.float32) + + image = data_test.__getitem__(6)[1].squeeze(0) + + # call function: + draw_kernel( + image=image.numpy(), + model=model, + ignore_output_conv_layer=ignore_output_conv_layer, + )