2024-02-28 16:14:50 +01:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
2024-02-28 18:55:37 +01:00
|
|
|
@torch.no_grad()
def binning(
    data: torch.Tensor,
    kernel_size: int = 4,
    stride: int = 4,
    divisor_override: int | None = 1,
) -> torch.Tensor:
    """Bin ``data`` with ``binning_internal``, retrying on the CPU after a CUDA OOM.

    The first attempt runs on ``data``'s own device. If that raises
    ``torch.cuda.OutOfMemoryError``, the computation is repeated on a CPU copy
    of ``data``. The result is then moved back to ``data.device``, so callers
    always get the output on the device they passed in.

    Args:
        data: Tensor forwarded to ``binning_internal``.
        kernel_size: Window size forwarded to ``binning_internal``.
        stride: Window stride forwarded to ``binning_internal``.
        divisor_override: Forwarded to ``binning_internal``.

    Returns:
        The binned tensor, located on ``data.device``.
    """
    try:
        return binning_internal(
            data=data,
            kernel_size=kernel_size,
            stride=stride,
            divisor_override=divisor_override,
        )
    except torch.cuda.OutOfMemoryError:
        # Only note the failure here. Per the PyTorch FAQ, the live exception
        # holds references to the frames it was raised in, which can keep their
        # tensors from being freed. The recovery therefore runs after this
        # clause, once the exception has been discarded.
        pass

    # Reached only after a CUDA OOM: redo the work on the CPU and move the
    # result back to the caller's device.
    return binning_internal(
        data=data.cpu(),
        kernel_size=kernel_size,
        stride=stride,
        divisor_override=divisor_override,
    ).to(device=data.device)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
def binning_internal(
|
|
|
|
data: torch.Tensor,
|
|
|
|
kernel_size: int = 4,
|
|
|
|
stride: int = 4,
|
|
|
|
divisor_override: int | None = 1,
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
2024-02-28 16:14:50 +01:00
|
|
|
assert data.ndim == 4
|
|
|
|
return (
|
|
|
|
torch.nn.functional.avg_pool2d(
|
|
|
|
input=data.movedim(0, -1).movedim(0, -1),
|
|
|
|
kernel_size=kernel_size,
|
|
|
|
stride=stride,
|
|
|
|
divisor_override=divisor_override,
|
|
|
|
)
|
|
|
|
.movedim(-1, 0)
|
|
|
|
.movedim(-1, 0)
|
|
|
|
)
|