Add files via upload
This commit is contained in:
parent
5439f31d0d
commit
ed5ac98241
1 changed files with 154 additions and 81 deletions
|
@ -10,39 +10,11 @@ mpl.rcParams["text.usetex"] = True
|
||||||
mpl.rcParams["font.family"] = "serif"
|
mpl.rcParams["font.family"] = "serif"
|
||||||
|
|
||||||
|
|
||||||
def extract_kernel_stride(model: torch.nn.Sequential) -> list[dict]:
    """Collect kernel size and stride of the Conv2d / MaxPool2d layers of a model.

    Args:
        model: Network whose ``modules()`` traversal is scanned.

    Returns:
        One dict per matching module, in traversal order, with the keys
        ``layer_index`` (position inside ``model.modules()``),
        ``layer_type`` (class name), ``kernel_size`` and ``stride``
        (both taken unmodified from the module, so int or tuple).
    """
    # NOTE(review): the index counts every entry of ``model.modules()``, which
    # presumably includes the container itself -- confirm before using it as
    # an index into the Sequential.
    wanted_types = (torch.nn.Conv2d, torch.nn.MaxPool2d)
    return [
        {
            "layer_index": position,
            "layer_type": type(module).__name__,
            "kernel_size": module.kernel_size,
            "stride": module.stride,
        }
        for position, module in enumerate(model.modules())
        if isinstance(module, wanted_types)
    ]
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_kernel_size(
    kernel: np.ndarray, stride: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Compute per-layer receptive field sizes and cumulative stride factors.

    Args:
        kernel: 1-D array with one kernel size per layer.
        stride: 1-D array with one stride per layer.

    Returns:
        ``(f, df)`` where ``df[n]`` is the product of the strides of all
        layers before layer ``n`` (``df[0] == 1``) and
        ``f[n] = 1 + sum_{m <= n} (kernel[m] - 1) * df[m]``.

    Side effects:
        Prints the receptive field sizes to stdout.
    """
    kernel_i64 = kernel.astype(np.int64)
    stride_i64 = stride.astype(np.int64)

    # A layer is only affected by the strides of the layers in front of it,
    # so shift the strides by one and start with a factor of 1.
    preceding_strides = np.concatenate(
        (np.ones(1, dtype=np.int64), stride_i64[:-1]), axis=0
    )
    df: np.ndarray = np.cumprod(preceding_strides)

    growth_per_layer = (kernel_i64 - 1) * df
    f = 1 + np.cumsum(growth_per_layer)

    print(f"Receptive field sizes: {f} ")
    return f, df
|
|
||||||
|
|
||||||
|
|
||||||
def draw_kernel(
|
def draw_kernel(
|
||||||
image: np.ndarray, model: torch.nn.Sequential, ignore_output_conv_layer: bool
|
image: np.ndarray,
|
||||||
|
coordinate_list: list,
|
||||||
|
layer_type_list: list,
|
||||||
|
ignore_output_conv_layer: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Call function after creating the model-to-be-trained.
|
Call function after creating the model-to-be-trained.
|
||||||
|
@ -56,25 +28,6 @@ def draw_kernel(
|
||||||
edge_color_cycler = iter(
|
edge_color_cycler = iter(
|
||||||
cycler(color=["sienna", "orange", "gold", "bisque"] + colors)
|
cycler(color=["sienna", "orange", "gold", "bisque"] + colors)
|
||||||
)
|
)
|
||||||
kernel_sizes: list[int] = []
|
|
||||||
stride_sizes: list[int] = []
|
|
||||||
layer_type: list[str] = []
|
|
||||||
|
|
||||||
# extract kernel and stride information
|
|
||||||
model_info: list[dict] = extract_kernel_stride(model)
|
|
||||||
|
|
||||||
# iterate over kernels to plot on image
|
|
||||||
for layer in model_info:
|
|
||||||
kernel_sizes.append(layer["kernel_size"])
|
|
||||||
stride_sizes.append(layer["stride"])
|
|
||||||
layer_type.append(layer["layer_type"])
|
|
||||||
|
|
||||||
# change tuples to list items:
|
|
||||||
kernel_array: np.ndarray = np.array([k[0] if isinstance(k, tuple) else k for k in kernel_sizes]) # type: ignore
|
|
||||||
stride_array: np.ndarray = np.array([s[0] if isinstance(s, tuple) else s for s in stride_sizes]) # type: ignore
|
|
||||||
|
|
||||||
# calculate values
|
|
||||||
kernels, strides = calculate_kernel_size(kernel_array, stride_array)
|
|
||||||
|
|
||||||
# position first kernel
|
# position first kernel
|
||||||
start_x: int = 4
|
start_x: int = 4
|
||||||
|
@ -87,39 +40,44 @@ def draw_kernel(
|
||||||
ax.tick_params(axis="both", which="major", labelsize=15)
|
ax.tick_params(axis="both", which="major", labelsize=15)
|
||||||
|
|
||||||
if ignore_output_conv_layer:
|
if ignore_output_conv_layer:
|
||||||
number_of_layers: int = len(kernels) - 1
|
number_of_layers: int = len(layer_type_list) - 1
|
||||||
else:
|
else:
|
||||||
number_of_layers = len(kernels)
|
number_of_layers = len(layer_type_list)
|
||||||
|
|
||||||
for i in range(0, number_of_layers):
|
for i in range(0, number_of_layers):
|
||||||
edgecolor = next(edge_color_cycler)["color"]
|
if layer_type_list[i] is not None:
|
||||||
# draw kernel
|
kernels = int(coordinate_list[i].shape[0])
|
||||||
kernel = patch.Rectangle(
|
edgecolor = next(edge_color_cycler)["color"]
|
||||||
(start_x, start_y),
|
# draw kernel
|
||||||
kernels[i],
|
kernel = patch.Rectangle(
|
||||||
kernels[i],
|
(start_x, start_y),
|
||||||
linewidth=1.2,
|
kernels,
|
||||||
edgecolor=edgecolor,
|
kernels,
|
||||||
facecolor="none",
|
linewidth=1.2,
|
||||||
label=layer_type[i],
|
edgecolor=edgecolor,
|
||||||
)
|
facecolor="none",
|
||||||
ax.add_patch(kernel)
|
label=layer_type_list[i],
|
||||||
|
)
|
||||||
|
ax.add_patch(kernel)
|
||||||
|
|
||||||
# draw stride
|
if coordinate_list[i].shape[1] > 1:
|
||||||
stride = patch.Rectangle(
|
strides = int(coordinate_list[i][0, 1]) - int(coordinate_list[i][0, 0])
|
||||||
(start_x + strides[i], start_y + strides[i]),
|
|
||||||
kernels[i],
|
|
||||||
kernels[i],
|
|
||||||
linewidth=1.2,
|
|
||||||
edgecolor=edgecolor,
|
|
||||||
facecolor="none",
|
|
||||||
linestyle="dashed",
|
|
||||||
)
|
|
||||||
ax.add_patch(stride)
|
|
||||||
|
|
||||||
# add distance of next drawing
|
# draw stride
|
||||||
start_x += 14
|
stride = patch.Rectangle(
|
||||||
start_y += 10
|
(start_x + strides, start_y + strides),
|
||||||
|
kernels,
|
||||||
|
kernels,
|
||||||
|
linewidth=1.2,
|
||||||
|
edgecolor=edgecolor,
|
||||||
|
facecolor="none",
|
||||||
|
linestyle="dashed",
|
||||||
|
)
|
||||||
|
ax.add_patch(stride)
|
||||||
|
|
||||||
|
# add distance of next drawing
|
||||||
|
start_x += 14
|
||||||
|
start_y += 10
|
||||||
|
|
||||||
# final plot
|
# final plot
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
@ -127,6 +85,108 @@ def draw_kernel(
|
||||||
plt.show(block=True)
|
plt.show(block=True)
|
||||||
|
|
||||||
|
|
||||||
|
def unfold(
|
||||||
|
layer: torch.nn.Conv2d | torch.nn.MaxPool2d | torch.nn.AvgPool2d, size: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if isinstance(layer.kernel_size, tuple):
|
||||||
|
assert layer.kernel_size[0] == layer.kernel_size[1]
|
||||||
|
kernel_size: int = int(layer.kernel_size[0])
|
||||||
|
else:
|
||||||
|
kernel_size = int(layer.kernel_size)
|
||||||
|
|
||||||
|
if isinstance(layer.dilation, tuple):
|
||||||
|
assert layer.dilation[0] == layer.dilation[1]
|
||||||
|
dilation: int = int(layer.dilation[0])
|
||||||
|
else:
|
||||||
|
dilation = int(layer.dilation) # type: ignore
|
||||||
|
|
||||||
|
if isinstance(layer.padding, tuple):
|
||||||
|
assert layer.padding[0] == layer.padding[1]
|
||||||
|
padding: int = int(layer.padding[0])
|
||||||
|
else:
|
||||||
|
padding = int(layer.padding)
|
||||||
|
|
||||||
|
if isinstance(layer.stride, tuple):
|
||||||
|
assert layer.stride[0] == layer.stride[1]
|
||||||
|
stride: int = int(layer.stride[0])
|
||||||
|
else:
|
||||||
|
stride = int(layer.stride)
|
||||||
|
|
||||||
|
out = (
|
||||||
|
torch.nn.functional.unfold(
|
||||||
|
torch.arange(0, size, dtype=torch.float32)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.unsqueeze(-1),
|
||||||
|
kernel_size=(kernel_size, 1),
|
||||||
|
dilation=(dilation, 1),
|
||||||
|
padding=(padding, 0),
|
||||||
|
stride=(stride, 1),
|
||||||
|
)
|
||||||
|
.squeeze(0)
|
||||||
|
.type(torch.int64)
|
||||||
|
)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def analyse_network(
    model: torch.nn.Sequential, input_shape: int
) -> tuple[list, list, list]:
    """Trace which input pixels (along one axis) feed each layer's outputs.

    Args:
        model: Sequential network; its first layer must be a Conv2d,
            MaxPool2d or AvgPool2d.
        input_shape: Length of the (square) network input along one axis.

    Returns:
        ``(combined_list, layer_type_list, pixel_used)``, one entry per layer
        of ``model``:

        * ``combined_list[i]``: int64 tensor ``(field_size, positions)`` with
          the input coordinates reaching every output position of layer
          ``i``. Layers that are neither convolution nor pooling repeat the
          entry of the previous layer.
        * ``layer_type_list[i]``: class name of the layer, or ``None`` for
          layers that are neither convolution nor pooling.
        * ``pixel_used[i]``: number of distinct input coordinates appearing
          in ``combined_list[i]``.

    Raises:
        AssertionError: If the model is empty, does not start with a
            convolution / pooling layer, a layer has no output positions, or
            the receptive field size differs between positions of a layer.
    """
    spatial_layers = (torch.nn.Conv2d, torch.nn.MaxPool2d, torch.nn.AvgPool2d)

    coordinate_list: list = []
    layer_type_list: list = []

    # Per-layer coordinates, expressed relative to that layer's own input.
    size: int = int(input_shape)
    for layer in model:
        if isinstance(layer, spatial_layers):
            out = unfold(layer=layer, size=size)
            coordinate_list.append(out)
            layer_type_list.append(type(layer).__name__)
            size = int(out.shape[-1])
        else:
            coordinate_list.append(None)
            layer_type_list.append(None)

    assert (
        coordinate_list and coordinate_list[0] is not None
    ), "the first layer must be a Conv2d, MaxPool2d or AvgPool2d"

    # Chain the per-layer coordinates back to the network input.
    combined_list: list = [coordinate_list[0]]
    for coordinates in coordinate_list[1:]:
        previous = combined_list[-1]
        if coordinates is None:
            combined_list.append(previous)
            continue

        # Input pixels reachable from each output position, computed once.
        fields = [
            torch.unique(torch.flatten(previous[:, coordinates[:, pos]]))
            for pos in range(coordinates.shape[-1])
        ]
        assert fields, "layer has no output positions"

        # All positions must see equally many pixels to fit into one tensor.
        field_size = int(fields[0].shape[0])
        assert all(
            int(field.shape[0]) == field_size for field in fields
        ), "receptive field size differs between output positions"

        combined_list.append(torch.stack(fields, dim=1).to(torch.int64))

    pixel_used: list[int] = [
        int(torch.unique(torch.flatten(combined)).shape[0])
        for combined in combined_list
    ]

    return combined_list, layer_type_list, pixel_used
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import os
|
import os
|
||||||
|
@ -170,13 +230,26 @@ if __name__ == "__main__":
|
||||||
)
|
)
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
# test_image = torch.zeros((1, *input_shape), dtype=torch.float32)
|
assert input_shape[-2] == input_shape[-1]
|
||||||
|
coordinate_list, layer_type_list, pixel_used = analyse_network(
|
||||||
|
model=model, input_shape=int(input_shape[-1])
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(0, len(coordinate_list)):
|
||||||
|
print(
|
||||||
|
(
|
||||||
|
f"Layer: {i}, Positions: {coordinate_list[i].shape[1]}, "
|
||||||
|
f"Pixel per Positions: {coordinate_list[i].shape[0]}, "
|
||||||
|
f"Type: {layer_type_list[i]}, Number of pixel used: {pixel_used[i]}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
image = data_test.__getitem__(6)[1].squeeze(0)
|
image = data_test.__getitem__(6)[1].squeeze(0)
|
||||||
|
|
||||||
# call function:
|
# call function:
|
||||||
draw_kernel(
|
draw_kernel(
|
||||||
image=image.numpy(),
|
image=image.numpy(),
|
||||||
model=model,
|
coordinate_list=coordinate_list,
|
||||||
|
layer_type_list=layer_type_list,
|
||||||
ignore_output_conv_layer=ignore_output_conv_layer,
|
ignore_output_conv_layer=ignore_output_conv_layer,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue