From ed5ac98241d8d3ec4704aaf40fe8ee286865e15c Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Sat, 29 Jul 2023 03:20:24 +0200 Subject: [PATCH] Add files via upload --- draw_kernels_v2.py | 235 +++++++++++++++++++++++++++++---------------- 1 file changed, 154 insertions(+), 81 deletions(-) diff --git a/draw_kernels_v2.py b/draw_kernels_v2.py index f88a670..62a1a65 100644 --- a/draw_kernels_v2.py +++ b/draw_kernels_v2.py @@ -10,39 +10,11 @@ mpl.rcParams["text.usetex"] = True mpl.rcParams["font.family"] = "serif" -def extract_kernel_stride(model: torch.nn.Sequential) -> list[dict]: - result = [] - for idx, m in enumerate(model.modules()): - if isinstance(m, (torch.nn.Conv2d, torch.nn.MaxPool2d)): - result.append( - { - "layer_index": idx, - "layer_type": type(m).__name__, - "kernel_size": m.kernel_size, - "stride": m.stride, - } - ) - return result - - -def calculate_kernel_size( - kernel: np.ndarray, stride: np.ndarray -) -> tuple[np.ndarray, np.ndarray]: - df: np.ndarray = np.cumprod( - ( - np.concatenate( - (np.array(1)[np.newaxis], stride.astype(dtype=np.int64)[:-1]), axis=0 - ) - ) - ) - f = 1 + np.cumsum((kernel.astype(dtype=np.int64) - 1) * df) - - print(f"Receptive field sizes: {f} ") - return f, df - - def draw_kernel( - image: np.ndarray, model: torch.nn.Sequential, ignore_output_conv_layer: bool + image: np.ndarray, + coordinate_list: list, + layer_type_list: list, + ignore_output_conv_layer: bool, ) -> None: """ Call function after creating the model-to-be-trained. @@ -56,25 +28,6 @@ def draw_kernel( edge_color_cycler = iter( cycler(color=["sienna", "orange", "gold", "bisque"] + colors) ) - kernel_sizes: list[int] = [] - stride_sizes: list[int] = [] - layer_type: list[str] = [] - - # extract kernel and stride information - model_info: list[dict] = extract_kernel_stride(model) - - # iterate over kernels to plot on image - for layer in model_info: - kernel_sizes.append(layer["kernel_size"]) - stride_sizes.append(layer["stride"]) - layer_type.append(layer["layer_type"]) - - # change tuples to list items: - kernel_array: np.ndarray = np.array([k[0] if isinstance(k, tuple) else k for k in kernel_sizes]) # type: ignore - stride_array: np.ndarray = np.array([s[0] if isinstance(s, tuple) else s for s in stride_sizes]) # type: ignore - - # calculate values - kernels, strides = calculate_kernel_size(kernel_array, stride_array) # position first kernel start_x: int = 4 @@ -87,39 +40,44 @@ def draw_kernel( ax.tick_params(axis="both", which="major", labelsize=15) if ignore_output_conv_layer: - number_of_layers: int = len(kernels) - 1 + number_of_layers: int = len(layer_type_list) - 1 else: - number_of_layers = len(kernels) + number_of_layers = len(layer_type_list) for i in range(0, number_of_layers): - edgecolor = next(edge_color_cycler)["color"] - # draw kernel - kernel = patch.Rectangle( - (start_x, start_y), - kernels[i], - kernels[i], - linewidth=1.2, - edgecolor=edgecolor, - facecolor="none", - label=layer_type[i], - ) - ax.add_patch(kernel) + if layer_type_list[i] is not None: + kernels = int(coordinate_list[i].shape[0]) + edgecolor = next(edge_color_cycler)["color"] + # draw kernel + kernel = patch.Rectangle( + (start_x, start_y), + kernels, + kernels, + linewidth=1.2, + edgecolor=edgecolor, + facecolor="none", + label=layer_type_list[i], + ) + ax.add_patch(kernel) - # draw stride - stride = patch.Rectangle( - (start_x + strides[i], start_y + strides[i]), - kernels[i], - kernels[i], - linewidth=1.2, - edgecolor=edgecolor, - facecolor="none", - linestyle="dashed", - ) - ax.add_patch(stride) + if coordinate_list[i].shape[1] > 1: + strides = int(coordinate_list[i][0, 1]) - int(coordinate_list[i][0, 0]) - # add distance of next drawing - start_x += 14 - start_y += 10 + # draw stride + stride = patch.Rectangle( + (start_x + strides, start_y + strides), + kernels, + kernels, + linewidth=1.2, + edgecolor=edgecolor, + facecolor="none", + linestyle="dashed", + ) + ax.add_patch(stride) + + # add distance of next drawing + start_x += 14 + start_y += 10 # final plot plt.tight_layout() @@ -127,6 +85,108 @@ def draw_kernel( plt.show(block=True) +def unfold( + layer: torch.nn.Conv2d | torch.nn.MaxPool2d | torch.nn.AvgPool2d, size: int +) -> torch.Tensor: + if isinstance(layer.kernel_size, tuple): + assert layer.kernel_size[0] == layer.kernel_size[1] + kernel_size: int = int(layer.kernel_size[0]) + else: + kernel_size = int(layer.kernel_size) + + if isinstance(layer.dilation, tuple): + assert layer.dilation[0] == layer.dilation[1] + dilation: int = int(layer.dilation[0]) + else: + dilation = int(layer.dilation) # type: ignore + + if isinstance(layer.padding, tuple): + assert layer.padding[0] == layer.padding[1] + padding: int = int(layer.padding[0]) + else: + padding = int(layer.padding) + + if isinstance(layer.stride, tuple): + assert layer.stride[0] == layer.stride[1] + stride: int = int(layer.stride[0]) + else: + stride = int(layer.stride) + + out = ( + torch.nn.functional.unfold( + torch.arange(0, size, dtype=torch.float32) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(-1), + kernel_size=(kernel_size, 1), + dilation=(dilation, 1), + padding=(padding, 0), + stride=(stride, 1), + ) + .squeeze(0) + .type(torch.int64) + ) + + return out + + +def analyse_network( + model: torch.nn.Sequential, input_shape: int +) -> tuple[list, list, list]: + combined_list: list = [] + coordinate_list: list = [] + layer_type_list: list = [] + pixel_used: list[int] = [] + + size: int = int(input_shape) + + for layer_id in range(0, len(model)): + if isinstance( + model[layer_id], (torch.nn.Conv2d, torch.nn.MaxPool2d, torch.nn.AvgPool2d) + ): + out = unfold(layer=model[layer_id], size=size) + coordinate_list.append(out) + layer_type_list.append( + str(type(model[layer_id])).split(".")[-1].split("'")[0] + ) + size = int(out.shape[-1]) + else: + coordinate_list.append(None) + layer_type_list.append(None) + + assert coordinate_list[0] is not None + combined_list.append(coordinate_list[0]) + + for i in range(1, len(coordinate_list)): + if coordinate_list[i] is None: + combined_list.append(combined_list[i - 1]) + else: + for pos in range(0, coordinate_list[i].shape[-1]): + idx_shape: int | None = None + + idx = torch.unique( + torch.flatten(combined_list[i - 1][:, coordinate_list[i][:, pos]]) + ) + if idx_shape is None: + idx_shape = idx.shape[0] + assert idx_shape == idx.shape[0] + + assert idx_shape is not None + + temp = torch.zeros((idx_shape, coordinate_list[i].shape[-1])) + for pos in range(0, coordinate_list[i].shape[-1]): + idx = torch.unique( + torch.flatten(combined_list[i - 1][:, coordinate_list[i][:, pos]]) + ) + temp[:, pos] = idx + combined_list.append(temp) + + for i in range(0, len(combined_list)): + pixel_used.append(int(torch.unique(torch.flatten(combined_list[i])).shape[0])) + + return combined_list, layer_type_list, pixel_used + + # %% if __name__ == "__main__": import os @@ -170,13 +230,26 @@ if __name__ == "__main__": ) print(model) - # test_image = torch.zeros((1, *input_shape), dtype=torch.float32) + assert input_shape[-2] == input_shape[-1] + coordinate_list, layer_type_list, pixel_used = analyse_network( + model=model, input_shape=int(input_shape[-1]) + ) + + for i in range(0, len(coordinate_list)): + print( + ( + f"Layer: {i}, Positions: {coordinate_list[i].shape[1]}, " + f"Pixel per Positions: {coordinate_list[i].shape[0]}, " + f"Type: {layer_type_list[i]}, Number of pixel used: {pixel_used[i]}" + ) + ) image = data_test.__getitem__(6)[1].squeeze(0) # call function: draw_kernel( image=image.numpy(), - model=model, + coordinate_list=coordinate_list, + layer_type_list=layer_type_list, ignore_output_conv_layer=ignore_output_conv_layer, )