Add files via upload
This commit is contained in:
parent
d17a19b6f9
commit
c9dcb20c64
12 changed files with 760 additions and 0 deletions
85
reproduction_effort/functions/bandpass.py
Normal file
85
reproduction_effort/functions/bandpass.py
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
import torchaudio as ta # type: ignore
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def filtfilt(
|
||||||
|
input: torch.Tensor,
|
||||||
|
butter_a: torch.Tensor,
|
||||||
|
butter_b: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert butter_a.ndim == 1
|
||||||
|
assert butter_b.ndim == 1
|
||||||
|
assert butter_a.shape[0] == butter_b.shape[0]
|
||||||
|
|
||||||
|
process_data: torch.Tensor = input.detach().clone()
|
||||||
|
|
||||||
|
padding_length = 12 * int(butter_a.shape[0])
|
||||||
|
left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[
|
||||||
|
..., 1 : padding_length + 1
|
||||||
|
].flip(-1)
|
||||||
|
right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[
|
||||||
|
..., -(padding_length + 1) : -1
|
||||||
|
].flip(-1)
|
||||||
|
process_data_padded = torch.cat((left_padding, process_data, right_padding), dim=-1)
|
||||||
|
|
||||||
|
output = ta.functional.filtfilt(
|
||||||
|
process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False
|
||||||
|
).squeeze(0)
|
||||||
|
|
||||||
|
output = output[..., padding_length:-padding_length]
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def butter_bandpass(
|
||||||
|
device: torch.device,
|
||||||
|
low_frequency: float = 0.1,
|
||||||
|
high_frequency: float = 1.0,
|
||||||
|
fs: float = 30.0,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
import scipy # type: ignore
|
||||||
|
|
||||||
|
butter_b_np, butter_a_np = scipy.signal.butter(
|
||||||
|
4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs
|
||||||
|
)
|
||||||
|
butter_a = torch.tensor(butter_a_np, device=device, dtype=torch.float32)
|
||||||
|
butter_b = torch.tensor(butter_b_np, device=device, dtype=torch.float32)
|
||||||
|
return butter_a, butter_b
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def chunk_iterator(array: torch.Tensor, chunk_size: int):
|
||||||
|
for i in range(0, array.shape[0], chunk_size):
|
||||||
|
yield array[i : i + chunk_size]
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def bandpass(
|
||||||
|
data: torch.Tensor,
|
||||||
|
device: torch.device,
|
||||||
|
low_frequency: float = 0.1,
|
||||||
|
high_frequency: float = 1.0,
|
||||||
|
fs=30.0,
|
||||||
|
filtfilt_chuck_size: int = 10,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
butter_a, butter_b = butter_bandpass(
|
||||||
|
device=device,
|
||||||
|
low_frequency=low_frequency,
|
||||||
|
high_frequency=high_frequency,
|
||||||
|
fs=fs,
|
||||||
|
)
|
||||||
|
|
||||||
|
index_full_dataset: torch.Tensor = torch.arange(
|
||||||
|
0, data.shape[1], device=device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk in chunk_iterator(index_full_dataset, filtfilt_chuck_size):
|
||||||
|
temp_filtfilt = filtfilt(
|
||||||
|
data[:, chunk, :],
|
||||||
|
butter_a=butter_a,
|
||||||
|
butter_b=butter_b,
|
||||||
|
)
|
||||||
|
data[:, chunk, :] = temp_filtfilt
|
||||||
|
|
||||||
|
return data
|
|
@ -0,0 +1,13 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert_camera_sequenc_to_list(
|
||||||
|
data: torch.Tensor, required_order: list[str], cameras: list[str]
|
||||||
|
) -> list[torch.Tensor]:
|
||||||
|
camera_sequence: list[torch.Tensor] = []
|
||||||
|
|
||||||
|
for cam in required_order:
|
||||||
|
camera_sequence.append(data[:, :, :, cameras.index(cam)].clone())
|
||||||
|
|
||||||
|
return camera_sequence
|
41
reproduction_effort/functions/gauss_smear.py
Normal file
41
reproduction_effort/functions/gauss_smear.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
import torch
|
||||||
|
from functions.gauss_smear_individual import gauss_smear_individual
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def gauss_smear(
|
||||||
|
input_cameras: list[torch.Tensor],
|
||||||
|
input_mask: torch.Tensor,
|
||||||
|
spatial_width: float,
|
||||||
|
temporal_width: float,
|
||||||
|
use_matlab_mask: bool = True,
|
||||||
|
epsilon: float = float(torch.finfo(torch.float64).eps),
|
||||||
|
) -> list[torch.Tensor]:
|
||||||
|
assert len(input_cameras) == 4
|
||||||
|
|
||||||
|
filtered_mask: torch.Tensor
|
||||||
|
filtered_mask, _ = gauss_smear_individual(
|
||||||
|
input=input_mask,
|
||||||
|
spatial_width=spatial_width,
|
||||||
|
temporal_width=temporal_width,
|
||||||
|
use_matlab_mask=use_matlab_mask,
|
||||||
|
epsilon=epsilon,
|
||||||
|
)
|
||||||
|
|
||||||
|
overwrite_fft_gauss: None | torch.Tensor = None
|
||||||
|
for id in range(0, len(input_cameras)):
|
||||||
|
|
||||||
|
input_cameras[id] *= input_mask.unsqueeze(-1)
|
||||||
|
input_cameras[id], overwrite_fft_gauss = gauss_smear_individual(
|
||||||
|
input=input_cameras[id],
|
||||||
|
spatial_width=spatial_width,
|
||||||
|
temporal_width=temporal_width,
|
||||||
|
overwrite_fft_gauss=overwrite_fft_gauss,
|
||||||
|
use_matlab_mask=use_matlab_mask,
|
||||||
|
epsilon=epsilon,
|
||||||
|
)
|
||||||
|
|
||||||
|
input_cameras[id] /= filtered_mask + 1e-20
|
||||||
|
input_cameras[id] += 1.0 - input_mask.unsqueeze(-1)
|
||||||
|
|
||||||
|
return input_cameras
|
127
reproduction_effort/functions/gauss_smear_individual.py
Normal file
127
reproduction_effort/functions/gauss_smear_individual.py
Normal file
|
@ -0,0 +1,127 @@
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def gauss_smear_individual(
|
||||||
|
input: torch.Tensor,
|
||||||
|
spatial_width: float,
|
||||||
|
temporal_width: float,
|
||||||
|
overwrite_fft_gauss: None | torch.Tensor = None,
|
||||||
|
use_matlab_mask: bool = True,
|
||||||
|
epsilon: float = float(torch.finfo(torch.float64).eps),
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
dim_x: int = int(2 * math.ceil(2 * spatial_width) + 1)
|
||||||
|
dim_y: int = int(2 * math.ceil(2 * spatial_width) + 1)
|
||||||
|
dim_t: int = int(2 * math.ceil(2 * temporal_width) + 1)
|
||||||
|
dims_xyt: torch.Tensor = torch.tensor(
|
||||||
|
[dim_x, dim_y, dim_t], dtype=torch.int64, device=input.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if input.ndim == 2:
|
||||||
|
input = input.unsqueeze(-1)
|
||||||
|
|
||||||
|
input_padded = torch.nn.functional.pad(
|
||||||
|
input.unsqueeze(0),
|
||||||
|
pad=(
|
||||||
|
dim_t,
|
||||||
|
dim_t,
|
||||||
|
dim_y,
|
||||||
|
dim_y,
|
||||||
|
dim_x,
|
||||||
|
dim_x,
|
||||||
|
),
|
||||||
|
mode="replicate",
|
||||||
|
).squeeze(0)
|
||||||
|
|
||||||
|
if overwrite_fft_gauss is None:
|
||||||
|
center_x: int = int(math.ceil(input_padded.shape[0] / 2))
|
||||||
|
center_y: int = int(math.ceil(input_padded.shape[1] / 2))
|
||||||
|
center_z: int = int(math.ceil(input_padded.shape[2] / 2))
|
||||||
|
grid_x: torch.Tensor = (
|
||||||
|
torch.arange(0, input_padded.shape[0], device=input.device) - center_x + 1
|
||||||
|
)
|
||||||
|
grid_y: torch.Tensor = (
|
||||||
|
torch.arange(0, input_padded.shape[1], device=input.device) - center_y + 1
|
||||||
|
)
|
||||||
|
grid_z: torch.Tensor = (
|
||||||
|
torch.arange(0, input_padded.shape[2], device=input.device) - center_z + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
grid_x = grid_x.unsqueeze(-1).unsqueeze(-1) ** 2
|
||||||
|
grid_y = grid_y.unsqueeze(0).unsqueeze(-1) ** 2
|
||||||
|
grid_z = grid_z.unsqueeze(0).unsqueeze(0) ** 2
|
||||||
|
|
||||||
|
gauss_kernel: torch.Tensor = (
|
||||||
|
(grid_x / (spatial_width**2))
|
||||||
|
+ (grid_y / (spatial_width**2))
|
||||||
|
+ (grid_z / (temporal_width**2))
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_matlab_mask:
|
||||||
|
filter_radius: torch.Tensor = (dims_xyt - 1) // 2
|
||||||
|
|
||||||
|
border_lower: list[int] = [
|
||||||
|
center_x - int(filter_radius[0]) - 1,
|
||||||
|
center_y - int(filter_radius[1]) - 1,
|
||||||
|
center_z - int(filter_radius[2]) - 1,
|
||||||
|
]
|
||||||
|
|
||||||
|
border_upper: list[int] = [
|
||||||
|
center_x + int(filter_radius[0]),
|
||||||
|
center_y + int(filter_radius[1]),
|
||||||
|
center_z + int(filter_radius[2]),
|
||||||
|
]
|
||||||
|
|
||||||
|
matlab_mask: torch.Tensor = torch.zeros_like(gauss_kernel)
|
||||||
|
matlab_mask[
|
||||||
|
border_lower[0] : border_upper[0],
|
||||||
|
border_lower[1] : border_upper[1],
|
||||||
|
border_lower[2] : border_upper[2],
|
||||||
|
] = 1.0
|
||||||
|
|
||||||
|
gauss_kernel = torch.exp(-gauss_kernel / 2.0)
|
||||||
|
if use_matlab_mask:
|
||||||
|
gauss_kernel = gauss_kernel * matlab_mask
|
||||||
|
|
||||||
|
gauss_kernel[gauss_kernel < (epsilon * gauss_kernel.max())] = 0
|
||||||
|
|
||||||
|
sum_gauss_kernel: float = float(gauss_kernel.sum())
|
||||||
|
|
||||||
|
if sum_gauss_kernel != 0.0:
|
||||||
|
gauss_kernel = gauss_kernel / sum_gauss_kernel
|
||||||
|
|
||||||
|
# FFT Shift
|
||||||
|
gauss_kernel = torch.cat(
|
||||||
|
(gauss_kernel[center_x - 1 :, :, :], gauss_kernel[: center_x - 1, :, :]),
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
gauss_kernel = torch.cat(
|
||||||
|
(gauss_kernel[:, center_y - 1 :, :], gauss_kernel[:, : center_y - 1, :]),
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
gauss_kernel = torch.cat(
|
||||||
|
(gauss_kernel[:, :, center_z - 1 :], gauss_kernel[:, :, : center_z - 1]),
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
overwrite_fft_gauss = torch.fft.fftn(gauss_kernel)
|
||||||
|
input_padded_gauss_filtered: torch.Tensor = torch.real(
|
||||||
|
torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
input_padded_gauss_filtered = torch.real(
|
||||||
|
torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss)
|
||||||
|
)
|
||||||
|
|
||||||
|
start = dims_xyt
|
||||||
|
stop = (
|
||||||
|
torch.tensor(input_padded.shape, device=dims_xyt.device, dtype=dims_xyt.dtype)
|
||||||
|
- dims_xyt
|
||||||
|
)
|
||||||
|
|
||||||
|
output = input_padded_gauss_filtered[
|
||||||
|
start[0] : stop[0], start[1] : stop[1], start[2] : stop[2]
|
||||||
|
]
|
||||||
|
|
||||||
|
return (output, overwrite_fft_gauss)
|
11
reproduction_effort/functions/interpolate_along_time.py
Normal file
11
reproduction_effort/functions/interpolate_along_time.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def interpolate_along_time(camera_sequence: list[torch.Tensor]) -> None:
|
||||||
|
camera_sequence[2][:, :, 1:] = (
|
||||||
|
camera_sequence[2][:, :, 1:] + camera_sequence[2][:, :, :-1]
|
||||||
|
) / 2.0
|
||||||
|
|
||||||
|
camera_sequence[3][:, :, 1:] = (
|
||||||
|
camera_sequence[3][:, :, 1:] + camera_sequence[3][:, :, :-1]
|
||||||
|
) / 2.0
|
26
reproduction_effort/functions/make_mask.py
Normal file
26
reproduction_effort/functions/make_mask.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
import scipy.io as sio # type: ignore
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def make_mask(
|
||||||
|
filename_mask: str, data: torch.Tensor, device: torch.device, dtype: torch.dtype
|
||||||
|
) -> torch.Tensor:
|
||||||
|
mask: torch.Tensor = torch.tensor(
|
||||||
|
sio.loadmat(filename_mask)["maskInfo"]["maskIdx2D"][0][0],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.bool,
|
||||||
|
)
|
||||||
|
mask = mask > 0.5
|
||||||
|
|
||||||
|
limit: torch.Tensor = torch.tensor(
|
||||||
|
2**16 - 1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if torch.any(data.flatten() >= limit):
|
||||||
|
mask = mask & ~(torch.any(torch.any(data >= limit, dim=-1), dim=-1))
|
||||||
|
|
||||||
|
return mask
|
36
reproduction_effort/functions/preprocess_camera_sequence.py
Normal file
36
reproduction_effort/functions/preprocess_camera_sequence.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def preprocess_camera_sequence(
|
||||||
|
camera_sequence: torch.Tensor,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
first_none_ramp_frame: int,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
limit: torch.Tensor = torch.tensor(
|
||||||
|
2**16 - 1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
camera_sequence = camera_sequence / camera_sequence[
|
||||||
|
:, :, first_none_ramp_frame:
|
||||||
|
].mean(
|
||||||
|
dim=2,
|
||||||
|
keepdim=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
camera_sequence = camera_sequence.nan_to_num(nan=0.0)
|
||||||
|
|
||||||
|
camera_sequence_zero_idx = torch.any(camera_sequence == 0, dim=-1, keepdim=True)
|
||||||
|
mask &= (~camera_sequence_zero_idx.squeeze(-1)).type(dtype=torch.bool)
|
||||||
|
camera_sequence_zero_idx = torch.tile(
|
||||||
|
camera_sequence_zero_idx, (1, 1, camera_sequence.shape[-1])
|
||||||
|
)
|
||||||
|
camera_sequence_zero_idx[:, :, :first_none_ramp_frame] = False
|
||||||
|
camera_sequence[camera_sequence_zero_idx] = limit
|
||||||
|
|
||||||
|
return camera_sequence, mask
|
93
reproduction_effort/functions/preprocessing.py
Normal file
93
reproduction_effort/functions/preprocessing.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
import scipy.io as sio # type: ignore
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
|
||||||
|
from functions.make_mask import make_mask
|
||||||
|
from functions.convert_camera_sequenc_to_list import convert_camera_sequenc_to_list
|
||||||
|
from functions.preprocess_camera_sequence import preprocess_camera_sequence
|
||||||
|
from functions.interpolate_along_time import interpolate_along_time
|
||||||
|
from functions.gauss_smear import gauss_smear
|
||||||
|
from functions.regression import regression
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def preprocessing(
|
||||||
|
filename_metadata: str,
|
||||||
|
filename_data: str,
|
||||||
|
filename_mask: str,
|
||||||
|
device: torch.device,
|
||||||
|
first_none_ramp_frame: int,
|
||||||
|
spatial_width: float,
|
||||||
|
temporal_width: float,
|
||||||
|
target_camera: list[str],
|
||||||
|
regressor_cameras: list[str],
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
data: torch.Tensor = torch.tensor(
|
||||||
|
sio.loadmat(filename_data)["data"].astype(np.float32),
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(filename_metadata, "r") as file_handle:
|
||||||
|
metadata: dict = json.load(file_handle)
|
||||||
|
cameras: list[str] = metadata["channelKey"]
|
||||||
|
|
||||||
|
required_order: list[str] = ["acceptor", "donor", "oxygenation", "volume"]
|
||||||
|
|
||||||
|
mask: torch.Tensor = make_mask(
|
||||||
|
filename_mask=filename_mask, data=data, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
camera_sequence: list[torch.Tensor] = convert_camera_sequenc_to_list(
|
||||||
|
data=data, required_order=required_order, cameras=cameras
|
||||||
|
)
|
||||||
|
|
||||||
|
for num_cams in range(len(camera_sequence)):
|
||||||
|
camera_sequence[num_cams], mask = preprocess_camera_sequence(
|
||||||
|
camera_sequence=camera_sequence[num_cams],
|
||||||
|
mask=mask,
|
||||||
|
first_none_ramp_frame=first_none_ramp_frame,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Interpolate in-between images
|
||||||
|
interpolate_along_time(camera_sequence)
|
||||||
|
|
||||||
|
camera_sequence_filtered: list[torch.Tensor] = []
|
||||||
|
for id in range(0, len(camera_sequence)):
|
||||||
|
camera_sequence_filtered.append(camera_sequence[id].clone())
|
||||||
|
|
||||||
|
camera_sequence_filtered = gauss_smear(
|
||||||
|
camera_sequence_filtered,
|
||||||
|
mask.type(dtype=dtype),
|
||||||
|
spatial_width=spatial_width,
|
||||||
|
temporal_width=temporal_width,
|
||||||
|
)
|
||||||
|
|
||||||
|
regressor_camera_ids: list[int] = []
|
||||||
|
|
||||||
|
for cam in regressor_cameras:
|
||||||
|
regressor_camera_ids.append(cameras.index(cam))
|
||||||
|
|
||||||
|
results: list[torch.Tensor] = []
|
||||||
|
|
||||||
|
for channel_position in range(0, len(target_camera)):
|
||||||
|
print(f"channel position: {channel_position}")
|
||||||
|
target_camera_selected = target_camera[channel_position]
|
||||||
|
target_camera_id: int = cameras.index(target_camera_selected)
|
||||||
|
|
||||||
|
output = regression(
|
||||||
|
target_camera_id=target_camera_id,
|
||||||
|
regressor_camera_ids=regressor_camera_ids,
|
||||||
|
mask=mask,
|
||||||
|
camera_sequence=camera_sequence,
|
||||||
|
camera_sequence_filtered=camera_sequence_filtered,
|
||||||
|
first_none_ramp_frame=first_none_ramp_frame,
|
||||||
|
)
|
||||||
|
results.append(output)
|
||||||
|
|
||||||
|
return results[0], results[1], mask
|
118
reproduction_effort/functions/regression.py
Normal file
118
reproduction_effort/functions/regression.py
Normal file
|
@ -0,0 +1,118 @@
|
||||||
|
import torch
|
||||||
|
from functions.regression_internal import regression_internal
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def regression(
|
||||||
|
target_camera_id: int,
|
||||||
|
regressor_camera_ids: list[int],
|
||||||
|
mask: torch.Tensor,
|
||||||
|
camera_sequence: list[torch.Tensor],
|
||||||
|
camera_sequence_filtered: list[torch.Tensor],
|
||||||
|
first_none_ramp_frame: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
assert len(regressor_camera_ids) > 0
|
||||||
|
|
||||||
|
# ------- Prepare the target signals ----------
|
||||||
|
target_signals_train: torch.Tensor = (
|
||||||
|
camera_sequence_filtered[target_camera_id][..., first_none_ramp_frame:].clone()
|
||||||
|
- 1.0
|
||||||
|
)
|
||||||
|
target_signals_train[target_signals_train < -1] = 0.0
|
||||||
|
|
||||||
|
target_signals_perform: torch.Tensor = (
|
||||||
|
camera_sequence[target_camera_id].clone() - 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if everything is happy
|
||||||
|
assert target_signals_train.ndim == 3
|
||||||
|
assert target_signals_train.ndim == target_signals_perform.ndim
|
||||||
|
assert target_signals_train.shape[0] == target_signals_perform.shape[0]
|
||||||
|
assert target_signals_train.shape[1] == target_signals_perform.shape[1]
|
||||||
|
assert (
|
||||||
|
target_signals_train.shape[2] + first_none_ramp_frame
|
||||||
|
) == target_signals_perform.shape[2]
|
||||||
|
# --==DONE==-
|
||||||
|
|
||||||
|
# ------- Prepare the regressor signals ----------
|
||||||
|
|
||||||
|
# --- Train ---
|
||||||
|
|
||||||
|
regressor_signals_train: torch.Tensor = torch.zeros(
|
||||||
|
(
|
||||||
|
camera_sequence_filtered[0].shape[0],
|
||||||
|
camera_sequence_filtered[0].shape[1],
|
||||||
|
camera_sequence_filtered[0].shape[2],
|
||||||
|
len(regressor_camera_ids) + 1,
|
||||||
|
),
|
||||||
|
device=camera_sequence_filtered[0].device,
|
||||||
|
dtype=camera_sequence_filtered[0].dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy the regressor signals -1
|
||||||
|
for matrix_id, id in enumerate(regressor_camera_ids):
|
||||||
|
regressor_signals_train[..., matrix_id] = camera_sequence_filtered[id] - 1.0
|
||||||
|
|
||||||
|
regressor_signals_train[regressor_signals_train < -1] = 0.0
|
||||||
|
|
||||||
|
# Linear regressor
|
||||||
|
trend = torch.arange(
|
||||||
|
0, regressor_signals_train.shape[-2], device=camera_sequence_filtered[0].device
|
||||||
|
) / float(regressor_signals_train.shape[-2] - 1)
|
||||||
|
trend -= trend.mean()
|
||||||
|
trend = trend.unsqueeze(0).unsqueeze(0)
|
||||||
|
trend = trend.tile(
|
||||||
|
(regressor_signals_train.shape[0], regressor_signals_train.shape[1], 1)
|
||||||
|
)
|
||||||
|
regressor_signals_train[..., -1] = trend
|
||||||
|
|
||||||
|
regressor_signals_train = regressor_signals_train[:, :, first_none_ramp_frame:, :]
|
||||||
|
|
||||||
|
# --- Perform ---
|
||||||
|
|
||||||
|
regressor_signals_perform: torch.Tensor = torch.zeros(
|
||||||
|
(
|
||||||
|
camera_sequence[0].shape[0],
|
||||||
|
camera_sequence[0].shape[1],
|
||||||
|
camera_sequence[0].shape[2],
|
||||||
|
len(regressor_camera_ids) + 1,
|
||||||
|
),
|
||||||
|
device=camera_sequence[0].device,
|
||||||
|
dtype=camera_sequence[0].dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy the regressor signals -1
|
||||||
|
for matrix_id, id in enumerate(regressor_camera_ids):
|
||||||
|
regressor_signals_perform[..., matrix_id] = camera_sequence[id] - 1.0
|
||||||
|
|
||||||
|
# Linear regressor
|
||||||
|
trend = torch.arange(
|
||||||
|
0, regressor_signals_perform.shape[-2], device=camera_sequence[0].device
|
||||||
|
) / float(regressor_signals_perform.shape[-2] - 1)
|
||||||
|
trend -= trend.mean()
|
||||||
|
trend = trend.unsqueeze(0).unsqueeze(0)
|
||||||
|
trend = trend.tile(
|
||||||
|
(regressor_signals_perform.shape[0], regressor_signals_perform.shape[1], 1)
|
||||||
|
)
|
||||||
|
regressor_signals_perform[..., -1] = trend
|
||||||
|
|
||||||
|
# --==DONE==-
|
||||||
|
|
||||||
|
coefficients, intercept = regression_internal(
|
||||||
|
input_regressor=regressor_signals_train, input_target=target_signals_train
|
||||||
|
)
|
||||||
|
|
||||||
|
target_signals_perform -= (
|
||||||
|
regressor_signals_perform * coefficients.unsqueeze(-2)
|
||||||
|
).sum(dim=-1)
|
||||||
|
|
||||||
|
target_signals_perform -= intercept.unsqueeze(-1)
|
||||||
|
|
||||||
|
target_signals_perform[
|
||||||
|
~mask.unsqueeze(-1).tile((1, 1, target_signals_perform.shape[-1]))
|
||||||
|
] = 0.0
|
||||||
|
|
||||||
|
target_signals_perform += 1.0
|
||||||
|
|
||||||
|
return target_signals_perform
|
20
reproduction_effort/functions/regression_internal.py
Normal file
20
reproduction_effort/functions/regression_internal.py
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def regression_internal(
|
||||||
|
input_regressor: torch.Tensor, input_target: torch.Tensor
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
regressor_offset = input_regressor.mean(keepdim=True, dim=-2)
|
||||||
|
target_offset = input_target.mean(keepdim=True, dim=-1)
|
||||||
|
|
||||||
|
regressor = input_regressor - regressor_offset
|
||||||
|
target = input_target - target_offset
|
||||||
|
|
||||||
|
coefficients, _, _, _ = torch.linalg.lstsq(regressor, target, rcond=None) # None ?
|
||||||
|
|
||||||
|
intercept = target_offset.squeeze(-1) - (
|
||||||
|
coefficients * regressor_offset.squeeze(-2)
|
||||||
|
).sum(dim=-1)
|
||||||
|
|
||||||
|
return coefficients, intercept
|
122
reproduction_effort/heartbeatanalyse.py
Normal file
122
reproduction_effort/heartbeatanalyse.py
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
import torch
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
from functions.preprocessing import preprocessing
|
||||||
|
from functions.bandpass import bandpass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device_name: str = "cuda:0"
|
||||||
|
else:
|
||||||
|
device_name = "cpu"
|
||||||
|
print(f"Using device: {device_name}")
|
||||||
|
device: torch.device = torch.device(device_name)
|
||||||
|
|
||||||
|
filename_metadata: str = "raw/Exp001_Trial001_Part001_meta.txt"
|
||||||
|
filename_data: str = "Exp001_Trial001_Part001.mat"
|
||||||
|
filename_mask: str = "2020-12-08maskPixelraw2.mat"
|
||||||
|
|
||||||
|
first_none_ramp_frame: int = 100
|
||||||
|
spatial_width: float = 2
|
||||||
|
temporal_width: float = 0.1
|
||||||
|
|
||||||
|
lower_freqency_bandpass: float = 5.0
|
||||||
|
upper_freqency_bandpass: float = 14.0
|
||||||
|
|
||||||
|
target_camera: list[str] = ["acceptor", "donor"]
|
||||||
|
regressor_cameras: list[str] = ["oxygenation", "volume"]
|
||||||
|
|
||||||
|
ratio_sequence_a, ratio_sequence_b, mask = preprocessing(
|
||||||
|
filename_metadata=filename_metadata,
|
||||||
|
filename_data=filename_data,
|
||||||
|
filename_mask=filename_mask,
|
||||||
|
device=device,
|
||||||
|
first_none_ramp_frame=first_none_ramp_frame,
|
||||||
|
spatial_width=spatial_width,
|
||||||
|
temporal_width=temporal_width,
|
||||||
|
target_camera=target_camera,
|
||||||
|
regressor_cameras=regressor_cameras,
|
||||||
|
)
|
||||||
|
|
||||||
|
ratio_sequence_a = bandpass(
|
||||||
|
data=ratio_sequence_a,
|
||||||
|
device=ratio_sequence_a.device,
|
||||||
|
low_frequency=lower_freqency_bandpass,
|
||||||
|
high_frequency=upper_freqency_bandpass,
|
||||||
|
fs=100.0,
|
||||||
|
filtfilt_chuck_size=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
ratio_sequence_b = bandpass(
|
||||||
|
data=ratio_sequence_b,
|
||||||
|
device=ratio_sequence_b.device,
|
||||||
|
low_frequency=lower_freqency_bandpass,
|
||||||
|
high_frequency=upper_freqency_bandpass,
|
||||||
|
fs=100.0,
|
||||||
|
filtfilt_chuck_size=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
original_shape = ratio_sequence_a.shape
|
||||||
|
|
||||||
|
ratio_sequence_a = ratio_sequence_a.flatten(start_dim=0, end_dim=-2)
|
||||||
|
ratio_sequence_b = ratio_sequence_b.flatten(start_dim=0, end_dim=-2)
|
||||||
|
|
||||||
|
mask = mask.flatten(start_dim=0, end_dim=-1)
|
||||||
|
ratio_sequence_a = ratio_sequence_a[mask, :]
|
||||||
|
ratio_sequence_b = ratio_sequence_b[mask, :]
|
||||||
|
|
||||||
|
ratio_sequence_a = ratio_sequence_a.movedim(0, -1)
|
||||||
|
ratio_sequence_b = ratio_sequence_b.movedim(0, -1)
|
||||||
|
|
||||||
|
ratio_sequence_a -= ratio_sequence_a.mean(dim=0, keepdim=True)
|
||||||
|
ratio_sequence_b -= ratio_sequence_b.mean(dim=0, keepdim=True)
|
||||||
|
|
||||||
|
u_a, s_a, Vh_a = torch.linalg.svd(ratio_sequence_a, full_matrices=False)
|
||||||
|
u_a = u_a[:, 0]
|
||||||
|
s_a = s_a[0]
|
||||||
|
Vh_a = Vh_a[0, :]
|
||||||
|
|
||||||
|
heartbeatactivitmap_a = torch.zeros(
|
||||||
|
(original_shape[0], original_shape[1]), device=Vh_a.device, dtype=Vh_a.dtype
|
||||||
|
).flatten(start_dim=0, end_dim=-1)
|
||||||
|
|
||||||
|
heartbeatactivitmap_a *= torch.nan
|
||||||
|
heartbeatactivitmap_a[mask] = s_a * Vh_a
|
||||||
|
heartbeatactivitmap_a = heartbeatactivitmap_a.reshape(
|
||||||
|
(original_shape[0], original_shape[1])
|
||||||
|
)
|
||||||
|
|
||||||
|
u_b, s_b, Vh_b = torch.linalg.svd(ratio_sequence_b, full_matrices=False)
|
||||||
|
u_b = u_b[:, 0]
|
||||||
|
s_b = s_b[0]
|
||||||
|
Vh_b = Vh_b[0, :]
|
||||||
|
|
||||||
|
heartbeatactivitmap_b = torch.zeros(
|
||||||
|
(original_shape[0], original_shape[1]), device=Vh_b.device, dtype=Vh_b.dtype
|
||||||
|
).flatten(start_dim=0, end_dim=-1)
|
||||||
|
|
||||||
|
heartbeatactivitmap_b *= torch.nan
|
||||||
|
heartbeatactivitmap_b[mask] = s_b * Vh_b
|
||||||
|
heartbeatactivitmap_b = heartbeatactivitmap_b.reshape(
|
||||||
|
(original_shape[0], original_shape[1])
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.subplot(2, 1, 1)
|
||||||
|
plt.plot(u_a.cpu(), label="aceptor")
|
||||||
|
plt.plot(u_b.cpu(), label="donor")
|
||||||
|
plt.legend()
|
||||||
|
plt.subplot(2, 1, 2)
|
||||||
|
plt.imshow(
|
||||||
|
torch.cat(
|
||||||
|
(
|
||||||
|
heartbeatactivitmap_a,
|
||||||
|
heartbeatactivitmap_b,
|
||||||
|
),
|
||||||
|
dim=1,
|
||||||
|
).cpu()
|
||||||
|
)
|
||||||
|
plt.colorbar()
|
||||||
|
plt.show()
|
68
reproduction_effort/preprocessing.py
Normal file
68
reproduction_effort/preprocessing.py
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import h5py # type: ignore
|
||||||
|
|
||||||
|
from functions.preprocessing import preprocessing
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device_name: str = "cuda:0"
|
||||||
|
else:
|
||||||
|
device_name = "cpu"
|
||||||
|
print(f"Using device: {device_name}")
|
||||||
|
device: torch.device = torch.device(device_name)
|
||||||
|
|
||||||
|
filename_metadata: str = "raw/Exp001_Trial001_Part001_meta.txt"
|
||||||
|
filename_data: str = "Exp001_Trial001_Part001.mat"
|
||||||
|
filename_mask: str = "2020-12-08maskPixelraw2.mat"
|
||||||
|
|
||||||
|
first_none_ramp_frame: int = 100
|
||||||
|
spatial_width: float = 2
|
||||||
|
temporal_width: float = 0.1
|
||||||
|
|
||||||
|
target_camera: list[str] = ["acceptor", "donor"]
|
||||||
|
regressor_cameras: list[str] = ["oxygenation", "volume"]
|
||||||
|
|
||||||
|
data_acceptor, data_donor, mask = preprocessing(
|
||||||
|
filename_metadata=filename_metadata,
|
||||||
|
filename_data=filename_data,
|
||||||
|
filename_mask=filename_mask,
|
||||||
|
device=device,
|
||||||
|
first_none_ramp_frame=first_none_ramp_frame,
|
||||||
|
spatial_width=spatial_width,
|
||||||
|
temporal_width=temporal_width,
|
||||||
|
target_camera=target_camera,
|
||||||
|
regressor_cameras=regressor_cameras,
|
||||||
|
)
|
||||||
|
|
||||||
|
ratio_sequence: torch.Tensor = data_acceptor / data_donor
|
||||||
|
|
||||||
|
new: np.ndarray = ratio_sequence.cpu().numpy()
|
||||||
|
|
||||||
|
file_handle = h5py.File("old.mat", "r")
|
||||||
|
old: np.ndarray = np.array(file_handle["ratioSequence"])
|
||||||
|
# HDF5 loads everything backwards...
|
||||||
|
old = np.moveaxis(old, 0, -1)
|
||||||
|
old = np.moveaxis(old, 0, -2)
|
||||||
|
|
||||||
|
pos_x = 25
|
||||||
|
pos_y = 75
|
||||||
|
|
||||||
|
plt.subplot(2, 1, 1)
|
||||||
|
new_select = new[pos_x, pos_y, :]
|
||||||
|
old_select = old[pos_x, pos_y, :]
|
||||||
|
plt.plot(new_select, label="New")
|
||||||
|
plt.plot(old_select, "--", label="Old")
|
||||||
|
plt.plot(old_select - new_select + 1.0, label="Old - New + 1")
|
||||||
|
plt.title(f"Position: {pos_x}, {pos_y}")
|
||||||
|
plt.legend()
|
||||||
|
|
||||||
|
plt.subplot(2, 1, 2)
|
||||||
|
differences = (np.abs(new - old)).max(axis=-1)
|
||||||
|
plt.imshow(differences)
|
||||||
|
plt.title("Max of abs(new-old) along time")
|
||||||
|
plt.colorbar()
|
||||||
|
plt.show()
|
Loading…
Reference in a new issue