Add files via upload
This commit is contained in:
parent
9bdadfef7c
commit
5b5152fc8b
1 changed files with 30 additions and 38 deletions
|
@ -1,4 +1,3 @@
|
|||
# %%
|
||||
import torch
|
||||
|
||||
|
||||
|
@ -9,7 +8,6 @@ def calculate_output_size(
|
|||
dilation: list[int],
|
||||
padding: list[int],
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert len(value) == 2
|
||||
assert len(kernel_size) == 2
|
||||
assert len(stride) == 2
|
||||
|
@ -44,52 +42,46 @@ def get_coordinates(
|
|||
"""Function converts parameter in coordinates
|
||||
for the convolution window"""
|
||||
|
||||
unfold_0: torch.nn.Unfold = torch.nn.Unfold(
|
||||
kernel_size=(int(kernel_size[0]), 1),
|
||||
dilation=int(dilation[0]),
|
||||
padding=int(padding[0]),
|
||||
stride=int(stride[0]),
|
||||
)
|
||||
|
||||
unfold_1: torch.nn.Unfold = torch.nn.Unfold(
|
||||
kernel_size=(1, int(kernel_size[1])),
|
||||
dilation=int(dilation[1]),
|
||||
padding=int(padding[1]),
|
||||
stride=int(stride[1]),
|
||||
)
|
||||
|
||||
coordinates_0: torch.Tensor = (
|
||||
unfold_0(
|
||||
torch.unsqueeze(
|
||||
torch.unsqueeze(
|
||||
torch.unsqueeze(
|
||||
torch.arange(0, int(value[0]), dtype=torch.float32),
|
||||
1,
|
||||
),
|
||||
0,
|
||||
),
|
||||
0,
|
||||
)
|
||||
torch.nn.functional.unfold(
|
||||
torch.arange(0, int(value[0]), dtype=torch.float32)
|
||||
.unsqueeze(1)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(0),
|
||||
kernel_size=(int(kernel_size[0]), 1),
|
||||
dilation=int(dilation[0]),
|
||||
padding=(int(padding[0]), 0),
|
||||
stride=int(stride[0]),
|
||||
)
|
||||
.squeeze(0)
|
||||
.type(torch.int64)
|
||||
)
|
||||
|
||||
coordinates_1: torch.Tensor = (
|
||||
unfold_1(
|
||||
torch.unsqueeze(
|
||||
torch.unsqueeze(
|
||||
torch.unsqueeze(
|
||||
torch.arange(0, int(value[1]), dtype=torch.float32),
|
||||
0,
|
||||
),
|
||||
0,
|
||||
),
|
||||
0,
|
||||
)
|
||||
torch.nn.functional.unfold(
|
||||
torch.arange(0, int(value[1]), dtype=torch.float32)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(0),
|
||||
kernel_size=(1, int(kernel_size[1])),
|
||||
dilation=int(dilation[1]),
|
||||
padding=(0, int(padding[1])),
|
||||
stride=int(stride[1]),
|
||||
)
|
||||
.squeeze(0)
|
||||
.type(torch.int64)
|
||||
)
|
||||
|
||||
return coordinates_0, coordinates_1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
a, b = get_coordinates(
|
||||
value=[28, 28],
|
||||
kernel_size=[5, 5],
|
||||
stride=[1, 1],
|
||||
dilation=[1, 1],
|
||||
padding=[0, 0],
|
||||
)
|
||||
print(a.shape)
|
||||
print(b.shape)
|
||||
|
|
Loading…
Reference in a new issue