diff --git a/config.json b/config.json new file mode 100644 index 0000000..2746a54 --- /dev/null +++ b/config.json @@ -0,0 +1,62 @@ +{ + "basic_path": "/data_1/hendrik", + "recoding_data": "2021-06-17", + "mouse_identifier": "M3859M", + //"basic_path": "/data_1/robert", + //"recoding_data": "2021-10-05", + //"mouse_identifier": "M3879M", + "raw_path": "raw", + "export_path": "output", + "ref_image_path": "ref_images", + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} \ No newline at end of file diff --git a/stage_1_get_ref_image.py b/stage_1_get_ref_image.py new file mode 100644 index 0000000..55435f4 --- /dev/null +++ b/stage_1_get_ref_image.py @@ -0,0 +1,126 @@ +import os +import torch +import numpy as np + + +from functions.get_experiments import get_experiments +from functions.get_trials import get_trials +from functions.bandpass import bandpass +from functions.create_logger import create_logger +from functions.get_torch_device import get_torch_device +from functions.load_config import load_config +from functions.data_raw_loader import data_raw_loader + +mylogger = create_logger( + save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_1" +) + +config = load_config(mylogger=mylogger) + +if config["binning_enable"] and (config["binning_at_the_end"] is False): + device: torch.device = torch.device("cpu") +else: + device = get_torch_device(mylogger, config["force_to_cpu"]) + + +dtype_str: str = config["dtype"] +dtype: torch.dtype = getattr(torch, dtype_str) + +raw_data_path: str = os.path.join( + config["basic_path"], + config["recoding_data"], + config["mouse_identifier"], + config["raw_path"], +) + +mylogger.info(f"Using data path: {raw_data_path}") + +first_experiment_id: int = int(get_experiments(raw_data_path).min()) +first_trial_id: int = int(get_trials(raw_data_path, first_experiment_id).min()) + +meta_channels: list[str] +meta_mouse_markings: str +meta_recording_date: str +meta_stimulation_times: dict +meta_experiment_names: dict +meta_trial_recording_duration: float +meta_frame_time: float +meta_mouse: str +data: torch.Tensor + +if config["binning_enable"] and (config["binning_at_the_end"] is False): + force_to_cpu_memory: bool = True +else: + force_to_cpu_memory = False + +mylogger.info("Loading data") + +( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, +) = data_raw_loader( + raw_data_path=raw_data_path, + mylogger=mylogger, + experiment_id=first_experiment_id, + trial_id=first_trial_id, + device=device, + force_to_cpu_memory=force_to_cpu_memory, + config=config, +) +mylogger.info("-==- Done -==-") + +output_path = config["ref_image_path"] +mylogger.info(f"Create directory {output_path} in the case it does not exist") +os.makedirs(output_path, exist_ok=True) + +mylogger.info("Reference images") +for i in range(0, len(meta_channels)): + temp_path: str = os.path.join(output_path, meta_channels[i] + ".npy") + mylogger.info(f"Extract and save: {temp_path}") + frame_id: int = data.shape[-2] // 2 + mylogger.info(f"Will use frame id: {frame_id}") + ref_image: np.ndarray = ( + data[:, :, frame_id, meta_channels.index(meta_channels[i])] + .clone() + .cpu() + .numpy() + ) + np.save(temp_path, ref_image) +mylogger.info("-==- Done -==-") + +sample_frequency: float = 1.0 / meta_frame_time +mylogger.info( + ( + f"Heartbeat power {config['lower_freqency_bandpass']} Hz" + f" - {config['upper_freqency_bandpass']} Hz," + f" sample-rate: {sample_frequency}," + f" skipping the first {config['skip_frames_in_the_beginning']} frames" + ) +) + +for i in range(0, len(meta_channels)): + temp_path = os.path.join(output_path, meta_channels[i] + "_var.npy") + mylogger.info(f"Extract and save: {temp_path}") + + heartbeat_ts: torch.Tensor = bandpass( + data=data[..., i], + device=data.device, + low_frequency=config["lower_freqency_bandpass"], + high_frequency=config["upper_freqency_bandpass"], + fs=sample_frequency, + filtfilt_chuck_size=10, + ) + + heartbeat_power = heartbeat_ts[..., config["skip_frames_in_the_beginning"] :].var( + dim=-1 + ) + np.save(temp_path, heartbeat_power) + +mylogger.info("-==- Done -==-") diff --git a/stage_2_make_heartbeat_mask.py b/stage_2_make_heartbeat_mask.py new file mode 100644 index 0000000..e36516b --- /dev/null +++ b/stage_2_make_heartbeat_mask.py @@ -0,0 +1,153 @@ +import matplotlib.pyplot as plt # type:ignore +import matplotlib +import numpy as np +import torch +import os + +from matplotlib.widgets import Slider, Button # type:ignore +from functools import partial +from functions.gauss_smear_individual import gauss_smear_individual +from functions.create_logger import create_logger +from functions.get_torch_device import get_torch_device +from functions.load_config import load_config + +mylogger = create_logger( + save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_2" +) + +config = load_config(mylogger=mylogger) + +path: str = config["ref_image_path"] +use_channel: str = "donor" +spatial_width: float = 4.0 +temporal_width: float = 0.1 + +threshold: float = 0.05 + +heartbeat_mask_threshold_file: str = os.path.join(path, "heartbeat_mask_threshold.npy") +if os.path.isfile(heartbeat_mask_threshold_file): + mylogger.info(f"loading previous threshold file: {heartbeat_mask_threshold_file}") + threshold = float(np.load(heartbeat_mask_threshold_file)[0]) + +mylogger.info(f"initial threshold is {threshold}") + +image_ref_file: str = os.path.join(path, use_channel + ".npy") +image_var_file: str = os.path.join(path, use_channel + "_var.npy") +heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy") + +device = get_torch_device(mylogger, config["force_to_cpu"]) + + +def next_frame( + i: float, images: np.ndarray, image_handle: matplotlib.image.AxesImage +) -> None: + global threshold + threshold = i + + display_image: np.ndarray = images.copy() + display_image[..., 2] = display_image[..., 0] + mask: np.ndarray = np.where(images[..., 2] >= i, 1.0, np.nan)[..., np.newaxis] + display_image *= mask + display_image = np.nan_to_num(display_image, nan=1.0) + + image_handle.set_data(display_image) + return + + +def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None: + global threshold + global image_3color + global path + global mylogger + global heartbeat_mask_file + global heartbeat_mask_threshold_file + + mylogger.info(f"Threshold: {threshold}") + + mask: np.ndarray = image_3color[..., 2] >= threshold + mylogger.info(f"Save mask to: {heartbeat_mask_file}") + np.save(heartbeat_mask_file, mask) + mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}") + np.save(heartbeat_mask_threshold_file, np.array([threshold])) + exit() + + +def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None: + exit() + + +mylogger.info(f"loading image reference file: {image_ref_file}") +image_ref: np.ndarray = np.load(image_ref_file) +image_ref /= image_ref.max() + +mylogger.info(f"loading image heartbeat power: {image_var_file}") +image_var: np.ndarray = np.load(image_var_file) +image_var /= image_var.max() + +mylogger.info("Smear the image heartbeat power patially") +temp, _ = gauss_smear_individual( + input=torch.tensor(image_var[..., np.newaxis], device=device), + spatial_width=spatial_width, + temporal_width=temporal_width, + use_matlab_mask=False, +) +temp /= temp.max() + +mylogger.info("-==- DONE -==-") + +image_3color = np.concatenate( + ( + np.zeros_like(image_ref[..., np.newaxis]), + image_ref[..., np.newaxis], + temp.cpu().numpy(), + ), + axis=-1, +) + +mylogger.info("Prepare image") + +display_image = image_3color.copy() +display_image[..., 2] = display_image[..., 0] +mask = np.where(image_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis] +display_image *= mask +display_image = np.nan_to_num(display_image, nan=1.0) + +value_sort = np.sort(image_var.flatten()) +value_sort_max = value_sort[int(value_sort.shape[0] * 0.95)] +mylogger.info("-==- DONE -==-") + +mylogger.info("Create figure") + +fig: matplotlib.figure.Figure = plt.figure() + +image_handle = plt.imshow(display_image, vmin=0, vmax=1, cmap="hot") + +mylogger.info("Add controls") + +axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03)) +slice_slider = Slider( + ax=axfreq, + label="Threshold", + valmin=0, + valmax=value_sort_max, + valinit=threshold, + valstep=value_sort_max / 100.0, +) +axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04)) +button_accept = Button( + ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95" +) +button_accept.on_clicked(on_clicked_accept) # type: ignore + +axbutton_cancel = fig.add_axes(rect=(0.55, 0.85, 0.2, 0.04)) +button_cancel = Button( + ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95" +) +button_cancel.on_clicked(on_clicked_cancel) # type: ignore + +slice_slider.on_changed( + partial(next_frame, images=image_3color, image_handle=image_handle) +) + +mylogger.info("Display") +plt.show() diff --git a/stage_3_refine_mask.py b/stage_3_refine_mask.py new file mode 100644 index 0000000..83f9ecd --- /dev/null +++ b/stage_3_refine_mask.py @@ -0,0 +1,157 @@ +import os +import numpy as np + +import matplotlib.pyplot as plt # type:ignore +import matplotlib +from matplotlib.widgets import Button # type:ignore + +# pip install roipoly +from roipoly import RoiPoly # type:ignore + +from functions.create_logger import create_logger +from functions.get_torch_device import get_torch_device +from functions.load_config import load_config + + +def compose_image(image_3color: np.ndarray, mask: np.ndarray) -> np.ndarray: + display_image = image_3color.copy() + display_image[..., 2] = display_image[..., 0] + display_image[mask == 0, :] = 1.0 + return display_image + + +def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None: + global mylogger + global refined_mask_file + global mask + + mylogger.info(f"Save mask to: {refined_mask_file}") + np.save(refined_mask_file, mask) + + exit() + + +def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None: + global mylogger + mylogger.info("Ended without saving the mask") + exit() + + +def on_clicked_add(event: matplotlib.backend_bases.MouseEvent) -> None: + global new_roi + global mask + global image_3color + global display_image + global mylogger + if len(new_roi.x) > 0: + mylogger.info("A ROI with the following coordiantes has been added to the mask") + for i in range(0, len(new_roi.x)): + mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}") + mylogger.info("") + new_mask = new_roi.get_mask(display_image[:, :, 0]) + mask[new_mask] = 0.0 + display_image = compose_image(image_3color=image_3color, mask=mask) + image_handle.set_data(display_image) + for line in ax_main.lines: + line.remove() + plt.draw() + + new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) + + +def on_clicked_remove(event: matplotlib.backend_bases.MouseEvent) -> None: + global new_roi + global mask + global image_3color + global display_image + if len(new_roi.x) > 0: + mylogger.info( + "A ROI with the following coordiantes has been removed from the mask" + ) + for i in range(0, len(new_roi.x)): + mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}") + mylogger.info("") + new_mask = new_roi.get_mask(display_image[:, :, 0]) + mask[new_mask] = 1.0 + display_image = compose_image(image_3color=image_3color, mask=mask) + image_handle.set_data(display_image) + for line in ax_main.lines: + line.remove() + plt.draw() + new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) + + +mylogger = create_logger( + save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_3" +) + +config = load_config(mylogger=mylogger) + +device = get_torch_device(mylogger, config["force_to_cpu"]) + +path: str = config["ref_image_path"] +use_channel: str = "donor" + +image_ref_file: str = os.path.join(path, use_channel + ".npy") +heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy") +refined_mask_file: str = os.path.join(path, "mask_not_rotated.npy") + +mylogger.info(f"loading image reference file: {image_ref_file}") +image_ref: np.ndarray = np.load(image_ref_file) +image_ref /= image_ref.max() + +mylogger.info(f"loading heartbeat mask: {heartbeat_mask_file}") +mask: np.ndarray = np.load(heartbeat_mask_file) + +image_3color = np.concatenate( + ( + np.zeros_like(image_ref[..., np.newaxis]), + image_ref[..., np.newaxis], + np.zeros_like(image_ref[..., np.newaxis]), + ), + axis=-1, +) + +mylogger.info("-==- DONE -==-") + +fig, ax_main = plt.subplots() + +display_image = compose_image(image_3color=image_3color, mask=mask) +image_handle = ax_main.imshow(display_image, vmin=0, vmax=1, cmap="hot") + +mylogger.info("Add controls") + +axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04)) +button_accept = Button( + ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95" +) +button_accept.on_clicked(on_clicked_accept) # type: ignore + +axbutton_cancel = fig.add_axes(rect=(0.5, 0.85, 0.2, 0.04)) +button_cancel = Button( + ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95" +) +button_cancel.on_clicked(on_clicked_cancel) # type: ignore + +axbutton_addmask = fig.add_axes(rect=(0.3, 0.9, 0.2, 0.04)) +button_addmask = Button( + ax=axbutton_addmask, label="Add mask", image=None, color="0.85", hovercolor="0.95" +) +button_addmask.on_clicked(on_clicked_add) # type: ignore + +axbutton_removemask = fig.add_axes(rect=(0.5, 0.9, 0.2, 0.04)) +button_removemask = Button( + ax=axbutton_removemask, + label="Remove mask", + image=None, + color="0.85", + hovercolor="0.95", +) +button_removemask.on_clicked(on_clicked_remove) # type: ignore + +# ax_main.cla() + +mylogger.info("Display") +new_roi: RoiPoly = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) + +plt.show() diff --git a/stage_4_process.py b/stage_4_process.py new file mode 100644 index 0000000..b822856 --- /dev/null +++ b/stage_4_process.py @@ -0,0 +1,918 @@ +import numpy as np +import torch +import torchvision as tv # type: ignore + +import os +import logging +import h5py # type: ignore + +from functions.create_logger import create_logger +from functions.get_torch_device import get_torch_device +from functions.load_config import load_config +from functions.get_experiments import get_experiments +from functions.get_trials import get_trials +from functions.binning import binning +from functions.ImageAlignment import ImageAlignment +from functions.align_refref import align_refref +from functions.perform_donor_volume_rotation import perform_donor_volume_rotation +from functions.perform_donor_volume_translation import perform_donor_volume_translation +from functions.bandpass import bandpass +from functions.gauss_smear_individual import gauss_smear_individual +from functions.regression import regression +from functions.data_raw_loader import data_raw_loader + + +@torch.no_grad() +def process_trial( + config: dict, + mylogger: logging.Logger, + experiment_id: int, + trial_id: int, + device: torch.device, +): + + mylogger.info("") + mylogger.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + mylogger.info("~ TRIAL START ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + mylogger.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + mylogger.info("") + + if device != torch.device("cpu"): + torch.cuda.empty_cache() + mylogger.info("Empty CUDA cache") + cuda_total_memory: int = torch.cuda.get_device_properties( + device.index + ).total_memory + else: + cuda_total_memory = 0 + + raw_data_path: str = os.path.join( + config["basic_path"], + config["recoding_data"], + config["mouse_identifier"], + config["raw_path"], + ) + + if config["binning_enable"] and (config["binning_at_the_end"] is False): + force_to_cpu_memory: bool = True + else: + force_to_cpu_memory = False + + meta_channels: list[str] + meta_mouse_markings: str + meta_recording_date: str + meta_stimulation_times: dict + meta_experiment_names: dict + meta_trial_recording_duration: float + meta_frame_time: float + meta_mouse: str + data: torch.Tensor + + ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) = data_raw_loader( + raw_data_path=raw_data_path, + mylogger=mylogger, + experiment_id=experiment_id, + trial_id=trial_id, + device=device, + force_to_cpu_memory=force_to_cpu_memory, + config=config, + ) + experiment_name: str = f"Exp{experiment_id:03d}_Trial{trial_id:03d}" + + dtype_str = config["dtype"] + dtype_np: np.dtype = getattr(np, dtype_str) + + dtype: torch.dtype = data.dtype + + if device != torch.device("cpu"): + free_mem = cuda_total_memory - max( + [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] + ) + mylogger.info(f"CUDA memory: {free_mem//1024} MByte") + + mylogger.info(f"Data shape: {data.shape}") + mylogger.info("-==- Done -==-") + + mylogger.info("Finding limit values in the RAW data and mark them for masking") + limit: float = (2**16) - 1 + for i in range(0, data.shape[3]): + zero_pixel_mask: torch.Tensor = torch.any(data[..., i] >= limit, dim=-1) + data[zero_pixel_mask, :, i] = -100.0 + mylogger.info( + f"{meta_channels[i]}: " + f"found {int(zero_pixel_mask.type(dtype=dtype).sum())} pixel " + f"with limit values " + ) + mylogger.info("-==- Done -==-") + + mylogger.info("Reference images and mask") + + ref_image_path: str = config["ref_image_path"] + + ref_image_path_acceptor: str = os.path.join(ref_image_path, "acceptor.npy") + if os.path.isfile(ref_image_path_acceptor) is False: + mylogger.info(f"Could not load ref file: {ref_image_path_acceptor}") + assert os.path.isfile(ref_image_path_acceptor) + return + + mylogger.info(f"Loading ref file data: {ref_image_path_acceptor}") + ref_image_acceptor: torch.Tensor = torch.tensor( + np.load(ref_image_path_acceptor).astype(dtype_np), dtype=dtype, device=device + ) + + ref_image_path_donor: str = os.path.join(ref_image_path, "donor.npy") + if os.path.isfile(ref_image_path_donor) is False: + mylogger.info(f"Could not load ref file: {ref_image_path_donor}") + assert os.path.isfile(ref_image_path_donor) + return + + mylogger.info(f"Loading ref file data: {ref_image_path_donor}") + ref_image_donor: torch.Tensor = torch.tensor( + np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=device + ) + + ref_image_path_oxygenation: str = os.path.join(ref_image_path, "oxygenation.npy") + if os.path.isfile(ref_image_path_oxygenation) is False: + mylogger.info(f"Could not load ref file: {ref_image_path_oxygenation}") + assert os.path.isfile(ref_image_path_oxygenation) + return + + mylogger.info(f"Loading ref file data: {ref_image_path_oxygenation}") + ref_image_oxygenation: torch.Tensor = torch.tensor( + np.load(ref_image_path_oxygenation).astype(dtype_np), dtype=dtype, device=device + ) + + ref_image_path_volume: str = os.path.join(ref_image_path, "volume.npy") + if os.path.isfile(ref_image_path_volume) is False: + mylogger.info(f"Could not load ref file: {ref_image_path_volume}") + assert os.path.isfile(ref_image_path_volume) + return + + mylogger.info(f"Loading ref file data: {ref_image_path_volume}") + ref_image_volume: torch.Tensor = torch.tensor( + np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=device + ) + + refined_mask_file: str = os.path.join(ref_image_path, "mask_not_rotated.npy") + if os.path.isfile(refined_mask_file) is False: + mylogger.info(f"Could not load mask file: {refined_mask_file}") + assert os.path.isfile(refined_mask_file) + return + + mylogger.info(f"Loading mask file data: {refined_mask_file}") + mask: torch.Tensor = torch.tensor( + np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=device + ) + mylogger.info("-==- Done -==-") + + if config["binning_enable"] and (config["binning_at_the_end"] is False): + mylogger.info("Binning of data") + mylogger.info( + ( + f"kernel_size={int(config['binning_kernel_size'])}, " + f"stride={int(config['binning_stride'])}, " + f"divisor_override={int(config['binning_divisor_override'])}" + ) + ) + + data = binning( + data, + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ).to(device=device) + ref_image_acceptor = ( + binning( + ref_image_acceptor.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + ref_image_donor = ( + binning( + ref_image_donor.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + ref_image_oxygenation = ( + binning( + ref_image_oxygenation.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + ref_image_volume = ( + binning( + ref_image_volume.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + mask = ( + binning( + mask.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + mylogger.info(f"Data shape: {data.shape}") + mylogger.info("-==- Done -==-") + + mylogger.info("Preparing alignment") + image_alignment = ImageAlignment(default_dtype=dtype, device=device) + + mylogger.info("Re-order Raw data") + data = data.moveaxis(-2, 0).moveaxis(-1, 0) + mylogger.info(f"Data shape: {data.shape}") + mylogger.info("-==- Done -==-") + + mylogger.info("Alignment of the ref images and the mask") + mylogger.info("Ref image of donor stays fixed.") + mylogger.info("Ref image of volume and the mask doesn't need to be touched") + mylogger.info("Calculate translation and rotation between the reference images") + angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor = align_refref( + mylogger=mylogger, + ref_image_acceptor=ref_image_acceptor, + ref_image_donor=ref_image_donor, + image_alignment=image_alignment, + batch_size=config["alignment_batch_size"], + fill_value=-100.0, + ) + mylogger.info(f"Rotation: {round(float(angle_refref[0]),2)} degree") + mylogger.info( + f"Translation: {round(float(tvec_refref[0]),1)} x {round(float(tvec_refref[1]),1)} pixel" + ) + + if config["save_alignment"]: + temp_path: str = os.path.join( + config["export_path"], experiment_name + "_angle_refref.npy" + ) + mylogger.info(f"Save angle to {temp_path}") + np.save(temp_path, angle_refref.cpu()) + + temp_path = os.path.join( + config["export_path"], experiment_name + "_tvec_refref.npy" + ) + mylogger.info(f"Save translation vector to {temp_path}") + np.save(temp_path, tvec_refref.cpu()) + + mylogger.info("Moving & rotating the oxygenation ref image") + ref_image_oxygenation = tv.transforms.functional.affine( + img=ref_image_oxygenation.unsqueeze(0), + angle=-float(angle_refref), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + + ref_image_oxygenation = tv.transforms.functional.affine( + img=ref_image_oxygenation, + angle=0, + translate=[tvec_refref[1], tvec_refref[0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ).squeeze(0) + mylogger.info("-==- Done -==-") + + mylogger.info("Rotate and translate the acceptor and oxygenation data accordingly") + acceptor_index: int = config["required_order"].index("acceptor") + donor_index: int = config["required_order"].index("donor") + oxygenation_index: int = config["required_order"].index("oxygenation") + volume_index: int = config["required_order"].index("volume") + + mylogger.info("Rotate acceptor") + data[acceptor_index, ...] = tv.transforms.functional.affine( + img=data[acceptor_index, ...], + angle=-float(angle_refref), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + + mylogger.info("Translate acceptor") + data[acceptor_index, ...] = tv.transforms.functional.affine( + img=data[acceptor_index, ...], + angle=0, + translate=[tvec_refref[1], tvec_refref[0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + + mylogger.info("Rotate oxygenation") + data[oxygenation_index, ...] = tv.transforms.functional.affine( + img=data[oxygenation_index, ...], + angle=-float(angle_refref), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + + mylogger.info("Translate oxygenation") + data[oxygenation_index, ...] = tv.transforms.functional.affine( + img=data[oxygenation_index, ...], + angle=0, + translate=[tvec_refref[1], tvec_refref[0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + mylogger.info("-==- Done -==-") + + mylogger.info("Perform rotation between donor and volume and its ref images") + mylogger.info("for all frames and then rotate all the data accordingly") + perform_donor_volume_rotation + ( + data[acceptor_index, ...], + data[donor_index, ...], + data[oxygenation_index, ...], + data[volume_index, ...], + angle_donor_volume, + ) = perform_donor_volume_rotation( + mylogger=mylogger, + acceptor=data[acceptor_index, ...], + donor=data[donor_index, ...], + oxygenation=data[oxygenation_index, ...], + volume=data[volume_index, ...], + ref_image_donor=ref_image_donor, + ref_image_volume=ref_image_volume, + image_alignment=image_alignment, + batch_size=config["alignment_batch_size"], + fill_value=-100.0, + config=config, + ) + + mylogger.info( + f"angles: " + f"min {round(float(angle_donor_volume.min()),2)} " + f"max {round(float(angle_donor_volume.max()),2)} " + f"mean {round(float(angle_donor_volume.mean()),2)} " + ) + + if config["save_alignment"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_angle_donor_volume.npy" + ) + mylogger.info(f"Save angles to {temp_path}") + np.save(temp_path, angle_donor_volume.cpu()) + mylogger.info("-==- Done -==-") + + mylogger.info("Perform translation between donor and volume and its ref images") + mylogger.info("for all frames and then translate all the data accordingly") + ( + data[acceptor_index, ...], + data[donor_index, ...], + data[oxygenation_index, ...], + data[volume_index, ...], + tvec_donor_volume, + ) = perform_donor_volume_translation( + mylogger=mylogger, + acceptor=data[acceptor_index, ...], + donor=data[donor_index, ...], + oxygenation=data[oxygenation_index, ...], + volume=data[volume_index, ...], + ref_image_donor=ref_image_donor, + ref_image_volume=ref_image_volume, + image_alignment=image_alignment, + batch_size=config["alignment_batch_size"], + fill_value=-100.0, + config=config, + ) + + mylogger.info( + f"translation dim 0: " + f"min {round(float(tvec_donor_volume[:,0].min()),1)} " + f"max {round(float(tvec_donor_volume[:,0].max()),1)} " + f"mean {round(float(tvec_donor_volume[:,0].mean()),1)} " + ) + mylogger.info( + f"translation dim 1: " + f"min {round(float(tvec_donor_volume[:,1].min()),1)} " + f"max {round(float(tvec_donor_volume[:,1].max()),1)} " + f"mean {round(float(tvec_donor_volume[:,1].mean()),1)} " + ) + + if config["save_alignment"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_tvec_donor_volume.npy" + ) + mylogger.info(f"Save translation vector to {temp_path}") + np.save(temp_path, tvec_donor_volume.cpu()) + mylogger.info("-==- Done -==-") + + mylogger.info("Finding zeros values in the RAW data and mark them for masking") + for i in range(0, data.shape[0]): + zero_pixel_mask = torch.any(data[i, ...] == 0, dim=0) + data[i, :, zero_pixel_mask] = -100.0 + mylogger.info( + f"{config['required_order'][i]}: " + f"found {int(zero_pixel_mask.type(dtype=dtype).sum())} pixel " + f"with zeros " + ) + mylogger.info("-==- Done -==-") + + mylogger.info("Update mask with the new regions due to alignment") + + new_mask_area: torch.Tensor = torch.any(torch.any(data < -0.1, dim=0), dim=0).bool() + mask = (mask == 0).bool() + mask = torch.logical_or(mask, new_mask_area) + mask_negative: torch.Tensor = mask.clone() + mask_positve: torch.Tensor = torch.logical_not(mask) + del mask + + mylogger.info("Update the data with the new mask") + data *= mask_positve.unsqueeze(0).unsqueeze(0).type(dtype=dtype) + mylogger.info("-==- Done -==-") + + mylogger.info("Interpolate the 'in-between' frames for oxygenation and volume") + data[oxygenation_index, 1:, ...] = ( + data[oxygenation_index, 1:, ...] + data[oxygenation_index, :-1, ...] + ) / 2.0 + data[volume_index, 1:, ...] = ( + data[volume_index, 1:, ...] + data[volume_index, :-1, ...] + ) / 2.0 + mylogger.info("-==- Done -==-") + + sample_frequency: float = 1.0 / meta_frame_time + + mylogger.info("Extract heartbeat from volume signal") + heartbeat_ts: torch.Tensor = bandpass( + data=data[volume_index, ...].movedim(0, -1).clone(), + device=data.device, + low_frequency=config["lower_freqency_bandpass"], + high_frequency=config["upper_freqency_bandpass"], + fs=sample_frequency, + filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"], + ) + heartbeat_ts = heartbeat_ts.flatten(start_dim=0, end_dim=-2) + mask_flatten: torch.Tensor = mask_positve.flatten(start_dim=0, end_dim=-1) + + heartbeat_ts = heartbeat_ts[mask_flatten, :] + + heartbeat_ts = heartbeat_ts.movedim(0, -1) + heartbeat_ts -= heartbeat_ts.mean(dim=0, keepdim=True) + + volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False) + volume_heartbeat = volume_heartbeat[:, 0] + volume_heartbeat -= volume_heartbeat[ + config["skip_frames_in_the_beginning"] : + ].mean() + + del heartbeat_ts + + if device != torch.device("cpu"): + torch.cuda.empty_cache() + mylogger.info("Empty CUDA cache") + free_mem = cuda_total_memory - max( + [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] + ) + mylogger.info(f"CUDA memory: {free_mem//1024} MByte") + + if config["save_heartbeat"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_volume_heartbeat.npy" + ) + mylogger.info(f"Save volume heartbeat to {temp_path}") + np.save(temp_path, volume_heartbeat.cpu()) + mylogger.info("-==- Done -==-") + + volume_heartbeat = volume_heartbeat.unsqueeze(0).unsqueeze(0) + norm_volume_heartbeat = ( + volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] ** 2 + ).sum(dim=-1) + + heartbeat_coefficients: torch.Tensor = torch.zeros( + (data.shape[0], data.shape[-2], data.shape[-1]), + dtype=data.dtype, + device=data.device, + ) + for i in range(0, data.shape[0]): + y = bandpass( + data=data[i, ...].movedim(0, -1).clone(), + device=data.device, + low_frequency=config["lower_freqency_bandpass"], + high_frequency=config["upper_freqency_bandpass"], + fs=sample_frequency, + filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"], + )[..., config["skip_frames_in_the_beginning"] :] + y -= y.mean(dim=-1, keepdim=True) + + heartbeat_coefficients[i, ...] = ( + volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] * y + ).sum(dim=-1) / norm_volume_heartbeat + + heartbeat_coefficients[i, ...] *= mask_positve.type( + dtype=heartbeat_coefficients.dtype + ) + del y + + if config["save_heartbeat"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_heartbeat_coefficients.npy" + ) + mylogger.info(f"Save heartbeat coefficients to {temp_path}") + np.save(temp_path, heartbeat_coefficients.cpu()) + mylogger.info("-==- Done -==-") + + mylogger.info("Remove heart beat from data") + data -= heartbeat_coefficients.unsqueeze(1) * volume_heartbeat.unsqueeze(0).movedim( + -1, 1 + ) + mylogger.info("-==- Done -==-") + + donor_heartbeat_factor = heartbeat_coefficients[donor_index, ...].clone() + acceptor_heartbeat_factor = heartbeat_coefficients[acceptor_index, ...].clone() + del heartbeat_coefficients + + if device != torch.device("cpu"): + torch.cuda.empty_cache() + mylogger.info("Empty CUDA cache") + free_mem = cuda_total_memory - max( + [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] + ) + mylogger.info(f"CUDA memory: {free_mem//1024} MByte") + + mylogger.info("Calculate scaling factor for donor and acceptor") + donor_factor: torch.Tensor = ( + donor_heartbeat_factor + acceptor_heartbeat_factor + ) / (2 * donor_heartbeat_factor) + acceptor_factor: torch.Tensor = ( + donor_heartbeat_factor + acceptor_heartbeat_factor + ) / (2 * acceptor_heartbeat_factor) + + del donor_heartbeat_factor + del acceptor_heartbeat_factor + + if config["save_factors"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_donor_factor.npy" + ) + mylogger.info(f"Save donor factor to {temp_path}") + np.save(temp_path, donor_factor.cpu()) + + temp_path = os.path.join( + config["export_path"], experiment_name + "_acceptor_factor.npy" + ) + mylogger.info(f"Save acceptor factor to {temp_path}") + np.save(temp_path, acceptor_factor.cpu()) + mylogger.info("-==- Done -==-") + + mylogger.info("Scale acceptor to heart beat amplitude") + mylogger.info("Calculate mean") + mean_values_acceptor = data[ + acceptor_index, config["skip_frames_in_the_beginning"] :, ... + ].nanmean(dim=0, keepdim=True) + + mylogger.info("Remove mean") + data[acceptor_index, ...] -= mean_values_acceptor + mylogger.info("Apply acceptor_factor and mask") + data[acceptor_index, ...] *= acceptor_factor.unsqueeze(0) * mask_positve.unsqueeze( + 0 + ) + mylogger.info("Add mean") + data[acceptor_index, ...] += mean_values_acceptor + mylogger.info("-==- Done -==-") + + mylogger.info("Scale donor to heart beat amplitude") + mylogger.info("Calculate mean") + mean_values_donor = data[ + donor_index, config["skip_frames_in_the_beginning"] :, ... + ].nanmean(dim=0, keepdim=True) + mylogger.info("Remove mean") + data[donor_index, ...] -= mean_values_donor + mylogger.info("Apply donor_factor and mask") + data[donor_index, ...] *= donor_factor.unsqueeze(0) * mask_positve.unsqueeze(0) + mylogger.info("Add mean") + data[donor_index, ...] += mean_values_donor + mylogger.info("-==- Done -==-") + + mylogger.info("Divide by mean over time") + data /= data[:, config["skip_frames_in_the_beginning"] :, ...].nanmean( + dim=1, + keepdim=True, + ) + data = data.nan_to_num(nan=0.0) + mylogger.info("-==- Done -==-") + + mylogger.info("Preparation for regression -- Gauss smear") + spatial_width = float(config["gauss_smear_spatial_width"]) + + if config["binning_enable"] and (config["binning_at_the_end"] is False): + spatial_width /= float(config["binning_kernel_size"]) + + mylogger.info( + f"Mask -- " + f"spatial width: {spatial_width}, " + f"temporal width: {float(config['gauss_smear_temporal_width'])}, " + f"use matlab mode: {bool(config['gauss_smear_use_matlab_mask'])} " + ) + + input_mask = mask_positve.type(dtype=dtype).clone() + + filtered_mask: torch.Tensor + filtered_mask, _ = gauss_smear_individual( + input=input_mask, + spatial_width=spatial_width, + temporal_width=float(config["gauss_smear_temporal_width"]), + use_matlab_mask=bool(config["gauss_smear_use_matlab_mask"]), + epsilon=float(torch.finfo(input_mask.dtype).eps), + ) + + mylogger.info("creating a copy of the data") + data_filtered = data.clone().movedim(1, -1) + if device != torch.device("cpu"): + torch.cuda.empty_cache() + mylogger.info("Empty CUDA cache") + free_mem = cuda_total_memory - max( + [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] + ) + mylogger.info(f"CUDA memory: {free_mem//1024} MByte") + + overwrite_fft_gauss: None | torch.Tensor = None + for i in range(0, data_filtered.shape[0]): + mylogger.info( + f"{config['required_order'][i]} -- " + f"spatial width: {spatial_width}, " + f"temporal width: {float(config['gauss_smear_temporal_width'])}, " + f"use matlab mode: {bool(config['gauss_smear_use_matlab_mask'])} " + ) + data_filtered[i, ...] *= input_mask.unsqueeze(-1) + data_filtered[i, ...], overwrite_fft_gauss = gauss_smear_individual( + input=data_filtered[i, ...], + spatial_width=spatial_width, + temporal_width=float(config["gauss_smear_temporal_width"]), + overwrite_fft_gauss=overwrite_fft_gauss, + use_matlab_mask=bool(config["gauss_smear_use_matlab_mask"]), + epsilon=float(torch.finfo(input_mask.dtype).eps), + ) + + data_filtered[i, ...] /= filtered_mask + 1e-20 + data_filtered[i, ...] += 1.0 - input_mask.unsqueeze(-1) + + del filtered_mask + del overwrite_fft_gauss + del input_mask + mylogger.info("data_filtered is populated") + + if device != torch.device("cpu"): + torch.cuda.empty_cache() + mylogger.info("Empty CUDA cache") + free_mem = cuda_total_memory - max( + [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] + ) + mylogger.info(f"CUDA memory: {free_mem//1024} MByte") + mylogger.info("-==- Done -==-") + + mylogger.info("Preperation for Regression") + mylogger.info("Move time dimensions of data to the last dimension") + data = data.movedim(1, -1) + + mylogger.info("Regression Acceptor") + mylogger.info(f"Target: {config['target_camera_acceptor']}") + mylogger.info( + f"Regressors: constant, linear and {config['regressor_cameras_acceptor']}" + ) + target_id: int = config["required_order"].index(config["target_camera_acceptor"]) + regressor_id: list[int] = [] + for i in range(0, len(config["regressor_cameras_acceptor"])): + regressor_id.append( + config["required_order"].index(config["regressor_cameras_acceptor"][i]) + ) + + data_acceptor, coefficients_acceptor = regression( + mylogger=mylogger, + target_camera_id=target_id, + regressor_camera_ids=regressor_id, + mask=mask_negative, + data=data, + data_filtered=data_filtered, + first_none_ramp_frame=int(config["skip_frames_in_the_beginning"]), + ) + + if config["save_regression_coefficients"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_coefficients_acceptor.npy" + ) + mylogger.info(f"Save acceptor coefficients to {temp_path}") + np.save(temp_path, coefficients_acceptor.cpu()) + del coefficients_acceptor + + mylogger.info("-==- Done -==-") + + mylogger.info("Regression Donor") + mylogger.info(f"Target: {config['target_camera_donor']}") + mylogger.info( + f"Regressors: constant, linear and {config['regressor_cameras_donor']}" + ) + target_id = config["required_order"].index(config["target_camera_donor"]) + regressor_id = [] + for i in range(0, len(config["regressor_cameras_donor"])): + regressor_id.append( + config["required_order"].index(config["regressor_cameras_donor"][i]) + ) + + data_donor, coefficients_donor = regression( + mylogger=mylogger, + target_camera_id=target_id, + regressor_camera_ids=regressor_id, + mask=mask_negative, + data=data, + data_filtered=data_filtered, + first_none_ramp_frame=int(config["skip_frames_in_the_beginning"]), + ) + + if config["save_regression_coefficients"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_coefficients_donor.npy" + ) + mylogger.info(f"Save acceptor donor to {temp_path}") + np.save(temp_path, coefficients_donor.cpu()) + del coefficients_donor + mylogger.info("-==- Done -==-") + + del data + del data_filtered + + if device != torch.device("cpu"): + torch.cuda.empty_cache() + mylogger.info("Empty CUDA cache") + free_mem = cuda_total_memory - max( + [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] + ) + mylogger.info(f"CUDA memory: {free_mem//1024} MByte") + + mylogger.info("Calculate ratio sequence") + if config["classical_ratio_mode"]: + mylogger.info("via acceptor / donor") + ratio_sequence: torch.Tensor = data_acceptor / data_donor + mylogger.info("via / mean over time") + ratio_sequence /= ratio_sequence.mean(dim=-1, keepdim=True) + else: + mylogger.info("via 1.0 + acceptor - donor") + ratio_sequence = 1.0 + data_acceptor - data_donor + + mylogger.info("Remove nan") + ratio_sequence = torch.nan_to_num(ratio_sequence, nan=0.0) + mylogger.info("-==- Done -==-") + + if config["binning_enable"] and config["binning_at_the_end"]: + mylogger.info("Binning of data") + mylogger.info( + ( + f"kernel_size={int(config['binning_kernel_size'])}, " + f"stride={int(config['binning_stride'])}, " + "divisor_override=None" + ) + ) + + ratio_sequence = binning( + ratio_sequence.unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ).squeeze(-1) + + mask_positve = ( + binning( + mask_positve.unsqueeze(-1).unsqueeze(-1).type(dtype=dtype), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ) + .squeeze(-1) + .squeeze(-1) + ) + mask_positve = (mask_positve > 0).type(torch.bool) + + if config["save_as_python"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_ratio_sequence.npz" + ) + mylogger.info(f"Save ratio_sequence and mask to {temp_path}") + np.savez_compressed( + temp_path, ratio_sequence=ratio_sequence.cpu(), mask=mask_positve.cpu() + ) + + if config["save_as_matlab"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_ratio_sequence.hd5" + ) + mylogger.info(f"Save ratio_sequence and mask to {temp_path}") + file_handle = h5py.File(temp_path, "w") + + mask_positve = mask_positve.movedim(0, -1) + ratio_sequence = ratio_sequence.movedim(1, -1).movedim(0, -1) + _ = file_handle.create_dataset( + "mask", + data=mask_positve.type(torch.uint8).cpu(), + compression="gzip", + compression_opts=9, + ) + _ = file_handle.create_dataset( + "ratio_sequence", + data=ratio_sequence.cpu(), + compression="gzip", + compression_opts=9, + ) + mylogger.info("Reminder: How to read with matlab:") + mylogger.info(f"mask = h5read('{temp_path}','/mask');") + mylogger.info(f"ratio_sequence = h5read('{temp_path}','/ratio_sequence');") + file_handle.close() + + del ratio_sequence + del mask_positve + del mask_negative + + mylogger.info("") + mylogger.info("***********************************************") + mylogger.info("* TRIAL END ***********************************") + mylogger.info("***********************************************") + mylogger.info("") + + return + + +mylogger = create_logger( + save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_4" +) +config = load_config(mylogger=mylogger) + +if (config["save_as_python"] is False) and (config["save_as_matlab"] is False): + mylogger.info("No output will be created. ") + mylogger.info("Change save_as_python and/or save_as_matlab in the config file") + mylogger.info("ERROR: STOP!!!") + exit() + +device = get_torch_device(mylogger, config["force_to_cpu"]) + +mylogger.info(f"Create directory {config['export_path']} in the case it does not exist") +os.makedirs(config["export_path"], exist_ok=True) + +raw_data_path: str = os.path.join( + config["basic_path"], + config["recoding_data"], + config["mouse_identifier"], + config["raw_path"], +) + +if os.path.isdir(raw_data_path) is False: + mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!") + exit() + +experiments = get_experiments(raw_data_path) + +for experiment_counter in range(0, experiments.shape[0]): + experiment_id = int(experiments[experiment_counter]) + trials = get_trials(raw_data_path, experiment_id) + for trial_counter in range(0, trials.shape[0]): + trial_id = int(trials[trial_counter]) + + mylogger.info("") + mylogger.info( + f"======= EXPERIMENT ID: {experiment_id} ==== TRIAL ID: {trial_id} =======" + ) + mylogger.info("") + + process_trial( + config=config, + mylogger=mylogger, + experiment_id=experiment_id, + trial_id=trial_id, + device=device, + )