gevi/reproduction_effort/functions/gauss_smear.py
2024-02-03 12:44:36 +01:00

41 lines
1.2 KiB
Python

import torch
from functions.gauss_smear_individual import gauss_smear_individual
@torch.no_grad()
def gauss_smear(
    input_cameras: list[torch.Tensor],
    input_mask: torch.Tensor,
    spatial_width: float,
    temporal_width: float,
    use_matlab_mask: bool = True,
    epsilon: float = float(torch.finfo(torch.float64).eps),
) -> list[torch.Tensor]:
    """Filter four camera tensors with ``gauss_smear_individual``, normalised by the filtered mask.

    The mask itself is passed through ``gauss_smear_individual`` once. Each
    camera tensor is then multiplied by the mask, filtered with the same
    settings, divided by the filtered mask and finally offset by
    ``1.0 - mask``.

    Runs under ``torch.no_grad()``, so no autograd graph is recorded.

    Args:
        input_cameras: List of exactly four tensors. The list is updated
            in place, and the tensors originally stored in it are also
            multiplied by the mask in place.
        input_mask: Mask tensor; it is broadcast against each camera tensor
            after a trailing dimension is appended with ``unsqueeze(-1)``.
            NOTE(review): presumably a 2-D spatial mask applied to data
            with a trailing time axis -- confirm against callers.
        spatial_width: Forwarded unchanged to ``gauss_smear_individual``.
        temporal_width: Forwarded unchanged to ``gauss_smear_individual``.
        use_matlab_mask: Forwarded unchanged to ``gauss_smear_individual``.
        epsilon: Forwarded unchanged to ``gauss_smear_individual``.

    Returns:
        The same list object that was passed in, with every entry replaced
        by its filtered, mask-normalised version.

    Raises:
        AssertionError: If ``input_cameras`` does not hold exactly four
            entries.
    """
    assert len(input_cameras) == 4

    # Settings that are identical for the mask and for every camera.
    smear_settings = {
        "spatial_width": spatial_width,
        "temporal_width": temporal_width,
        "use_matlab_mask": use_matlab_mask,
        "epsilon": epsilon,
    }

    # Filtered mask; the second return value is not needed here.
    smoothed_mask: torch.Tensor
    smoothed_mask, _ = gauss_smear_individual(input=input_mask, **smear_settings)

    # Second return value of gauss_smear_individual, handed back to it on
    # the following iterations. Presumably the FFT of the Gaussian kernel,
    # reused to avoid recomputation -- TODO confirm in gauss_smear_individual.
    cached_fft_gauss: torch.Tensor | None = None

    for camera_index in range(len(input_cameras)):
        # In-place multiply: this also alters the tensor object supplied by
        # the caller, not only the list slot.
        input_cameras[camera_index] *= input_mask.unsqueeze(-1)

        smoothed, cached_fft_gauss = gauss_smear_individual(
            input=input_cameras[camera_index],
            overwrite_fft_gauss=cached_fft_gauss,
            **smear_settings,
        )
        input_cameras[camera_index] = smoothed

        # Normalise by the filtered mask; the tiny constant keeps the
        # divisor away from exact zero.
        input_cameras[camera_index] /= smoothed_mask + 1e-20

        # Add the complement of the mask (adds 1.0 wherever the mask is 0,
        # assuming a 0/1 mask -- verify against caller).
        input_cameras[camera_index] += 1.0 - input_mask.unsqueeze(-1)

    return input_cameras