pytorch-sbs/network/calculate_output_size.py

96 lines
2.2 KiB
Python
Raw Normal View History

2023-01-05 13:23:58 +01:00
# %%
import torch
def calculate_output_size(
    value: list[int],
    kernel_size: list[int],
    stride: list[int],
    dilation: list[int],
    padding: list[int],
) -> torch.Tensor:
    """Compute the 2D output size of a sliding-window (convolution-like) operation.

    For each axis ``i`` the number of window positions is::

        floor((value[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1)
              / stride[i]) + 1

    which is the formula documented for ``torch.nn.Unfold``. The two axes are
    independent: padding, stride and dilation of one axis never influence the
    other.

    Args:
        value: Input size ``[size_0, size_1]``.
        kernel_size: Window size per axis.
        stride: Step between consecutive windows per axis.
        dilation: Spacing between the elements of one window per axis.
        padding: Padding added on both sides of each axis.

    Returns:
        int64 tensor ``[output_size_0, output_size_1]``.

    Raises:
        AssertionError: If any argument does not have exactly two entries.
        RuntimeError: If a size/kernel/stride/dilation is < 1, a padding is
            negative, or the window does not fit into the padded input.
    """
    # Kept as asserts so callers see the same exception type as before.
    assert len(value) == 2
    assert len(kernel_size) == 2
    assert len(stride) == 2
    assert len(dilation) == 2
    assert len(padding) == 2

    output_size: list[int] = []
    for axis in range(2):
        size = int(value[axis])
        kernel = int(kernel_size[axis])
        step = int(stride[axis])
        spacing = int(dilation[axis])
        pad = int(padding[axis])

        if size < 1 or kernel < 1 or step < 1 or spacing < 1 or pad < 0:
            raise RuntimeError(
                f"calculate_output_size: invalid parameters on axis {axis}: "
                f"value={size}, kernel_size={kernel}, stride={step}, "
                f"dilation={spacing}, padding={pad}"
            )

        # Extent covered by one dilated window.
        window_extent = spacing * (kernel - 1) + 1
        # Python's // floors toward -inf (like torch's im2col shape check),
        # so a window that does not fit produces a count below 1.
        positions = (size + 2 * pad - window_extent) // step + 1
        if positions < 1:
            raise RuntimeError(
                f"calculate_output_size: window does not fit on axis {axis}: "
                f"value={size}, padding={pad}, window extent={window_extent}"
            )
        output_size.append(positions)

    return torch.tensor(output_size, dtype=torch.int64)
def get_coordinates(
    value: list[int],
    kernel_size: list[int],
    stride: list[int],
    dilation: list[int],
    padding: list[int],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert sliding-window parameters into per-axis input coordinates.

    Each axis is processed separately: ``torch.nn.Unfold`` is applied to the
    index ramp ``0, 1, ..., value[axis] - 1``. Column ``j`` of the result for an
    axis lists the input indices covered by window ``j`` along that axis.

    Args:
        value: Input size ``[size_0, size_1]``.
        kernel_size: Window size per axis.
        stride: Step between consecutive windows per axis.
        dilation: Spacing between the elements of one window per axis.
        padding: Padding added on both sides of each axis.

    Returns:
        ``(coordinates_0, coordinates_1)``: int64 tensors of shape
        ``(kernel_size[0], n_0)`` and ``(kernel_size[1], n_1)``, where ``n_i`` is
        the number of window positions along axis ``i``.

    Raises:
        RuntimeError: Propagated from ``torch.arange`` / ``torch.nn.Unfold`` when
            the parameters are invalid or the window does not fit.

    Note:
        ``torch.nn.Unfold`` pads with zeros, so padded positions show up as
        index 0 and cannot be told apart from the real index 0.
    """
    # Parameters are given as (height, width) pairs with a neutral value for
    # the singleton dimension. A bare int would be applied by Unfold to BOTH
    # dimensions; for padding > 0 that also padded the singleton dimension and
    # multiplied the window count by floor(2 * padding / stride) + 1.
    unfold_0: torch.nn.Unfold = torch.nn.Unfold(
        kernel_size=(int(kernel_size[0]), 1),
        dilation=(int(dilation[0]), 1),
        padding=(int(padding[0]), 0),
        stride=(int(stride[0]), 1),
    )
    unfold_1: torch.nn.Unfold = torch.nn.Unfold(
        kernel_size=(1, int(kernel_size[1])),
        dilation=(1, int(dilation[1])),
        padding=(0, int(padding[1])),
        stride=(1, int(stride[1])),
    )

    size_0: int = int(value[0])
    size_1: int = int(value[1])

    # float64 keeps every index exactly representable (float32 is exact only
    # up to 2**24). Axis 0 runs along the height, axis 1 along the width of an
    # (N=1, C=1, H, W) input.
    ramp_0: torch.Tensor = torch.arange(0, size_0, dtype=torch.float64).reshape(
        1, 1, size_0, 1
    )
    ramp_1: torch.Tensor = torch.arange(0, size_1, dtype=torch.float64).reshape(
        1, 1, 1, size_1
    )

    # Unfold yields (1, kernel, n_windows); drop the batch dimension.
    coordinates_0: torch.Tensor = unfold_0(ramp_0).squeeze(0).to(torch.int64)
    coordinates_1: torch.Tensor = unfold_1(ramp_1).squeeze(0).to(torch.int64)
    return coordinates_0, coordinates_1