36 lines
1,000 B
Python
36 lines
1,000 B
Python
import torch
|
|
|
|
|
|
@torch.no_grad()
def preprocess_camera_sequence(
    camera_sequence: torch.Tensor,
    mask: torch.Tensor,
    first_none_ramp_frame: int,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Normalize a camera sequence and flag pixels with invalid samples.

    Each pixel's time series is divided by its NaN-aware mean over the
    frames from ``first_none_ramp_frame`` onward. NaN frames then become
    exact zeros; any pixel whose normalized series contains a zero is
    cleared in ``mask`` (NOTE: ``mask`` is mutated in place via ``&=``),
    and that pixel's post-ramp frames are overwritten with the sentinel
    ``2**16 - 1`` (presumably a 16-bit full-scale value — TODO confirm).

    Args:
        camera_sequence: Pixel data indexed ``[dim0, dim1, frame]`` along
            the last axis — assumed 3-D; confirm with callers.
        mask: Boolean validity mask over the first two dimensions; updated
            in place and also returned.
        first_none_ramp_frame: Index of the first frame past the ramp
            period; earlier frames are excluded from both the baseline
            mean and the sentinel overwrite.
        device: Device on which the sentinel tensor is allocated.
        dtype: Dtype of the sentinel tensor.

    Returns:
        Tuple of the normalized sequence and the updated mask.
    """
    # Sentinel written into invalidated post-ramp frames.
    sentinel: torch.Tensor = torch.tensor(
        2**16 - 1, device=device, dtype=dtype
    )

    # Per-pixel baseline: NaN-aware mean across the non-ramp frames.
    baseline = camera_sequence[:, :, first_none_ramp_frame:].nanmean(
        dim=2, keepdim=True
    )
    camera_sequence = camera_sequence / baseline

    # Frames that were NaN become exact zeros so the check below sees them.
    camera_sequence = camera_sequence.nan_to_num(nan=0.0)

    # A zero anywhere in a pixel's series marks the whole pixel invalid.
    has_zero = torch.any(camera_sequence == 0, dim=-1, keepdim=True)
    mask &= (~has_zero.squeeze(-1)).type(dtype=torch.bool)

    # Broadcast the per-pixel flag across every frame, then exempt the
    # ramp frames; flagged post-ramp frames receive the sentinel.
    invalid_frames = torch.tile(
        has_zero, (1, 1, camera_sequence.shape[-1])
    )
    invalid_frames[:, :, :first_none_ramp_frame] = False
    camera_sequence[invalid_frames] = sentinel

    return camera_sequence, mask
|