gevi/functions/binning.py

47 lines
1.1 KiB
Python
Raw Normal View History

2024-02-28 16:14:50 +01:00
import torch
2024-02-28 18:55:37 +01:00
@torch.no_grad()
def binning(
    data: torch.Tensor,
    kernel_size: int = 4,
    stride: int = 4,
    divisor_override: int | None = 1,
) -> torch.Tensor:
    """Bin ``data`` via ``binning_internal``, retrying on the CPU after a CUDA OOM.

    The computation is first attempted on whatever device ``data`` lives on.
    If that raises ``torch.cuda.OutOfMemoryError``, the input is copied to the
    CPU and binned there. The result is then moved back to ``data.device``.

    Args:
        data: Input tensor, forwarded unchanged to ``binning_internal``.
        kernel_size: Pooling window size forwarded to ``binning_internal``.
        stride: Pooling stride forwarded to ``binning_internal``.
        divisor_override: Pooling divisor forwarded to ``binning_internal``.

    Returns:
        The binned tensor, on the same device as ``data``.
    """
    pool_args = dict(
        kernel_size=kernel_size,
        stride=stride,
        divisor_override=divisor_override,
    )
    try:
        return binning_internal(data=data, **pool_args)
    except torch.cuda.OutOfMemoryError:
        # Fall back to host memory, then hand the result back on the caller's device.
        binned_on_cpu = binning_internal(data=data.cpu(), **pool_args)
        return binned_on_cpu.to(device=data.device)
@torch.no_grad()
def binning_internal(
data: torch.Tensor,
kernel_size: int = 4,
stride: int = 4,
divisor_override: int | None = 1,
) -> torch.Tensor:
2024-02-28 16:14:50 +01:00
assert data.ndim == 4
return (
torch.nn.functional.avg_pool2d(
input=data.movedim(0, -1).movedim(0, -1),
kernel_size=kernel_size,
stride=stride,
divisor_override=divisor_override,
)
.movedim(-1, 0)
.movedim(-1, 0)
)