gevi/reproduction_effort/functions/preprocessing.py
2024-02-03 12:44:36 +01:00

93 lines
3 KiB
Python

import scipy.io as sio # type: ignore
import torch
import numpy as np
import json
from functions.make_mask import make_mask
from functions.convert_camera_sequenc_to_list import convert_camera_sequenc_to_list
from functions.preprocess_camera_sequence import preprocess_camera_sequence
from functions.interpolate_along_time import interpolate_along_time
from functions.gauss_smear import gauss_smear
from functions.regression import regression
def _camera_indices(cameras: list[str], names: list[str], role: str) -> list[int]:
    """Return the position of every entry of ``names`` inside ``cameras``.

    Args:
        cameras: Camera names as listed in the metadata ``channelKey``.
        names: Camera names to look up.
        role: Label ("target" / "regressor") used only in the error message.

    Returns:
        One index into ``cameras`` per entry of ``names``, in the same order.

    Raises:
        ValueError: If a name is not present in ``cameras``.
    """
    indices: list[int] = []
    for name in names:
        try:
            indices.append(cameras.index(name))
        except ValueError as err:
            raise ValueError(
                f"{role} camera {name!r} is not listed in the metadata "
                f"'channelKey' {cameras!r}"
            ) from err
    return indices


@torch.no_grad()
def preprocessing(
    filename_metadata: str,
    filename_data: str,
    filename_mask: str,
    device: torch.device,
    first_none_ramp_frame: int,
    spatial_width: float,
    temporal_width: float,
    target_camera: list[str],
    regressor_cameras: list[str],
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Load one recording, preprocess every camera channel and regress targets.

    Runs entirely under ``torch.no_grad()``. Steps:

    1. Read the camera names from the JSON metadata key ``"channelKey"`` and
       resolve ``target_camera`` / ``regressor_cameras`` to positions in it.
    2. Load the ``"data"`` array from the MAT file ``filename_data``, cast
       it to float32 and then to a ``dtype`` tensor on ``device``.
    3. Build the mask with ``make_mask``, split the data into a per-camera
       list ordered by ``required_order`` and run
       ``preprocess_camera_sequence`` on each entry. That call also returns
       an updated mask, which is passed into the next call.
    4. Call ``interpolate_along_time`` on the sequence list, then
       ``gauss_smear`` on clones of it.
    5. Call ``regression`` once per entry of ``target_camera``.

    Args:
        filename_metadata: Path to the JSON metadata file. It must contain
            ``"channelKey"``.
        filename_data: Path to the MAT file holding a ``"data"`` variable.
        filename_mask: Path forwarded to ``make_mask``.
        device: Device on which all tensors are created.
        first_none_ramp_frame: Forwarded to ``preprocess_camera_sequence``
            and ``regression``. Presumably this is the first frame after an
            initial ramp; confirm in those helpers.
        spatial_width: Spatial width forwarded to ``gauss_smear``.
        temporal_width: Temporal width forwarded to ``gauss_smear``.
        target_camera: Names of the cameras to regress. It must contain at
            least two names; only the first two results are returned.
        regressor_cameras: Names of the cameras used as regressors.
        dtype: Floating-point dtype for the loaded data and the mask.

    Returns:
        A tuple of three tensors:

        - the ``regression`` output for ``target_camera[0]``;
        - the ``regression`` output for ``target_camera[1]``;
        - the final mask.

    Raises:
        ValueError: If ``target_camera`` has fewer than two entries, or if a
            target or regressor name is not listed in ``"channelKey"``.
    """
    # Validate cheap inputs first so a bad call fails before the expensive
    # loading, filtering and regression work; the return below needs two
    # regression results.
    if len(target_camera) < 2:
        raise ValueError(
            "target_camera must list at least two cameras because the first "
            f"two regression results are returned; got {target_camera!r}"
        )

    with open(filename_metadata, "r", encoding="utf-8") as file_handle:
        metadata: dict = json.load(file_handle)
    cameras: list[str] = metadata["channelKey"]

    # NOTE(review): these ids are positions in the metadata "channelKey"
    # list, while camera_sequence below is built from required_order. If
    # convert_camera_sequenc_to_list reorders channels, the two index spaces
    # may disagree. Confirm against that helper and regression().
    target_camera_ids: list[int] = _camera_indices(
        cameras, target_camera, "target"
    )
    regressor_camera_ids: list[int] = _camera_indices(
        cameras, regressor_cameras, "regressor"
    )

    data: torch.Tensor = torch.tensor(
        sio.loadmat(filename_data)["data"].astype(np.float32),
        device=device,
        dtype=dtype,
    )

    required_order: list[str] = ["acceptor", "donor", "oxygenation", "volume"]

    mask: torch.Tensor = make_mask(
        filename_mask=filename_mask, data=data, device=device, dtype=dtype
    )

    camera_sequence: list[torch.Tensor] = convert_camera_sequenc_to_list(
        data=data, required_order=required_order, cameras=cameras
    )

    # Each call returns the processed channel plus an updated mask, which is
    # passed into the processing of the next channel.
    for index, sequence in enumerate(camera_sequence):
        camera_sequence[index], mask = preprocess_camera_sequence(
            camera_sequence=sequence,
            mask=mask,
            first_none_ramp_frame=first_none_ramp_frame,
            device=device,
            dtype=dtype,
        )

    # Interpolate in-between images. The return value is not used, so this
    # presumably updates camera_sequence in place; confirm in the helper.
    interpolate_along_time(camera_sequence)

    # Smooth clones so gauss_smear cannot alter camera_sequence, which is
    # also handed (unfiltered) to regression below.
    camera_sequence_filtered: list[torch.Tensor] = gauss_smear(
        [sequence.clone() for sequence in camera_sequence],
        mask.type(dtype=dtype),
        spatial_width=spatial_width,
        temporal_width=temporal_width,
    )

    results: list[torch.Tensor] = []
    for channel_position, target_camera_id in enumerate(target_camera_ids):
        print(f"channel position: {channel_position}")
        results.append(
            regression(
                target_camera_id=target_camera_id,
                regressor_camera_ids=regressor_camera_ids,
                mask=mask,
                camera_sequence=camera_sequence,
                camera_sequence_filtered=camera_sequence_filtered,
                first_none_ramp_frame=first_none_ramp_frame,
            )
        )

    return results[0], results[1], mask