Add files via upload

This commit is contained in:
David Rotermund 2023-07-26 12:44:15 +02:00 committed by GitHub
parent 9bdadfef7c
commit 5b5152fc8b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,4 +1,3 @@
# %%
import torch import torch
@ -9,7 +8,6 @@ def calculate_output_size(
dilation: list[int], dilation: list[int],
padding: list[int], padding: list[int],
) -> torch.Tensor: ) -> torch.Tensor:
assert len(value) == 2 assert len(value) == 2
assert len(kernel_size) == 2 assert len(kernel_size) == 2
assert len(stride) == 2 assert len(stride) == 2
@ -44,52 +42,46 @@ def get_coordinates(
"""Function converts parameter in coordinates """Function converts parameter in coordinates
for the convolution window""" for the convolution window"""
unfold_0: torch.nn.Unfold = torch.nn.Unfold( coordinates_0: torch.Tensor = (
torch.nn.functional.unfold(
torch.arange(0, int(value[0]), dtype=torch.float32)
.unsqueeze(1)
.unsqueeze(0)
.unsqueeze(0),
kernel_size=(int(kernel_size[0]), 1), kernel_size=(int(kernel_size[0]), 1),
dilation=int(dilation[0]), dilation=int(dilation[0]),
padding=int(padding[0]), padding=(int(padding[0]), 0),
stride=int(stride[0]), stride=int(stride[0]),
) )
unfold_1: torch.nn.Unfold = torch.nn.Unfold(
kernel_size=(1, int(kernel_size[1])),
dilation=int(dilation[1]),
padding=int(padding[1]),
stride=int(stride[1]),
)
coordinates_0: torch.Tensor = (
unfold_0(
torch.unsqueeze(
torch.unsqueeze(
torch.unsqueeze(
torch.arange(0, int(value[0]), dtype=torch.float32),
1,
),
0,
),
0,
)
)
.squeeze(0) .squeeze(0)
.type(torch.int64) .type(torch.int64)
) )
coordinates_1: torch.Tensor = ( coordinates_1: torch.Tensor = (
unfold_1( torch.nn.functional.unfold(
torch.unsqueeze( torch.arange(0, int(value[1]), dtype=torch.float32)
torch.unsqueeze( .unsqueeze(0)
torch.unsqueeze( .unsqueeze(0)
torch.arange(0, int(value[1]), dtype=torch.float32), .unsqueeze(0),
0, kernel_size=(1, int(kernel_size[1])),
), dilation=int(dilation[1]),
0, padding=(0, int(padding[1])),
), stride=int(stride[1]),
0,
)
) )
.squeeze(0) .squeeze(0)
.type(torch.int64) .type(torch.int64)
) )
return coordinates_0, coordinates_1 return coordinates_0, coordinates_1
if __name__ == "__main__":
a, b = get_coordinates(
value=[28, 28],
kernel_size=[5, 5],
stride=[1, 1],
dilation=[1, 1],
padding=[0, 0],
)
print(a.shape)
print(b.shape)