41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
import torch
|
|
from functions.gauss_smear_individual import gauss_smear_individual
|
|
|
|
|
|
@torch.no_grad()
|
|
def gauss_smear(
|
|
input_cameras: list[torch.Tensor],
|
|
input_mask: torch.Tensor,
|
|
spatial_width: float,
|
|
temporal_width: float,
|
|
use_matlab_mask: bool = True,
|
|
epsilon: float = float(torch.finfo(torch.float64).eps),
|
|
) -> list[torch.Tensor]:
|
|
assert len(input_cameras) == 4
|
|
|
|
filtered_mask: torch.Tensor
|
|
filtered_mask, _ = gauss_smear_individual(
|
|
input=input_mask,
|
|
spatial_width=spatial_width,
|
|
temporal_width=temporal_width,
|
|
use_matlab_mask=use_matlab_mask,
|
|
epsilon=epsilon,
|
|
)
|
|
|
|
overwrite_fft_gauss: None | torch.Tensor = None
|
|
for id in range(0, len(input_cameras)):
|
|
|
|
input_cameras[id] *= input_mask.unsqueeze(-1)
|
|
input_cameras[id], overwrite_fft_gauss = gauss_smear_individual(
|
|
input=input_cameras[id],
|
|
spatial_width=spatial_width,
|
|
temporal_width=temporal_width,
|
|
overwrite_fft_gauss=overwrite_fft_gauss,
|
|
use_matlab_mask=use_matlab_mask,
|
|
epsilon=epsilon,
|
|
)
|
|
|
|
input_cameras[id] /= filtered_mask + 1e-20
|
|
input_cameras[id] += 1.0 - input_mask.unsqueeze(-1)
|
|
|
|
return input_cameras
|