21 lines
465 B
Python
21 lines
465 B
Python
import torch
|
|
|
|
|
|
def binning(
|
|
data: torch.Tensor,
|
|
kernel_size: int = 4,
|
|
stride: int = 4,
|
|
divisor_override: int | None = 1,
|
|
) -> torch.Tensor:
|
|
|
|
assert data.ndim == 4
|
|
return (
|
|
torch.nn.functional.avg_pool2d(
|
|
input=data.movedim(0, -1).movedim(0, -1),
|
|
kernel_size=kernel_size,
|
|
stride=stride,
|
|
divisor_override=divisor_override,
|
|
)
|
|
.movedim(-1, 0)
|
|
.movedim(-1, 0)
|
|
)
|