From 369540f47259420f63e677f8426aceffa5823dfe Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Sat, 24 Feb 2024 17:28:26 +0100 Subject: [PATCH] Add files via upload --- new_pipeline/config.json | 26 +++ new_pipeline/functions/bandpass.py | 85 ++++++++++ new_pipeline/functions/create_logger.py | 37 +++++ .../functions/gauss_smear_individual.py | 127 ++++++++++++++ new_pipeline/functions/get_experiments.py | 19 +++ new_pipeline/functions/get_parts.py | 18 ++ new_pipeline/functions/get_trials.py | 18 ++ new_pipeline/functions/load_meta_data.py | 54 ++++++ new_pipeline/stage_1_get_ref_image.py | 143 ++++++++++++++++ new_pipeline/stage_2_make_heartbeat_mask.py | 155 ++++++++++++++++++ 10 files changed, 682 insertions(+) create mode 100644 new_pipeline/config.json create mode 100644 new_pipeline/functions/bandpass.py create mode 100644 new_pipeline/functions/create_logger.py create mode 100644 new_pipeline/functions/gauss_smear_individual.py create mode 100644 new_pipeline/functions/get_experiments.py create mode 100644 new_pipeline/functions/get_parts.py create mode 100644 new_pipeline/functions/get_trials.py create mode 100644 new_pipeline/functions/load_meta_data.py create mode 100644 new_pipeline/stage_1_get_ref_image.py create mode 100644 new_pipeline/stage_2_make_heartbeat_mask.py diff --git a/new_pipeline/config.json b/new_pipeline/config.json new file mode 100644 index 0000000..e0e6eac --- /dev/null +++ b/new_pipeline/config.json @@ -0,0 +1,26 @@ +{ + "basic_path": "/data_1/robert", + "recoding_data": "2021-05-05", + "mouse_identifier": "M3852M", + "raw_path": "raw", + "export_path": "output", + "ref_image_path": "ref_images", + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ], + "dtype": "float32", + "binning_enable": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "force_to_cpu": false +} \ No newline at end of file diff --git a/new_pipeline/functions/bandpass.py b/new_pipeline/functions/bandpass.py new file mode 100644 index 0000000..171baf5 --- /dev/null +++ b/new_pipeline/functions/bandpass.py @@ -0,0 +1,85 @@ +import torchaudio as ta # type: ignore +import torch + + +@torch.no_grad() +def filtfilt( + input: torch.Tensor, + butter_a: torch.Tensor, + butter_b: torch.Tensor, +) -> torch.Tensor: + assert butter_a.ndim == 1 + assert butter_b.ndim == 1 + assert butter_a.shape[0] == butter_b.shape[0] + + process_data: torch.Tensor = input.detach().clone() + + padding_length = 12 * int(butter_a.shape[0]) + left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[ + ..., 1 : padding_length + 1 + ].flip(-1) + right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[ + ..., -(padding_length + 1) : -1 + ].flip(-1) + process_data_padded = torch.cat((left_padding, process_data, right_padding), dim=-1) + + output = ta.functional.filtfilt( + process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False + ).squeeze(0) + + output = output[..., padding_length:-padding_length] + return output + + +@torch.no_grad() +def butter_bandpass( + device: torch.device, + low_frequency: float = 0.1, + high_frequency: float = 1.0, + fs: float = 30.0, +) -> tuple[torch.Tensor, torch.Tensor]: + import scipy # type: ignore + + butter_b_np, butter_a_np = scipy.signal.butter( + 4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs + ) + butter_a = torch.tensor(butter_a_np, device=device, dtype=torch.float32) + butter_b = torch.tensor(butter_b_np, device=device, dtype=torch.float32) + return butter_a, butter_b + + +@torch.no_grad() +def chunk_iterator(array: torch.Tensor, chunk_size: int): + for i in range(0, array.shape[0], chunk_size): + yield array[i : i + chunk_size] + + +@torch.no_grad() +def bandpass( + data: torch.Tensor, + device: torch.device, + low_frequency: float = 0.1, + high_frequency: float = 1.0, + fs=30.0, + filtfilt_chuck_size: int = 10, +) -> torch.Tensor: + butter_a, butter_b = butter_bandpass( + device=device, + low_frequency=low_frequency, + high_frequency=high_frequency, + fs=fs, + ) + + index_full_dataset: torch.Tensor = torch.arange( + 0, data.shape[1], device=device, dtype=torch.int64 + ) + + for chunk in chunk_iterator(index_full_dataset, filtfilt_chuck_size): + temp_filtfilt = filtfilt( + data[:, chunk, :], + butter_a=butter_a, + butter_b=butter_b, + ) + data[:, chunk, :] = temp_filtfilt + + return data diff --git a/new_pipeline/functions/create_logger.py b/new_pipeline/functions/create_logger.py new file mode 100644 index 0000000..8fcfa8a --- /dev/null +++ b/new_pipeline/functions/create_logger.py @@ -0,0 +1,37 @@ +import logging +import datetime +import os + + +def create_logger( + save_logging_messages: bool, display_logging_messages: bool, log_stage_name: str +): + now = datetime.datetime.now() + dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S") + + logger = logging.getLogger("MyLittleLogger") + logger.setLevel(logging.DEBUG) + + if save_logging_messages: + time_format = "%b %-d %Y %H:%M:%S" + logformat = "%(asctime)s %(message)s" + file_formatter = logging.Formatter(fmt=logformat, datefmt=time_format) + os.makedirs("logs_" + log_stage_name, exist_ok=True) + file_handler = logging.FileHandler( + os.path.join("logs_" + log_stage_name, f"log_{dt_string_filename}.txt") + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(file_formatter) + logger.addHandler(file_handler) + + if display_logging_messages: + time_format = "%H:%M:%S" + logformat = "%(asctime)s %(message)s" + stream_formatter = logging.Formatter(fmt=logformat, datefmt=time_format) + + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(stream_formatter) + logger.addHandler(stream_handler) + + return logger diff --git a/new_pipeline/functions/gauss_smear_individual.py b/new_pipeline/functions/gauss_smear_individual.py new file mode 100644 index 0000000..36700e7 --- /dev/null +++ b/new_pipeline/functions/gauss_smear_individual.py @@ -0,0 +1,127 @@ +import torch +import math + + +@torch.no_grad() +def gauss_smear_individual( + input: torch.Tensor, + spatial_width: float, + temporal_width: float, + overwrite_fft_gauss: None | torch.Tensor = None, + use_matlab_mask: bool = True, + epsilon: float = float(torch.finfo(torch.float64).eps), +) -> tuple[torch.Tensor, torch.Tensor]: + + dim_x: int = int(2 * math.ceil(2 * spatial_width) + 1) + dim_y: int = int(2 * math.ceil(2 * spatial_width) + 1) + dim_t: int = int(2 * math.ceil(2 * temporal_width) + 1) + dims_xyt: torch.Tensor = torch.tensor( + [dim_x, dim_y, dim_t], dtype=torch.int64, device=input.device + ) + + if input.ndim == 2: + input = input.unsqueeze(-1) + + input_padded = torch.nn.functional.pad( + input.unsqueeze(0), + pad=( + dim_t, + dim_t, + dim_y, + dim_y, + dim_x, + dim_x, + ), + mode="replicate", + ).squeeze(0) + + if overwrite_fft_gauss is None: + center_x: int = int(math.ceil(input_padded.shape[0] / 2)) + center_y: int = int(math.ceil(input_padded.shape[1] / 2)) + center_z: int = int(math.ceil(input_padded.shape[2] / 2)) + grid_x: torch.Tensor = ( + torch.arange(0, input_padded.shape[0], device=input.device) - center_x + 1 + ) + grid_y: torch.Tensor = ( + torch.arange(0, input_padded.shape[1], device=input.device) - center_y + 1 + ) + grid_z: torch.Tensor = ( + torch.arange(0, input_padded.shape[2], device=input.device) - center_z + 1 + ) + + grid_x = grid_x.unsqueeze(-1).unsqueeze(-1) ** 2 + grid_y = grid_y.unsqueeze(0).unsqueeze(-1) ** 2 + grid_z = grid_z.unsqueeze(0).unsqueeze(0) ** 2 + + gauss_kernel: torch.Tensor = ( + (grid_x / (spatial_width**2)) + + (grid_y / (spatial_width**2)) + + (grid_z / (temporal_width**2)) + ) + + if use_matlab_mask: + filter_radius: torch.Tensor = (dims_xyt - 1) // 2 + + border_lower: list[int] = [ + center_x - int(filter_radius[0]) - 1, + center_y - int(filter_radius[1]) - 1, + center_z - int(filter_radius[2]) - 1, + ] + + border_upper: list[int] = [ + center_x + int(filter_radius[0]), + center_y + int(filter_radius[1]), + center_z + int(filter_radius[2]), + ] + + matlab_mask: torch.Tensor = torch.zeros_like(gauss_kernel) + matlab_mask[ + border_lower[0] : border_upper[0], + border_lower[1] : border_upper[1], + border_lower[2] : border_upper[2], + ] = 1.0 + + gauss_kernel = torch.exp(-gauss_kernel / 2.0) + if use_matlab_mask: + gauss_kernel = gauss_kernel * matlab_mask + + gauss_kernel[gauss_kernel < (epsilon * gauss_kernel.max())] = 0 + + sum_gauss_kernel: float = float(gauss_kernel.sum()) + + if sum_gauss_kernel != 0.0: + gauss_kernel = gauss_kernel / sum_gauss_kernel + + # FFT Shift + gauss_kernel = torch.cat( + (gauss_kernel[center_x - 1 :, :, :], gauss_kernel[: center_x - 1, :, :]), + dim=0, + ) + gauss_kernel = torch.cat( + (gauss_kernel[:, center_y - 1 :, :], gauss_kernel[:, : center_y - 1, :]), + dim=1, + ) + gauss_kernel = torch.cat( + (gauss_kernel[:, :, center_z - 1 :], gauss_kernel[:, :, : center_z - 1]), + dim=2, + ) + overwrite_fft_gauss = torch.fft.fftn(gauss_kernel) + input_padded_gauss_filtered: torch.Tensor = torch.real( + torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss) + ) + else: + input_padded_gauss_filtered = torch.real( + torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss) + ) + + start = dims_xyt + stop = ( + torch.tensor(input_padded.shape, device=dims_xyt.device, dtype=dims_xyt.dtype) + - dims_xyt + ) + + output = input_padded_gauss_filtered[ + start[0] : stop[0], start[1] : stop[1], start[2] : stop[2] + ] + + return (output, overwrite_fft_gauss) diff --git a/new_pipeline/functions/get_experiments.py b/new_pipeline/functions/get_experiments.py new file mode 100644 index 0000000..d92b936 --- /dev/null +++ b/new_pipeline/functions/get_experiments.py @@ -0,0 +1,19 @@ +import torch +import os +import glob + + +@torch.no_grad() +def get_experiments(path: str) -> torch.Tensor: + filename_np: str = os.path.join( + path, + "Exp*_Part001.npy", + ) + + list_str = glob.glob(filename_np) + list_int: list[int] = [] + for i in range(0, len(list_str)): + list_int.append(int(list_str[i].split("Exp")[-1].split("_Trial")[0])) + list_int = sorted(list_int) + + return torch.tensor(list_int).unique() diff --git a/new_pipeline/functions/get_parts.py b/new_pipeline/functions/get_parts.py new file mode 100644 index 0000000..d68e1ae --- /dev/null +++ b/new_pipeline/functions/get_parts.py @@ -0,0 +1,18 @@ +import torch +import os +import glob + + +@torch.no_grad() +def get_parts(path: str, experiment_id: int, trial_id: int) -> torch.Tensor: + filename_np: str = os.path.join( + path, + f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part*.npy", + ) + + list_str = glob.glob(filename_np) + list_int: list[int] = [] + for i in range(0, len(list_str)): + list_int.append(int(list_str[i].split("_Part")[-1].split(".npy")[0])) + list_int = sorted(list_int) + return torch.tensor(list_int).unique() diff --git a/new_pipeline/functions/get_trials.py b/new_pipeline/functions/get_trials.py new file mode 100644 index 0000000..8c687d9 --- /dev/null +++ b/new_pipeline/functions/get_trials.py @@ -0,0 +1,18 @@ +import torch +import os +import glob + + +@torch.no_grad() +def get_trials(path: str, experiment_id: int) -> torch.Tensor: + filename_np: str = os.path.join( + path, + f"Exp{experiment_id:03d}_Trial*_Part001.npy", + ) + + list_str = glob.glob(filename_np) + list_int: list[int] = [] + for i in range(0, len(list_str)): + list_int.append(int(list_str[i].split("_Trial")[-1].split("_Part")[0])) + list_int = sorted(list_int) + return torch.tensor(list_int).unique() diff --git a/new_pipeline/functions/load_meta_data.py b/new_pipeline/functions/load_meta_data.py new file mode 100644 index 0000000..641beb7 --- /dev/null +++ b/new_pipeline/functions/load_meta_data.py @@ -0,0 +1,54 @@ +import logging +import json + + +def load_meta_data( + mylogger: logging.Logger, filename_meta: str +) -> tuple[list[str], str, str, dict, dict, float, float, str]: + + mylogger.info("Loading meta data") + with open(filename_meta, "r") as file_handle: + metadata: dict = json.load(file_handle) + + channels: list[str] = metadata["channelKey"] + + mylogger.info(f"meta data: channel order: {channels}") + + mouse_markings: str = metadata["sessionMetaData"]["mouseMarkings"] + mylogger.info(f"meta data: mouse markings: {mouse_markings}") + + recording_date: str = metadata["sessionMetaData"]["date"] + mylogger.info(f"meta data: recording data: {recording_date}") + + stimulation_times: dict = metadata["sessionMetaData"]["stimulationTimes"] + mylogger.info(f"meta data: stimulation times: {stimulation_times}") + + experiment_names: dict = metadata["sessionMetaData"]["experimentNames"] + mylogger.info(f"meta data: experiment names: {experiment_names}") + + trial_recording_duration: float = float( + metadata["sessionMetaData"]["trialRecordingDuration"] + ) + mylogger.info( + f"meta data: trial recording duration: {trial_recording_duration} sec" + ) + + frame_time: float = float(metadata["sessionMetaData"]["frameTime"]) + mylogger.info( + f"meta data: frame time: {frame_time} sec ; frame rate: {1.0/frame_time}Hz" + ) + + mouse: str = metadata["sessionMetaData"]["mouse"] + mylogger.info(f"meta data: mouse: {mouse}") + mylogger.info("-==- Done -==-") + + return ( + channels, + mouse_markings, + recording_date, + stimulation_times, + experiment_names, + trial_recording_duration, + frame_time, + mouse, + ) diff --git a/new_pipeline/stage_1_get_ref_image.py b/new_pipeline/stage_1_get_ref_image.py new file mode 100644 index 0000000..e9eaff3 --- /dev/null +++ b/new_pipeline/stage_1_get_ref_image.py @@ -0,0 +1,143 @@ +import json +import os +from jsmin import jsmin # type: ignore +import torch +import numpy as np + + +from functions.get_experiments import get_experiments +from functions.get_trials import get_trials +from functions.get_parts import get_parts +from functions.bandpass import bandpass +from functions.create_logger import create_logger +from functions.load_meta_data import load_meta_data + +mylogger = create_logger( + save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_1" +) + +mylogger.info("loading config file") +with open("config.json", "r") as file: + config = json.loads(jsmin(file.read())) + +if torch.cuda.is_available(): + device_name: str = "cuda:0" +else: + device_name = "cpu" + +if config["force_to_cpu"]: + device_name = "cpu" + +mylogger.info(f"Using device: {device_name}") +device: torch.device = torch.device(device_name) + +dtype_str: str = config["dtype"] +dtype: torch.dtype = getattr(torch, dtype_str) + +raw_data_path: str = os.path.join( + config["basic_path"], + config["recoding_data"], + config["mouse_identifier"], + config["raw_path"], +) + +mylogger.info(f"Using data path: {raw_data_path}") + +first_experiment_id: int = int(get_experiments(raw_data_path).min()) +first_trial_id: int = int(get_trials(raw_data_path, first_experiment_id).min()) +first_part_id: int = int( + get_parts(raw_data_path, first_experiment_id, first_trial_id).min() +) + +filename_data: str = os.path.join( + raw_data_path, + f"Exp{first_experiment_id:03d}_Trial{first_trial_id:03d}_Part{first_part_id:03d}.npy", +) + +mylogger.info(f"Will use: {filename_data} for data") + +filename_meta: str = os.path.join( + raw_data_path, + f"Exp{first_experiment_id:03d}_Trial{first_trial_id:03d}_Part{first_part_id:03d}_meta.txt", +) + +mylogger.info(f"Will use: {filename_meta} for meta data") + +meta_channels: list[str] +meta_mouse_markings: str +meta_recording_date: str +meta_stimulation_times: dict +meta_experiment_names: dict +meta_trial_recording_duration: float +meta_frame_time: float +meta_mouse: str + +( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, +) = load_meta_data(mylogger=mylogger, filename_meta=filename_meta) + + +dtype_str = config["dtype"] +dtype_np: np.dtype = getattr(np, dtype_str) + +mylogger.info("Loading data") +data = torch.tensor( + np.load(filename_data).astype(dtype_np), dtype=dtype, device=torch.device("cpu") +) +mylogger.info("-==- Done -==-") + +output_path = config["ref_image_path"] +mylogger.info(f"Create directory {output_path} in the case it does not exist") +os.makedirs(output_path, exist_ok=True) + +mylogger.info("Reference images") +for i in range(0, len(meta_channels)): + temp_path: str = os.path.join(output_path, meta_channels[i] + ".npy") + mylogger.info(f"Extract and save: {temp_path}") + frame_id: int = data.shape[-2] // 2 + mylogger.info(f"Will use frame id: {frame_id}") + ref_image: np.ndarray = ( + data[:, :, frame_id, meta_channels.index(meta_channels[i])] + .clone() + .cpu() + .numpy() + ) + np.save(temp_path, ref_image) +mylogger.info("-==- Done -==-") + +sample_frequency: float = 1.0 / meta_frame_time +mylogger.info( + ( + f"Heartbeat power {config['lower_freqency_bandpass']}Hz" + f" - {config['upper_freqency_bandpass']}Hz," + f" sample-rate: {sample_frequency}," + f" skipping the first {config['skip_frames_in_the_beginning']} frames" + ) +) + +for i in range(0, len(meta_channels)): + temp_path = os.path.join(output_path, meta_channels[i] + "_var.npy") + mylogger.info(f"Extract and save: {temp_path}") + + heartbeat_ts: torch.Tensor = bandpass( + data=data[..., i], + device=data.device, + low_frequency=config["lower_freqency_bandpass"], + high_frequency=config["upper_freqency_bandpass"], + fs=sample_frequency, + filtfilt_chuck_size=10, + ) + + heartbeat_power = heartbeat_ts[..., config["skip_frames_in_the_beginning"] :].var( + dim=-1 + ) + np.save(temp_path, heartbeat_power) + +mylogger.info("-==- Done -==-") diff --git a/new_pipeline/stage_2_make_heartbeat_mask.py b/new_pipeline/stage_2_make_heartbeat_mask.py new file mode 100644 index 0000000..435027f --- /dev/null +++ b/new_pipeline/stage_2_make_heartbeat_mask.py @@ -0,0 +1,155 @@ +import matplotlib.pyplot as plt # type:ignore +import matplotlib +import numpy as np +import torch +import os +import json + +from jsmin import jsmin # type:ignore +from matplotlib.widgets import Slider, Button # type:ignore +from functools import partial +from functions.gauss_smear_individual import gauss_smear_individual +from functions.create_logger import create_logger + + +mylogger = create_logger( + save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_2" +) + +mylogger.info("loading config file") +with open("config.json", "r") as file: + config = json.loads(jsmin(file.read())) + +threshold: float = 0.05 +path: str = config["ref_image_path"] + +image_ref_file: str = os.path.join(path, "donor.npy") +image_var_file: str = os.path.join(path, "donor_var.npy") +heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy") +heartbeat_mask_threshold_file: str = os.path.join(path, "heartbeat_mask_threshold.npy") + +if torch.cuda.is_available(): + device_name: str = "cuda:0" +else: + device_name = "cpu" + +if config["force_to_cpu"]: + device_name = "cpu" + +mylogger.info(f"Using device: {device_name}") +device: torch.device = torch.device(device_name) + + +def next_frame( + i: float, images: np.ndarray, image_handle: matplotlib.image.AxesImage +) -> None: + global threshold + threshold = i + + display_image: np.ndarray = images.copy() + display_image[..., 2] = display_image[..., 0] + mask: np.ndarray = np.where(images[..., 2] >= i, 1.0, np.nan)[..., np.newaxis] + display_image *= mask + display_image = np.nan_to_num(display_image, nan=1.0) + + image_handle.set_data(display_image) + return + + +def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None: + global threshold + global volume_3color + global path + global mylogger + global heartbeat_mask_file + global heartbeat_mask_threshold_file + + mylogger.info(f"Threshold: {threshold}") + + mask: np.ndarray = volume_3color[..., 2] >= threshold + mylogger.info(f"Save mask to: {heartbeat_mask_file}") + np.save(heartbeat_mask_file, mask) + mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}") + np.save(heartbeat_mask_threshold_file, np.array([threshold])) + exit() + + +def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None: + exit() + + +mylogger.info(f"loading image reference file: {image_ref_file}") +image_ref: np.ndarray = np.load(image_ref_file) +image_ref /= image_ref.max() + +mylogger.info(f"loading image heartbeat power: {image_var_file}") +image_var: np.ndarray = np.load(image_var_file) +image_var /= image_var.max() + +mylogger.info("Smear the image heartbeat power patially") +temp, _ = gauss_smear_individual( + input=torch.tensor(image_var[..., np.newaxis], device=device), + spatial_width=4.0, + temporal_width=0.1, + use_matlab_mask=False, +) +temp /= temp.max() + +mylogger.info("-==- DONE -==-") + +volume_3color = np.concatenate( + ( + np.zeros_like(image_ref[..., np.newaxis]), + image_ref[..., np.newaxis], + temp.cpu().numpy(), + ), + axis=-1, +) + +mylogger.info("Prepare image") + +display_image = volume_3color.copy() +display_image[..., 2] = display_image[..., 0] +mask = np.where(volume_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis] +display_image *= mask +display_image = np.nan_to_num(display_image, nan=1.0) + +value_sort = np.sort(image_var.flatten()) +value_sort_max = value_sort[int(value_sort.shape[0] * 0.95)] +mylogger.info("-==- DONE -==-") + +mylogger.info("Create figure") + +fig: matplotlib.figure.Figure = plt.figure() + +image_handle = plt.imshow(display_image, vmin=0, vmax=1, cmap="hot") + +mylogger.info("Add controls") + +axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03)) +slice_slider = Slider( + ax=axfreq, + label="Slice", + valmin=0, + valmax=value_sort_max, + valinit=threshold, + valstep=value_sort_max / 100.0, +) +axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04)) +button_accept = Button( + ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95" +) +button_accept.on_clicked(on_clicked_accept) # type: ignore + +axbutton_cancel = fig.add_axes(rect=(0.55, 0.85, 0.2, 0.04)) +button_cancel = Button( + ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95" +) +button_cancel.on_clicked(on_clicked_cancel) # type: ignore + +slice_slider.on_changed( + partial(next_frame, images=volume_3color, image_handle=image_handle) +) + +mylogger.info("Display") +plt.show()