From cc54cf1a2972bd8347f19093782d783c13f51f00 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Wed, 28 Feb 2024 16:13:11 +0100 Subject: [PATCH] Delete functions/DataContainer.py --- functions/DataContainer.py | 1857 ------------------------------------ 1 file changed, 1857 deletions(-) delete mode 100644 functions/DataContainer.py diff --git a/functions/DataContainer.py b/functions/DataContainer.py deleted file mode 100644 index ccc7233..0000000 --- a/functions/DataContainer.py +++ /dev/null @@ -1,1857 +0,0 @@ -# pip install roipoly natsort numpy matplotlib -# Also install: torch torchaudio torchvision -# (for details see https://pytorch.org/get-started/locally/ ) -# Tested on Python 3.11 - -import glob -import json -import logging -import math -import os -from datetime import datetime - -import matplotlib.pyplot as plt -import natsort -import numpy as np -import torch -import torchaudio as ta -import torchvision as tv -from roipoly import RoiPoly - -from functions.ImageAlignment import ImageAlignment - - -class DataContainer(torch.nn.Module): - ref_image_acceptor: torch.Tensor | None = None - ref_image_donor: torch.Tensor | None = None - - acceptor: torch.Tensor | None = None - donor: torch.Tensor | None = None - oxygenation: torch.Tensor | None = None - volume: torch.Tensor | None = None - - acceptor_whiten_mean: torch.Tensor | None = None - acceptor_whiten_k: torch.Tensor | None = None - acceptor_eigenvalues: torch.Tensor | None = None - acceptor_residuum: torch.Tensor | None = None - - donor_whiten_mean: torch.Tensor | None = None - donor_whiten_k: torch.Tensor | None = None - donor_eigenvalues: torch.Tensor | None = None - donor_residuum: torch.Tensor | None = None - - oxygenation_whiten_mean: torch.Tensor | None = None - oxygenation_whiten_k: torch.Tensor | None = None - oxygenation_eigenvalues: torch.Tensor | None = None - oxygenation_residuum: torch.Tensor | None = None - - volume_whiten_mean: torch.Tensor | None = None - volume_whiten_k: torch.Tensor | None = None - volume_eigenvalues: torch.Tensor | None = None - volume_residuum: torch.Tensor | None = None - - acceptor_scale: torch.Tensor | None = None - donor_scale: torch.Tensor | None = None - oxygenation_scale: torch.Tensor | None = None - volume_scale: torch.Tensor | None = None - - # ------- - image_alignment: ImageAlignment - - acceptor_index: int - donor_index: int - oxygenation_index: int - volume_index: int - - path: str - channels: list[str] - device: torch.device - - batch_size: int = 200 - - fill_value: float = -0.1 - - filtfilt_chuck_size: int = 10 - - level0 = str("=") - level1 = str("==") - level2 = str("===") - level3 = str("====") - - @torch.no_grad() - def __init__( - self, - path: str, - device: torch.device, - display_logging_messages: bool = False, - save_logging_messages: bool = False, - ) -> None: - super().__init__() - self.device = device - - assert path is not None - self.path = path - now = datetime.now() - dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S") - - self.logger = logging.getLogger("DataContainer") - self.logger.setLevel(logging.DEBUG) - - if save_logging_messages: - time_format = "%b %-d %Y %H:%M:%S" - logformat = "%(asctime)s %(message)s" - file_formatter = logging.Formatter(fmt=logformat, datefmt=time_format) - - file_handler = logging.FileHandler(f"log_{dt_string_filename}.txt") - file_handler.setLevel(logging.INFO) - file_handler.setFormatter(file_formatter) - self.logger.addHandler(file_handler) - - if display_logging_messages: - time_format = "%b %-d %Y %H:%M:%S" - logformat = "%(asctime)s %(message)s" - stream_formatter = logging.Formatter(fmt=logformat, datefmt=time_format) - - stream_handler = logging.StreamHandler() - stream_handler.setLevel(logging.INFO) - stream_handler.setFormatter(stream_formatter) - self.logger.addHandler(stream_handler) - - file_input_ref_image = self._find_ref_image_file() - - data = np.load(file_input_ref_image, mmap_mode="r") - ref_image = torch.tensor( - data[:, :, data.shape[2] // 2, :].astype(np.float32), - device=self.device, - dtype=torch.float32, - ) - - json_postfix: str = "_meta.txt" - found_name_json: str = file_input_ref_image.replace(".npy", json_postfix) - - assert os.path.isfile(found_name_json) - - with open(found_name_json, "r") as file_handle: - metadata = json.load(file_handle) - self.channels = metadata["channelKey"] - - self.acceptor_index = self.channels.index("acceptor") - self.donor_index = self.channels.index("donor") - self.oxygenation_index = self.channels.index("oxygenation") - self.volume_index = self.channels.index("volume") - - self.ref_image_acceptor: torch.Tensor = ref_image[:, :, self.acceptor_index] - self.ref_image_donor: torch.Tensor = ref_image[:, :, self.donor_index] - - self.image_alignment = ImageAlignment( - default_dtype=torch.float32, device=self.device - ) - - @torch.no_grad() - def get_trials(self, experiment_id: int) -> torch.Tensor: - filename_np: str = os.path.join( - self.path, - f"Exp{experiment_id:03d}_Trial*_Part001.npy", - ) - - list_str = glob.glob(filename_np) - list_int: list[int] = [] - for i in range(0, len(list_str)): - list_int.append(int(list_str[i].split("_Trial")[-1].split("_Part")[0])) - list_int = sorted(list_int) - return torch.tensor(list_int).unique() - - @torch.no_grad() - def get_experiments( - self, - ) -> torch.Tensor: - filename_np: str = os.path.join( - self.path, - "Exp*_Part001.npy", - ) - - list_str = glob.glob(filename_np) - list_int: list[int] = [] - for i in range(0, len(list_str)): - list_int.append(int(list_str[i].split("Exp")[-1].split("_Trial")[0])) - list_int = sorted(list_int) - - return torch.tensor(list_int).unique() - - @torch.no_grad() - def load_data( # start_position_coefficients: OK - self, - experiment_id: int, - trial_id: int, - align: bool = True, - enable_secondary_data: bool = True, - mmap_mode: bool = True, - start_position_coefficients: int = 0, - ): - self.acceptor = None - self.donor = None - self.oxygenation = None - self.volume = None - - part_id: int = 1 - filename_np: str = os.path.join( - self.path, - f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}.npy", - ) - - filename_meta: str = os.path.join( - self.path, - f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}_meta.txt", - ) - - while (os.path.isfile(filename_np)) and (os.path.isfile(filename_meta)): - self.logger.info(f"{self.level3} work in {filename_np}") - # Check if channel asignment is still okay - with open(filename_meta, "r") as file_handle: - metadata = json.load(file_handle) - channels = metadata["channelKey"] - - assert len(channels) == len(self.channels) - for i in range(0, len(channels)): - assert channels[i] == self.channels[i] - - # Load the data... - self.logger.info(f"{self.level3} np.load") - if mmap_mode: - temp: np.ndarray = np.load(filename_np, mmap_mode="r") - else: - temp = np.load(filename_np) - - self.logger.info(f"{self.level3} organize acceptor") - if self.acceptor is None: - self.acceptor = torch.tensor( - temp[:, :, :, self.acceptor_index].astype(np.float32), - device=self.device, - dtype=torch.float32, - ) - - else: - assert self.acceptor is not None - assert self.acceptor.ndim + 1 == temp.ndim - assert self.acceptor.shape[0] == temp.shape[0] - assert self.acceptor.shape[1] == temp.shape[1] - # assert self.acceptor.shape[2] == temp.shape[2] - assert temp.shape[3] == 4 - - self.acceptor = torch.cat( - ( - self.acceptor, - torch.tensor( - temp[:, :, :, self.acceptor_index].astype(np.float32), - device=self.device, - dtype=torch.float32, - ), - ), - dim=2, - ) - - self.logger.info(f"{self.level3} organize donor") - if self.donor is None: - self.donor = torch.tensor( - temp[:, :, :, self.donor_index].astype(np.float32), - device=self.device, - dtype=torch.float32, - ) - - else: - assert self.donor is not None - assert self.donor.ndim + 1 == temp.ndim - assert self.donor.shape[0] == temp.shape[0] - assert self.donor.shape[1] == temp.shape[1] - # assert self.donor.shape[2] == temp.shape[2] - assert temp.shape[3] == 4 - - self.donor = torch.cat( - ( - self.donor, - torch.tensor( - temp[:, :, :, self.donor_index].astype(np.float32), - device=self.device, - dtype=torch.float32, - ), - ), - dim=2, - ) - - if enable_secondary_data: - self.logger.info(f"{self.level3} organize oxygenation") - if self.oxygenation is None: - self.oxygenation = torch.tensor( - temp[:, :, :, self.oxygenation_index].astype(np.float32), - device=self.device, - dtype=torch.float32, - ) - else: - assert self.oxygenation is not None - assert self.oxygenation.ndim + 1 == temp.ndim - assert self.oxygenation.shape[0] == temp.shape[0] - assert self.oxygenation.shape[1] == temp.shape[1] - # assert self.oxygenation.shape[2] == temp.shape[2] - assert temp.shape[3] == 4 - - self.oxygenation = torch.cat( - ( - self.oxygenation, - torch.tensor( - temp[:, :, :, self.oxygenation_index].astype( - np.float32 - ), - device=self.device, - dtype=torch.float32, - ), - ), - dim=2, - ) - - if self.volume is None: - self.logger.info(f"{self.level3} organize volume") - self.volume = torch.tensor( - temp[:, :, :, self.volume_index].astype(np.float32), - device=self.device, - dtype=torch.float32, - ) - else: - assert self.volume is not None - assert self.volume.ndim + 1 == temp.ndim - assert self.volume.shape[0] == temp.shape[0] - assert self.volume.shape[1] == temp.shape[1] - # assert self.volume.shape[2] == temp.shape[2] - assert temp.shape[3] == 4 - - self.volume = torch.cat( - ( - self.volume, - torch.tensor( - temp[:, :, :, self.volume_index].astype(np.float32), - device=self.device, - dtype=torch.float32, - ), - ), - dim=2, - ) - - part_id += 1 - filename_np = os.path.join( - self.path, - f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}.npy", - ) - filename_meta = os.path.join( - self.path, - f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}_meta.txt", - ) - - self.logger.info(f"{self.level3} move axis") - assert self.acceptor is not None - assert self.donor is not None - self.acceptor = self.acceptor.moveaxis(-1, 0) - self.donor = self.donor.moveaxis(-1, 0) - - if enable_secondary_data: - assert self.oxygenation is not None - assert self.volume is not None - self.oxygenation = self.oxygenation.moveaxis(-1, 0) - self.volume = self.volume.moveaxis(-1, 0) - - if align: - self.logger.info(f"{self.level3} move intra timeseries") - self._move_intra_timeseries( - enable_secondary_data=enable_secondary_data, - start_position_coefficients=start_position_coefficients, - ) - self.logger.info(f"{self.level3} rotate inter timeseries") - self._rotate_inter_timeseries( - enable_secondary_data=enable_secondary_data, - start_position_coefficients=start_position_coefficients, - ) - self.logger.info(f"{self.level3} move inter timeseries") - self._move_inter_timeseries( - enable_secondary_data=enable_secondary_data, - start_position_coefficients=start_position_coefficients, - ) - - # reset svd - self.acceptor_whiten_mean = None - self.acceptor_whiten_k = None - self.acceptor_eigenvalues = None - - self.donor_whiten_mean = None - self.donor_whiten_k = None - self.donor_eigenvalues = None - - self.oxygenation_whiten_mean = None - self.oxygenation_whiten_k = None - self.oxygenation_eigenvalues = None - - self.volume_whiten_mean = None - self.volume_whiten_k = None - self.volume_eigenvalues = None - - @torch.no_grad() - def _find_ref_image_file(self) -> str: - filename_postfix: str = "Exp*.npy" - new_list = glob.glob(os.path.join(self.path, filename_postfix)) - new_list = natsort.natsorted(new_list) - - found_name: str | None = None - for filename in new_list: - if (filename.find("Trial") != -1) and (filename.find("Part") != -1): - found_name = filename - break - assert found_name is not None - - return found_name - - @torch.no_grad() - def _calculate_translation( # start_position_coefficients: OK - self, - input: torch.Tensor, - reference_image: torch.Tensor, - start_position_coefficients: int = 0, - ) -> torch.Tensor: - tvec = torch.zeros((input.shape[0], 2)) - - data_loader = torch.utils.data.DataLoader( - torch.utils.data.TensorDataset(input[start_position_coefficients:, ...]), - batch_size=self.batch_size, - shuffle=False, - ) - start_position: int = 0 - for input_batch in data_loader: - assert len(input_batch) == 1 - - end_position = start_position + input_batch[0].shape[0] - - tvec_temp = self.image_alignment.dry_run_translation( - input=input_batch[0], - new_reference_image=reference_image, - ) - - assert tvec_temp is not None - - tvec[start_position:end_position, :] = tvec_temp - - start_position += input_batch[0].shape[0] - - tvec = torch.round(torch.median(tvec, dim=0)[0]) - return tvec - - @torch.no_grad() - def _calculate_rotation( # start_position_coefficients: OK - self, - input: torch.Tensor, - reference_image: torch.Tensor, - start_position_coefficients: int = 0, - ) -> torch.Tensor: - angle = torch.zeros((input.shape[0])) - - data_loader = torch.utils.data.DataLoader( - torch.utils.data.TensorDataset(input[start_position_coefficients:, ...]), - batch_size=self.batch_size, - shuffle=False, - ) - start_position: int = 0 - for input_batch in data_loader: - assert len(input_batch) == 1 - - end_position = start_position + input_batch[0].shape[0] - - angle_temp = self.image_alignment.dry_run_angle( - input=input_batch[0], - new_reference_image=reference_image, - ) - - assert angle_temp is not None - - angle[start_position:end_position] = angle_temp - - start_position += input_batch[0].shape[0] - - angle = torch.where(angle >= 180, 360.0 - angle, angle) - angle = torch.where(angle <= -180, 360.0 + angle, angle) - angle = torch.median(angle, dim=0)[0] - - return angle - - @torch.no_grad() - def _move_intra_timeseries( # start_position_coefficients: OK - self, - enable_secondary_data: bool = True, - start_position_coefficients: int = 0, - ) -> None: - # donor_volume - assert self.donor is not None - assert self.ref_image_donor is not None - tvec_donor = self._calculate_translation( - self.donor, - self.ref_image_donor, - start_position_coefficients=start_position_coefficients, - ) - - self.donor = tv.transforms.functional.affine( - img=self.donor, - angle=0, - translate=[tvec_donor[1], tvec_donor[0]], - scale=1.0, - shear=0, - fill=self.fill_value, - ) - - if enable_secondary_data: - assert self.volume is not None - self.volume = tv.transforms.functional.affine( - img=self.volume, - angle=0, - translate=[tvec_donor[1], tvec_donor[0]], - scale=1.0, - shear=0, - fill=self.fill_value, - ) - - # acceptor_oxy - assert self.acceptor is not None - assert self.ref_image_acceptor is not None - tvec_acceptor = self._calculate_translation( - self.acceptor, - self.ref_image_acceptor, - start_position_coefficients=start_position_coefficients, - ) - - self.acceptor = tv.transforms.functional.affine( - img=self.acceptor, - angle=0, - translate=[tvec_acceptor[1], tvec_acceptor[0]], - scale=1.0, - shear=0, - fill=self.fill_value, - ) - if enable_secondary_data: - assert self.oxygenation is not None - self.oxygenation = tv.transforms.functional.affine( - img=self.oxygenation, - angle=0, - translate=[tvec_acceptor[1], tvec_acceptor[0]], - scale=1.0, - shear=0, - fill=self.fill_value, - ) - - @torch.no_grad() - def _move_inter_timeseries( # start_position_coefficients: OK - self, - enable_secondary_data: bool = True, - start_position_coefficients: int = 0, - ) -> None: - # acceptor_oxy - assert self.acceptor is not None - assert self.ref_image_donor is not None - tvec = self._calculate_translation( - self.acceptor, - self.ref_image_donor, - start_position_coefficients=start_position_coefficients, - ) - - self.acceptor = tv.transforms.functional.affine( - img=self.acceptor, - angle=0, - translate=[tvec[1], tvec[0]], - scale=1.0, - shear=0, - fill=self.fill_value, - ) - - if enable_secondary_data: - assert self.oxygenation is not None - self.oxygenation = tv.transforms.functional.affine( - img=self.oxygenation, - angle=0, - translate=[tvec[1], tvec[0]], - scale=1.0, - shear=0, - fill=self.fill_value, - ) - - @torch.no_grad() - def _rotate_inter_timeseries( # start_position_coefficients: OK - self, - enable_secondary_data: bool = True, - start_position_coefficients: int = 0, - ) -> None: - # acceptor_oxy - assert self.acceptor is not None - assert self.ref_image_donor is not None - angle = self._calculate_rotation( - self.acceptor, - self.ref_image_donor, - start_position_coefficients=start_position_coefficients, - ) - - self.acceptor = tv.transforms.functional.affine( - img=self.acceptor, - angle=-float(angle), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=self.fill_value, - ) - - if enable_secondary_data: - assert self.oxygenation is not None - self.oxygenation = tv.transforms.functional.affine( - img=self.oxygenation, - angle=-float(angle), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=self.fill_value, - ) - - @torch.no_grad() - def _svd( # start_position_coefficients: OK - self, - input: torch.Tensor, - lowrank_method: bool = True, - lowrank_q: int = 6, - start_position_coefficients: int = 0, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - selection = torch.flatten( - input[start_position_coefficients:, ...].clone().movedim(0, -1), - start_dim=0, - end_dim=1, - ) - whiten_mean = torch.mean(selection, dim=-1) - selection -= whiten_mean.unsqueeze(-1) - whiten_mean = whiten_mean.reshape((input.shape[1], input.shape[2])) - - if lowrank_method is False: - svd_u, svd_s, _ = torch.linalg.svd(selection, full_matrices=False) - else: - svd_u, svd_s, _ = torch.svd_lowrank(selection, q=lowrank_q) - - whiten_k = ( - torch.sign(svd_u[0, :]).unsqueeze(0) * svd_u / (svd_s.unsqueeze(0) + 1e-20) - ) - whiten_k = whiten_k.reshape((input.shape[1], input.shape[2], svd_s.shape[0])) - eigenvalues = svd_s - - return whiten_mean, whiten_k, eigenvalues - - @torch.no_grad() - def _to_remove( # start_position_coefficients: OK - self, - input: torch.Tensor | None, - lowrank_method: bool = True, - lowrank_q: int = 6, - start_position_coefficients: int = 0, - ) -> tuple[torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor]: - assert input is not None - - id: int = 0 - ( - input_whiten_mean, - input_whiten_k, - input_eigenvalues, - ) = self._svd( - input, - lowrank_method=lowrank_method, - lowrank_q=lowrank_q, - start_position_coefficients=start_position_coefficients, - ) - - assert input_whiten_mean is not None - assert input_whiten_k is not None - assert input_eigenvalues is not None - - eigenvalue = float(input_eigenvalues[id]) - whiten_mean = input_whiten_mean - whiten_k = input_whiten_k[:, :, 0] - - data = (input - input_whiten_mean.unsqueeze(0)) * input_whiten_k[ - :, :, id - ].unsqueeze(0) - - input_svd = data.sum(dim=-1).sum(dim=-1).unsqueeze(-1).unsqueeze(-1) - factor = (data * input_svd).sum(dim=0, keepdim=True) / (input_svd**2).sum( - dim=0, keepdim=True - ) - to_remove = input_svd * factor - to_remove /= input_whiten_k[:, :, id].unsqueeze(0) + 1e-20 - to_remove += input_whiten_mean.unsqueeze(0) - - output = input - to_remove - - return output, to_remove, eigenvalue, whiten_mean, whiten_k - - @torch.no_grad() - def acceptor_svd_remove( # start_position_coefficients: OK - self, - lowrank_method: bool = True, - lowrank_q: int = 6, - start_position_coefficients: int = 0, - ) -> tuple[torch.Tensor, float, torch.Tensor, torch.Tensor]: - self.acceptor, to_remove, eigenvalue, whiten_mean, whiten_k = self._to_remove( - input=self.acceptor, - lowrank_method=lowrank_method, - lowrank_q=lowrank_q, - start_position_coefficients=start_position_coefficients, - ) - - return to_remove, eigenvalue, whiten_mean, whiten_k - - @torch.no_grad() - def donor_svd_remove( # start_position_coefficients: OK - self, - lowrank_method: bool = True, - lowrank_q: int = 6, - start_position_coefficients: int = 0, - ) -> tuple[torch.Tensor, float, torch.Tensor, torch.Tensor]: - self.donor, to_remove, eigenvalue, whiten_mean, whiten_k = self._to_remove( - input=self.donor, - lowrank_method=lowrank_method, - lowrank_q=lowrank_q, - start_position_coefficients=start_position_coefficients, - ) - - return to_remove, eigenvalue, whiten_mean, whiten_k - - @torch.no_grad() - def volume_svd_remove( # start_position_coefficients: OK - self, - lowrank_method: bool = True, - lowrank_q: int = 6, - start_position_coefficients: int = 0, - ) -> tuple[torch.Tensor, float, torch.Tensor, torch.Tensor]: - self.volume, to_remove, eigenvalue, whiten_mean, whiten_k = self._to_remove( - input=self.volume, - lowrank_method=lowrank_method, - lowrank_q=lowrank_q, - start_position_coefficients=start_position_coefficients, - ) - - return to_remove, eigenvalue, whiten_mean, whiten_k - - @torch.no_grad() - def oxygenation_svd_remove( # start_position_coefficients: OK - self, - lowrank_method: bool = True, - lowrank_q: int = 6, - start_position_coefficients: int = 0, - ) -> tuple[torch.Tensor, float, torch.Tensor, torch.Tensor]: - ( - self.oxygenation, - to_remove, - eigenvalue, - whiten_mean, - whiten_k, - ) = self._to_remove( - input=self.oxygenation, - lowrank_method=lowrank_method, - lowrank_q=lowrank_q, - start_position_coefficients=start_position_coefficients, - ) - - return to_remove, eigenvalue, whiten_mean, whiten_k - - @torch.no_grad() - def remove_heartbeat( # start_position_coefficients: OK - self, - iterations: int = 2, - lowrank_method: bool = True, - lowrank_q: int = 6, - enable_secondary_data: bool = True, - start_position_coefficients: int = 0, - ): - self.acceptor_residuum = None - self.donor_residuum = None - self.oxygenation_residuum = None - self.volume_residuum = None - - for _ in range(0, iterations): - to_remove, _, _, _ = self.acceptor_svd_remove( - lowrank_method=lowrank_method, - lowrank_q=lowrank_q, - start_position_coefficients=start_position_coefficients, - ) - if self.acceptor_residuum is None: - self.acceptor_residuum = to_remove - else: - self.acceptor_residuum += to_remove - - to_remove, _, _, _ = self.donor_svd_remove( - lowrank_method=lowrank_method, - lowrank_q=lowrank_q, - start_position_coefficients=start_position_coefficients, - ) - if self.donor_residuum is None: - self.donor_residuum = to_remove - else: - self.donor_residuum += to_remove - - if enable_secondary_data: - to_remove, _, _, _ = self.volume_svd_remove( - lowrank_method=lowrank_method, - lowrank_q=lowrank_q, - start_position_coefficients=start_position_coefficients, - ) - if self.volume_residuum is None: - self.volume_residuum = to_remove - else: - self.volume_residuum += to_remove - - to_remove, _, _, _ = self.oxygenation_svd_remove( - lowrank_method=lowrank_method, - lowrank_q=lowrank_q, - start_position_coefficients=start_position_coefficients, - ) - if self.oxygenation_residuum is None: - self.oxygenation_residuum = to_remove - else: - self.oxygenation_residuum += to_remove - - @torch.no_grad() - def remove_mean_data(self, enable_secondary_data: bool = True) -> None: - assert self.donor is not None - assert self.acceptor is not None - self.donor -= self.donor.mean(dim=0, keepdim=True) - self.acceptor -= self.acceptor.mean(dim=0, keepdim=True) - - if enable_secondary_data: - assert self.volume is not None - assert self.oxygenation is not None - self.volume -= self.volume.mean(dim=0, keepdim=True) - self.oxygenation -= self.oxygenation.mean(dim=0, keepdim=True) - - @torch.no_grad() - def remove_mean_residuum(self, enable_secondary_data: bool = True) -> None: - assert self.donor_residuum is not None - assert self.acceptor_residuum is not None - self.donor_residuum -= self.donor_residuum.mean(dim=0, keepdim=True) - self.acceptor_residuum -= self.acceptor_residuum.mean(dim=0, keepdim=True) - - if enable_secondary_data: - assert self.volume_residuum is not None - assert self.oxygenation_residuum is not None - self.volume_residuum -= self.volume_residuum.mean(dim=0, keepdim=True) - self.oxygenation_residuum -= self.oxygenation_residuum.mean( - dim=0, keepdim=True - ) - - @torch.no_grad() - def _calculate_linear_trend_data(self, input: torch.Tensor) -> torch.Tensor: - assert input.ndim == 3 - time_beam: torch.Tensor = torch.arange( - 0, input.shape[0], dtype=torch.float32, device=self.device - ) - time_beam -= time_beam.mean() - input_mean = input.mean(dim=0, keepdim=True) - factor = (time_beam.unsqueeze(-1).unsqueeze(-1) * (input - input_mean)).sum( - dim=0, keepdim=True - ) / (time_beam**2).sum(dim=0, keepdim=True).unsqueeze(-1).unsqueeze(-1) - - output = factor * time_beam.unsqueeze(-1).unsqueeze(-1) + input_mean - - return output - - @torch.no_grad() - def remove_linear_trend_data(self, enable_secondary_data: bool = True) -> None: - assert self.donor is not None - assert self.acceptor is not None - self.donor -= self._calculate_linear_trend_data(self.donor) - self.acceptor -= self._calculate_linear_trend_data(self.acceptor) - - if enable_secondary_data: - assert self.volume is not None - assert self.oxygenation is not None - self.volume -= self._calculate_linear_trend_data(self.volume) - self.oxygenation -= self._calculate_linear_trend_data(self.oxygenation) - - @torch.no_grad() - def remove_linear_trend_residuum( - self, - enable_secondary_data: bool = True, - ) -> None: - assert self.donor_residuum is not None - assert self.acceptor_residuum is not None - - self.donor_residuum -= self._calculate_linear_trend_data(self.donor_residuum) - self.acceptor_residuum -= self._calculate_linear_trend_data( - self.acceptor_residuum - ) - - if enable_secondary_data: - assert self.volume_residuum is not None - assert self.oxygenation_residuum is not None - self.volume_residuum -= self._calculate_linear_trend_data( - self.volume_residuum - ) - self.oxygenation_residuum -= self._calculate_linear_trend_data( - self.oxygenation_residuum - ) - - @torch.no_grad() - def frame_shift( - self, - enable_secondary_data: bool = True, - ): - assert self.donor is not None - assert self.acceptor is not None - self.donor = self.donor[1:, :, :] - self.acceptor = self.acceptor[1:, :, :] - - if enable_secondary_data: - assert self.volume is not None - assert self.oxygenation is not None - self.volume = (self.volume[1:, :, :] + self.volume[:-1, :, :]) / 2.0 - self.oxygenation = ( - self.oxygenation[1:, :, :] + self.oxygenation[:-1, :, :] - ) / 2.0 - - if self.donor_residuum is not None: - self.donor_residuum = self.donor_residuum[1:, :, :] - - if self.acceptor_residuum is not None: - self.acceptor_residuum = self.acceptor_residuum[1:, :, :] - - if enable_secondary_data: - if self.volume_residuum is not None: - self.volume_residuum = ( - self.volume_residuum[1:, :, :] + self.volume_residuum[:-1, :, :] - ) / 2.0 - - if self.oxygenation_residuum is not None: - self.oxygenation_residuum = ( - self.oxygenation_residuum[1:, :, :] - + self.oxygenation_residuum[:-1, :, :] - ) / 2.0 - - @torch.no_grad() - def cleaned_load_data( - self, - experiment_id: int, - trial_id: int, - align: bool = True, - iterations: int = 1, - lowrank_method: bool = True, - lowrank_q: int = 6, - remove_heartbeat: bool = True, - remove_mean: bool = True, - remove_linear: bool = True, - remove_heartbeat_mean: bool = False, - remove_heartbeat_linear: bool = False, - bin_size: int = 4, - do_frame_shift: bool = True, - enable_secondary_data: bool = True, - mmap_mode: bool = True, - initital_mask: torch.Tensor | None = None, - start_position_coefficients: int = 0, - ) -> None: - self.logger.info(f"{self.level2} start load_data") - self.load_data( - experiment_id=experiment_id, - trial_id=trial_id, - align=align, - enable_secondary_data=enable_secondary_data, - mmap_mode=mmap_mode, - start_position_coefficients=start_position_coefficients, - ) - assert self.donor is not None - assert self.acceptor is not None - - if bin_size > 1: - self.logger.info(f"{self.level2} spatial pooling") - pool = torch.nn.AvgPool2d((bin_size, bin_size), stride=(bin_size, bin_size)) - self.donor = pool(self.donor) - self.acceptor = pool(self.acceptor) - if enable_secondary_data: - assert self.volume is not None - assert self.oxygenation is not None - self.volume = pool(self.volume) - self.oxygenation = pool(self.oxygenation) - - if self.donor is not None: - self.donor_scale = self.donor.mean(dim=0, keepdim=True) - self.donor /= self.donor_scale - self.donor -= 1.0 - - if self.acceptor is not None: - self.acceptor_scale = self.acceptor.mean(dim=0, keepdim=True) - self.acceptor /= self.acceptor_scale - self.acceptor -= 1.0 - - if self.volume is not None: - self.volume_scale = self.volume.mean(dim=0, keepdim=True) - self.volume /= self.volume_scale - self.volume -= 1.0 - - if self.oxygenation is not None: - self.oxygenation_scale = self.oxygenation.mean(dim=0, keepdim=True) - self.oxygenation /= self.oxygenation_scale - self.oxygenation -= 1.0 - - if initital_mask is not None: - self.logger.info(f"{self.level2} initial mask is applied on the data") - assert self.acceptor is not None - assert self.donor is not None - assert initital_mask.ndim == 2 - assert initital_mask.shape[0] == self.donor.shape[1] - assert initital_mask.shape[1] == self.donor.shape[2] - - self.acceptor *= initital_mask.unsqueeze(0) - self.donor *= initital_mask.unsqueeze(0) - - if enable_secondary_data: - assert self.oxygenation is not None - assert self.volume is not None - self.oxygenation *= initital_mask.unsqueeze(0) - self.volume *= initital_mask.unsqueeze(0) - - if remove_heartbeat: - self.logger.info(f"{self.level2} remove the heart beat via SVD") - self.remove_heartbeat( - iterations=iterations, - lowrank_method=lowrank_method, - lowrank_q=lowrank_q, - enable_secondary_data=enable_secondary_data, - start_position_coefficients=start_position_coefficients, - ) - - if remove_mean: - self.logger.info(f"{self.level2} remove mean") - self.remove_mean_data(enable_secondary_data=enable_secondary_data) - - if remove_linear: - self.logger.info(f"{self.level2} remove linear trends") - self.remove_linear_trend_data(enable_secondary_data=enable_secondary_data) - - if remove_heartbeat: - if remove_heartbeat_mean: - self.logger.info(f"{self.level2} remove mean (heart beat signal)") - self.remove_mean_residuum(enable_secondary_data=enable_secondary_data) - if remove_heartbeat_linear: - self.logger.info( - f"{self.level2} remove linear trends (heart beat signal)" - ) - self.remove_linear_trend_residuum( - enable_secondary_data=enable_secondary_data - ) - - if do_frame_shift: - self.logger.info(f"{self.level2} frame shift") - self.frame_shift(enable_secondary_data=enable_secondary_data) - - @torch.no_grad() - def remove_other_signals( # start_position_coefficients: OK - self, - start_position_coefficients: int = 0, - match_iterations: int = 25, - export_parameters: bool = True, - ) -> tuple[ - torch.Tensor, - torch.Tensor, - float, - float, - float, - float, - torch.Tensor | None, - torch.Tensor | None, - ]: - assert self.acceptor is not None - assert self.donor is not None - assert self.oxygenation is not None - assert self.volume is not None - - index_full_dataset = torch.arange( - 0, self.acceptor.shape[1], device=self.device, dtype=torch.int64 - ) - - result_a: torch.Tensor = torch.zeros_like(self.acceptor) - result_d: torch.Tensor = torch.zeros_like(self.donor) - - max_scale_value_a = 0.0 - initial_scale_value_a = 0.0 - max_scale_value_d = 0.0 - initial_scale_value_d = 0.0 - - parameter_a: torch.Tensor | None = None - parameter_d: torch.Tensor | None = None - - for chunk in self._chunk_iterator(index_full_dataset, self.filtfilt_chuck_size): - a: torch.Tensor = self.acceptor[:, chunk, :].detach().clone() - d: torch.Tensor = self.donor[:, chunk, :].detach().clone() - - o: torch.Tensor = self.oxygenation[:, chunk, :].detach().clone() - v: torch.Tensor = self.volume[:, chunk, :].detach().clone() - - a_mean = a[start_position_coefficients:, ...].mean(dim=0, keepdim=True) - a_mean_full = a.mean(dim=0, keepdim=True) - a -= a_mean_full - a_correction = a_mean - a_mean_full - - d_mean = d[start_position_coefficients:, ...].mean(dim=0, keepdim=True) - d_mean_full = d.mean(dim=0, keepdim=True) - d -= d_mean_full - d_correction = d_mean - d_mean_full - - o_mean = o[start_position_coefficients:, ...].mean(dim=0, keepdim=True) - o_mean_full = o.mean(dim=0, keepdim=True) - o -= o_mean - o_correction = o_mean - o_mean_full - o_norm = 1.0 / ( - (o[start_position_coefficients:, ...] ** 2).sum(dim=0) + 1e-20 - ) - - v_mean = v[start_position_coefficients:, ...].mean(dim=0, keepdim=True) - v_mean_full = v.mean(dim=0, keepdim=True) - v -= v_mean - v_correction = v_mean - v_mean_full - v_norm = 1.0 / ( - (v[start_position_coefficients:, ...] ** 2).sum(dim=0) + 1e-20 - ) - - linear: torch.Tensor = ( - torch.arange(0, o.shape[0], device=self.device, dtype=torch.float32) - .unsqueeze(-1) - .unsqueeze(-1) - ) - l_mean = linear[start_position_coefficients:, ...].mean(dim=0, keepdim=True) - l_mean_full = linear.mean(dim=0, keepdim=True) - linear -= l_mean - l_correction = l_mean - l_mean_full - linear_norm = 1.0 / ( - (linear[start_position_coefficients:, ...] ** 2).sum(dim=0) + 1e-20 - ) - linear = torch.tile(linear, (1, o.shape[1], o.shape[2])) - linear_norm = torch.tile(linear_norm, (o.shape[1], o.shape[2])) - l_correction = torch.tile(l_correction, (1, o.shape[1], o.shape[2])) - - data = torch.cat( - (linear.unsqueeze(-1), o.unsqueeze(-1), v.unsqueeze(-1)), dim=-1 - ) - del linear - del o - del v - - data_mean_correction = torch.cat( - ( - l_correction.unsqueeze(-1), - o_correction.unsqueeze(-1), - v_correction.unsqueeze(-1), - ), - dim=-1, - ) - - data_norm = torch.cat( - (linear_norm.unsqueeze(-1), o_norm.unsqueeze(-1), v_norm.unsqueeze(-1)), - dim=-1, - ) - del linear_norm - del o_norm - del v_norm - - if export_parameters: - parameter_a_temp: torch.Tensor | None = torch.zeros_like(data_norm) - parameter_d_temp: torch.Tensor | None = torch.zeros_like(data_norm) - else: - parameter_a_temp = None - parameter_d_temp = None - - for mode_a in [True, False]: - if mode_a: - result = a.detach().clone() - result_mean_correct = a_correction - - else: - result = d.detach().clone() - result_mean_correct = d_correction - - for i in range(0, match_iterations): - scale = ( - ( - data[start_position_coefficients:, ...] - * ( - result[start_position_coefficients:, ...] - + result_mean_correct - ).unsqueeze(-1) - ).sum(dim=0) - ) * data_norm - - idx = torch.abs(scale).argmax(dim=-1) - scale = torch.gather(scale, -1, idx.unsqueeze(-1)).squeeze(-1) - - idx_3d = torch.tile(idx.unsqueeze(0), (data.shape[0], 1, 1)) - data_selected = torch.gather( - (data - data_mean_correction), -1, idx_3d.unsqueeze(-1) - ).squeeze(-1) - - result -= data_selected * scale.unsqueeze(0) - - if mode_a: - if i == 0: - initial_scale_value_a = max( - [max_scale_value_a, float(scale.max())] - ) - if parameter_a_temp is not None: - parameter_a_temp.scatter_add_( - -1, idx.unsqueeze(-1), scale.unsqueeze(-1) - ) - - else: - if i == 0: - initial_scale_value_d = max( - [max_scale_value_d, float(scale.max())] - ) - if parameter_d_temp is not None: - parameter_d_temp.scatter_add_( - -1, idx.unsqueeze(-1), scale.unsqueeze(-1) - ) - - if mode_a: - result_a[:, chunk, :] = result.detach().clone() - max_scale_value_a = max([max_scale_value_a, float(scale.max())]) - if parameter_a_temp is not None: - parameter_a_temp = torch.cat( - (parameter_a_temp, a_mean_full.squeeze(0).unsqueeze(-1)), - dim=-1, - ) - else: - result_d[:, chunk, :] = result.detach().clone() - max_scale_value_d = max([max_scale_value_d, float(scale.max())]) - if parameter_d_temp is not None: - parameter_d_temp = torch.cat( - (parameter_d_temp, d_mean_full.squeeze(0).unsqueeze(-1)), - dim=-1, - ) - if export_parameters: - if (parameter_a is None) and (parameter_a_temp is not None): - parameter_a = torch.zeros( - ( - self.acceptor.shape[1], - parameter_a_temp.shape[1], - parameter_a_temp.shape[2], - ), - device=self.device, - dtype=torch.float32, - ) - if (parameter_a is not None) and (parameter_a_temp is not None): - parameter_a[chunk, ...] = parameter_a_temp - - if (parameter_d is None) and (parameter_d_temp is not None): - parameter_d = torch.zeros( - ( - self.acceptor.shape[1], - parameter_d_temp.shape[1], - parameter_d_temp.shape[2], - ), - device=self.device, - dtype=torch.float32, - ) - if (parameter_d is not None) and (parameter_d_temp is not None): - parameter_d[chunk, ...] = parameter_d_temp - - self.logger.info( - f"{self.level3} acceptor -- Progression scale: {initial_scale_value_a} -> {max_scale_value_a}" - ) - self.logger.info( - f"{self.level3} donor -- Progression scale: {initial_scale_value_d} -> {max_scale_value_d}" - ) - return ( - result_a, - result_d, - max_scale_value_a, - initial_scale_value_a, - max_scale_value_d, - initial_scale_value_d, - parameter_a, - parameter_d, - ) - - @torch.no_grad() - def _filtfilt( - self, - input: torch.Tensor, - butter_a: torch.Tensor, - butter_b: torch.Tensor, - ) -> torch.Tensor: - assert butter_a.ndim == 1 - assert butter_b.ndim == 1 - assert butter_a.shape[0] == butter_b.shape[0] - - process_data: torch.Tensor = input.movedim(0, -1).detach().clone() - - padding_length = 12 * int(butter_a.shape[0]) - left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[ - ..., 1 : padding_length + 1 - ].flip(-1) - right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[ - ..., -(padding_length + 1) : -1 - ].flip(-1) - process_data_padded = torch.cat( - (left_padding, process_data, right_padding), dim=-1 - ) - - output = ta.functional.filtfilt( - process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False - ).squeeze(0) - output = output[..., padding_length:-padding_length].movedim(-1, 0) - return output - - @torch.no_grad() - def _butter_bandpass( - self, low_frequency: float = 5, high_frequency: float = 15, fs: float = 100.0 - ) -> tuple[torch.Tensor, torch.Tensor]: - import scipy - - butter_b_np, butter_a_np = scipy.signal.butter( - 4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs - ) - butter_a = torch.tensor(butter_a_np, device=self.device, dtype=torch.float32) - butter_b = torch.tensor(butter_b_np, device=self.device, dtype=torch.float32) - return butter_a, butter_b - - @torch.no_grad() - def _chunk_iterator(self, array: torch.Tensor, chunk_size: int): - for i in range(0, array.shape[0], chunk_size): - yield array[i : i + chunk_size] - - @torch.no_grad() - def heartbeat_scale( # start_position_coefficients: OK - self, - low_frequency: float = 5, - high_frequency: float = 15, - fs: float = 100.0, - apply_to_data: bool = False, - threshold: float | None = 0.5, - start_position_coefficients: int = 0, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: - assert self.donor_residuum is not None - assert self.acceptor_residuum is not None - - butter_a, butter_b = self._butter_bandpass( - low_frequency=low_frequency, high_frequency=high_frequency, fs=fs - ) - - butter_a, butter_b = self._butter_bandpass( - low_frequency=low_frequency, high_frequency=high_frequency, fs=100.0 - ) - self.logger.info(f"{self.level3} apply bandpass donor_residuum (filtfilt)") - - index_full_dataset: torch.Tensor = torch.arange( - 0, self.donor_residuum.shape[1], device=self.device, dtype=torch.int64 - ) - - hb_d = torch.zeros_like(self.donor_residuum[start_position_coefficients:, ...]) - for chunk in self._chunk_iterator(index_full_dataset, self.filtfilt_chuck_size): - temp_filtfilt = self._filtfilt( - self.donor_residuum[start_position_coefficients:, chunk, :], - butter_a=butter_a, - butter_b=butter_b, - ) - hb_d[:, chunk, :] = temp_filtfilt - - # hb_d = hb_d[start_position:, ...] - hb_d -= hb_d.mean(dim=0, keepdim=True) - - self.logger.info(f"{self.level3} apply bandpass acceptor_residuum (filtfilt)") - - index_full_dataset = torch.arange( - 0, self.acceptor_residuum.shape[1], device=self.device, dtype=torch.int64 - ) - hb_a = torch.zeros_like(self.donor_residuum[start_position_coefficients:, ...]) - for chunk in self._chunk_iterator(index_full_dataset, self.filtfilt_chuck_size): - temp_filtfilt = self._filtfilt( - self.acceptor_residuum[start_position_coefficients:, chunk, :], - butter_a=butter_a, - butter_b=butter_b, - ) - hb_a[:, chunk, :] = temp_filtfilt - - # hb_a = hb_a[start_position:, ...] - hb_a -= hb_a.mean(dim=0, keepdim=True) - - scale = (hb_a * hb_d).sum(dim=0) / (hb_a**2).sum(dim=0) - - heartbeat_a = torch.sqrt(scale) - heartbeat_d = 1.0 / (heartbeat_a + 1e-20) - - if apply_to_data: - if self.donor is not None: - self.donor *= heartbeat_d.unsqueeze(0) - if self.volume is not None: - self.volume *= heartbeat_d.unsqueeze(0) - if self.acceptor is not None: - self.acceptor *= heartbeat_a.unsqueeze(0) - if self.oxygenation is not None: - self.oxygenation *= heartbeat_a.unsqueeze(0) - - if threshold is not None: - self.logger.info(f"{self.level3} calculate mask") - assert self.donor_scale is not None - assert self.acceptor_scale is not None - temp_d = hb_d.std(dim=0) * self.donor_scale.squeeze(0) - temp_d -= temp_d.min() - temp_d /= temp_d.max() - - temp_a = hb_a.std(dim=0) * self.acceptor_scale.squeeze(0) - temp_a -= temp_a.min() - temp_a /= temp_a.max() - - mask = torch.where(temp_d > threshold, 1.0, 0.0) * torch.where( - temp_a > threshold, 1.0, 0.0 - ) - else: - mask = None - - return heartbeat_a, heartbeat_d, mask - - @torch.no_grad() - def measure_heartbeat_frequency( # start_position_coefficients: OK - self, - low_frequency: float = 5, - high_frequency: float = 15, - fs: float = 100.0, - use_input_source: str = "volume", - start_position_coefficients: int = 0, - half_width_frequency_window: float = 3.0, # Hz (on side ) - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if use_input_source == "donor": - assert self.donor is not None - hb: torch.Tensor = self.donor[start_position_coefficients:, ...] - - elif use_input_source == "acceptor": - assert self.acceptor is not None - hb = self.acceptor[start_position_coefficients:, ...] - - elif use_input_source == "volume": - assert self.volume is not None - hb = self.volume[start_position_coefficients:, ...] - - else: - assert self.oxygenation is not None - hb = self.oxygenation[start_position_coefficients:, ...] - - frequency_axis: torch.Tensor = ( - torch.fft.rfftfreq(hb.shape[0]).to(device=self.device) * fs - ) - - delta_idx = int( - math.ceil( - half_width_frequency_window - / (float(frequency_axis[1]) - float(frequency_axis[0])) - ) - ) - - idx: torch.Tensor = torch.where( - (frequency_axis >= low_frequency) * (frequency_axis <= high_frequency) - )[0] - - power_hb: torch.Tensor = torch.abs(torch.fft.rfft(hb, dim=0)) ** 2 - power_hb = power_hb[idx, :, :].argmax(dim=0) + idx[0] - power_hb_low = power_hb - delta_idx - power_hb_low = power_hb_low.clamp(min=0) - power_hb_high = power_hb + delta_idx - power_hb_high = power_hb_high.clamp(max=frequency_axis.shape[0]) - - return power_hb_low, power_hb_high, frequency_axis - - @torch.no_grad() - def measure_heartbeat_power( # start_position_coefficients: OK - self, - use_input_source: str = "donor", - start_position_coefficients: int = 0, - power_hb_low: torch.Tensor | None = None, - power_hb_high: torch.Tensor | None = None, - custom_input: torch.Tensor | None = None, - ) -> torch.Tensor: - if use_input_source == "donor": - assert self.donor is not None - hb: torch.Tensor = self.donor[start_position_coefficients:, ...] - - elif use_input_source == "acceptor": - assert self.acceptor is not None - hb = self.acceptor[start_position_coefficients:, ...] - - elif use_input_source == "volume": - assert self.volume is not None - hb = self.volume[start_position_coefficients:, ...] - - elif use_input_source == "custom": - assert custom_input is not None - hb = custom_input[start_position_coefficients:, ...] - else: - assert self.oxygenation is not None - hb = self.oxygenation[start_position_coefficients:, ...] - - counter: torch.Tensor = torch.zeros( - (hb.shape[1], hb.shape[2]), - dtype=hb.dtype, - device=self.device, - ) - - index_full_dataset = torch.arange( - 0, hb.shape[1], device=self.device, dtype=torch.int64 - ) - - power_hb: torch.Tensor | None = None - for chunk in self._chunk_iterator(index_full_dataset, self.filtfilt_chuck_size): - temp_power = torch.abs(torch.fft.rfft(hb[:, chunk, :], dim=0)) ** 2 - if power_hb is None: - power_hb = torch.zeros( - (temp_power.shape[0], hb.shape[1], temp_power.shape[2]), - dtype=temp_power.dtype, - device=temp_power.device, - ) - assert power_hb is not None - power_hb[:, chunk, :] = temp_power - - assert power_hb is not None - for pos in range(0, power_hb.shape[0]): - pos_torch = torch.tensor(pos, dtype=torch.int64, device=self.device) - slice_temp = ( - (pos_torch >= power_hb_low) * (pos_torch < power_hb_high) - ).type(dtype=power_hb.dtype) - power_hb[pos, ...] *= slice_temp - counter += slice_temp - power_hb = power_hb.sum(dim=0) / counter - - return power_hb - - @torch.no_grad() - def automatic_load( # start_position_coefficients: OK - self, - experiment_id: int = 1, - trial_id: int = 1, - start_position: int = 0, - start_position_coefficients: int = 100, - fs: float = 100.0, - use_regression: bool | None = None, - # Heartbeat - remove_heartbeat: bool = True, # i.e. use SVD - low_frequency: float = 5, # Hz Butter Bandpass Heartbeat - high_frequency: float = 15, # Hz Butter Bandpass Heartbeat - threshold: float | None = 0.5, # For the mask - # Extra exposed parameters: - align: bool = True, - iterations: int = 1, # SVD iterations: Do not touch! Keep at 1 - lowrank_method: bool = True, - lowrank_q: int = 6, - remove_heartbeat_mean: bool = False, - remove_heartbeat_linear: bool = False, - bin_size: int = 4, - do_frame_shift: bool | None = None, - half_width_frequency_window: float = 3.0, # Hz (on side ) measure_heartbeat_frequency - mmap_mode: bool = True, - initital_mask_name: str | None = None, - initital_mask_update: bool = True, - initital_mask_roi: bool = False, - gaussian_blur_kernel_size: int | None = 3, - gaussian_blur_sigma: float = 1.0, - bin_size_post: int | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - self.logger.info(f"{self.level0} start automatic_load") - - if use_regression is None: - use_regression = not remove_heartbeat - - if do_frame_shift is None: - do_frame_shift = not remove_heartbeat - - initital_mask: torch.Tensor | None = None - - if (initital_mask_name is not None) and os.path.isfile(initital_mask_name): - initital_mask = torch.tensor( - np.load(initital_mask_name), device=self.device, dtype=torch.float32 - ) - self.logger.info(f"{self.level1} try to load previous mask: found") - else: - self.logger.info(f"{self.level1} try to load previous mask: NOT found") - - self.logger.info(f"{self.level1} start cleaned_load_data") - self.cleaned_load_data( - experiment_id=experiment_id, - trial_id=trial_id, - remove_heartbeat=remove_heartbeat, - remove_mean=not use_regression, - remove_linear=not use_regression, - enable_secondary_data=use_regression, - align=align, - iterations=iterations, - lowrank_method=lowrank_method, - lowrank_q=lowrank_q, - remove_heartbeat_mean=remove_heartbeat_mean, - remove_heartbeat_linear=remove_heartbeat_linear, - bin_size=bin_size, - do_frame_shift=do_frame_shift, - mmap_mode=mmap_mode, - initital_mask=initital_mask, - start_position_coefficients=start_position_coefficients, - ) - - heartbeat_a: torch.Tensor | None = None - heartbeat_d: torch.Tensor | None = None - mask: torch.Tensor | None = None - power_hb_low: torch.Tensor | None = None - power_hb_high: torch.Tensor | None = None - - if remove_heartbeat: - self.logger.info(f"{self.level1} remove heart beat (heartbeat_scale)") - heartbeat_a, heartbeat_d, mask = self.heartbeat_scale( - low_frequency=low_frequency, - high_frequency=high_frequency, - fs=fs, - apply_to_data=False, - threshold=threshold, - start_position_coefficients=start_position_coefficients, - ) - else: - self.logger.info( - f"{self.level1} measure heart rate (measure_heartbeat_frequency)" - ) - assert self.volume is not None - ( - power_hb_low, - power_hb_high, - _, - ) = self.measure_heartbeat_frequency( - low_frequency=low_frequency, - high_frequency=high_frequency, - fs=fs, - use_input_source="volume", - start_position_coefficients=start_position_coefficients, - half_width_frequency_window=half_width_frequency_window, - ) - - if use_regression: - self.logger.info(f"{self.level1} use regression") - ( - result_a, - result_d, - _, - _, - _, - _, - _, - _, - ) = self.remove_other_signals( - start_position_coefficients=start_position_coefficients, - match_iterations=25, - export_parameters=False, - ) - result_a = result_a[start_position:, ...] - result_d = result_d[start_position:, ...] - else: - self.logger.info(f"{self.level1} don't use regression") - assert self.acceptor is not None - assert self.donor is not None - result_a = self.acceptor[start_position:, ...].clone() - result_d = self.donor[start_position:, ...].clone() - - if mask is not None: - result_a *= mask.unsqueeze(0) - result_d *= mask.unsqueeze(0) - - if remove_heartbeat is False: - self.logger.info( - f"{self.level1} donor: measure heart beat spectral power (measure_heartbeat_power)" - ) - temp_d = self.measure_heartbeat_power( - use_input_source="donor", - start_position_coefficients=start_position_coefficients, - power_hb_low=power_hb_low, - power_hb_high=power_hb_high, - ) - self.logger.info( - f"{self.level1} acceptor: measure heart beat spectral power (measure_heartbeat_power)" - ) - temp_a = self.measure_heartbeat_power( - use_input_source="acceptor", - start_position_coefficients=start_position_coefficients, - power_hb_low=power_hb_low, - power_hb_high=power_hb_high, - ) - scale = temp_d / (temp_a + 1e-20) - - heartbeat_a = torch.sqrt(scale) - heartbeat_d = 1.0 / (heartbeat_a + 1e-20) - - self.logger.info(f"{self.level1} scale acceptor and donor signals") - if heartbeat_a is not None: - result_a *= heartbeat_a.unsqueeze(0) - if heartbeat_d is not None: - result_d *= heartbeat_d.unsqueeze(0) - - if mask is not None: - if initital_mask_update: - self.logger.info(f"{self.level1} update inital mask") - if initital_mask is None: - initital_mask = mask.clone() - else: - initital_mask *= mask - - if (initital_mask_roi) and (initital_mask is not None): - self.logger.info(f"{self.level1} enter roi mask drawing modus") - yes_choices = ["yes", "y"] - contiue_roi: bool = True - - image: np.ndarray = (result_a - result_d)[0, ...].cpu().numpy() - image[initital_mask.cpu().numpy() == 0] = float("NaN") - - while contiue_roi: - user_input = input( - "Mask: Do you want to remove more pixel (yes/no)? " - ) - - if user_input.lower() in yes_choices: - plt.imshow(image, cmap="hot") - plt.title("Select a region for removal") - - temp_roi = RoiPoly(color="g") - plt.show() - - if len(temp_roi.x) > 0: - new_mask = temp_roi.get_mask(image) - new_mask_np = new_mask.astype(np.float32) - new_mask_torch = torch.tensor( - new_mask_np, - device=self.device, - dtype=torch.float32, - ) - - plt.imshow(image, cmap="hot") - temp_roi.display_roi() - plt.title("Selected region for removal") - print("Please close figure when ready...") - plt.show() - user_input = input( - "Mask: Remove these pixel (yes/no)? " - ) - - if user_input.lower() in yes_choices: - initital_mask *= 1.0 - new_mask_torch - image[new_mask] = float("NaN") - - else: - contiue_roi = False - - if initital_mask_name is not None: - self.logger.info(f"{self.level1} save mask") - np.save(initital_mask_name, initital_mask.cpu().numpy()) - - self.logger.info(f"{self.level0} end automatic_load") - - # result = (1.0 + result_a) / (1.0 + result_d) - result = 1.0 + result_a - result_d - - if (gaussian_blur_kernel_size is not None) and (gaussian_blur_kernel_size > 0): - gaussian_blur = tv.transforms.GaussianBlur( - kernel_size=[gaussian_blur_kernel_size, gaussian_blur_kernel_size], - sigma=gaussian_blur_sigma, - ) - result = gaussian_blur(result) - - if (bin_size_post is not None) and (bin_size_post > 1): - pool = torch.nn.AvgPool2d( - (bin_size_post, bin_size_post), stride=(bin_size_post, bin_size_post) - ) - result = pool(result) - - if mask is not None: - mask = ( - (pool(mask.unsqueeze(0)) > 0).type(dtype=torch.float32).squeeze(0) - ) - - return result, mask - - -if __name__ == "__main__": - from functions.Anime import Anime - - path: str = "/data_1/hendrik/2023-07-17/M_Sert_Cre_41/raw" - initital_mask_name: str | None = None - initital_mask_update: bool = True - initital_mask_roi: bool = False # default: True - - experiment_id: int = 1 - trial_id: int = 1 - start_position: int = 0 - start_position_coefficients: int = 100 - remove_heartbeat: bool = True # i.e. use SVD - svd_iterations: int = 1 # SVD iterations: Do not touch! Keep at 1 - bin_size: int = 4 - threshold: float | None = 0.05 # Between 0 and 1.0 - - example_position_x: int = 280 - example_position_y: int = 440 - - display_logging_messages: bool = False - save_logging_messages: bool = False - - show_example_timeseries: bool = True - save_example_timeseries: bool = False - play_movie: bool = False - - # Post data processing modifiations - gaussian_blur_kernel_size: int | None = 3 - gaussian_blur_sigma: float = 1.0 - bin_size_post: int | None = None - - # ------------------------ - example_position_x = example_position_x // bin_size - example_position_y = example_position_y // bin_size - if bin_size_post is not None: - example_position_x = example_position_x // bin_size_post - example_position_y = example_position_y // bin_size_post - - torch_device: torch.device = torch.device( - "cuda:0" if torch.cuda.is_available() else "cpu" - ) - - af = DataContainer( - path=path, - device=torch_device, - display_logging_messages=display_logging_messages, - save_logging_messages=save_logging_messages, - ) - result, mask = af.automatic_load( - experiment_id=experiment_id, - trial_id=trial_id, - start_position=start_position, - remove_heartbeat=remove_heartbeat, # i.e. use SVD - iterations=svd_iterations, - bin_size=bin_size, - initital_mask_name=initital_mask_name, - initital_mask_update=initital_mask_update, - initital_mask_roi=initital_mask_roi, - start_position_coefficients=start_position_coefficients, - gaussian_blur_kernel_size=gaussian_blur_kernel_size, - gaussian_blur_sigma=gaussian_blur_sigma, - bin_size_post=bin_size_post, - threshold=threshold, - ) - - if show_example_timeseries: - plt.plot(result[:, example_position_x, example_position_y].cpu()) - plt.show() - - if save_example_timeseries: - if remove_heartbeat: - np.save( - f"SVD_{svd_iterations}.npy", - result[:, example_position_x, example_position_y].cpu().numpy(), - ) - else: - np.save( - "Classic.npy", - result[:, example_position_x, example_position_y].cpu().numpy(), - ) - - if play_movie: - ani = Anime() - ani.show( - result - 1.0, mask=mask, vmin_scale=0.5, vmax_scale=0.5 - ) # , vmin=0.98) # , vmin=1.0, vmax_scale=1.0)