From 2290dfe0d975a52a0cfc3ac63cc9522ff991cb16 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Fri, 1 Mar 2024 01:17:12 +0100 Subject: [PATCH] Add files via upload --- config_M3879M_2021-10-05.json | 60 ++++++++ config_M_Sert_Cre_41.json | 60 ++++++++ config_M_Sert_Cre_49.json | 60 ++++++++ olivia_data_plotter_svd.py | 2 +- stage_1_get_ref_image.py | 210 ++++++++++++------------- stage_2_make_heartbeat_mask.py | 257 ++++++++++++++++--------------- stage_3_refine_mask.py | 272 +++++++++++++++++---------------- stage_4_process.py | 147 ++++++++++-------- 8 files changed, 648 insertions(+), 420 deletions(-) create mode 100644 config_M3879M_2021-10-05.json create mode 100644 config_M_Sert_Cre_41.json create mode 100644 config_M_Sert_Cre_49.json diff --git a/config_M3879M_2021-10-05.json b/config_M3879M_2021-10-05.json new file mode 100644 index 0000000..81b98ec --- /dev/null +++ b/config_M3879M_2021-10-05.json @@ -0,0 +1,60 @@ +{ + "basic_path": "/data_1/robert", + "recoding_data": "2021-10-05", + "mouse_identifier": "M3879M", + "raw_path": "raw", + "export_path": "output_M3879M_2021-10-05", + "ref_image_path": "ref_images_M3879M_2021-10-05", + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + //"target_camera_acceptor": "acceptor", + "target_camera_acceptor": "", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + // "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} \ No newline at end of file diff --git a/config_M_Sert_Cre_41.json b/config_M_Sert_Cre_41.json new file mode 100644 index 0000000..668bd6c --- /dev/null +++ b/config_M_Sert_Cre_41.json @@ -0,0 +1,60 @@ +{ + "basic_path": "/data_1/hendrik", + "recoding_data": "2023-07-17", + "mouse_identifier": "M_Sert_Cre_41", + "raw_path": "raw", + "export_path": "output_M_Sert_Cre_41", + "ref_image_path": "ref_images_M_Sert_Cre_41", + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + //"target_camera_acceptor": "acceptor", + "target_camera_acceptor": "", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + // "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} \ No newline at end of file diff --git a/config_M_Sert_Cre_49.json b/config_M_Sert_Cre_49.json new file mode 100644 index 0000000..1b0e58e --- /dev/null +++ b/config_M_Sert_Cre_49.json @@ -0,0 +1,60 @@ +{ + "basic_path": "/data_1/hendrik", + "recoding_data": "2023-03-15", + "mouse_identifier": "M_Sert_Cre_49", + "raw_path": "raw", + "export_path": "output_M_Sert_Cre_49", + "ref_image_path": "ref_images_M_Sert_Cre_49", + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + //"target_camera_acceptor": "acceptor", + "target_camera_acceptor": "", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + // "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} \ No newline at end of file diff --git a/olivia_data_plotter_svd.py b/olivia_data_plotter_svd.py index 956b92a..c2bbffb 100644 --- a/olivia_data_plotter_svd.py +++ b/olivia_data_plotter_svd.py @@ -23,7 +23,7 @@ mylogger = create_logger( ) config = load_config(mylogger=mylogger) -experiment_id: int = 1 +experiment_id: int = 2 raw_data_path: str = os.path.join( config["basic_path"], diff --git a/stage_1_get_ref_image.py b/stage_1_get_ref_image.py index 637e324..0e5b6da 100644 --- a/stage_1_get_ref_image.py +++ b/stage_1_get_ref_image.py @@ -1,7 +1,7 @@ import os import torch import numpy as np - +import argh from functions.get_experiments import get_experiments from functions.get_trials import get_trials @@ -11,115 +11,119 @@ from functions.get_torch_device import get_torch_device from functions.load_config import load_config from functions.data_raw_loader import data_raw_loader -mylogger = create_logger( - save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_1" -) -config = load_config(mylogger=mylogger) - -if config["binning_enable"] and (config["binning_at_the_end"] is False): - device: torch.device = torch.device("cpu") -else: - device = get_torch_device(mylogger, config["force_to_cpu"]) - - -dtype_str: str = config["dtype"] -dtype: torch.dtype = getattr(torch, dtype_str) - -raw_data_path: str = os.path.join( - config["basic_path"], - config["recoding_data"], - config["mouse_identifier"], - config["raw_path"], -) - -mylogger.info(f"Using data path: {raw_data_path}") - -first_experiment_id: int = int(get_experiments(raw_data_path).min()) -first_trial_id: int = int(get_trials(raw_data_path, first_experiment_id).min()) - -meta_channels: list[str] -meta_mouse_markings: str -meta_recording_date: str -meta_stimulation_times: dict -meta_experiment_names: dict -meta_trial_recording_duration: float -meta_frame_time: float -meta_mouse: str -data: torch.Tensor - -if config["binning_enable"] and (config["binning_at_the_end"] is False): - force_to_cpu_memory: bool = True -else: - force_to_cpu_memory = False - -mylogger.info("Loading data") - -( - meta_channels, - meta_mouse_markings, - meta_recording_date, - meta_stimulation_times, - meta_experiment_names, - meta_trial_recording_duration, - meta_frame_time, - meta_mouse, - data, -) = data_raw_loader( - raw_data_path=raw_data_path, - mylogger=mylogger, - experiment_id=first_experiment_id, - trial_id=first_trial_id, - device=device, - force_to_cpu_memory=force_to_cpu_memory, - config=config, -) -mylogger.info("-==- Done -==-") - -output_path = config["ref_image_path"] -mylogger.info(f"Create directory {output_path} in the case it does not exist") -os.makedirs(output_path, exist_ok=True) - -mylogger.info("Reference images") -for i in range(0, len(meta_channels)): - temp_path: str = os.path.join(output_path, meta_channels[i] + ".npy") - mylogger.info(f"Extract and save: {temp_path}") - frame_id: int = data.shape[-2] // 2 - mylogger.info(f"Will use frame id: {frame_id}") - ref_image: np.ndarray = ( - data[:, :, frame_id, meta_channels.index(meta_channels[i])] - .clone() - .cpu() - .numpy() +def main(*, config_filename: str = "config.json") -> None: + mylogger = create_logger( + save_logging_messages=True, + display_logging_messages=True, + log_stage_name="stage_1", ) - np.save(temp_path, ref_image) -mylogger.info("-==- Done -==-") -sample_frequency: float = 1.0 / meta_frame_time -mylogger.info( + config = load_config(mylogger=mylogger, filename=config_filename) + + if config["binning_enable"] and (config["binning_at_the_end"] is False): + device: torch.device = torch.device("cpu") + else: + device = get_torch_device(mylogger, config["force_to_cpu"]) + + raw_data_path: str = os.path.join( + config["basic_path"], + config["recoding_data"], + config["mouse_identifier"], + config["raw_path"], + ) + + mylogger.info(f"Using data path: {raw_data_path}") + + first_experiment_id: int = int(get_experiments(raw_data_path).min()) + first_trial_id: int = int(get_trials(raw_data_path, first_experiment_id).min()) + + meta_channels: list[str] + meta_mouse_markings: str + meta_recording_date: str + meta_stimulation_times: dict + meta_experiment_names: dict + meta_trial_recording_duration: float + meta_frame_time: float + meta_mouse: str + data: torch.Tensor + + if config["binning_enable"] and (config["binning_at_the_end"] is False): + force_to_cpu_memory: bool = True + else: + force_to_cpu_memory = False + + mylogger.info("Loading data") + ( - f"Heartbeat power {config['lower_freqency_bandpass']} Hz" - f" - {config['upper_freqency_bandpass']} Hz," - f" sample-rate: {sample_frequency}," - f" skipping the first {config['skip_frames_in_the_beginning']} frames" + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) = data_raw_loader( + raw_data_path=raw_data_path, + mylogger=mylogger, + experiment_id=first_experiment_id, + trial_id=first_trial_id, + device=device, + force_to_cpu_memory=force_to_cpu_memory, + config=config, ) -) + mylogger.info("-==- Done -==-") -for i in range(0, len(meta_channels)): - temp_path = os.path.join(output_path, meta_channels[i] + "_var.npy") - mylogger.info(f"Extract and save: {temp_path}") + output_path = config["ref_image_path"] + mylogger.info(f"Create directory {output_path} in the case it does not exist") + os.makedirs(output_path, exist_ok=True) - heartbeat_ts: torch.Tensor = bandpass( - data=data[..., i], - low_frequency=config["lower_freqency_bandpass"], - high_frequency=config["upper_freqency_bandpass"], - fs=sample_frequency, - filtfilt_chuck_size=10, + mylogger.info("Reference images") + for i in range(0, len(meta_channels)): + temp_path: str = os.path.join(output_path, meta_channels[i] + ".npy") + mylogger.info(f"Extract and save: {temp_path}") + frame_id: int = data.shape[-2] // 2 + mylogger.info(f"Will use frame id: {frame_id}") + ref_image: np.ndarray = ( + data[:, :, frame_id, meta_channels.index(meta_channels[i])] + .clone() + .cpu() + .numpy() + ) + np.save(temp_path, ref_image) + mylogger.info("-==- Done -==-") + + sample_frequency: float = 1.0 / meta_frame_time + mylogger.info( + ( + f"Heartbeat power {config['lower_freqency_bandpass']} Hz" + f" - {config['upper_freqency_bandpass']} Hz," + f" sample-rate: {sample_frequency}," + f" skipping the first {config['skip_frames_in_the_beginning']} frames" + ) ) - heartbeat_power = heartbeat_ts[..., config["skip_frames_in_the_beginning"] :].var( - dim=-1 - ) - np.save(temp_path, heartbeat_power) + for i in range(0, len(meta_channels)): + temp_path = os.path.join(output_path, meta_channels[i] + "_var.npy") + mylogger.info(f"Extract and save: {temp_path}") -mylogger.info("-==- Done -==-") + heartbeat_ts: torch.Tensor = bandpass( + data=data[..., i], + low_frequency=config["lower_freqency_bandpass"], + high_frequency=config["upper_freqency_bandpass"], + fs=sample_frequency, + filtfilt_chuck_size=10, + ) + + heartbeat_power = heartbeat_ts[ + ..., config["skip_frames_in_the_beginning"] : + ].var(dim=-1) + np.save(temp_path, heartbeat_power) + + mylogger.info("-==- Done -==-") + + +if __name__ == "__main__": + argh.dispatch_command(main) diff --git a/stage_2_make_heartbeat_mask.py b/stage_2_make_heartbeat_mask.py index e36516b..f9f3cd9 100644 --- a/stage_2_make_heartbeat_mask.py +++ b/stage_2_make_heartbeat_mask.py @@ -3,6 +3,7 @@ import matplotlib import numpy as np import torch import os +import argh from matplotlib.widgets import Slider, Button # type:ignore from functools import partial @@ -11,143 +12,151 @@ from functions.create_logger import create_logger from functions.get_torch_device import get_torch_device from functions.load_config import load_config -mylogger = create_logger( - save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_2" -) -config = load_config(mylogger=mylogger) +def main(*, config_filename: str = "config.json") -> None: + mylogger = create_logger( + save_logging_messages=True, + display_logging_messages=True, + log_stage_name="stage_2", + ) -path: str = config["ref_image_path"] -use_channel: str = "donor" -spatial_width: float = 4.0 -temporal_width: float = 0.1 + config = load_config(mylogger=mylogger, filename=config_filename) -threshold: float = 0.05 + path: str = config["ref_image_path"] + use_channel: str = "donor" + spatial_width: float = 4.0 + temporal_width: float = 0.1 -heartbeat_mask_threshold_file: str = os.path.join(path, "heartbeat_mask_threshold.npy") -if os.path.isfile(heartbeat_mask_threshold_file): - mylogger.info(f"loading previous threshold file: {heartbeat_mask_threshold_file}") - threshold = float(np.load(heartbeat_mask_threshold_file)[0]) + threshold: float = 0.05 -mylogger.info(f"initial threshold is {threshold}") + heartbeat_mask_threshold_file: str = os.path.join( + path, "heartbeat_mask_threshold.npy" + ) + if os.path.isfile(heartbeat_mask_threshold_file): + mylogger.info( + f"loading previous threshold file: {heartbeat_mask_threshold_file}" + ) + threshold = float(np.load(heartbeat_mask_threshold_file)[0]) -image_ref_file: str = os.path.join(path, use_channel + ".npy") -image_var_file: str = os.path.join(path, use_channel + "_var.npy") -heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy") + mylogger.info(f"initial threshold is {threshold}") -device = get_torch_device(mylogger, config["force_to_cpu"]) + image_ref_file: str = os.path.join(path, use_channel + ".npy") + image_var_file: str = os.path.join(path, use_channel + "_var.npy") + heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy") + device = get_torch_device(mylogger, config["force_to_cpu"]) -def next_frame( - i: float, images: np.ndarray, image_handle: matplotlib.image.AxesImage -) -> None: - global threshold - threshold = i + mylogger.info(f"loading image reference file: {image_ref_file}") + image_ref: np.ndarray = np.load(image_ref_file) + image_ref /= image_ref.max() - display_image: np.ndarray = images.copy() + mylogger.info(f"loading image heartbeat power: {image_var_file}") + image_var: np.ndarray = np.load(image_var_file) + image_var /= image_var.max() + + mylogger.info("Smear the image heartbeat power patially") + temp, _ = gauss_smear_individual( + input=torch.tensor(image_var[..., np.newaxis], device=device), + spatial_width=spatial_width, + temporal_width=temporal_width, + use_matlab_mask=False, + ) + temp /= temp.max() + + mylogger.info("-==- DONE -==-") + + image_3color = np.concatenate( + ( + np.zeros_like(image_ref[..., np.newaxis]), + image_ref[..., np.newaxis], + temp.cpu().numpy(), + ), + axis=-1, + ) + + mylogger.info("Prepare image") + + display_image = image_3color.copy() display_image[..., 2] = display_image[..., 0] - mask: np.ndarray = np.where(images[..., 2] >= i, 1.0, np.nan)[..., np.newaxis] + mask = np.where(image_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis] display_image *= mask display_image = np.nan_to_num(display_image, nan=1.0) - image_handle.set_data(display_image) - return + value_sort = np.sort(image_var.flatten()) + value_sort_max = value_sort[int(value_sort.shape[0] * 0.95)] + mylogger.info("-==- DONE -==-") + + mylogger.info("Create figure") + + fig: matplotlib.figure.Figure = plt.figure() + + image_handle = plt.imshow(display_image, vmin=0, vmax=1, cmap="hot") + + mylogger.info("Add controls") + + def next_frame( + i: float, images: np.ndarray, image_handle: matplotlib.image.AxesImage + ) -> None: + nonlocal threshold + threshold = i + + display_image: np.ndarray = images.copy() + display_image[..., 2] = display_image[..., 0] + mask: np.ndarray = np.where(images[..., 2] >= i, 1.0, np.nan)[..., np.newaxis] + display_image *= mask + display_image = np.nan_to_num(display_image, nan=1.0) + + image_handle.set_data(display_image) + return + + def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None: + nonlocal threshold + nonlocal image_3color + nonlocal path + nonlocal mylogger + nonlocal heartbeat_mask_file + nonlocal heartbeat_mask_threshold_file + + mylogger.info(f"Threshold: {threshold}") + + mask: np.ndarray = image_3color[..., 2] >= threshold + mylogger.info(f"Save mask to: {heartbeat_mask_file}") + np.save(heartbeat_mask_file, mask) + mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}") + np.save(heartbeat_mask_threshold_file, np.array([threshold])) + exit() + + def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None: + exit() + + axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03)) + slice_slider = Slider( + ax=axfreq, + label="Threshold", + valmin=0, + valmax=value_sort_max, + valinit=threshold, + valstep=value_sort_max / 100.0, + ) + axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04)) + button_accept = Button( + ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95" + ) + button_accept.on_clicked(on_clicked_accept) # type: ignore + + axbutton_cancel = fig.add_axes(rect=(0.55, 0.85, 0.2, 0.04)) + button_cancel = Button( + ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95" + ) + button_cancel.on_clicked(on_clicked_cancel) # type: ignore + + slice_slider.on_changed( + partial(next_frame, images=image_3color, image_handle=image_handle) + ) + + mylogger.info("Display") + plt.show() -def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None: - global threshold - global image_3color - global path - global mylogger - global heartbeat_mask_file - global heartbeat_mask_threshold_file - - mylogger.info(f"Threshold: {threshold}") - - mask: np.ndarray = image_3color[..., 2] >= threshold - mylogger.info(f"Save mask to: {heartbeat_mask_file}") - np.save(heartbeat_mask_file, mask) - mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}") - np.save(heartbeat_mask_threshold_file, np.array([threshold])) - exit() - - -def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None: - exit() - - -mylogger.info(f"loading image reference file: {image_ref_file}") -image_ref: np.ndarray = np.load(image_ref_file) -image_ref /= image_ref.max() - -mylogger.info(f"loading image heartbeat power: {image_var_file}") -image_var: np.ndarray = np.load(image_var_file) -image_var /= image_var.max() - -mylogger.info("Smear the image heartbeat power patially") -temp, _ = gauss_smear_individual( - input=torch.tensor(image_var[..., np.newaxis], device=device), - spatial_width=spatial_width, - temporal_width=temporal_width, - use_matlab_mask=False, -) -temp /= temp.max() - -mylogger.info("-==- DONE -==-") - -image_3color = np.concatenate( - ( - np.zeros_like(image_ref[..., np.newaxis]), - image_ref[..., np.newaxis], - temp.cpu().numpy(), - ), - axis=-1, -) - -mylogger.info("Prepare image") - -display_image = image_3color.copy() -display_image[..., 2] = display_image[..., 0] -mask = np.where(image_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis] -display_image *= mask -display_image = np.nan_to_num(display_image, nan=1.0) - -value_sort = np.sort(image_var.flatten()) -value_sort_max = value_sort[int(value_sort.shape[0] * 0.95)] -mylogger.info("-==- DONE -==-") - -mylogger.info("Create figure") - -fig: matplotlib.figure.Figure = plt.figure() - -image_handle = plt.imshow(display_image, vmin=0, vmax=1, cmap="hot") - -mylogger.info("Add controls") - -axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03)) -slice_slider = Slider( - ax=axfreq, - label="Threshold", - valmin=0, - valmax=value_sort_max, - valinit=threshold, - valstep=value_sort_max / 100.0, -) -axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04)) -button_accept = Button( - ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95" -) -button_accept.on_clicked(on_clicked_accept) # type: ignore - -axbutton_cancel = fig.add_axes(rect=(0.55, 0.85, 0.2, 0.04)) -button_cancel = Button( - ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95" -) -button_cancel.on_clicked(on_clicked_cancel) # type: ignore - -slice_slider.on_changed( - partial(next_frame, images=image_3color, image_handle=image_handle) -) - -mylogger.info("Display") -plt.show() +if __name__ == "__main__": + argh.dispatch_command(main) diff --git a/stage_3_refine_mask.py b/stage_3_refine_mask.py index 83f9ecd..1e68a93 100644 --- a/stage_3_refine_mask.py +++ b/stage_3_refine_mask.py @@ -9,9 +9,10 @@ from matplotlib.widgets import Button # type:ignore from roipoly import RoiPoly # type:ignore from functions.create_logger import create_logger -from functions.get_torch_device import get_torch_device from functions.load_config import load_config +import argh + def compose_image(image_3color: np.ndarray, mask: np.ndarray) -> np.ndarray: display_image = image_3color.copy() @@ -20,138 +21,145 @@ def compose_image(image_3color: np.ndarray, mask: np.ndarray) -> np.ndarray: return display_image -def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None: - global mylogger - global refined_mask_file - global mask +def main(*, config_filename: str = "config.json") -> None: + mylogger = create_logger( + save_logging_messages=True, + display_logging_messages=True, + log_stage_name="stage_3", + ) - mylogger.info(f"Save mask to: {refined_mask_file}") - np.save(refined_mask_file, mask) + config = load_config(mylogger=mylogger, filename=config_filename) - exit() + path: str = config["ref_image_path"] + use_channel: str = "donor" + + image_ref_file: str = os.path.join(path, use_channel + ".npy") + heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy") + refined_mask_file: str = os.path.join(path, "mask_not_rotated.npy") + + mylogger.info(f"loading image reference file: {image_ref_file}") + image_ref: np.ndarray = np.load(image_ref_file) + image_ref /= image_ref.max() + + mylogger.info(f"loading heartbeat mask: {heartbeat_mask_file}") + mask: np.ndarray = np.load(heartbeat_mask_file) + + image_3color = np.concatenate( + ( + np.zeros_like(image_ref[..., np.newaxis]), + image_ref[..., np.newaxis], + np.zeros_like(image_ref[..., np.newaxis]), + ), + axis=-1, + ) + + mylogger.info("-==- DONE -==-") + + fig, ax_main = plt.subplots() + + display_image = compose_image(image_3color=image_3color, mask=mask) + image_handle = ax_main.imshow(display_image, vmin=0, vmax=1, cmap="hot") + + mylogger.info("Add controls") + + def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None: + nonlocal mylogger + nonlocal refined_mask_file + nonlocal mask + + mylogger.info(f"Save mask to: {refined_mask_file}") + np.save(refined_mask_file, mask) + + exit() + + def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None: + nonlocal mylogger + mylogger.info("Ended without saving the mask") + exit() + + def on_clicked_add(event: matplotlib.backend_bases.MouseEvent) -> None: + nonlocal new_roi # type: ignore + nonlocal mask + nonlocal image_3color + nonlocal display_image + nonlocal mylogger + if len(new_roi.x) > 0: + mylogger.info( + "A ROI with the following coordiantes has been added to the mask" + ) + for i in range(0, len(new_roi.x)): + mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}") + mylogger.info("") + new_mask = new_roi.get_mask(display_image[:, :, 0]) + mask[new_mask] = 0.0 + display_image = compose_image(image_3color=image_3color, mask=mask) + image_handle.set_data(display_image) + for line in ax_main.lines: + line.remove() + plt.draw() + + new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) + + def on_clicked_remove(event: matplotlib.backend_bases.MouseEvent) -> None: + nonlocal new_roi # type: ignore + nonlocal mask + nonlocal image_3color + nonlocal display_image + if len(new_roi.x) > 0: + mylogger.info( + "A ROI with the following coordiantes has been removed from the mask" + ) + for i in range(0, len(new_roi.x)): + mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}") + mylogger.info("") + new_mask = new_roi.get_mask(display_image[:, :, 0]) + mask[new_mask] = 1.0 + display_image = compose_image(image_3color=image_3color, mask=mask) + image_handle.set_data(display_image) + for line in ax_main.lines: + line.remove() + plt.draw() + new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) + + axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04)) + button_accept = Button( + ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95" + ) + button_accept.on_clicked(on_clicked_accept) # type: ignore + + axbutton_cancel = fig.add_axes(rect=(0.5, 0.85, 0.2, 0.04)) + button_cancel = Button( + ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95" + ) + button_cancel.on_clicked(on_clicked_cancel) # type: ignore + + axbutton_addmask = fig.add_axes(rect=(0.3, 0.9, 0.2, 0.04)) + button_addmask = Button( + ax=axbutton_addmask, + label="Add mask", + image=None, + color="0.85", + hovercolor="0.95", + ) + button_addmask.on_clicked(on_clicked_add) # type: ignore + + axbutton_removemask = fig.add_axes(rect=(0.5, 0.9, 0.2, 0.04)) + button_removemask = Button( + ax=axbutton_removemask, + label="Remove mask", + image=None, + color="0.85", + hovercolor="0.95", + ) + button_removemask.on_clicked(on_clicked_remove) # type: ignore + + # ax_main.cla() + + mylogger.info("Display") + new_roi: RoiPoly = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) + + plt.show() -def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None: - global mylogger - mylogger.info("Ended without saving the mask") - exit() - - -def on_clicked_add(event: matplotlib.backend_bases.MouseEvent) -> None: - global new_roi - global mask - global image_3color - global display_image - global mylogger - if len(new_roi.x) > 0: - mylogger.info("A ROI with the following coordiantes has been added to the mask") - for i in range(0, len(new_roi.x)): - mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}") - mylogger.info("") - new_mask = new_roi.get_mask(display_image[:, :, 0]) - mask[new_mask] = 0.0 - display_image = compose_image(image_3color=image_3color, mask=mask) - image_handle.set_data(display_image) - for line in ax_main.lines: - line.remove() - plt.draw() - - new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) - - -def on_clicked_remove(event: matplotlib.backend_bases.MouseEvent) -> None: - global new_roi - global mask - global image_3color - global display_image - if len(new_roi.x) > 0: - mylogger.info( - "A ROI with the following coordiantes has been removed from the mask" - ) - for i in range(0, len(new_roi.x)): - mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}") - mylogger.info("") - new_mask = new_roi.get_mask(display_image[:, :, 0]) - mask[new_mask] = 1.0 - display_image = compose_image(image_3color=image_3color, mask=mask) - image_handle.set_data(display_image) - for line in ax_main.lines: - line.remove() - plt.draw() - new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) - - -mylogger = create_logger( - save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_3" -) - -config = load_config(mylogger=mylogger) - -device = get_torch_device(mylogger, config["force_to_cpu"]) - -path: str = config["ref_image_path"] -use_channel: str = "donor" - -image_ref_file: str = os.path.join(path, use_channel + ".npy") -heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy") -refined_mask_file: str = os.path.join(path, "mask_not_rotated.npy") - -mylogger.info(f"loading image reference file: {image_ref_file}") -image_ref: np.ndarray = np.load(image_ref_file) -image_ref /= image_ref.max() - -mylogger.info(f"loading heartbeat mask: {heartbeat_mask_file}") -mask: np.ndarray = np.load(heartbeat_mask_file) - -image_3color = np.concatenate( - ( - np.zeros_like(image_ref[..., np.newaxis]), - image_ref[..., np.newaxis], - np.zeros_like(image_ref[..., np.newaxis]), - ), - axis=-1, -) - -mylogger.info("-==- DONE -==-") - -fig, ax_main = plt.subplots() - -display_image = compose_image(image_3color=image_3color, mask=mask) -image_handle = ax_main.imshow(display_image, vmin=0, vmax=1, cmap="hot") - -mylogger.info("Add controls") - -axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04)) -button_accept = Button( - ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95" -) -button_accept.on_clicked(on_clicked_accept) # type: ignore - -axbutton_cancel = fig.add_axes(rect=(0.5, 0.85, 0.2, 0.04)) -button_cancel = Button( - ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95" -) -button_cancel.on_clicked(on_clicked_cancel) # type: ignore - -axbutton_addmask = fig.add_axes(rect=(0.3, 0.9, 0.2, 0.04)) -button_addmask = Button( - ax=axbutton_addmask, label="Add mask", image=None, color="0.85", hovercolor="0.95" -) -button_addmask.on_clicked(on_clicked_add) # type: ignore - -axbutton_removemask = fig.add_axes(rect=(0.5, 0.9, 0.2, 0.04)) -button_removemask = Button( - ax=axbutton_removemask, - label="Remove mask", - image=None, - color="0.85", - hovercolor="0.95", -) -button_removemask.on_clicked(on_clicked_remove) # type: ignore - -# ax_main.cla() - -mylogger.info("Display") -new_roi: RoiPoly = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) - -plt.show() +if __name__ == "__main__": + argh.dispatch_command(main) diff --git a/stage_4_process.py b/stage_4_process.py index 05abaea..909731f 100644 --- a/stage_4_process.py +++ b/stage_4_process.py @@ -20,6 +20,8 @@ from functions.gauss_smear_individual import gauss_smear_individual from functions.regression import regression from functions.data_raw_loader import data_raw_loader +import argh + @torch.no_grad() def process_trial( @@ -889,71 +891,96 @@ def process_trial( return -mylogger = create_logger( - save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_4" -) -config = load_config(mylogger=mylogger) - -if (config["save_as_python"] is False) and (config["save_as_matlab"] is False): - mylogger.info("No output will be created. ") - mylogger.info("Change save_as_python and/or save_as_matlab in the config file") - mylogger.info("ERROR: STOP!!!") - exit() - -if (len(config["target_camera_donor"]) == 0) and ( - len(config["target_camera_acceptor"]) == 0 -): - mylogger.info( - "Configure at least target_camera_donor or target_camera_acceptor correctly." +def main( + *, + config_filename: str = "config.json", + experiment_id_overwrite: int = -1, + trial_id_overwrite: int = -1, +) -> None: + mylogger = create_logger( + save_logging_messages=True, + display_logging_messages=True, + log_stage_name="stage_4", ) - mylogger.info("ERROR: STOP!!!") - exit() -device = get_torch_device(mylogger, config["force_to_cpu"]) + config = load_config(mylogger=mylogger, filename=config_filename) -mylogger.info(f"Create directory {config['export_path']} in the case it does not exist") -os.makedirs(config["export_path"], exist_ok=True) + if (config["save_as_python"] is False) and (config["save_as_matlab"] is False): + mylogger.info("No output will be created. ") + mylogger.info("Change save_as_python and/or save_as_matlab in the config file") + mylogger.info("ERROR: STOP!!!") + exit() -raw_data_path: str = os.path.join( - config["basic_path"], - config["recoding_data"], - config["mouse_identifier"], - config["raw_path"], -) - -if os.path.isdir(raw_data_path) is False: - mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!") - exit() - -experiments = get_experiments(raw_data_path) - -for experiment_counter in range(0, experiments.shape[0]): - experiment_id = int(experiments[experiment_counter]) - trials = get_trials(raw_data_path, experiment_id) - for trial_counter in range(0, trials.shape[0]): - trial_id = int(trials[trial_counter]) - - mylogger.info("") + if (len(config["target_camera_donor"]) == 0) and ( + len(config["target_camera_acceptor"]) == 0 + ): mylogger.info( - f"======= EXPERIMENT ID: {experiment_id} ==== TRIAL ID: {trial_id} =======" + "Configure at least target_camera_donor or target_camera_acceptor correctly." ) - mylogger.info("") + mylogger.info("ERROR: STOP!!!") + exit() - try: - process_trial( - config=config, - mylogger=mylogger, - experiment_id=experiment_id, - trial_id=trial_id, - device=device, - ) - except torch.cuda.OutOfMemoryError: - mylogger.info("WARNING: RUNNING IN FAILBACK MODE!!!!") - mylogger.info("Not enough GPU memory. Retry on CPU") - process_trial( - config=config, - mylogger=mylogger, - experiment_id=experiment_id, - trial_id=trial_id, - device=torch.device("cpu"), + device = get_torch_device(mylogger, config["force_to_cpu"]) + + mylogger.info( + f"Create directory {config['export_path']} in the case it does not exist" + ) + os.makedirs(config["export_path"], exist_ok=True) + + raw_data_path: str = os.path.join( + config["basic_path"], + config["recoding_data"], + config["mouse_identifier"], + config["raw_path"], + ) + + if os.path.isdir(raw_data_path) is False: + mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!") + exit() + + if experiment_id_overwrite == -1: + experiments = get_experiments(raw_data_path) + else: + assert experiment_id_overwrite >= 0 + experiments = torch.tensor([experiment_id_overwrite]) + + for experiment_counter in range(0, experiments.shape[0]): + experiment_id = int(experiments[experiment_counter]) + + if trial_id_overwrite == -1: + trials = get_trials(raw_data_path, experiment_id) + else: + assert trial_id_overwrite >= 0 + trials = torch.tensor([trial_id_overwrite]) + + for trial_counter in range(0, trials.shape[0]): + trial_id = int(trials[trial_counter]) + + mylogger.info("") + mylogger.info( + f"======= EXPERIMENT ID: {experiment_id} ==== TRIAL ID: {trial_id} =======" ) + mylogger.info("") + + try: + process_trial( + config=config, + mylogger=mylogger, + experiment_id=experiment_id, + trial_id=trial_id, + device=device, + ) + except torch.cuda.OutOfMemoryError: + mylogger.info("WARNING: RUNNING IN FAILBACK MODE!!!!") + mylogger.info("Not enough GPU memory. Retry on CPU") + process_trial( + config=config, + mylogger=mylogger, + experiment_id=experiment_id, + trial_id=trial_id, + device=torch.device("cpu"), + ) + + +if __name__ == "__main__": + argh.dispatch_command(main)