From 5b5152fc8b8cda9394678a6e5e14e703cccd55b1 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Wed, 26 Jul 2023 12:44:15 +0200 Subject: [PATCH] Add files via upload --- network/calculate_output_size.py | 68 ++++++++++++++------------------ 1 file changed, 30 insertions(+), 38 deletions(-) diff --git a/network/calculate_output_size.py b/network/calculate_output_size.py index d99681e..56c5457 100644 --- a/network/calculate_output_size.py +++ b/network/calculate_output_size.py @@ -1,4 +1,3 @@ -# %% import torch @@ -9,7 +8,6 @@ def calculate_output_size( dilation: list[int], padding: list[int], ) -> torch.Tensor: - assert len(value) == 2 assert len(kernel_size) == 2 assert len(stride) == 2 @@ -44,52 +42,46 @@ def get_coordinates( """Function converts parameter in coordinates for the convolution window""" - unfold_0: torch.nn.Unfold = torch.nn.Unfold( - kernel_size=(int(kernel_size[0]), 1), - dilation=int(dilation[0]), - padding=int(padding[0]), - stride=int(stride[0]), - ) - - unfold_1: torch.nn.Unfold = torch.nn.Unfold( - kernel_size=(1, int(kernel_size[1])), - dilation=int(dilation[1]), - padding=int(padding[1]), - stride=int(stride[1]), - ) - coordinates_0: torch.Tensor = ( - unfold_0( - torch.unsqueeze( - torch.unsqueeze( - torch.unsqueeze( - torch.arange(0, int(value[0]), dtype=torch.float32), - 1, - ), - 0, - ), - 0, - ) + torch.nn.functional.unfold( + torch.arange(0, int(value[0]), dtype=torch.float32) + .unsqueeze(1) + .unsqueeze(0) + .unsqueeze(0), + kernel_size=(int(kernel_size[0]), 1), + dilation=int(dilation[0]), + padding=(int(padding[0]), 0), + stride=int(stride[0]), ) .squeeze(0) .type(torch.int64) ) coordinates_1: torch.Tensor = ( - unfold_1( - torch.unsqueeze( - torch.unsqueeze( - torch.unsqueeze( - torch.arange(0, int(value[1]), dtype=torch.float32), - 0, - ), - 0, - ), - 0, - ) + torch.nn.functional.unfold( + torch.arange(0, int(value[1]), dtype=torch.float32) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + kernel_size=(1, int(kernel_size[1])), + dilation=int(dilation[1]), + padding=(0, int(padding[1])), + stride=int(stride[1]), ) .squeeze(0) .type(torch.int64) ) return coordinates_0, coordinates_1 + + +if __name__ == "__main__": + a, b = get_coordinates( + value=[28, 28], + kernel_size=[5, 5], + stride=[1, 1], + dilation=[1, 1], + padding=[0, 0], + ) + print(a.shape) + print(b.shape)