From 6b3146be0f06849d1613a2700d5478c6f801b9a5 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Mon, 26 Feb 2024 18:56:59 +0100 Subject: [PATCH] Add files via upload --- new_pipeline/config.json | 9 +- new_pipeline/stage_1_get_ref_image.py | 7 +- new_pipeline/stage_2_make_heartbeat_mask.py | 8 +- new_pipeline/stage_3_refine_mask.py | 11 +- new_pipeline/stage_4_process.py | 619 ++++++++++++++++++++ 5 files changed, 632 insertions(+), 22 deletions(-) create mode 100644 new_pipeline/stage_4_process.py diff --git a/new_pipeline/config.json b/new_pipeline/config.json index e0e6eac..a9e4f2c 100644 --- a/new_pipeline/config.json +++ b/new_pipeline/config.json @@ -11,16 +11,21 @@ "oxygenation", "volume" ], - "dtype": "float32", - "binning_enable": false, + // binning + "binning_enable": true, + "binning_before_alignment": false, // otherwise at the end after everything else "binning_kernel_size": 4, "binning_stride": 4, "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, // Heart beat detection "lower_freqency_bandpass": 5.0, // Hz "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, // LED Ramp on "skip_frames_in_the_beginning": 100, // Frames // PyTorch + "dtype": "float32", "force_to_cpu": false } \ No newline at end of file diff --git a/new_pipeline/stage_1_get_ref_image.py b/new_pipeline/stage_1_get_ref_image.py index c1c1eb4..2aa7bf8 100644 --- a/new_pipeline/stage_1_get_ref_image.py +++ b/new_pipeline/stage_1_get_ref_image.py @@ -1,6 +1,4 @@ -import json import os -from jsmin import jsmin # type: ignore import torch import numpy as np @@ -12,14 +10,13 @@ from functions.bandpass import bandpass from functions.create_logger import create_logger from functions.load_meta_data import load_meta_data from functions.get_torch_device import get_torch_device +from functions.load_config import load_config mylogger = create_logger( save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_1" ) -mylogger.info("loading config file") -with open("config.json", "r") as file: - config = json.loads(jsmin(file.read())) +config = load_config(mylogger=mylogger) device = get_torch_device(mylogger, config["force_to_cpu"]) diff --git a/new_pipeline/stage_2_make_heartbeat_mask.py b/new_pipeline/stage_2_make_heartbeat_mask.py index ee7a8f9..e36516b 100644 --- a/new_pipeline/stage_2_make_heartbeat_mask.py +++ b/new_pipeline/stage_2_make_heartbeat_mask.py @@ -3,23 +3,19 @@ import matplotlib import numpy as np import torch import os -import json -from jsmin import jsmin # type:ignore from matplotlib.widgets import Slider, Button # type:ignore from functools import partial from functions.gauss_smear_individual import gauss_smear_individual from functions.create_logger import create_logger from functions.get_torch_device import get_torch_device +from functions.load_config import load_config mylogger = create_logger( save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_2" ) -mylogger.info("loading config file") -with open("config.json", "r") as file: - config = json.loads(jsmin(file.read())) - +config = load_config(mylogger=mylogger) path: str = config["ref_image_path"] use_channel: str = "donor" diff --git a/new_pipeline/stage_3_refine_mask.py b/new_pipeline/stage_3_refine_mask.py index 6bb0bd7..0d02a19 100644 --- a/new_pipeline/stage_3_refine_mask.py +++ b/new_pipeline/stage_3_refine_mask.py @@ -1,5 +1,4 @@ import os -import json import numpy as np import matplotlib.pyplot as plt # type:ignore @@ -9,9 +8,9 @@ from matplotlib.widgets import Button # type:ignore # pip install roipoly from roipoly import RoiPoly -from jsmin import jsmin # type:ignore from functions.create_logger import create_logger from functions.get_torch_device import get_torch_device +from functions.load_config import load_config def compose_image(image_3color: np.ndarray, mask: np.ndarray) -> np.ndarray: @@ -86,9 +85,7 @@ mylogger = create_logger( save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_3" ) -mylogger.info("loading config file") -with open("config.json", "r") as file: - config = json.loads(jsmin(file.read())) +config = load_config(mylogger=mylogger) device = get_torch_device(mylogger, config["force_to_cpu"]) @@ -158,7 +155,3 @@ mylogger.info("Display") new_roi: RoiPoly = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) plt.show() - - -# image_handle.remove() -# diff --git a/new_pipeline/stage_4_process.py b/new_pipeline/stage_4_process.py new file mode 100644 index 0000000..c84b380 --- /dev/null +++ b/new_pipeline/stage_4_process.py @@ -0,0 +1,619 @@ +import numpy as np +import torch +import torchvision as tv # type: ignore + +import os +import logging + +from functions.create_logger import create_logger +from functions.get_torch_device import get_torch_device +from functions.load_config import load_config +from functions.load_meta_data import load_meta_data + +from functions.get_experiments import get_experiments +from functions.get_trials import get_trials +from functions.get_parts import get_parts +from functions.binning import binning +from functions.ImageAlignment import ImageAlignment +from functions.align_refref import align_refref +from functions.perform_donor_volume_rotation import perform_donor_volume_rotation +from functions.perform_donor_volume_translation import perform_donor_volume_translation +from functions.bandpass import bandpass + +import matplotlib.pyplot as plt + + +@torch.no_grad() +def process_trial( + config: dict, + mylogger: logging.Logger, + experiment_id: int, + trial_id: int, + device: torch.device, +): + if device != torch.device("cpu"): + torch.cuda.empty_cache() + mylogger.info("Empty CUDA cache") + cuda_total_memory: int = torch.cuda.get_device_properties( + device.index + ).total_memory + else: + cuda_total_memory = 0 + + raw_data_path: str = os.path.join( + config["basic_path"], + config["recoding_data"], + config["mouse_identifier"], + config["raw_path"], + ) + + if os.path.isdir(raw_data_path) is False: + mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!") + return + + if (torch.where(get_experiments(raw_data_path) == experiment_id)[0].shape[0]) != 1: + mylogger.info(f"ERROR: could not find experiment id {experiment_id}!!!!") + return + + if ( + torch.where(get_trials(raw_data_path, experiment_id) == trial_id)[0].shape[0] + ) != 1: + mylogger.info(f"ERROR: could not find trial id {trial_id}!!!!") + return + + if get_parts(raw_data_path, experiment_id, trial_id).shape[0] != 1: + mylogger.info("ERROR: this has more than one part. NOT IMPLEMENTED YET!!!!") + assert get_parts(raw_data_path, experiment_id, trial_id).shape[0] == 1 + part_id: int = 1 + + experiment_name = f"Exp{experiment_id:03d}_Trial{trial_id:03d}" + mylogger.info(f"Will work on: {experiment_name}") + + filename_data: str = os.path.join( + raw_data_path, + f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}.npy", + ) + + mylogger.info(f"Will use: {filename_data} for data") + + filename_meta: str = os.path.join( + raw_data_path, + f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}_meta.txt", + ) + + mylogger.info(f"Will use: {filename_meta} for meta data") + + if os.path.isfile(filename_meta) is False: + mylogger.info(f"Could not load meta data... {filename_meta}") + mylogger.info(f"ERROR: skipping {experiment_name}!!!!") + return + + meta_channels: list[str] + meta_mouse_markings: str + meta_recording_date: str + meta_stimulation_times: dict + meta_experiment_names: dict + meta_trial_recording_duration: float + meta_frame_time: float + meta_mouse: str + + ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + ) = load_meta_data(mylogger=mylogger, filename_meta=filename_meta) + + dtype_str = config["dtype"] + mylogger.info(f"Data precision will be {dtype_str}") + dtype: torch.dtype = getattr(torch, dtype_str) + dtype_np: np.dtype = getattr(np, dtype_str) + + mylogger.info("Loading raw data") + + if device != torch.device("cpu"): + free_mem: int = cuda_total_memory - max( + [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] + ) + mylogger.info(f"CUDA memory before loading RAW data: {free_mem//1024} MByte") + + data_np: np.ndarray = np.load(filename_data, mmap_mode="r").astype(dtype_np) + data: torch.Tensor = torch.zeros(data_np.shape, dtype=dtype, device=device) + for i in range(0, len(config["required_order"])): + mylogger.info(f"Move raw data to PyTorch device: {config['required_order'][i]}") + idx = meta_channels.index(config["required_order"][i]) + data[..., i] = torch.tensor(data_np[..., idx], dtype=dtype, device=device) + + if device != torch.device("cpu"): + free_mem = cuda_total_memory - max( + [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] + ) + mylogger.info(f"CUDA memory after loading RAW data: {free_mem//1024} MByte") + + del data_np + mylogger.info(f"Data shape: {data.shape}") + mylogger.info("-==- Done -==-") + + mylogger.info("Reference images and mask") + + ref_image_path: str = config["ref_image_path"] + + ref_image_path_acceptor: str = os.path.join(ref_image_path, "acceptor.npy") + if os.path.isfile(ref_image_path_acceptor) is False: + mylogger.info(f"Could not load ref file: {ref_image_path_acceptor}") + return + + mylogger.info(f"Loading ref file data: {ref_image_path_acceptor}") + ref_image_acceptor: torch.Tensor = torch.tensor( + np.load(ref_image_path_acceptor).astype(dtype_np), dtype=dtype, device=device + ) + + ref_image_path_donor: str = os.path.join(ref_image_path, "donor.npy") + if os.path.isfile(ref_image_path_donor) is False: + mylogger.info(f"Could not load ref file: {ref_image_path_donor}") + return + + mylogger.info(f"Loading ref file data: {ref_image_path_donor}") + ref_image_donor: torch.Tensor = torch.tensor( + np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=device + ) + + ref_image_path_oxygenation: str = os.path.join(ref_image_path, "oxygenation.npy") + if os.path.isfile(ref_image_path_oxygenation) is False: + mylogger.info(f"Could not load ref file: {ref_image_path_oxygenation}") + return + + mylogger.info(f"Loading ref file data: {ref_image_path_oxygenation}") + ref_image_oxygenation: torch.Tensor = torch.tensor( + np.load(ref_image_path_oxygenation).astype(dtype_np), dtype=dtype, device=device + ) + + ref_image_path_volume: str = os.path.join(ref_image_path, "volume.npy") + if os.path.isfile(ref_image_path_volume) is False: + mylogger.info(f"Could not load ref file: {ref_image_path_volume}") + return + + mylogger.info(f"Loading ref file data: {ref_image_path_volume}") + ref_image_volume: torch.Tensor = torch.tensor( + np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=device + ) + + refined_mask_file: str = os.path.join(ref_image_path, "mask_not_rotated.npy") + if os.path.isfile(refined_mask_file) is False: + mylogger.info(f"Could not load mask file: {refined_mask_file}") + return + + mylogger.info(f"Loading mask file data: {refined_mask_file}") + mask: torch.Tensor = torch.tensor( + np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=device + ) + mylogger.info("-==- Done -==-") + + if config["binning_enable"] and config["binning_before_alignment"]: + mylogger.info("Binning of data") + mylogger.info( + ( + f"kernel_size={int(config['binning_kernel_size'])}," + f"stride={int(config['binning_stride'])}," + f"divisor_override={int(config['binning_divisor_override'])}" + ) + ) + + data = binning( + data, + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + ref_image_acceptor = ( + binning( + ref_image_acceptor.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + ref_image_donor = ( + binning( + ref_image_donor.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + ref_image_oxygenation = ( + binning( + ref_image_oxygenation.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + ref_image_volume = ( + binning( + ref_image_volume.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + mask = ( + binning( + mask.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + mylogger.info(f"Data shape: {data.shape}") + mylogger.info("-==- Done -==-") + + mylogger.info("Preparing alignment") + image_alignment = ImageAlignment(default_dtype=dtype, device=device) + + mylogger.info("Re-order Raw data") + data = data.moveaxis(-2, 0).moveaxis(-1, 0) + mylogger.info(f"Data shape: {data.shape}") + mylogger.info("-==- Done -==-") + + mylogger.info("Alignment of the ref images and the mask") + mylogger.info("Ref image of donor stays fixed.") + mylogger.info("Ref image of volume and the mask doesn't need to be touched") + mylogger.info("Calculate translation and rotation between the reference images") + angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor = align_refref( + mylogger=mylogger, + ref_image_acceptor=ref_image_acceptor, + ref_image_donor=ref_image_donor, + image_alignment=image_alignment, + batch_size=config["alignment_batch_size"], + fill_value=-1.0, + ) + mylogger.info(f"Rotation: {round(float(angle_refref[0]),2)} degree") + mylogger.info( + f"Translation: {round(float(tvec_refref[0]),1)} x {round(float(tvec_refref[1]),1)} pixel" + ) + + temp_path: str = os.path.join( + config["export_path"], experiment_name + "_angle_refref.npy" + ) + mylogger.info(f"Save angle to {temp_path}") + np.save(temp_path, angle_refref.cpu()) + + temp_path = os.path.join( + config["export_path"], experiment_name + "_tvec_refref.npy" + ) + mylogger.info(f"Save translation vector to {temp_path}") + np.save(temp_path, tvec_refref.cpu()) + + mylogger.info("Moving & rotating the oxygenation ref image") + ref_image_oxygenation = tv.transforms.functional.affine( + img=ref_image_oxygenation.unsqueeze(0), + angle=-float(angle_refref), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-1.0, + ) + + ref_image_oxygenation = tv.transforms.functional.affine( + img=ref_image_oxygenation, + angle=0, + translate=[tvec_refref[1], tvec_refref[0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-1.0, + ).squeeze(0) + mylogger.info("-==- Done -==-") + + mylogger.info("Rotate and translate the acceptor and oxygenation data accordingly") + acceptor_index: int = config["required_order"].index("acceptor") + donor_index: int = config["required_order"].index("donor") + oxygenation_index: int = config["required_order"].index("oxygenation") + volume_index: int = config["required_order"].index("volume") + + mylogger.info("Rotate acceptor") + data[acceptor_index, ...] = tv.transforms.functional.affine( + img=data[acceptor_index, ...], + angle=-float(angle_refref), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-1.0, + ) + + mylogger.info("Translate acceptor") + data[acceptor_index, ...] = tv.transforms.functional.affine( + img=data[acceptor_index, ...], + angle=0, + translate=[tvec_refref[1], tvec_refref[0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-1.0, + ) + + mylogger.info("Rotate oxygenation") + data[oxygenation_index, ...] = tv.transforms.functional.affine( + img=data[oxygenation_index, ...], + angle=-float(angle_refref), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-1.0, + ) + + mylogger.info("Translate oxygenation") + data[oxygenation_index, ...] = tv.transforms.functional.affine( + img=data[oxygenation_index, ...], + angle=0, + translate=[tvec_refref[1], tvec_refref[0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-1.0, + ) + mylogger.info("-==- Done -==-") + + mylogger.info("Perform rotation between donor and volume and its ref images") + mylogger.info("for all frames and then rotate all the data accordingly") + perform_donor_volume_rotation + ( + data[acceptor_index, ...], + data[donor_index, ...], + data[oxygenation_index, ...], + data[volume_index, ...], + angle_donor_volume, + ) = perform_donor_volume_rotation( + mylogger=mylogger, + acceptor=data[acceptor_index, ...], + donor=data[donor_index, ...], + oxygenation=data[oxygenation_index, ...], + volume=data[volume_index, ...], + ref_image_donor=ref_image_donor, + ref_image_volume=ref_image_volume, + image_alignment=image_alignment, + batch_size=config["alignment_batch_size"], + fill_value=-1.0, + ) + + mylogger.info( + f"angles: " + f"min {round(float(angle_donor_volume.min()),2)} " + f"max {round(float(angle_donor_volume.max()),2)} " + f"mean {round(float(angle_donor_volume.mean()),2)} " + ) + + temp_path = os.path.join( + config["export_path"], experiment_name + "_angle_donor_volume.npy" + ) + mylogger.info(f"Save angles to {temp_path}") + np.save(temp_path, angle_donor_volume.cpu()) + mylogger.info("-==- Done -==-") + + mylogger.info("Perform translation between donor and volume and its ref images") + mylogger.info("for all frames and then translate all the data accordingly") + ( + data[acceptor_index, ...], + data[donor_index, ...], + data[oxygenation_index, ...], + data[volume_index, ...], + tvec_donor_volume, + ) = perform_donor_volume_translation( + mylogger=mylogger, + acceptor=data[acceptor_index, ...], + donor=data[donor_index, ...], + oxygenation=data[oxygenation_index, ...], + volume=data[volume_index, ...], + ref_image_donor=ref_image_donor, + ref_image_volume=ref_image_volume, + image_alignment=image_alignment, + batch_size=config["alignment_batch_size"], + fill_value=-1.0, + ) + + mylogger.info( + f"translation dim 0: " + f"min {round(float(tvec_donor_volume[:,0].min()),1)} " + f"max {round(float(tvec_donor_volume[:,0].max()),1)} " + f"mean {round(float(tvec_donor_volume[:,0].mean()),1)} " + ) + mylogger.info( + f"translation dim 1: " + f"min {round(float(tvec_donor_volume[:,1].min()),1)} " + f"max {round(float(tvec_donor_volume[:,1].max()),1)} " + f"mean {round(float(tvec_donor_volume[:,1].mean()),1)} " + ) + + temp_path = os.path.join( + config["export_path"], experiment_name + "_tvec_donor_volume.npy" + ) + mylogger.info(f"Save translation vector to {temp_path}") + np.save(temp_path, tvec_donor_volume.cpu()) + mylogger.info("-==- Done -==-") + + mylogger.info("Update mask with the new regions due to alignment") + + new_mask_area: torch.Tensor = torch.any(torch.any(data < -0.1, dim=0), dim=0).bool() + mask = (mask == 0).bool() + mask = torch.logical_or(mask, new_mask_area) + mask_positve: torch.Tensor = torch.logical_not(mask) + + mylogger.info("Update the data with the new mask") + data *= mask_positve.unsqueeze(0).unsqueeze(0).type(dtype=dtype) + mylogger.info("-==- Done -==-") + + mylogger.info("Interpolate the 'in-between' frames for oxygenation and volume") + data[oxygenation_index, 1:, ...] = ( + data[oxygenation_index, 1:, ...] + data[oxygenation_index, :-1, ...] + ) / 2.0 + data[volume_index, 1:, ...] = ( + data[volume_index, 1:, ...] + data[volume_index, :-1, ...] + ) / 2.0 + mylogger.info("-==- Done -==-") + + sample_frequency: float = 1.0 / meta_frame_time + + mylogger.info("Extract heartbeat from volume signal") + heartbeat_ts: torch.Tensor = bandpass( + data=data[volume_index, ...].movedim(0, -1).clone(), + device=data.device, + low_frequency=config["lower_freqency_bandpass"], + high_frequency=config["upper_freqency_bandpass"], + fs=sample_frequency, + filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"], + ) + heartbeat_ts = heartbeat_ts.flatten(start_dim=0, end_dim=-2) + mask_flatten: torch.Tensor = mask_positve.flatten(start_dim=0, end_dim=-1) + + heartbeat_ts = heartbeat_ts[mask_flatten, :] + + heartbeat_ts = heartbeat_ts.movedim(0, -1) + heartbeat_ts -= heartbeat_ts.mean(dim=0, keepdim=True) + + volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False) + volume_heartbeat = volume_heartbeat[:, 0] + volume_heartbeat -= volume_heartbeat[ + config["skip_frames_in_the_beginning"] : + ].mean() + + del heartbeat_ts + + temp_path = os.path.join( + config["export_path"], experiment_name + "_volume_heartbeat.npy" + ) + mylogger.info(f"Save volume heartbeat to {temp_path}") + np.save(temp_path, volume_heartbeat.cpu()) + mylogger.info("-==- Done -==-") + + volume_heartbeat = volume_heartbeat.unsqueeze(0).unsqueeze(0) + norm_volume_heartbeat = ( + volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] ** 2 + ).sum(dim=-1) + + heartbeat_coefficients: torch.Tensor = torch.zeros( + (data.shape[0], data.shape[-2], data.shape[-1]), + dtype=data.dtype, + device=data.device, + ) + for i in range(0, data.shape[0]): + y = bandpass( + data=data[i, ...].movedim(0, -1).clone(), + device=data.device, + low_frequency=config["lower_freqency_bandpass"], + high_frequency=config["upper_freqency_bandpass"], + fs=sample_frequency, + filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"], + )[..., config["skip_frames_in_the_beginning"] :] + y -= y.mean(dim=-1, keepdim=True) + + heartbeat_coefficients[i, ...] = ( + volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] * y + ).sum(dim=-1) / norm_volume_heartbeat + + heartbeat_coefficients[i, ...] *= mask_positve.type( + dtype=heartbeat_coefficients.dtype + ) + del y + + temp_path = os.path.join( + config["export_path"], experiment_name + "_heartbeat_coefficients.npy" + ) + mylogger.info(f"Save heartbeat coefficients to {temp_path}") + np.save(temp_path, heartbeat_coefficients.cpu()) + mylogger.info("-==- Done -==-") + + mylogger.info("Remove heart beat from data") + data -= heartbeat_coefficients.unsqueeze(1) * volume_heartbeat.unsqueeze(0).movedim( + -1, 1 + ) + mylogger.info("-==- Done -==-") + + donor_heartbeat_factor = heartbeat_coefficients[donor_index, ...].clone() + acceptor_heartbeat_factor = heartbeat_coefficients[acceptor_index, ...].clone() + del heartbeat_coefficients + + mylogger.info("Calculate scaling factor for donor and acceptor") + donor_factor: torch.Tensor = ( + donor_heartbeat_factor + acceptor_heartbeat_factor + ) / (2 * donor_heartbeat_factor) + acceptor_factor: torch.Tensor = ( + donor_heartbeat_factor + acceptor_heartbeat_factor + ) / (2 * acceptor_heartbeat_factor) + + del donor_heartbeat_factor + del acceptor_heartbeat_factor + + temp_path = os.path.join( + config["export_path"], experiment_name + "_donor_factor.npy" + ) + mylogger.info(f"Save donor factor to {temp_path}") + np.save(temp_path, donor_factor.cpu()) + + temp_path = os.path.join( + config["export_path"], experiment_name + "_acceptor_factor.npy" + ) + mylogger.info(f"Save acceptor factor to {temp_path}") + np.save(temp_path, acceptor_factor.cpu()) + mylogger.info("-==- Done -==-") + + mylogger.info("Scale acceptor to heart beat amplitude") + mylogger.info("Calculate mean") + mean_values_acceptor = data[ + acceptor_index, config["skip_frames_in_the_beginning"] :, ... + ].nanmean(dim=0, keepdim=True) + mylogger.info("Remove mean") + data[acceptor_index, ...] -= mean_values_acceptor + mylogger.info("Apply acceptor_factor and mask") + data[acceptor_index, ...] *= acceptor_factor.unsqueeze(0) * mask.unsqueeze(0) + mylogger.info("Add mean") + data[acceptor_index, ...] += mean_values_acceptor + mylogger.info("-==- Done -==-") + + mylogger.info("Scale donor to heart beat amplitude") + mylogger.info("Calculate mean") + mean_values_donor = data[ + donor_index, config["skip_frames_in_the_beginning"] :, ... + ].nanmean(dim=0, keepdim=True) + mylogger.info("Remove mean") + data[donor_index, ...] -= mean_values_donor + mylogger.info("Apply donor_factor and mask") + data[donor_index, ...] *= donor_factor.unsqueeze(0) * mask.unsqueeze(0) + mylogger.info("Add mean") + data[donor_index, ...] += mean_values_donor + mylogger.info("-==- Done -==-") + + exit() + + return + + +mylogger = create_logger( + save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_4" +) +config = load_config(mylogger=mylogger) +device = get_torch_device(mylogger, config["force_to_cpu"]) + +mylogger.info(f"Create directory {config['export_path']} in the case it does not exist") +os.makedirs(config["export_path"], exist_ok=True) + +process_trial( + config=config, mylogger=mylogger, experiment_id=1, trial_id=1, device=device +)