gevi/new_pipeline/functions/regression.py

118 lines
4 KiB
Python
Raw Normal View History

2024-02-27 01:13:52 +01:00
import torch
2024-02-27 15:18:53 +01:00
import logging
2024-02-27 01:13:52 +01:00
from functions.regression_internal import regression_internal
@torch.no_grad()
def regression(
    mylogger: logging.Logger,
    target_camera_id: int,
    regressor_camera_ids: list[int],
    mask: torch.Tensor,
    data: torch.Tensor,
    data_filtered: torch.Tensor,
    first_none_ramp_frame: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Regress other cameras and a linear trend out of the target camera signal.

    The regression coefficients are fitted on ``data_filtered`` (frames from
    ``first_none_ramp_frame`` onward along the last axis) and then applied to
    the full-length unfiltered ``data``:

    1. Target and regressor signals are shifted by ``-1.0``. In the fit only,
       shifted values below ``-1`` (raw values below 0) are replaced by ``0.0``.
    2. One regressor channel per entry of ``regressor_camera_ids`` plus a final
       zero-mean linear-trend channel are passed, together with the target, to
       ``regression_internal``, which returns ``(coefficients, intercept)``.
    3. The weighted regressors and the intercept are subtracted from
       ``data[target_camera_id] - 1.0``, pixels selected by ``mask`` are set to
       0, and ``1.0`` is added back.

    Args:
        mylogger: Logger used for progress messages.
        target_camera_id: Index along the first axis of ``data`` /
            ``data_filtered`` of the signal to be corrected.
        regressor_camera_ids: Indices (same axis) of the signals used as
            regressors. Must not be empty.
        mask: Index applied to the leading axes of the target signal; selected
            pixels end up at exactly 1.0 in the output. Presumably a boolean
            tensor over the two non-time axes -- TODO confirm with callers.
        data: Unfiltered signals, indexed by this function as
            (camera, dim1, dim2, time).
        data_filtered: Filtered signals used for fitting; must have the same
            per-camera shape as ``data``.
        first_none_ramp_frame: First frame (last axis) used for fitting.

    Returns:
        ``(corrected_target, coefficients)``: the corrected target signal with
        the shape of ``data[target_camera_id]``, and the coefficients returned
        by ``regression_internal`` (the last regressor channel is the trend).

    Raises:
        AssertionError: If ``regressor_camera_ids`` is empty or the shapes of
            ``data`` and ``data_filtered`` disagree (only while asserts are
            enabled).
    """
    assert len(regressor_camera_ids) > 0

    mylogger.info("Prepare the target signal - 1.0 (from data_filtered)")
    # The subtraction already allocates a new tensor, so the in-place edit
    # below never touches ``data_filtered`` and no extra clone is needed.
    target_signals_train: torch.Tensor = (
        data_filtered[target_camera_id, ..., first_none_ramp_frame:] - 1.0
    )
    target_signals_train[target_signals_train < -1] = 0.0

    # Check if everything is happy: the filtered target (minus the ramp frames)
    # must line up with the unfiltered target it will be applied to.
    target_full = data[target_camera_id, ...]
    assert target_signals_train.ndim == 3
    assert target_signals_train.ndim == target_full.ndim
    assert target_signals_train.shape[0] == target_full.shape[0]
    assert target_signals_train.shape[1] == target_full.shape[1]
    assert (
        target_signals_train.shape[2] + first_none_ramp_frame
    ) == target_full.shape[2]

    mylogger.info("Prepare the regressor signals (linear plus from data_filtered)")
    train_frames = slice(first_none_ramp_frame, None)
    mylogger.info("Copy the regressor signals - 1.0")
    # Only the fitted frames are allocated.
    regressor_signals_train = _camera_regressors(
        data_filtered, regressor_camera_ids, train_frames
    )
    regressor_signals_train[regressor_signals_train < -1] = 0.0

    mylogger.info("Create the linear regressor")
    # Center the trend over the full recording before trimming, so each frame
    # receives the same trend value here as in the perform step below.
    regressor_signals_train[..., -1] = _centered_linear_trend(
        data_filtered.shape[3], data_filtered.device
    )[train_frames]

    mylogger.info("Calculating the regression coefficients")
    coefficients, intercept = regression_internal(
        input_regressor=regressor_signals_train, input_target=target_signals_train
    )
    # Release the training tensors before allocating the full-length ones.
    del regressor_signals_train
    del target_signals_train

    mylogger.info("Prepare the target signal - 1.0 (from data)")
    target_signals_perform: torch.Tensor = target_full - 1.0

    mylogger.info("Prepare the regressor signals (linear plus from data)")
    mylogger.info("Copy the regressor signals - 1.0 ")
    # NOTE(review): unlike the training path, values below -1 are not zeroed
    # for the target or the regressors here -- confirm this is intended.
    regressor_signals_perform = _camera_regressors(
        data, regressor_camera_ids, slice(None)
    )

    mylogger.info("Create the linear regressor")
    regressor_signals_perform[..., -1] = _centered_linear_trend(
        data.shape[3], data.device
    )

    mylogger.info("Remove regressors")
    # coefficients.unsqueeze(-2) broadcasts one weight vector per pixel over
    # the time axis; summing the last axis yields the fitted signal.
    target_signals_perform -= (
        regressor_signals_perform * coefficients.unsqueeze(-2)
    ).sum(dim=-1)

    mylogger.info("Remove offset")
    target_signals_perform -= intercept.unsqueeze(-1)

    mylogger.info("Remove masked pixels")
    target_signals_perform[mask, :] = 0.0

    mylogger.info("Add an offset of 1.0")
    target_signals_perform += 1.0

    return target_signals_perform, coefficients


def _centered_linear_trend(num_frames: int, device: torch.device) -> torch.Tensor:
    """Return a zero-mean linear ramp with ``num_frames`` samples.

    Before centering the ramp runs from 0 to 1. The divisor is clamped to at
    least 1 so a single-frame input yields ``[0.0]`` instead of NaN.
    """
    trend = torch.arange(0, num_frames, device=device) / float(
        max(num_frames - 1, 1)
    )
    trend -= trend.mean()
    return trend


def _camera_regressors(
    source: torch.Tensor,
    regressor_camera_ids: list[int],
    frames: slice,
) -> torch.Tensor:
    """Stack ``source[camera, ..., frames] - 1.0`` for each regressor camera.

    The result has shape (source.shape[1], source.shape[2], n_frames,
    len(regressor_camera_ids) + 1); the final channel is left at zero so the
    caller can fill in the linear trend.
    """
    num_frames = source[..., frames].shape[-1]
    regressors = torch.zeros(
        (
            source.shape[1],
            source.shape[2],
            num_frames,
            len(regressor_camera_ids) + 1,
        ),
        device=source.device,
        dtype=source.dtype,
    )
    for channel, camera_id in enumerate(regressor_camera_ids):
        regressors[..., channel] = source[camera_id, ..., frames] - 1.0
    return regressors