# File: gevi/DataContainer.py (1836 lines, 67 KiB, Python)
# Header lines committed 2023-07-09 18:55:12 +02:00
# pip install roipoly natsort numpy matplotlib
# Also install: torch torchaudio torchvision
# (for details see https://pytorch.org/get-started/locally/ )
# Tested on Python 3.11
import glob
import json
import logging
import math
import os
from datetime import datetime
import matplotlib.pyplot as plt
import natsort
import numpy as np
import torch
import torchaudio as ta
import torchvision as tv
from roipoly import RoiPoly
from ImageAlignment import ImageAlignment
class DataContainer(torch.nn.Module):
    """Load, align and clean the image stacks of single recording trials.

    The channels handled are ``acceptor``, ``donor``, ``oxygenation`` and
    ``volume``. Their positions inside the raw ``.npy`` files are taken from
    the ``channelKey`` list of the accompanying ``*_meta.txt`` JSON files.
    """

    # Reference images used as alignment targets; set in __init__ from the
    # middle frame of the first recording found in ``path``.
    ref_image_acceptor: torch.Tensor | None = None
    ref_image_donor: torch.Tensor | None = None

    # Channel stacks of the currently loaded trial. load_data() moves the axis
    # along which the parts are concatenated to the front, so the first axis
    # indexes frames.
    acceptor: torch.Tensor | None = None
    donor: torch.Tensor | None = None
    oxygenation: torch.Tensor | None = None
    volume: torch.Tensor | None = None

    # SVD bookkeeping per channel (reset by load_data()) and the accumulated
    # signal removed by remove_heartbeat() ("residuum").
    acceptor_whiten_mean: torch.Tensor | None = None
    acceptor_whiten_k: torch.Tensor | None = None
    acceptor_eigenvalues: torch.Tensor | None = None
    acceptor_residuum: torch.Tensor | None = None
    donor_whiten_mean: torch.Tensor | None = None
    donor_whiten_k: torch.Tensor | None = None
    donor_eigenvalues: torch.Tensor | None = None
    donor_residuum: torch.Tensor | None = None
    oxygenation_whiten_mean: torch.Tensor | None = None
    oxygenation_whiten_k: torch.Tensor | None = None
    oxygenation_eigenvalues: torch.Tensor | None = None
    oxygenation_residuum: torch.Tensor | None = None
    volume_whiten_mean: torch.Tensor | None = None
    volume_whiten_k: torch.Tensor | None = None
    volume_eigenvalues: torch.Tensor | None = None
    volume_residuum: torch.Tensor | None = None

    # Per-pixel temporal means used by cleaned_load_data() to normalise each
    # channel to relative changes.
    acceptor_scale: torch.Tensor | None = None
    donor_scale: torch.Tensor | None = None
    oxygenation_scale: torch.Tensor | None = None
    volume_scale: torch.Tensor | None = None

    # -------
    image_alignment: ImageAlignment
    # Positions of the channels along the last axis of the raw data arrays.
    acceptor_index: int
    donor_index: int
    oxygenation_index: int
    volume_index: int
    path: str  # directory containing the Exp*_Trial*_Part*.npy recordings
    channels: list[str]  # "channelKey" list read from the metadata
    device: torch.device
    batch_size: int = 200  # frames per batch when estimating alignment
    fill_value: float = -0.1  # value for pixels uncovered by affine transforms
    # Chunk size passed to _chunk_iterator in remove_other_signals(). "chuck"
    # is a typo for "chunk"; the attribute name is kept because code reads it.
    filtfilt_chuck_size: int = 10
    # Prefixes of log messages; longer prefixes mark more deeply nested steps.
    level0: str = "="
    level1: str = "=="
    level2: str = "==="
    level3: str = "===="
@torch.no_grad()
def __init__(
    self,
    path: str,
    device: torch.device,
    display_logging_messages: bool = False,
    save_logging_messages: bool = False,
) -> None:
    """Open the recordings in ``path`` and prepare the reference images.

    Args:
        path: Directory containing the ``Exp*_Trial*_Part*.npy`` recordings
            and their ``*_meta.txt`` JSON side files.
        device: Torch device on which all tensors are created.
        display_logging_messages: Attach an INFO-level stream handler to the
            "DataContainer" logger (unless one is already attached).
        save_logging_messages: Attach an INFO-level file handler writing to
            ``log_<timestamp>.txt`` in the current working directory (unless
            a handler for that file is already attached).

    Raises:
        AssertionError: If ``path`` is None, no reference recording exists, or
            its metadata file is missing.
        ValueError: If "acceptor", "donor", "oxygenation" or "volume" is not
            listed in the metadata's ``channelKey``.
    """
    super().__init__()
    self.device = device
    assert path is not None
    self.path = path

    dt_string_filename = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    self.logger = logging.getLogger("DataContainer")
    self.logger.setLevel(logging.DEBUG)
    # NOTE(review): "%-d" (day without leading zero) is a glibc strftime
    # extension and is not supported on every platform.
    time_format = "%b %-d %Y %H:%M:%S"
    logformat = "%(asctime)s %(message)s"

    # logging.getLogger("DataContainer") returns the same logger object for
    # every instance, so handlers added by earlier instances are still
    # attached. Only add handlers that are missing; otherwise every message
    # would be emitted once per constructed DataContainer.
    if save_logging_messages:
        log_filename = f"log_{dt_string_filename}.txt"
        file_handler_present = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == os.path.abspath(log_filename)
            for handler in self.logger.handlers
        )
        if not file_handler_present:
            file_handler = logging.FileHandler(log_filename)
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(
                logging.Formatter(fmt=logformat, datefmt=time_format)
            )
            self.logger.addHandler(file_handler)

    if display_logging_messages:
        # FileHandler subclasses StreamHandler, hence the exact type check.
        stream_handler_present = any(
            type(handler) is logging.StreamHandler
            for handler in self.logger.handlers
        )
        if not stream_handler_present:
            stream_handler = logging.StreamHandler()
            stream_handler.setLevel(logging.INFO)
            stream_handler.setFormatter(
                logging.Formatter(fmt=logformat, datefmt=time_format)
            )
            self.logger.addHandler(stream_handler)

    file_input_ref_image = self._find_ref_image_file()
    data = np.load(file_input_ref_image, mmap_mode="r")
    # Middle index of axis 2 (the axis load_data() concatenates parts along),
    # keeping every channel on the last axis.
    ref_image = torch.tensor(
        data[:, :, data.shape[2] // 2, :].astype(np.float32),
        device=self.device,
        dtype=torch.float32,
    )

    found_name_json: str = file_input_ref_image.replace(".npy", "_meta.txt")
    assert os.path.isfile(found_name_json)
    with open(found_name_json, "r", encoding="utf-8") as file_handle:
        metadata = json.load(file_handle)
    self.channels = metadata["channelKey"]
    self.acceptor_index = self.channels.index("acceptor")
    self.donor_index = self.channels.index("donor")
    self.oxygenation_index = self.channels.index("oxygenation")
    self.volume_index = self.channels.index("volume")

    self.ref_image_acceptor = ref_image[:, :, self.acceptor_index]
    self.ref_image_donor = ref_image[:, :, self.donor_index]
    self.image_alignment = ImageAlignment(
        default_dtype=torch.float32, device=self.device
    )
@torch.no_grad()
def get_trials(self, experiment_id: int) -> torch.Tensor:
    """Return the trial numbers recorded for one experiment.

    A trial counts as present when its first part file
    ``Exp<experiment_id:03d>_Trial<nnn>_Part001.npy`` exists in ``self.path``.

    Args:
        experiment_id: Experiment number to look up.

    Returns:
        1-D int64 tensor of unique trial numbers in ascending order; empty if
        no matching file exists.
    """
    # Escape the directory so characters such as "[" in the path are taken
    # literally instead of as glob wildcards.
    pattern = os.path.join(
        glob.escape(self.path),
        f"Exp{experiment_id:03d}_Trial*_Part001.npy",
    )
    trial_ids = sorted(
        int(os.path.basename(filename).split("_Trial")[-1].split("_Part")[0])
        for filename in glob.glob(pattern)
    )
    # Explicit dtype: torch.tensor([]) would otherwise be float32.
    return torch.tensor(trial_ids, dtype=torch.int64).unique()
@torch.no_grad()
def get_experiments(
    self,
) -> torch.Tensor:
    """Return the experiment numbers present in ``self.path``.

    Every file matching ``Exp*_Part001.npy`` contributes the number between
    ``Exp`` and ``_Trial`` in its name.

    Returns:
        1-D int64 tensor of unique experiment numbers in ascending order;
        empty if no matching file exists.
    """
    # Escape the directory so characters such as "[" in the path are taken
    # literally instead of as glob wildcards.
    pattern = os.path.join(glob.escape(self.path), "Exp*_Part001.npy")
    experiment_ids = sorted(
        int(os.path.basename(filename).split("Exp")[-1].split("_Trial")[0])
        for filename in glob.glob(pattern)
    )
    # Explicit dtype: torch.tensor([]) would otherwise be float32.
    return torch.tensor(experiment_ids, dtype=torch.int64).unique()
@torch.no_grad()
def load_data(  # start_position_coefficients: OK
    self,
    experiment_id: int,
    trial_id: int,
    align: bool = True,
    enable_secondary_data: bool = True,
    mmap_mode: bool = True,
    start_position_coefficients: int = 0,
):
    """Load all parts of one trial into the per-channel tensors.

    The files ``Exp<eee>_Trial<ttt>_Part001.npy``, ``..._Part002.npy``, ...
    (each with a ``_meta.txt`` JSON file) are read in order until one is
    missing. Every part is a 4-D array whose last axis holds the channels;
    the parts are joined along axis 2, which is then moved to the front, so
    the channel tensors index frames first.

    Args:
        experiment_id: Experiment number used in the file names.
        trial_id: Trial number used in the file names.
        align: Run the intra-/inter-channel alignment steps after loading.
        enable_secondary_data: Also load ``oxygenation`` and ``volume``; when
            False those attributes stay None.
        mmap_mode: Memory-map the ``.npy`` files instead of reading them fully.
        start_position_coefficients: Index of the first frame used to estimate
            the alignment parameters.

    Raises:
        AssertionError: If Part001 does not exist, a part's ``channelKey``
            differs from ``self.channels``, or parts disagree in layout.
    """
    self.acceptor = None
    self.donor = None
    self.oxygenation = None
    self.volume = None

    # Channels to extract (in logging order) with their index along the last
    # axis of the raw arrays.
    channel_indices: dict[str, int] = {
        "acceptor": self.acceptor_index,
        "donor": self.donor_index,
    }
    if enable_secondary_data:
        channel_indices["oxygenation"] = self.oxygenation_index
        channel_indices["volume"] = self.volume_index

    def part_filenames(part_id: int) -> tuple[str, str]:
        # Paths of the data file and of the JSON metadata file of one part.
        stem = os.path.join(
            self.path,
            f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}",
        )
        return f"{stem}.npy", f"{stem}_meta.txt"

    def join_parts(parts: list[torch.Tensor]) -> torch.Tensor:
        # Concatenate along axis 2 and move that axis to the front; a single
        # part is used as is to avoid the copy torch.cat would make.
        joined = parts[0] if len(parts) == 1 else torch.cat(parts, dim=2)
        return joined.moveaxis(-1, 0)

    # Slices are collected per channel and concatenated once at the end;
    # a torch.cat per part would copy the growing tensor every time.
    collected: dict[str, list[torch.Tensor]] = {name: [] for name in channel_indices}
    # (size of axis 0, size of axis 1, channel count) of the first part.
    reference_layout: tuple[int, int, int] | None = None

    part_id = 1
    filename_np, filename_meta = part_filenames(part_id)
    while os.path.isfile(filename_np) and os.path.isfile(filename_meta):
        self.logger.info(f"{self.level3} work in {filename_np}")
        # Check if the channel assignment is still okay
        with open(filename_meta, "r") as file_handle:
            metadata = json.load(file_handle)
        assert metadata["channelKey"] == self.channels

        self.logger.info(f"{self.level3} np.load")
        if mmap_mode:
            temp: np.ndarray = np.load(filename_np, mmap_mode="r")
        else:
            temp = np.load(filename_np)

        # Compare the raw 4-D layouts of the parts with each other; the 3-D
        # per-channel slices have no channel axis to compare against.
        assert temp.ndim == 4
        layout = (temp.shape[0], temp.shape[1], temp.shape[3])
        if reference_layout is None:
            reference_layout = layout
        assert layout == reference_layout

        for name, channel_index in channel_indices.items():
            self.logger.info(f"{self.level3} organize {name}")
            collected[name].append(
                torch.tensor(
                    temp[:, :, :, channel_index].astype(np.float32),
                    device=self.device,
                    dtype=torch.float32,
                )
            )

        part_id += 1
        filename_np, filename_meta = part_filenames(part_id)

    # At least Part001 (with its metadata file) has to exist.
    assert part_id > 1

    self.logger.info(f"{self.level3} move axis")
    self.acceptor = join_parts(collected.pop("acceptor"))
    self.donor = join_parts(collected.pop("donor"))
    if enable_secondary_data:
        self.oxygenation = join_parts(collected.pop("oxygenation"))
        self.volume = join_parts(collected.pop("volume"))

    if align:
        self.logger.info(f"{self.level3} move intra timeseries")
        self._move_intra_timeseries(
            enable_secondary_data=enable_secondary_data,
            start_position_coefficients=start_position_coefficients,
        )
        self.logger.info(f"{self.level3} rotate inter timeseries")
        self._rotate_inter_timeseries(
            enable_secondary_data=enable_secondary_data,
            start_position_coefficients=start_position_coefficients,
        )
        self.logger.info(f"{self.level3} move inter timeseries")
        self._move_inter_timeseries(
            enable_secondary_data=enable_secondary_data,
            start_position_coefficients=start_position_coefficients,
        )

    # reset svd
    self.acceptor_whiten_mean = None
    self.acceptor_whiten_k = None
    self.acceptor_eigenvalues = None
    self.donor_whiten_mean = None
    self.donor_whiten_k = None
    self.donor_eigenvalues = None
    self.oxygenation_whiten_mean = None
    self.oxygenation_whiten_k = None
    self.oxygenation_eigenvalues = None
    self.volume_whiten_mean = None
    self.volume_whiten_k = None
    self.volume_eigenvalues = None
@torch.no_grad()
def _find_ref_image_file(self) -> str:
    """Return the first recording in ``self.path`` whose name has a trial and part.

    Candidates are the ``Exp*.npy`` files of the directory in natural sort
    order; the first whose file name contains both "Trial" and "Part" wins.

    Raises:
        AssertionError: If no such file exists.
    """
    # Escape the directory so characters such as "[" are not glob wildcards.
    candidates = natsort.natsorted(
        glob.glob(os.path.join(glob.escape(self.path), "Exp*.npy"))
    )
    # Inspect only the file name: the directory part may itself contain
    # "Trial" or "Part" and would otherwise match every candidate.
    found_name: str | None = next(
        (
            filename
            for filename in candidates
            if "Trial" in os.path.basename(filename)
            and "Part" in os.path.basename(filename)
        ),
        None,
    )
    assert found_name is not None
    return found_name
@torch.no_grad()
def _calculate_translation( # start_position_coefficients: OK
self,
input: torch.Tensor,
reference_image: torch.Tensor,
start_position_coefficients: int = 0,
) -> torch.Tensor:
tvec = torch.zeros((input.shape[0], 2))
data_loader = torch.utils.data.DataLoader(
torch.utils.data.TensorDataset(input[start_position_coefficients:, ...]),
batch_size=self.batch_size,
shuffle=False,
)
start_position: int = 0
for input_batch in data_loader:
assert len(input_batch) == 1
end_position = start_position + input_batch[0].shape[0]
tvec_temp = self.image_alignment.dry_run_translation(
input=input_batch[0],
new_reference_image=reference_image,
)
assert tvec_temp is not None
tvec[start_position:end_position, :] = tvec_temp
start_position += input_batch[0].shape[0]
tvec = torch.round(torch.median(tvec, dim=0)[0])
return tvec
@torch.no_grad()
def _calculate_rotation( # start_position_coefficients: OK
self,
input: torch.Tensor,
reference_image: torch.Tensor,
start_position_coefficients: int = 0,
) -> torch.Tensor:
angle = torch.zeros((input.shape[0]))
data_loader = torch.utils.data.DataLoader(
torch.utils.data.TensorDataset(input[start_position_coefficients:, ...]),
batch_size=self.batch_size,
shuffle=False,
)
start_position: int = 0
for input_batch in data_loader:
assert len(input_batch) == 1
end_position = start_position + input_batch[0].shape[0]
angle_temp = self.image_alignment.dry_run_angle(
input=input_batch[0],
new_reference_image=reference_image,
)
assert angle_temp is not None
angle[start_position:end_position] = angle_temp
start_position += input_batch[0].shape[0]
angle = torch.where(angle >= 180, 360.0 - angle, angle)
angle = torch.where(angle <= -180, 360.0 + angle, angle)
angle = torch.median(angle, dim=0)[0]
return angle
@torch.no_grad()
def _move_intra_timeseries(  # start_position_coefficients: OK
    self,
    enable_secondary_data: bool = True,
    start_position_coefficients: int = 0,
) -> None:
    """Shift donor and acceptor stacks onto their own reference images.

    One integer translation is estimated per primary channel: donor against
    ``ref_image_donor``, then acceptor against ``ref_image_acceptor``. When
    ``enable_secondary_data`` is True, the donor shift is also applied to
    ``volume`` and the acceptor shift to ``oxygenation``.
    """

    def translated(stack: torch.Tensor, tvec: torch.Tensor) -> torch.Tensor:
        # torchvision takes (horizontal, vertical); the two estimated
        # components are swapped to match.
        return tv.transforms.functional.affine(
            img=stack,
            angle=0,
            translate=[tvec[1], tvec[0]],
            scale=1.0,
            shear=0,
            fill=self.fill_value,
        )

    # (channel to measure, its reference image, channel sharing its shift)
    groups = (
        ("donor", "ref_image_donor", "volume"),
        ("acceptor", "ref_image_acceptor", "oxygenation"),
    )
    for primary, reference_name, companion in groups:
        stack = getattr(self, primary)
        reference = getattr(self, reference_name)
        assert stack is not None
        assert reference is not None
        shift = self._calculate_translation(
            stack,
            reference,
            start_position_coefficients=start_position_coefficients,
        )
        setattr(self, primary, translated(stack, shift))
        if enable_secondary_data:
            partner = getattr(self, companion)
            assert partner is not None
            setattr(self, companion, translated(partner, shift))
@torch.no_grad()
def _move_inter_timeseries(  # start_position_coefficients: OK
    self,
    enable_secondary_data: bool = True,
    start_position_coefficients: int = 0,
) -> None:
    """Shift the acceptor stack onto the donor reference image.

    A single integer translation of ``self.acceptor`` relative to
    ``self.ref_image_donor`` is estimated and applied to ``acceptor`` and,
    when ``enable_secondary_data`` is True, to ``oxygenation`` as well.
    """
    assert self.acceptor is not None
    assert self.ref_image_donor is not None
    shift = self._calculate_translation(
        self.acceptor,
        self.ref_image_donor,
        start_position_coefficients=start_position_coefficients,
    )

    targets = ("acceptor", "oxygenation") if enable_secondary_data else ("acceptor",)
    for name in targets:
        stack = getattr(self, name)
        assert stack is not None
        # torchvision takes (horizontal, vertical); components are swapped.
        moved = tv.transforms.functional.affine(
            img=stack,
            angle=0,
            translate=[shift[1], shift[0]],
            scale=1.0,
            shear=0,
            fill=self.fill_value,
        )
        setattr(self, name, moved)
@torch.no_grad()
def _rotate_inter_timeseries(  # start_position_coefficients: OK
    self,
    enable_secondary_data: bool = True,
    start_position_coefficients: int = 0,
) -> None:
    """Rotate the acceptor stack onto the donor reference image.

    One angle of ``self.acceptor`` relative to ``self.ref_image_donor`` is
    estimated; ``acceptor`` (and ``oxygenation`` when
    ``enable_secondary_data`` is True) is rotated by the negative of that
    angle using bilinear interpolation.
    """
    assert self.acceptor is not None
    assert self.ref_image_donor is not None
    estimated = self._calculate_rotation(
        self.acceptor,
        self.ref_image_donor,
        start_position_coefficients=start_position_coefficients,
    )
    correction = -float(estimated)

    targets = ("acceptor", "oxygenation") if enable_secondary_data else ("acceptor",)
    for name in targets:
        stack = getattr(self, name)
        assert stack is not None
        rotated = tv.transforms.functional.affine(
            img=stack,
            angle=correction,
            translate=[0, 0],
            scale=1.0,
            shear=0,
            interpolation=tv.transforms.InterpolationMode.BILINEAR,
            fill=self.fill_value,
        )
        setattr(self, name, rotated)
@torch.no_grad()
def _svd( # start_position_coefficients: OK
self,
input: torch.Tensor,
lowrank_method: bool = True,
lowrank_q: int = 6,
start_position_coefficients: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
selection = torch.flatten(
input[start_position_coefficients:, ...].clone().movedim(0, -1),
start_dim=0,
end_dim=1,
)
whiten_mean = torch.mean(selection, dim=-1)
selection -= whiten_mean.unsqueeze(-1)
whiten_mean = whiten_mean.reshape((input.shape[1], input.shape[2]))
if lowrank_method is False:
svd_u, svd_s, _ = torch.linalg.svd(selection, full_matrices=False)
else:
svd_u, svd_s, _ = torch.svd_lowrank(selection, q=lowrank_q)
whiten_k = (
torch.sign(svd_u[0, :]).unsqueeze(0) * svd_u / (svd_s.unsqueeze(0) + 1e-20)
)
whiten_k = whiten_k.reshape((input.shape[1], input.shape[2], svd_s.shape[0]))
eigenvalues = svd_s
return whiten_mean, whiten_k, eigenvalues
@torch.no_grad()
def _to_remove( # start_position_coefficients: OK
self,
input: torch.Tensor | None,
lowrank_method: bool = True,
lowrank_q: int = 6,
start_position_coefficients: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor]:
assert input is not None
id: int = 0
(
input_whiten_mean,
input_whiten_k,
input_eigenvalues,
) = self._svd(
input,
lowrank_method=lowrank_method,
lowrank_q=lowrank_q,
start_position_coefficients=start_position_coefficients,
)
assert input_whiten_mean is not None
assert input_whiten_k is not None
assert input_eigenvalues is not None
eigenvalue = float(input_eigenvalues[id])
whiten_mean = input_whiten_mean
whiten_k = input_whiten_k[:, :, 0]
data = (input - input_whiten_mean.unsqueeze(0)) * input_whiten_k[
:, :, id
].unsqueeze(0)
input_svd = data.sum(dim=-1).sum(dim=-1).unsqueeze(-1).unsqueeze(-1)
factor = (data * input_svd).sum(dim=0, keepdim=True) / (input_svd**2).sum(
dim=0, keepdim=True
)
to_remove = input_svd * factor
to_remove /= input_whiten_k[:, :, id].unsqueeze(0) + 1e-20
to_remove += input_whiten_mean.unsqueeze(0)
output = input - to_remove
return output, to_remove, eigenvalue, whiten_mean, whiten_k
@torch.no_grad()
def acceptor_svd_remove(  # start_position_coefficients: OK
    self,
    lowrank_method: bool = True,
    lowrank_q: int = 6,
    start_position_coefficients: int = 0,
) -> tuple[torch.Tensor, float, torch.Tensor, torch.Tensor]:
    """Strip the strongest SVD component from ``self.acceptor``.

    ``self.acceptor`` is replaced by the cleaned stack computed by
    ``self._to_remove``.

    Returns:
        ``(to_remove, eigenvalue, whiten_mean, whiten_k)`` from ``_to_remove``.
    """
    outcome = self._to_remove(
        input=self.acceptor,
        lowrank_method=lowrank_method,
        lowrank_q=lowrank_q,
        start_position_coefficients=start_position_coefficients,
    )
    self.acceptor = outcome[0]
    return outcome[1:]
@torch.no_grad()
def donor_svd_remove(  # start_position_coefficients: OK
    self,
    lowrank_method: bool = True,
    lowrank_q: int = 6,
    start_position_coefficients: int = 0,
) -> tuple[torch.Tensor, float, torch.Tensor, torch.Tensor]:
    """Strip the strongest SVD component from ``self.donor``.

    ``self.donor`` is replaced by the cleaned stack computed by
    ``self._to_remove``.

    Returns:
        ``(to_remove, eigenvalue, whiten_mean, whiten_k)`` from ``_to_remove``.
    """
    outcome = self._to_remove(
        input=self.donor,
        lowrank_method=lowrank_method,
        lowrank_q=lowrank_q,
        start_position_coefficients=start_position_coefficients,
    )
    self.donor = outcome[0]
    return outcome[1:]
@torch.no_grad()
def volume_svd_remove(  # start_position_coefficients: OK
    self,
    lowrank_method: bool = True,
    lowrank_q: int = 6,
    start_position_coefficients: int = 0,
) -> tuple[torch.Tensor, float, torch.Tensor, torch.Tensor]:
    """Strip the strongest SVD component from ``self.volume``.

    ``self.volume`` is replaced by the cleaned stack computed by
    ``self._to_remove``.

    Returns:
        ``(to_remove, eigenvalue, whiten_mean, whiten_k)`` from ``_to_remove``.
    """
    outcome = self._to_remove(
        input=self.volume,
        lowrank_method=lowrank_method,
        lowrank_q=lowrank_q,
        start_position_coefficients=start_position_coefficients,
    )
    self.volume = outcome[0]
    return outcome[1:]
@torch.no_grad()
def oxygenation_svd_remove(  # start_position_coefficients: OK
    self,
    lowrank_method: bool = True,
    lowrank_q: int = 6,
    start_position_coefficients: int = 0,
) -> tuple[torch.Tensor, float, torch.Tensor, torch.Tensor]:
    """Strip the strongest SVD component from ``self.oxygenation``.

    ``self.oxygenation`` is replaced by the cleaned stack computed by
    ``self._to_remove``.

    Returns:
        ``(to_remove, eigenvalue, whiten_mean, whiten_k)`` from ``_to_remove``.
    """
    outcome = self._to_remove(
        input=self.oxygenation,
        lowrank_method=lowrank_method,
        lowrank_q=lowrank_q,
        start_position_coefficients=start_position_coefficients,
    )
    self.oxygenation = outcome[0]
    return outcome[1:]
@torch.no_grad()
def remove_heartbeat(  # start_position_coefficients: OK
    self,
    iterations: int = 2,
    lowrank_method: bool = True,
    lowrank_q: int = 6,
    enable_secondary_data: bool = True,
    start_position_coefficients: int = 0,
):
    """Repeatedly strip the strongest SVD component from every channel.

    Each of the ``iterations`` passes calls ``acceptor_svd_remove`` and
    ``donor_svd_remove`` (followed by ``volume_svd_remove`` and
    ``oxygenation_svd_remove`` when ``enable_secondary_data`` is True). The
    removed signals are summed per channel into ``self.<channel>_residuum``;
    all four residua are reset to None first.
    """
    self.acceptor_residuum = None
    self.donor_residuum = None
    self.oxygenation_residuum = None
    self.volume_residuum = None

    channel_order = ["acceptor", "donor"]
    if enable_secondary_data:
        channel_order += ["volume", "oxygenation"]

    for _ in range(iterations):
        for channel in channel_order:
            svd_remove = getattr(self, f"{channel}_svd_remove")
            removed = svd_remove(
                lowrank_method=lowrank_method,
                lowrank_q=lowrank_q,
                start_position_coefficients=start_position_coefficients,
            )[0]
            residuum_name = f"{channel}_residuum"
            accumulated = getattr(self, residuum_name)
            if accumulated is None:
                setattr(self, residuum_name, removed)
            else:
                # In-place add: the stored tensor object itself is updated.
                accumulated += removed
@torch.no_grad()
def remove_mean_data(self, enable_secondary_data: bool = True) -> None:
    """Subtract every pixel's mean over axis 0 from the channel stacks in place.

    Applies to ``donor`` and ``acceptor``, and additionally to ``volume`` and
    ``oxygenation`` when ``enable_secondary_data`` is True.
    """
    groups = [("donor", "acceptor")]
    if enable_secondary_data:
        groups.append(("volume", "oxygenation"))
    for group in groups:
        # Check the whole group before touching any of its tensors.
        stacks = [getattr(self, name) for name in group]
        assert all(stack is not None for stack in stacks)
        for stack in stacks:
            stack -= stack.mean(dim=0, keepdim=True)
@torch.no_grad()
def remove_mean_residuum(self, enable_secondary_data: bool = True) -> None:
    """Subtract every pixel's mean over axis 0 from the residua in place.

    Applies to ``donor_residuum`` and ``acceptor_residuum``, and additionally
    to ``volume_residuum`` and ``oxygenation_residuum`` when
    ``enable_secondary_data`` is True. The residua are the signals
    accumulated by remove_heartbeat().
    """
    groups = [("donor_residuum", "acceptor_residuum")]
    if enable_secondary_data:
        groups.append(("volume_residuum", "oxygenation_residuum"))
    for group in groups:
        # Check the whole group before touching any of its tensors.
        residua = [getattr(self, name) for name in group]
        assert all(residuum is not None for residuum in residua)
        for residuum in residua:
            residuum -= residuum.mean(dim=0, keepdim=True)
@torch.no_grad()
def _calculate_linear_trend_data(self, input: torch.Tensor) -> torch.Tensor:
assert input.ndim == 3
time_beam: torch.Tensor = torch.arange(
0, input.shape[0], dtype=torch.float32, device=self.device
)
time_beam -= time_beam.mean()
input_mean = input.mean(dim=0, keepdim=True)
factor = (time_beam.unsqueeze(-1).unsqueeze(-1) * (input - input_mean)).sum(
dim=0, keepdim=True
) / (time_beam**2).sum(dim=0, keepdim=True).unsqueeze(-1).unsqueeze(-1)
output = factor * time_beam.unsqueeze(-1).unsqueeze(-1) + input_mean
return output
@torch.no_grad()
def remove_linear_trend_data(self, enable_secondary_data: bool = True) -> None:
    """Subtract the per-pixel linear trend from the channel stacks in place.

    The trend is computed by ``self._calculate_linear_trend_data``. Applies
    to ``donor`` and ``acceptor``, and additionally to ``volume`` and
    ``oxygenation`` when ``enable_secondary_data`` is True.
    """
    groups = [("donor", "acceptor")]
    if enable_secondary_data:
        groups.append(("volume", "oxygenation"))
    for group in groups:
        # Check the whole group before touching any of its tensors.
        stacks = [getattr(self, name) for name in group]
        assert all(stack is not None for stack in stacks)
        for stack in stacks:
            stack -= self._calculate_linear_trend_data(stack)
@torch.no_grad()
def remove_linear_trend_residuum(
    self,
    enable_secondary_data: bool = True,
) -> None:
    """Subtract the per-pixel linear trend from the residua in place.

    The trend is computed by ``self._calculate_linear_trend_data``. Applies
    to ``donor_residuum`` and ``acceptor_residuum``, and additionally to
    ``volume_residuum`` and ``oxygenation_residuum`` when
    ``enable_secondary_data`` is True.
    """
    groups = [("donor_residuum", "acceptor_residuum")]
    if enable_secondary_data:
        groups.append(("volume_residuum", "oxygenation_residuum"))
    for group in groups:
        # Check the whole group before touching any of its tensors.
        residua = [getattr(self, name) for name in group]
        assert all(residuum is not None for residuum in residua)
        for residuum in residua:
            residuum -= self._calculate_linear_trend_data(residuum)
@torch.no_grad()
def frame_shift(
    self,
    enable_secondary_data: bool = True,
):
    """Make all stacks one frame shorter.

    ``donor`` and ``acceptor`` (and their residua, if present) drop their
    first frame. When ``enable_secondary_data`` is True, ``volume`` and
    ``oxygenation`` (and their residua, if present) are replaced by the mean
    of each pair of consecutive frames.
    NOTE(review): presumably this lines up the sampling times of the
    secondary channels with donor/acceptor - confirm.
    """

    def drop_first(stack: torch.Tensor) -> torch.Tensor:
        return stack[1:, :, :]

    def pair_mean(stack: torch.Tensor) -> torch.Tensor:
        return (stack[1:, :, :] + stack[:-1, :, :]) / 2.0

    assert self.donor is not None
    assert self.acceptor is not None
    self.donor = drop_first(self.donor)
    self.acceptor = drop_first(self.acceptor)
    if enable_secondary_data:
        assert self.volume is not None
        assert self.oxygenation is not None
        self.volume = pair_mean(self.volume)
        self.oxygenation = pair_mean(self.oxygenation)

    for name in ("donor_residuum", "acceptor_residuum"):
        residuum = getattr(self, name)
        if residuum is not None:
            setattr(self, name, drop_first(residuum))
    if enable_secondary_data:
        for name in ("volume_residuum", "oxygenation_residuum"):
            residuum = getattr(self, name)
            if residuum is not None:
                setattr(self, name, pair_mean(residuum))
@torch.no_grad()
def cleaned_load_data(
self,
experiment_id: int,
trial_id: int,
align: bool = True,
iterations: int = 1,
lowrank_method: bool = True,
lowrank_q: int = 6,
remove_heartbeat: bool = True,
remove_mean: bool = True,
remove_linear: bool = True,
remove_heartbeat_mean: bool = False,
remove_heartbeat_linear: bool = False,
bin_size: int = 4,
do_frame_shift: bool = True,
enable_secondary_data: bool = True,
mmap_mode: bool = True,
initital_mask: torch.Tensor | None = None,
start_position_coefficients: int = 0,
) -> None:
self.logger.info(f"{self.level2} start load_data")
self.load_data(
experiment_id=experiment_id,
trial_id=trial_id,
align=align,
enable_secondary_data=enable_secondary_data,
mmap_mode=mmap_mode,
start_position_coefficients=start_position_coefficients,
)
assert self.donor is not None
assert self.acceptor is not None
if bin_size > 1:
self.logger.info(f"{self.level2} spatial pooling")
pool = torch.nn.AvgPool2d((bin_size, bin_size), stride=(bin_size, bin_size))
self.donor = pool(self.donor)
self.acceptor = pool(self.acceptor)
2023-07-10 13:06:54 +02:00
if enable_secondary_data:
2023-07-09 18:55:12 +02:00
assert self.volume is not None
assert self.oxygenation is not None
self.volume = pool(self.volume)
self.oxygenation = pool(self.oxygenation)
2023-07-10 13:06:54 +02:00
if self.donor is not None:
self.donor_scale = self.donor.mean(dim=0, keepdim=True)
self.donor /= self.donor_scale
self.donor -= 1.0
if self.acceptor is not None:
self.acceptor_scale = self.acceptor.mean(dim=0, keepdim=True)
self.acceptor /= self.acceptor_scale
self.acceptor -= 1.0
if self.volume is not None:
self.volume_scale = self.volume.mean(dim=0, keepdim=True)
self.volume /= self.volume_scale
self.volume -= 1.0
if self.oxygenation is not None:
self.oxygenation_scale = self.oxygenation.mean(dim=0, keepdim=True)
self.oxygenation /= self.oxygenation_scale
self.oxygenation -= 1.0
2023-07-09 18:55:12 +02:00
if initital_mask is not None:
self.logger.info(f"{self.level2} initial mask is applied on the data")
assert self.acceptor is not None
assert self.donor is not None
assert initital_mask.ndim == 2
assert initital_mask.shape[0] == self.donor.shape[1]
assert initital_mask.shape[1] == self.donor.shape[2]
self.acceptor *= initital_mask.unsqueeze(0)
self.donor *= initital_mask.unsqueeze(0)
2023-07-10 13:06:54 +02:00
if enable_secondary_data:
2023-07-09 18:55:12 +02:00
assert self.oxygenation is not None
assert self.volume is not None
self.oxygenation *= initital_mask.unsqueeze(0)
self.volume *= initital_mask.unsqueeze(0)
2023-07-10 13:06:54 +02:00
if remove_heartbeat:
2023-07-09 18:55:12 +02:00
self.logger.info(f"{self.level2} remove the heart beat via SVD")
self.remove_heartbeat(
iterations=iterations,
lowrank_method=lowrank_method,
lowrank_q=lowrank_q,
enable_secondary_data=enable_secondary_data,
start_position_coefficients=start_position_coefficients,
)
2023-07-10 13:06:54 +02:00
if remove_mean:
2023-07-09 18:55:12 +02:00
self.logger.info(f"{self.level2} remove mean")
self.remove_mean_data(enable_secondary_data=enable_secondary_data)
2023-07-10 13:06:54 +02:00
if remove_linear:
2023-07-09 18:55:12 +02:00
self.logger.info(f"{self.level2} remove linear trends")
self.remove_linear_trend_data(enable_secondary_data=enable_secondary_data)
2023-07-10 13:06:54 +02:00
if remove_heartbeat:
if remove_heartbeat_mean:
2023-07-09 18:55:12 +02:00
self.logger.info(f"{self.level2} remove mean (heart beat signal)")
self.remove_mean_residuum(enable_secondary_data=enable_secondary_data)
2023-07-10 13:06:54 +02:00
if remove_heartbeat_linear:
2023-07-09 18:55:12 +02:00
self.logger.info(
f"{self.level2} remove linear trends (heart beat signal)"
)
self.remove_linear_trend_residuum(
enable_secondary_data=enable_secondary_data
)
2023-07-10 13:06:54 +02:00
if do_frame_shift:
2023-07-09 18:55:12 +02:00
self.logger.info(f"{self.level2} frame shift")
self.frame_shift(enable_secondary_data=enable_secondary_data)
@torch.no_grad()
def remove_other_signals( # start_position_coefficients: OK
self,
start_position_coefficients: int = 0,
match_iterations: int = 25,
export_parameters: bool = True,
) -> tuple[
torch.Tensor,
torch.Tensor,
float,
float,
float,
float,
torch.Tensor | None,
torch.Tensor | None,
]:
assert self.acceptor is not None
assert self.donor is not None
assert self.oxygenation is not None
assert self.volume is not None
index_full_dataset = torch.arange(
0, self.acceptor.shape[1], device=self.device, dtype=torch.int64
)
result_a: torch.Tensor = torch.zeros_like(self.acceptor)
result_d: torch.Tensor = torch.zeros_like(self.donor)
max_scale_value_a = 0.0
initial_scale_value_a = 0.0
max_scale_value_d = 0.0
initial_scale_value_d = 0.0
parameter_a: torch.Tensor | None = None
parameter_d: torch.Tensor | None = None
for chunk in self._chunk_iterator(index_full_dataset, self.filtfilt_chuck_size):
a: torch.Tensor = self.acceptor[:, chunk, :].detach().clone()
d: torch.Tensor = self.donor[:, chunk, :].detach().clone()
o: torch.Tensor = self.oxygenation[:, chunk, :].detach().clone()
v: torch.Tensor = self.volume[:, chunk, :].detach().clone()
a_mean = a[start_position_coefficients:, ...].mean(dim=0, keepdim=True)
a_mean_full = a.mean(dim=0, keepdim=True)
a -= a_mean_full
a_correction = a_mean - a_mean_full
d_mean = d[start_position_coefficients:, ...].mean(dim=0, keepdim=True)
d_mean_full = d.mean(dim=0, keepdim=True)
d -= d_mean_full
d_correction = d_mean - d_mean_full
o_mean = o[start_position_coefficients:, ...].mean(dim=0, keepdim=True)
o_mean_full = o.mean(dim=0, keepdim=True)
o -= o_mean
o_correction = o_mean - o_mean_full
o_norm = 1.0 / (
(o[start_position_coefficients:, ...] ** 2).sum(dim=0) + 1e-20
)
v_mean = v[start_position_coefficients:, ...].mean(dim=0, keepdim=True)
v_mean_full = v.mean(dim=0, keepdim=True)
v -= v_mean
v_correction = v_mean - v_mean_full
v_norm = 1.0 / (
(v[start_position_coefficients:, ...] ** 2).sum(dim=0) + 1e-20
)
linear: torch.Tensor = (
torch.arange(0, o.shape[0], device=self.device, dtype=torch.float32)
.unsqueeze(-1)
.unsqueeze(-1)
)
l_mean = linear[start_position_coefficients:, ...].mean(dim=0, keepdim=True)
l_mean_full = linear.mean(dim=0, keepdim=True)
linear -= l_mean
l_correction = l_mean - l_mean_full
linear_norm = 1.0 / (
(linear[start_position_coefficients:, ...] ** 2).sum(dim=0) + 1e-20
)
linear = torch.tile(linear, (1, o.shape[1], o.shape[2]))
linear_norm = torch.tile(linear_norm, (o.shape[1], o.shape[2]))
l_correction = torch.tile(l_correction, (1, o.shape[1], o.shape[2]))
data = torch.cat(
(linear.unsqueeze(-1), o.unsqueeze(-1), v.unsqueeze(-1)), dim=-1
)
del linear
del o
del v
data_mean_correction = torch.cat(
(
l_correction.unsqueeze(-1),
o_correction.unsqueeze(-1),
v_correction.unsqueeze(-1),
),
dim=-1,
)
data_norm = torch.cat(
(linear_norm.unsqueeze(-1), o_norm.unsqueeze(-1), v_norm.unsqueeze(-1)),
dim=-1,
)
del linear_norm
del o_norm
del v_norm
2023-07-10 13:06:54 +02:00
if export_parameters:
2023-07-09 18:55:12 +02:00
parameter_a_temp: torch.Tensor | None = torch.zeros_like(data_norm)
parameter_d_temp: torch.Tensor | None = torch.zeros_like(data_norm)
else:
parameter_a_temp = None
parameter_d_temp = None
for mode_a in [True, False]:
2023-07-10 13:06:54 +02:00
if mode_a:
2023-07-09 18:55:12 +02:00
result = a.detach().clone()
result_mean_correct = a_correction
else:
result = d.detach().clone()
result_mean_correct = d_correction
for i in range(0, match_iterations):
scale = (
(
data[start_position_coefficients:, ...]
* (
result[start_position_coefficients:, ...]
+ result_mean_correct
).unsqueeze(-1)
).sum(dim=0)
) * data_norm
idx = torch.abs(scale).argmax(dim=-1)
scale = torch.gather(scale, -1, idx.unsqueeze(-1)).squeeze(-1)
idx_3d = torch.tile(idx.unsqueeze(0), (data.shape[0], 1, 1))
data_selected = torch.gather(
(data - data_mean_correction), -1, idx_3d.unsqueeze(-1)
).squeeze(-1)
result -= data_selected * scale.unsqueeze(0)
2023-07-10 13:06:54 +02:00
if mode_a:
2023-07-09 18:55:12 +02:00
if i == 0:
initial_scale_value_a = max(
[max_scale_value_a, float(scale.max())]
)
if parameter_a_temp is not None:
parameter_a_temp.scatter_add_(
-1, idx.unsqueeze(-1), scale.unsqueeze(-1)
)
else:
if i == 0:
initial_scale_value_d = max(
[max_scale_value_d, float(scale.max())]
)
if parameter_d_temp is not None:
parameter_d_temp.scatter_add_(
-1, idx.unsqueeze(-1), scale.unsqueeze(-1)
)
2023-07-10 13:06:54 +02:00
if mode_a:
2023-07-09 18:55:12 +02:00
result_a[:, chunk, :] = result.detach().clone()
max_scale_value_a = max([max_scale_value_a, float(scale.max())])
if parameter_a_temp is not None:
parameter_a_temp = torch.cat(
(parameter_a_temp, a_mean_full.squeeze(0).unsqueeze(-1)),
dim=-1,
)
else:
result_d[:, chunk, :] = result.detach().clone()
max_scale_value_d = max([max_scale_value_d, float(scale.max())])
if parameter_d_temp is not None:
parameter_d_temp = torch.cat(
(parameter_d_temp, d_mean_full.squeeze(0).unsqueeze(-1)),
dim=-1,
)
2023-07-10 13:06:54 +02:00
if export_parameters:
2023-07-09 18:55:12 +02:00
if (parameter_a is None) and (parameter_a_temp is not None):
parameter_a = torch.zeros(
(
self.acceptor.shape[1],
parameter_a_temp.shape[1],
parameter_a_temp.shape[2],
),
device=self.device,
dtype=torch.float32,
)
if (parameter_a is not None) and (parameter_a_temp is not None):
parameter_a[chunk, ...] = parameter_a_temp
if (parameter_d is None) and (parameter_d_temp is not None):
parameter_d = torch.zeros(
(
self.acceptor.shape[1],
parameter_d_temp.shape[1],
parameter_d_temp.shape[2],
),
device=self.device,
dtype=torch.float32,
)
if (parameter_d is not None) and (parameter_d_temp is not None):
parameter_d[chunk, ...] = parameter_d_temp
self.logger.info(
f"{self.level3} acceptor -- Progression scale: {initial_scale_value_a} -> {max_scale_value_a}"
)
self.logger.info(
f"{self.level3} donor -- Progression scale: {initial_scale_value_d} -> {max_scale_value_d}"
)
return (
result_a,
result_d,
max_scale_value_a,
initial_scale_value_a,
max_scale_value_d,
initial_scale_value_d,
parameter_a,
parameter_d,
)
@torch.no_grad()
def _filtfilt(
    self,
    input: torch.Tensor,
    butter_a: torch.Tensor,
    butter_b: torch.Tensor,
) -> torch.Tensor:
    """Zero-phase filter ``input`` along dim 0 with ``torchaudio`` filtfilt.

    Both ends of the trace are extended by an odd (point) reflection around
    the first / last sample before filtering; the padding is cut off again
    afterwards. The padding length is ``12 * len(butter_a)``. It is limited
    to ``n_samples - 1`` so that the reflection always fits inside the trace.
    Otherwise the trim would remove real samples.

    Args:
        input: tensor whose dim 0 is filtered; it is not modified.
        butter_a: 1-D denominator coefficients.
        butter_b: 1-D numerator coefficients (same length as ``butter_a``).

    Returns:
        The filtered tensor, with the same shape and dim order as ``input``.
    """
    assert butter_a.ndim == 1
    assert butter_b.ndim == 1
    assert butter_a.shape[0] == butter_b.shape[0]

    # Filter along the last axis, as torchaudio expects (..., time).
    process_data: torch.Tensor = input.movedim(0, -1).detach().clone()
    n_samples = process_data.shape[-1]
    padding_length = min(12 * int(butter_a.shape[0]), max(n_samples - 1, 0))

    left_padding = 2 * process_data[..., :1] - process_data[
        ..., 1 : padding_length + 1
    ].flip(-1)
    right_padding = 2 * process_data[..., -1:] - process_data[
        ..., n_samples - 1 - padding_length : n_samples - 1
    ].flip(-1)
    process_data_padded = torch.cat(
        (left_padding, process_data, right_padding), dim=-1
    )

    output = ta.functional.filtfilt(
        process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False
    ).squeeze(0)
    # Explicit length instead of a negative end index, so padding_length == 0
    # does not produce an empty slice.
    output = output[..., padding_length : padding_length + n_samples].movedim(-1, 0)
    return output
@torch.no_grad()
def _butter_bandpass(
self, low_frequency: float = 5, high_frequency: float = 15, fs: float = 100.0
) -> tuple[torch.Tensor, torch.Tensor]:
import scipy
butter_b_np, butter_a_np = scipy.signal.butter(
4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs
)
butter_a = torch.tensor(butter_a_np, device=self.device, dtype=torch.float32)
butter_b = torch.tensor(butter_b_np, device=self.device, dtype=torch.float32)
return butter_a, butter_b
@torch.no_grad()
def _chunk_iterator(self, array: torch.Tensor, chunk_size: int):
    """Yield consecutive slices of ``array`` along dim 0.

    Every slice holds ``chunk_size`` entries, except possibly the last one,
    which holds the remainder. An empty ``array`` yields nothing.
    """
    n_rows = array.shape[0]
    yield from (
        array[offset : offset + chunk_size]
        for offset in range(0, n_rows, chunk_size)
    )
@torch.no_grad()
def heartbeat_scale(  # start_position_coefficients: OK
    self,
    low_frequency: float = 5,
    high_frequency: float = 15,
    fs: float = 100.0,
    apply_to_data: bool = False,
    threshold: float | None = 0.5,
    start_position_coefficients: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """Estimate per-pixel acceptor/donor gains from the heart-beat band.

    ``donor_residuum`` and ``acceptor_residuum`` are taken from
    ``start_position_coefficients`` onwards. They are band-pass filtered
    (``_butter_bandpass`` + ``_filtfilt``, in chunks along dim 1) and
    mean-removed. Per pixel, ``scale = <hb_a, hb_d> / <hb_a, hb_a>`` is then
    computed, giving ``heartbeat_a = sqrt(scale)`` and
    ``heartbeat_d = 1 / heartbeat_a``. Negative ``scale`` values give NaN.

    Args:
        low_frequency, high_frequency: band edges in Hz.
        fs: sampling rate in Hz.
        apply_to_data: if True, multiply donor and volume by ``heartbeat_d``
            and acceptor and oxygenation by ``heartbeat_a``, in place.
        threshold: if not None, build a mask of pixels where both
            min-max-normalised ``std(hb) * *_scale`` exceed this value.
        start_position_coefficients: first sample along dim 0 that is used.

    Returns:
        ``(heartbeat_a, heartbeat_d, mask)``; ``mask`` is None when
        ``threshold`` is None.
    """
    assert self.donor_residuum is not None
    assert self.acceptor_residuum is not None

    butter_a, butter_b = self._butter_bandpass(
        low_frequency=low_frequency, high_frequency=high_frequency, fs=fs
    )

    self.logger.info(f"{self.level3} apply bandpass donor_residuum (filtfilt)")
    index_full_dataset: torch.Tensor = torch.arange(
        0, self.donor_residuum.shape[1], device=self.device, dtype=torch.int64
    )
    hb_d = torch.zeros_like(self.donor_residuum[start_position_coefficients:, ...])
    for chunk in self._chunk_iterator(index_full_dataset, self.filtfilt_chuck_size):
        hb_d[:, chunk, :] = self._filtfilt(
            self.donor_residuum[start_position_coefficients:, chunk, :],
            butter_a=butter_a,
            butter_b=butter_b,
        )
    hb_d -= hb_d.mean(dim=0, keepdim=True)

    self.logger.info(f"{self.level3} apply bandpass acceptor_residuum (filtfilt)")
    index_full_dataset = torch.arange(
        0, self.acceptor_residuum.shape[1], device=self.device, dtype=torch.int64
    )
    hb_a = torch.zeros_like(
        self.acceptor_residuum[start_position_coefficients:, ...]
    )
    for chunk in self._chunk_iterator(index_full_dataset, self.filtfilt_chuck_size):
        hb_a[:, chunk, :] = self._filtfilt(
            self.acceptor_residuum[start_position_coefficients:, chunk, :],
            butter_a=butter_a,
            butter_b=butter_b,
        )
    hb_a -= hb_a.mean(dim=0, keepdim=True)

    # Least-squares slope of hb_d on hb_a. The epsilon avoids 0/0 = NaN on
    # pixels that were zeroed by a mask.
    scale = (hb_a * hb_d).sum(dim=0) / ((hb_a**2).sum(dim=0) + 1e-20)
    heartbeat_a = torch.sqrt(scale)
    heartbeat_d = 1.0 / (heartbeat_a + 1e-20)

    if apply_to_data:
        if self.donor is not None:
            self.donor *= heartbeat_d.unsqueeze(0)
        if self.volume is not None:
            self.volume *= heartbeat_d.unsqueeze(0)
        if self.acceptor is not None:
            self.acceptor *= heartbeat_a.unsqueeze(0)
        if self.oxygenation is not None:
            self.oxygenation *= heartbeat_a.unsqueeze(0)

    if threshold is not None:
        self.logger.info(f"{self.level3} calculate mask")
        assert self.donor_scale is not None
        assert self.acceptor_scale is not None

        # Min-max normalise the heart-beat amplitude maps to [0, 1].
        temp_d = hb_d.std(dim=0) * self.donor_scale.squeeze(0)
        temp_d -= temp_d.min()
        temp_d /= temp_d.max()

        temp_a = hb_a.std(dim=0) * self.acceptor_scale.squeeze(0)
        temp_a -= temp_a.min()
        temp_a /= temp_a.max()

        mask = torch.where(temp_d > threshold, 1.0, 0.0) * torch.where(
            temp_a > threshold, 1.0, 0.0
        )
    else:
        mask = None

    return heartbeat_a, heartbeat_d, mask
@torch.no_grad()
def measure_heartbeat_frequency(  # start_position_coefficients: OK
    self,
    low_frequency: float = 5,
    high_frequency: float = 15,
    fs: float = 100.0,
    use_input_source: str = "volume",
    start_position_coefficients: int = 0,
    half_width_frequency_window: float = 3.0,  # Hz (on side )
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Find each pixel's strongest in-band frequency bin and a window around it.

    The selected source ("donor", "acceptor", "volume", anything else means
    oxygenation) is taken from ``start_position_coefficients`` onwards.
    Its power spectrum ``|rfft|**2`` along dim 0 is computed. Per pixel, the
    bin with the largest power inside ``[low_frequency, high_frequency]`` Hz
    is located.

    Returns:
        ``(power_hb_low, power_hb_high, frequency_axis)``.

        - ``power_hb_low`` / ``power_hb_high``: the peak bin index minus /
          plus ``ceil(half_width_frequency_window / bin_width)``, clamped to
          ``[0, number_of_bins]``.
        - ``frequency_axis``: the rfft frequencies in Hz.
    """
    candidates = {
        "donor": self.donor,
        "acceptor": self.acceptor,
        "volume": self.volume,
    }
    source = candidates.get(use_input_source, self.oxygenation)
    assert source is not None
    trace: torch.Tensor = source[start_position_coefficients:, ...]

    frequency_axis: torch.Tensor = (
        torch.fft.rfftfreq(trace.shape[0]).to(device=self.device) * fs
    )
    bin_width = float(frequency_axis[1]) - float(frequency_axis[0])
    half_window_bins = int(math.ceil(half_width_frequency_window / bin_width))

    in_band = (frequency_axis >= low_frequency) & (frequency_axis <= high_frequency)
    band_bins: torch.Tensor = torch.where(in_band)[0]

    spectrum_power = torch.abs(torch.fft.rfft(trace, dim=0)) ** 2
    peak_bin = spectrum_power[band_bins, :, :].argmax(dim=0) + band_bins[0]

    lower_bin = (peak_bin - half_window_bins).clamp(min=0)
    upper_bin = (peak_bin + half_window_bins).clamp(max=frequency_axis.shape[0])
    return lower_bin, upper_bin, frequency_axis
@torch.no_grad()
def measure_heartbeat_power( # start_position_coefficients: OK
self,
use_input_source: str = "donor",
start_position_coefficients: int = 0,
power_hb_low: torch.Tensor | None = None,
power_hb_high: torch.Tensor | None = None,
2023-07-09 19:26:30 +02:00
custom_input: torch.Tensor | None = None,
2023-07-09 18:55:12 +02:00
) -> torch.Tensor:
if use_input_source == "donor":
assert self.donor is not None
hb: torch.Tensor = self.donor[start_position_coefficients:, ...]
elif use_input_source == "acceptor":
assert self.acceptor is not None
hb = self.acceptor[start_position_coefficients:, ...]
elif use_input_source == "volume":
assert self.volume is not None
hb = self.volume[start_position_coefficients:, ...]
2023-07-09 19:26:30 +02:00
elif use_input_source == "custom":
assert custom_input is not None
hb = custom_input[start_position_coefficients:, ...]
2023-07-09 18:55:12 +02:00
else:
assert self.oxygenation is not None
hb = self.oxygenation[start_position_coefficients:, ...]
counter: torch.Tensor = torch.zeros(
(hb.shape[1], hb.shape[2]),
dtype=hb.dtype,
device=self.device,
)
index_full_dataset = torch.arange(
0, hb.shape[1], device=self.device, dtype=torch.int64
)
power_hb: torch.Tensor | None = None
for chunk in self._chunk_iterator(index_full_dataset, self.filtfilt_chuck_size):
temp_power = torch.abs(torch.fft.rfft(hb[:, chunk, :], dim=0)) ** 2
if power_hb is None:
power_hb = torch.zeros(
(temp_power.shape[0], hb.shape[1], temp_power.shape[2]),
dtype=temp_power.dtype,
device=temp_power.device,
)
assert power_hb is not None
power_hb[:, chunk, :] = temp_power
assert power_hb is not None
for pos in range(0, power_hb.shape[0]):
pos_torch = torch.tensor(pos, dtype=torch.int64, device=self.device)
slice_temp = (
(pos_torch >= power_hb_low) * (pos_torch < power_hb_high)
).type(dtype=power_hb.dtype)
power_hb[pos, ...] *= slice_temp
counter += slice_temp
power_hb = power_hb.sum(dim=0) / counter
return power_hb
@torch.no_grad()
def automatic_load(  # start_position_coefficients: OK
    self,
    experiment_id: int = 1,
    trial_id: int = 1,
    start_position: int = 0,
    start_position_coefficients: int = 100,
    fs: float = 100.0,
    use_regression: bool | None = None,
    # Heartbeat
    remove_heartbeat: bool = True,  # i.e. use SVD
    low_frequency: float = 5,  # Hz Butter Bandpass Heartbeat
    high_frequency: float = 15,  # Hz Butter Bandpass Heartbeat
    threshold: float | None = 0.5,  # For the mask
    # Extra exposed parameters:
    align: bool = True,
    iterations: int = 1,  # SVD iterations: Do not touch! Keep at 1
    lowrank_method: bool = True,
    lowrank_q: int = 6,
    remove_heartbeat_mean: bool = False,
    remove_heartbeat_linear: bool = False,
    bin_size: int = 4,
    do_frame_shift: bool | None = None,
    half_width_frequency_window: float = 3.0,  # Hz (on side ) measure_heartbeat_frequency
    mmap_mode: bool = True,
    initital_mask_name: str | None = None,
    initital_mask_update: bool = True,
    initital_mask_roi: bool = False,
    gaussian_blur_kernel_size: int | None = 3,
    gaussian_blur_sigma: float = 1.0,
    bin_size_post: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Load one trial, clean it and return ``1 + acceptor - donor``.

    Steps:

    1. Optionally load a stored pixel mask from ``initital_mask_name``
       (a ``.npy`` file).
    2. Call ``cleaned_load_data``.
    3. Estimate per-pixel acceptor/donor gains. With ``remove_heartbeat``
       this uses ``heartbeat_scale``, which also yields a mask. Otherwise it
       uses ``measure_heartbeat_frequency`` on the volume channel followed by
       ``measure_heartbeat_power`` on donor and acceptor.
    4. Optionally regress out the secondary signals via
       ``remove_other_signals``.
    5. Apply the mask and the gains.
    6. Optionally update / interactively edit / save the stored mask.
    7. Optionally apply a Gaussian blur and average pooling.

    ``use_regression`` and ``do_frame_shift`` default to
    ``not remove_heartbeat``.

    Returns:
        ``(result, mask)``.

        - ``result``: ``1 + acceptor - donor`` from frame ``start_position``
          onwards.
        - ``mask``: the ``heartbeat_scale`` mask, pooled like ``result`` if
          ``bin_size_post > 1``, or None when no mask was computed.
    """
    self.logger.info(f"{self.level0} start automatic_load")

    if use_regression is None:
        use_regression = not remove_heartbeat
    if do_frame_shift is None:
        do_frame_shift = not remove_heartbeat

    initital_mask: torch.Tensor | None = None
    if (initital_mask_name is not None) and os.path.isfile(initital_mask_name):
        initital_mask = torch.tensor(
            np.load(initital_mask_name), device=self.device, dtype=torch.float32
        )
        self.logger.info(f"{self.level1} try to load previous mask: found")
    else:
        self.logger.info(f"{self.level1} try to load previous mask: NOT found")

    self.logger.info(f"{self.level1} start cleaned_load_data")
    self.cleaned_load_data(
        experiment_id=experiment_id,
        trial_id=trial_id,
        remove_heartbeat=remove_heartbeat,
        remove_mean=not use_regression,
        remove_linear=not use_regression,
        enable_secondary_data=use_regression,
        align=align,
        iterations=iterations,
        lowrank_method=lowrank_method,
        lowrank_q=lowrank_q,
        remove_heartbeat_mean=remove_heartbeat_mean,
        remove_heartbeat_linear=remove_heartbeat_linear,
        bin_size=bin_size,
        do_frame_shift=do_frame_shift,
        mmap_mode=mmap_mode,
        initital_mask=initital_mask,
        start_position_coefficients=start_position_coefficients,
    )

    heartbeat_a: torch.Tensor | None = None
    heartbeat_d: torch.Tensor | None = None
    mask: torch.Tensor | None = None
    power_hb_low: torch.Tensor | None = None
    power_hb_high: torch.Tensor | None = None

    if remove_heartbeat:
        self.logger.info(f"{self.level1} remove heart beat (heartbeat_scale)")
        heartbeat_a, heartbeat_d, mask = self.heartbeat_scale(
            low_frequency=low_frequency,
            high_frequency=high_frequency,
            fs=fs,
            apply_to_data=False,
            threshold=threshold,
            start_position_coefficients=start_position_coefficients,
        )
    else:
        self.logger.info(
            f"{self.level1} measure heart rate (measure_heartbeat_frequency)"
        )
        assert self.volume is not None
        power_hb_low, power_hb_high, _ = self.measure_heartbeat_frequency(
            low_frequency=low_frequency,
            high_frequency=high_frequency,
            fs=fs,
            use_input_source="volume",
            start_position_coefficients=start_position_coefficients,
            half_width_frequency_window=half_width_frequency_window,
        )

    if use_regression:
        self.logger.info(f"{self.level1} use regression")
        result_a, result_d, *_ = self.remove_other_signals(
            start_position_coefficients=start_position_coefficients,
            match_iterations=25,
            export_parameters=False,
        )
        result_a = result_a[start_position:, ...]
        result_d = result_d[start_position:, ...]
    else:
        self.logger.info(f"{self.level1} don't use regression")
        assert self.acceptor is not None
        assert self.donor is not None
        result_a = self.acceptor[start_position:, ...].clone()
        result_d = self.donor[start_position:, ...].clone()

    if mask is not None:
        result_a *= mask.unsqueeze(0)
        result_d *= mask.unsqueeze(0)

    if not remove_heartbeat:
        # Gains from the spectral power in the heart-beat window found on
        # the volume channel.
        self.logger.info(
            f"{self.level1} donor: measure heart beat spectral power (measure_heartbeat_power)"
        )
        temp_d = self.measure_heartbeat_power(
            use_input_source="donor",
            start_position_coefficients=start_position_coefficients,
            power_hb_low=power_hb_low,
            power_hb_high=power_hb_high,
        )
        self.logger.info(
            f"{self.level1} acceptor: measure heart beat spectral power (measure_heartbeat_power)"
        )
        temp_a = self.measure_heartbeat_power(
            use_input_source="acceptor",
            start_position_coefficients=start_position_coefficients,
            power_hb_low=power_hb_low,
            power_hb_high=power_hb_high,
        )
        scale = temp_d / (temp_a + 1e-20)
        heartbeat_a = torch.sqrt(scale)
        heartbeat_d = 1.0 / (heartbeat_a + 1e-20)

    self.logger.info(f"{self.level1} scale acceptor and donor signals")
    if heartbeat_a is not None:
        result_a *= heartbeat_a.unsqueeze(0)
    if heartbeat_d is not None:
        result_d *= heartbeat_d.unsqueeze(0)

    if mask is not None:
        if initital_mask_update:
            self.logger.info(f"{self.level1} update inital mask")
            if initital_mask is None:
                initital_mask = mask.clone()
            else:
                initital_mask *= mask

        if initital_mask_roi and (initital_mask is not None):
            self.logger.info(f"{self.level1} enter roi mask drawing modus")
            yes_choices = ["yes", "y"]
            continue_roi: bool = True
            # Show the first frame with already-masked pixels blanked out.
            image: np.ndarray = (result_a - result_d)[0, ...].cpu().numpy()
            image[initital_mask.cpu().numpy() == 0] = float("NaN")

            while continue_roi:
                user_input = input(
                    "Mask: Do you want to remove more pixel (yes/no)? "
                )
                if user_input.lower() not in yes_choices:
                    continue_roi = False
                    continue

                plt.imshow(image, cmap="hot")
                plt.title("Select a region for removal")
                temp_roi = RoiPoly(color="g")
                plt.show()
                if len(temp_roi.x) == 0:
                    continue

                new_mask = temp_roi.get_mask(image)
                new_mask_torch = torch.tensor(
                    new_mask.astype(np.float32),
                    device=self.device,
                    dtype=torch.float32,
                )
                plt.imshow(image, cmap="hot")
                temp_roi.display_roi()
                plt.title("Selected region for removal")
                print("Please close figure when ready...")
                plt.show()
                user_input = input("Mask: Remove these pixel (yes/no)? ")
                if user_input.lower() in yes_choices:
                    initital_mask *= 1.0 - new_mask_torch
                    image[new_mask] = float("NaN")

        # Guard: without an update or a stored mask there is nothing to save
        # (and initital_mask.cpu() would fail on None).
        if (initital_mask_name is not None) and (initital_mask is not None):
            self.logger.info(f"{self.level1} save mask")
            np.save(initital_mask_name, initital_mask.cpu().numpy())

    self.logger.info(f"{self.level0} end automatic_load")

    result = 1.0 + result_a - result_d

    if (gaussian_blur_kernel_size is not None) and (gaussian_blur_kernel_size > 0):
        gaussian_blur = tv.transforms.GaussianBlur(
            kernel_size=[gaussian_blur_kernel_size, gaussian_blur_kernel_size],
            sigma=gaussian_blur_sigma,
        )
        result = gaussian_blur(result)

    if (bin_size_post is not None) and (bin_size_post > 1):
        pool = torch.nn.AvgPool2d(
            (bin_size_post, bin_size_post), stride=(bin_size_post, bin_size_post)
        )
        result = pool(result)
        if mask is not None:
            # A pooled pixel stays valid if any of its source pixels was valid.
            mask = (
                (pool(mask.unsqueeze(0)) > 0).type(dtype=torch.float32).squeeze(0)
            )

    return result, mask
if __name__ == "__main__":
    from Anime import Anime

    # path: str = "/data_1/robert/2021-05-05/M3852M/raw"
    path: str = "/data_1/robert/2021-05-21/M3852M/raw"
    initital_mask_name: str | None = None
    initital_mask_update: bool = True
    initital_mask_roi: bool = False  # default: True

    experiment_id: int = 2
    trial_id: int = 180
    start_position: int = 0
    start_position_coefficients: int = 100

    remove_heartbeat: bool = True  # i.e. use SVD
    bin_size: int = 4
    threshold: float | None = 0.05  # Between 0 and 1.0

    example_position_x: int = 280
    example_position_y: int = 440
    display_logging_messages: bool = False
    save_logging_messages: bool = False
    show_example_timeseries: bool = True
    play_movie: bool = True

    # Post data processing modifications
    gaussian_blur_kernel_size: int | None = 3
    gaussian_blur_sigma: float = 1.0
    bin_size_post: int | None = None
    # ------------------------

    # Convert the example pixel from raw-image coordinates into the
    # coordinates of the binned (and optionally post-pooled) result.
    example_position_x = example_position_x // bin_size
    example_position_y = example_position_y // bin_size
    # automatic_load only pools when bin_size_post > 1; mirror that here
    # (this also avoids a ZeroDivisionError for bin_size_post == 0).
    if (bin_size_post is not None) and (bin_size_post > 1):
        example_position_x = example_position_x // bin_size_post
        example_position_y = example_position_y // bin_size_post

    torch_device: torch.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu"
    )

    af = DataContainer(
        path=path,
        device=torch_device,
        display_logging_messages=display_logging_messages,
        save_logging_messages=save_logging_messages,
    )

    result, mask = af.automatic_load(
        experiment_id=experiment_id,
        trial_id=trial_id,
        start_position=start_position,
        remove_heartbeat=remove_heartbeat,  # i.e. use SVD
        bin_size=bin_size,
        initital_mask_name=initital_mask_name,
        initital_mask_update=initital_mask_update,
        initital_mask_roi=initital_mask_roi,
        start_position_coefficients=start_position_coefficients,
        gaussian_blur_kernel_size=gaussian_blur_kernel_size,
        gaussian_blur_sigma=gaussian_blur_sigma,
        bin_size_post=bin_size_post,
        threshold=threshold,
    )

    if show_example_timeseries:
        plt.plot(result[:, example_position_x, example_position_y].cpu())
        plt.show()

    if play_movie:
        ani = Anime()
        ani.show(result - 1.0, mask=mask, vmin_scale=0.5, vmax_scale=0.5)