Add files via upload

This commit is contained in:
David Rotermund 2024-03-01 01:17:12 +01:00 committed by GitHub
parent db43df93eb
commit 2290dfe0d9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 648 additions and 420 deletions

View file

@ -0,0 +1,60 @@
{
"basic_path": "/data_1/robert",
"recoding_data": "2021-10-05",
"mouse_identifier": "M3879M",
"raw_path": "raw",
"export_path": "output_M3879M_2021-10-05",
"ref_image_path": "ref_images_M3879M_2021-10-05",
// Ratio Sequence
"classical_ratio_mode": true, // true: a/d false: 1+a-d
// Regression
//"target_camera_acceptor": "acceptor",
"target_camera_acceptor": "",
"regressor_cameras_acceptor": [
"oxygenation",
"volume"
],
"target_camera_donor": "donor",
"regressor_cameras_donor": [
// "oxygenation",
"volume"
],
// binning
"binning_enable": true,
"binning_at_the_end": false,
"binning_kernel_size": 4,
"binning_stride": 4,
"binning_divisor_override": 1,
// alignment
"alignment_batch_size": 200,
"rotation_stabilization_threshold_factor": 3.0, // >= 1.0
"rotation_stabilization_threshold_border": 0.9, // <= 1.0
// Heart beat detection
"lower_freqency_bandpass": 5.0, // Hz
"upper_freqency_bandpass": 14.0, // Hz
"heartbeat_filtfilt_chuck_size": 10,
// Gauss smear
"gauss_smear_spatial_width": 8,
"gauss_smear_temporal_width": 0.1,
"gauss_smear_use_matlab_mask": false,
// LED Ramp on
"skip_frames_in_the_beginning": 100, // Frames
// PyTorch
"dtype": "float32",
"force_to_cpu": false,
// Save
"save_as_python": true, // produces .npz files (compressed)
"save_as_matlab": false, // produces .hd5 file (compressed)
// Save extra information
"save_alignment": false,
"save_heartbeat": false,
"save_factors": false,
"save_regression_coefficients": false,
// Not important parameter
"required_order": [
"acceptor",
"donor",
"oxygenation",
"volume"
]
}

60
config_M_Sert_Cre_41.json Normal file
View file

@ -0,0 +1,60 @@
{
"basic_path": "/data_1/hendrik",
"recoding_data": "2023-07-17",
"mouse_identifier": "M_Sert_Cre_41",
"raw_path": "raw",
"export_path": "output_M_Sert_Cre_41",
"ref_image_path": "ref_images_M_Sert_Cre_41",
// Ratio Sequence
"classical_ratio_mode": true, // true: a/d false: 1+a-d
// Regression
//"target_camera_acceptor": "acceptor",
"target_camera_acceptor": "",
"regressor_cameras_acceptor": [
"oxygenation",
"volume"
],
"target_camera_donor": "donor",
"regressor_cameras_donor": [
// "oxygenation",
"volume"
],
// binning
"binning_enable": true,
"binning_at_the_end": false,
"binning_kernel_size": 4,
"binning_stride": 4,
"binning_divisor_override": 1,
// alignment
"alignment_batch_size": 200,
"rotation_stabilization_threshold_factor": 3.0, // >= 1.0
"rotation_stabilization_threshold_border": 0.9, // <= 1.0
// Heart beat detection
"lower_freqency_bandpass": 5.0, // Hz
"upper_freqency_bandpass": 14.0, // Hz
"heartbeat_filtfilt_chuck_size": 10,
// Gauss smear
"gauss_smear_spatial_width": 8,
"gauss_smear_temporal_width": 0.1,
"gauss_smear_use_matlab_mask": false,
// LED Ramp on
"skip_frames_in_the_beginning": 100, // Frames
// PyTorch
"dtype": "float32",
"force_to_cpu": false,
// Save
"save_as_python": true, // produces .npz files (compressed)
"save_as_matlab": false, // produces .hd5 file (compressed)
// Save extra information
"save_alignment": false,
"save_heartbeat": false,
"save_factors": false,
"save_regression_coefficients": false,
// Not important parameter
"required_order": [
"acceptor",
"donor",
"oxygenation",
"volume"
]
}

60
config_M_Sert_Cre_49.json Normal file
View file

@ -0,0 +1,60 @@
{
"basic_path": "/data_1/hendrik",
"recoding_data": "2023-03-15",
"mouse_identifier": "M_Sert_Cre_49",
"raw_path": "raw",
"export_path": "output_M_Sert_Cre_49",
"ref_image_path": "ref_images_M_Sert_Cre_49",
// Ratio Sequence
"classical_ratio_mode": true, // true: a/d false: 1+a-d
// Regression
//"target_camera_acceptor": "acceptor",
"target_camera_acceptor": "",
"regressor_cameras_acceptor": [
"oxygenation",
"volume"
],
"target_camera_donor": "donor",
"regressor_cameras_donor": [
// "oxygenation",
"volume"
],
// binning
"binning_enable": true,
"binning_at_the_end": false,
"binning_kernel_size": 4,
"binning_stride": 4,
"binning_divisor_override": 1,
// alignment
"alignment_batch_size": 200,
"rotation_stabilization_threshold_factor": 3.0, // >= 1.0
"rotation_stabilization_threshold_border": 0.9, // <= 1.0
// Heart beat detection
"lower_freqency_bandpass": 5.0, // Hz
"upper_freqency_bandpass": 14.0, // Hz
"heartbeat_filtfilt_chuck_size": 10,
// Gauss smear
"gauss_smear_spatial_width": 8,
"gauss_smear_temporal_width": 0.1,
"gauss_smear_use_matlab_mask": false,
// LED Ramp on
"skip_frames_in_the_beginning": 100, // Frames
// PyTorch
"dtype": "float32",
"force_to_cpu": false,
// Save
"save_as_python": true, // produces .npz files (compressed)
"save_as_matlab": false, // produces .hd5 file (compressed)
// Save extra information
"save_alignment": false,
"save_heartbeat": false,
"save_factors": false,
"save_regression_coefficients": false,
// Not important parameter
"required_order": [
"acceptor",
"donor",
"oxygenation",
"volume"
]
}

View file

@ -23,7 +23,7 @@ mylogger = create_logger(
)
config = load_config(mylogger=mylogger)
experiment_id: int = 1
experiment_id: int = 2
raw_data_path: str = os.path.join(
config["basic_path"],

View file

@ -1,7 +1,7 @@
import os
import torch
import numpy as np
import argh
from functions.get_experiments import get_experiments
from functions.get_trials import get_trials
@ -11,115 +11,119 @@ from functions.get_torch_device import get_torch_device
from functions.load_config import load_config
from functions.data_raw_loader import data_raw_loader
mylogger = create_logger(
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_1"
)
config = load_config(mylogger=mylogger)
if config["binning_enable"] and (config["binning_at_the_end"] is False):
device: torch.device = torch.device("cpu")
else:
device = get_torch_device(mylogger, config["force_to_cpu"])
dtype_str: str = config["dtype"]
dtype: torch.dtype = getattr(torch, dtype_str)
raw_data_path: str = os.path.join(
config["basic_path"],
config["recoding_data"],
config["mouse_identifier"],
config["raw_path"],
)
mylogger.info(f"Using data path: {raw_data_path}")
first_experiment_id: int = int(get_experiments(raw_data_path).min())
first_trial_id: int = int(get_trials(raw_data_path, first_experiment_id).min())
meta_channels: list[str]
meta_mouse_markings: str
meta_recording_date: str
meta_stimulation_times: dict
meta_experiment_names: dict
meta_trial_recording_duration: float
meta_frame_time: float
meta_mouse: str
data: torch.Tensor
if config["binning_enable"] and (config["binning_at_the_end"] is False):
force_to_cpu_memory: bool = True
else:
force_to_cpu_memory = False
mylogger.info("Loading data")
(
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
) = data_raw_loader(
raw_data_path=raw_data_path,
mylogger=mylogger,
experiment_id=first_experiment_id,
trial_id=first_trial_id,
device=device,
force_to_cpu_memory=force_to_cpu_memory,
config=config,
)
mylogger.info("-==- Done -==-")
output_path = config["ref_image_path"]
mylogger.info(f"Create directory {output_path} in the case it does not exist")
os.makedirs(output_path, exist_ok=True)
mylogger.info("Reference images")
for i in range(0, len(meta_channels)):
temp_path: str = os.path.join(output_path, meta_channels[i] + ".npy")
mylogger.info(f"Extract and save: {temp_path}")
frame_id: int = data.shape[-2] // 2
mylogger.info(f"Will use frame id: {frame_id}")
ref_image: np.ndarray = (
data[:, :, frame_id, meta_channels.index(meta_channels[i])]
.clone()
.cpu()
.numpy()
def main(*, config_filename: str = "config.json") -> None:
mylogger = create_logger(
save_logging_messages=True,
display_logging_messages=True,
log_stage_name="stage_1",
)
np.save(temp_path, ref_image)
mylogger.info("-==- Done -==-")
sample_frequency: float = 1.0 / meta_frame_time
mylogger.info(
config = load_config(mylogger=mylogger, filename=config_filename)
if config["binning_enable"] and (config["binning_at_the_end"] is False):
device: torch.device = torch.device("cpu")
else:
device = get_torch_device(mylogger, config["force_to_cpu"])
raw_data_path: str = os.path.join(
config["basic_path"],
config["recoding_data"],
config["mouse_identifier"],
config["raw_path"],
)
mylogger.info(f"Using data path: {raw_data_path}")
first_experiment_id: int = int(get_experiments(raw_data_path).min())
first_trial_id: int = int(get_trials(raw_data_path, first_experiment_id).min())
meta_channels: list[str]
meta_mouse_markings: str
meta_recording_date: str
meta_stimulation_times: dict
meta_experiment_names: dict
meta_trial_recording_duration: float
meta_frame_time: float
meta_mouse: str
data: torch.Tensor
if config["binning_enable"] and (config["binning_at_the_end"] is False):
force_to_cpu_memory: bool = True
else:
force_to_cpu_memory = False
mylogger.info("Loading data")
(
f"Heartbeat power {config['lower_freqency_bandpass']} Hz"
f" - {config['upper_freqency_bandpass']} Hz,"
f" sample-rate: {sample_frequency},"
f" skipping the first {config['skip_frames_in_the_beginning']} frames"
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
) = data_raw_loader(
raw_data_path=raw_data_path,
mylogger=mylogger,
experiment_id=first_experiment_id,
trial_id=first_trial_id,
device=device,
force_to_cpu_memory=force_to_cpu_memory,
config=config,
)
)
mylogger.info("-==- Done -==-")
for i in range(0, len(meta_channels)):
temp_path = os.path.join(output_path, meta_channels[i] + "_var.npy")
mylogger.info(f"Extract and save: {temp_path}")
output_path = config["ref_image_path"]
mylogger.info(f"Create directory {output_path} in the case it does not exist")
os.makedirs(output_path, exist_ok=True)
heartbeat_ts: torch.Tensor = bandpass(
data=data[..., i],
low_frequency=config["lower_freqency_bandpass"],
high_frequency=config["upper_freqency_bandpass"],
fs=sample_frequency,
filtfilt_chuck_size=10,
mylogger.info("Reference images")
for i in range(0, len(meta_channels)):
temp_path: str = os.path.join(output_path, meta_channels[i] + ".npy")
mylogger.info(f"Extract and save: {temp_path}")
frame_id: int = data.shape[-2] // 2
mylogger.info(f"Will use frame id: {frame_id}")
ref_image: np.ndarray = (
data[:, :, frame_id, meta_channels.index(meta_channels[i])]
.clone()
.cpu()
.numpy()
)
np.save(temp_path, ref_image)
mylogger.info("-==- Done -==-")
sample_frequency: float = 1.0 / meta_frame_time
mylogger.info(
(
f"Heartbeat power {config['lower_freqency_bandpass']} Hz"
f" - {config['upper_freqency_bandpass']} Hz,"
f" sample-rate: {sample_frequency},"
f" skipping the first {config['skip_frames_in_the_beginning']} frames"
)
)
heartbeat_power = heartbeat_ts[..., config["skip_frames_in_the_beginning"] :].var(
dim=-1
)
np.save(temp_path, heartbeat_power)
for i in range(0, len(meta_channels)):
temp_path = os.path.join(output_path, meta_channels[i] + "_var.npy")
mylogger.info(f"Extract and save: {temp_path}")
mylogger.info("-==- Done -==-")
heartbeat_ts: torch.Tensor = bandpass(
data=data[..., i],
low_frequency=config["lower_freqency_bandpass"],
high_frequency=config["upper_freqency_bandpass"],
fs=sample_frequency,
filtfilt_chuck_size=10,
)
heartbeat_power = heartbeat_ts[
..., config["skip_frames_in_the_beginning"] :
].var(dim=-1)
np.save(temp_path, heartbeat_power)
mylogger.info("-==- Done -==-")
if __name__ == "__main__":
argh.dispatch_command(main)

View file

@ -3,6 +3,7 @@ import matplotlib
import numpy as np
import torch
import os
import argh
from matplotlib.widgets import Slider, Button # type:ignore
from functools import partial
@ -11,143 +12,151 @@ from functions.create_logger import create_logger
from functions.get_torch_device import get_torch_device
from functions.load_config import load_config
mylogger = create_logger(
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_2"
)
config = load_config(mylogger=mylogger)
def main(*, config_filename: str = "config.json") -> None:
mylogger = create_logger(
save_logging_messages=True,
display_logging_messages=True,
log_stage_name="stage_2",
)
path: str = config["ref_image_path"]
use_channel: str = "donor"
spatial_width: float = 4.0
temporal_width: float = 0.1
config = load_config(mylogger=mylogger, filename=config_filename)
threshold: float = 0.05
path: str = config["ref_image_path"]
use_channel: str = "donor"
spatial_width: float = 4.0
temporal_width: float = 0.1
heartbeat_mask_threshold_file: str = os.path.join(path, "heartbeat_mask_threshold.npy")
if os.path.isfile(heartbeat_mask_threshold_file):
mylogger.info(f"loading previous threshold file: {heartbeat_mask_threshold_file}")
threshold = float(np.load(heartbeat_mask_threshold_file)[0])
threshold: float = 0.05
mylogger.info(f"initial threshold is {threshold}")
heartbeat_mask_threshold_file: str = os.path.join(
path, "heartbeat_mask_threshold.npy"
)
if os.path.isfile(heartbeat_mask_threshold_file):
mylogger.info(
f"loading previous threshold file: {heartbeat_mask_threshold_file}"
)
threshold = float(np.load(heartbeat_mask_threshold_file)[0])
image_ref_file: str = os.path.join(path, use_channel + ".npy")
image_var_file: str = os.path.join(path, use_channel + "_var.npy")
heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")
mylogger.info(f"initial threshold is {threshold}")
device = get_torch_device(mylogger, config["force_to_cpu"])
image_ref_file: str = os.path.join(path, use_channel + ".npy")
image_var_file: str = os.path.join(path, use_channel + "_var.npy")
heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")
device = get_torch_device(mylogger, config["force_to_cpu"])
def next_frame(
i: float, images: np.ndarray, image_handle: matplotlib.image.AxesImage
) -> None:
global threshold
threshold = i
mylogger.info(f"loading image reference file: {image_ref_file}")
image_ref: np.ndarray = np.load(image_ref_file)
image_ref /= image_ref.max()
display_image: np.ndarray = images.copy()
mylogger.info(f"loading image heartbeat power: {image_var_file}")
image_var: np.ndarray = np.load(image_var_file)
image_var /= image_var.max()
mylogger.info("Smear the image heartbeat power patially")
temp, _ = gauss_smear_individual(
input=torch.tensor(image_var[..., np.newaxis], device=device),
spatial_width=spatial_width,
temporal_width=temporal_width,
use_matlab_mask=False,
)
temp /= temp.max()
mylogger.info("-==- DONE -==-")
image_3color = np.concatenate(
(
np.zeros_like(image_ref[..., np.newaxis]),
image_ref[..., np.newaxis],
temp.cpu().numpy(),
),
axis=-1,
)
mylogger.info("Prepare image")
display_image = image_3color.copy()
display_image[..., 2] = display_image[..., 0]
mask: np.ndarray = np.where(images[..., 2] >= i, 1.0, np.nan)[..., np.newaxis]
mask = np.where(image_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis]
display_image *= mask
display_image = np.nan_to_num(display_image, nan=1.0)
image_handle.set_data(display_image)
return
value_sort = np.sort(image_var.flatten())
value_sort_max = value_sort[int(value_sort.shape[0] * 0.95)]
mylogger.info("-==- DONE -==-")
mylogger.info("Create figure")
fig: matplotlib.figure.Figure = plt.figure()
image_handle = plt.imshow(display_image, vmin=0, vmax=1, cmap="hot")
mylogger.info("Add controls")
def next_frame(
i: float, images: np.ndarray, image_handle: matplotlib.image.AxesImage
) -> None:
nonlocal threshold
threshold = i
display_image: np.ndarray = images.copy()
display_image[..., 2] = display_image[..., 0]
mask: np.ndarray = np.where(images[..., 2] >= i, 1.0, np.nan)[..., np.newaxis]
display_image *= mask
display_image = np.nan_to_num(display_image, nan=1.0)
image_handle.set_data(display_image)
return
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
nonlocal threshold
nonlocal image_3color
nonlocal path
nonlocal mylogger
nonlocal heartbeat_mask_file
nonlocal heartbeat_mask_threshold_file
mylogger.info(f"Threshold: {threshold}")
mask: np.ndarray = image_3color[..., 2] >= threshold
mylogger.info(f"Save mask to: {heartbeat_mask_file}")
np.save(heartbeat_mask_file, mask)
mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}")
np.save(heartbeat_mask_threshold_file, np.array([threshold]))
exit()
def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None:
exit()
axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03))
slice_slider = Slider(
ax=axfreq,
label="Threshold",
valmin=0,
valmax=value_sort_max,
valinit=threshold,
valstep=value_sort_max / 100.0,
)
axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04))
button_accept = Button(
ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95"
)
button_accept.on_clicked(on_clicked_accept) # type: ignore
axbutton_cancel = fig.add_axes(rect=(0.55, 0.85, 0.2, 0.04))
button_cancel = Button(
ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95"
)
button_cancel.on_clicked(on_clicked_cancel) # type: ignore
slice_slider.on_changed(
partial(next_frame, images=image_3color, image_handle=image_handle)
)
mylogger.info("Display")
plt.show()
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
global threshold
global image_3color
global path
global mylogger
global heartbeat_mask_file
global heartbeat_mask_threshold_file
mylogger.info(f"Threshold: {threshold}")
mask: np.ndarray = image_3color[..., 2] >= threshold
mylogger.info(f"Save mask to: {heartbeat_mask_file}")
np.save(heartbeat_mask_file, mask)
mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}")
np.save(heartbeat_mask_threshold_file, np.array([threshold]))
exit()
def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None:
exit()
mylogger.info(f"loading image reference file: {image_ref_file}")
image_ref: np.ndarray = np.load(image_ref_file)
image_ref /= image_ref.max()
mylogger.info(f"loading image heartbeat power: {image_var_file}")
image_var: np.ndarray = np.load(image_var_file)
image_var /= image_var.max()
mylogger.info("Smear the image heartbeat power patially")
temp, _ = gauss_smear_individual(
input=torch.tensor(image_var[..., np.newaxis], device=device),
spatial_width=spatial_width,
temporal_width=temporal_width,
use_matlab_mask=False,
)
temp /= temp.max()
mylogger.info("-==- DONE -==-")
image_3color = np.concatenate(
(
np.zeros_like(image_ref[..., np.newaxis]),
image_ref[..., np.newaxis],
temp.cpu().numpy(),
),
axis=-1,
)
mylogger.info("Prepare image")
display_image = image_3color.copy()
display_image[..., 2] = display_image[..., 0]
mask = np.where(image_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis]
display_image *= mask
display_image = np.nan_to_num(display_image, nan=1.0)
value_sort = np.sort(image_var.flatten())
value_sort_max = value_sort[int(value_sort.shape[0] * 0.95)]
mylogger.info("-==- DONE -==-")
mylogger.info("Create figure")
fig: matplotlib.figure.Figure = plt.figure()
image_handle = plt.imshow(display_image, vmin=0, vmax=1, cmap="hot")
mylogger.info("Add controls")
axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03))
slice_slider = Slider(
ax=axfreq,
label="Threshold",
valmin=0,
valmax=value_sort_max,
valinit=threshold,
valstep=value_sort_max / 100.0,
)
axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04))
button_accept = Button(
ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95"
)
button_accept.on_clicked(on_clicked_accept) # type: ignore
axbutton_cancel = fig.add_axes(rect=(0.55, 0.85, 0.2, 0.04))
button_cancel = Button(
ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95"
)
button_cancel.on_clicked(on_clicked_cancel) # type: ignore
slice_slider.on_changed(
partial(next_frame, images=image_3color, image_handle=image_handle)
)
mylogger.info("Display")
plt.show()
if __name__ == "__main__":
argh.dispatch_command(main)

View file

@ -9,9 +9,10 @@ from matplotlib.widgets import Button # type:ignore
from roipoly import RoiPoly # type:ignore
from functions.create_logger import create_logger
from functions.get_torch_device import get_torch_device
from functions.load_config import load_config
import argh
def compose_image(image_3color: np.ndarray, mask: np.ndarray) -> np.ndarray:
display_image = image_3color.copy()
@ -20,138 +21,145 @@ def compose_image(image_3color: np.ndarray, mask: np.ndarray) -> np.ndarray:
return display_image
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
global mylogger
global refined_mask_file
global mask
def main(*, config_filename: str = "config.json") -> None:
mylogger = create_logger(
save_logging_messages=True,
display_logging_messages=True,
log_stage_name="stage_3",
)
mylogger.info(f"Save mask to: {refined_mask_file}")
np.save(refined_mask_file, mask)
config = load_config(mylogger=mylogger, filename=config_filename)
exit()
path: str = config["ref_image_path"]
use_channel: str = "donor"
image_ref_file: str = os.path.join(path, use_channel + ".npy")
heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")
refined_mask_file: str = os.path.join(path, "mask_not_rotated.npy")
mylogger.info(f"loading image reference file: {image_ref_file}")
image_ref: np.ndarray = np.load(image_ref_file)
image_ref /= image_ref.max()
mylogger.info(f"loading heartbeat mask: {heartbeat_mask_file}")
mask: np.ndarray = np.load(heartbeat_mask_file)
image_3color = np.concatenate(
(
np.zeros_like(image_ref[..., np.newaxis]),
image_ref[..., np.newaxis],
np.zeros_like(image_ref[..., np.newaxis]),
),
axis=-1,
)
mylogger.info("-==- DONE -==-")
fig, ax_main = plt.subplots()
display_image = compose_image(image_3color=image_3color, mask=mask)
image_handle = ax_main.imshow(display_image, vmin=0, vmax=1, cmap="hot")
mylogger.info("Add controls")
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
nonlocal mylogger
nonlocal refined_mask_file
nonlocal mask
mylogger.info(f"Save mask to: {refined_mask_file}")
np.save(refined_mask_file, mask)
exit()
def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None:
nonlocal mylogger
mylogger.info("Ended without saving the mask")
exit()
def on_clicked_add(event: matplotlib.backend_bases.MouseEvent) -> None:
nonlocal new_roi # type: ignore
nonlocal mask
nonlocal image_3color
nonlocal display_image
nonlocal mylogger
if len(new_roi.x) > 0:
mylogger.info(
"A ROI with the following coordiantes has been added to the mask"
)
for i in range(0, len(new_roi.x)):
mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}")
mylogger.info("")
new_mask = new_roi.get_mask(display_image[:, :, 0])
mask[new_mask] = 0.0
display_image = compose_image(image_3color=image_3color, mask=mask)
image_handle.set_data(display_image)
for line in ax_main.lines:
line.remove()
plt.draw()
new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
def on_clicked_remove(event: matplotlib.backend_bases.MouseEvent) -> None:
nonlocal new_roi # type: ignore
nonlocal mask
nonlocal image_3color
nonlocal display_image
if len(new_roi.x) > 0:
mylogger.info(
"A ROI with the following coordiantes has been removed from the mask"
)
for i in range(0, len(new_roi.x)):
mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}")
mylogger.info("")
new_mask = new_roi.get_mask(display_image[:, :, 0])
mask[new_mask] = 1.0
display_image = compose_image(image_3color=image_3color, mask=mask)
image_handle.set_data(display_image)
for line in ax_main.lines:
line.remove()
plt.draw()
new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04))
button_accept = Button(
ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95"
)
button_accept.on_clicked(on_clicked_accept) # type: ignore
axbutton_cancel = fig.add_axes(rect=(0.5, 0.85, 0.2, 0.04))
button_cancel = Button(
ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95"
)
button_cancel.on_clicked(on_clicked_cancel) # type: ignore
axbutton_addmask = fig.add_axes(rect=(0.3, 0.9, 0.2, 0.04))
button_addmask = Button(
ax=axbutton_addmask,
label="Add mask",
image=None,
color="0.85",
hovercolor="0.95",
)
button_addmask.on_clicked(on_clicked_add) # type: ignore
axbutton_removemask = fig.add_axes(rect=(0.5, 0.9, 0.2, 0.04))
button_removemask = Button(
ax=axbutton_removemask,
label="Remove mask",
image=None,
color="0.85",
hovercolor="0.95",
)
button_removemask.on_clicked(on_clicked_remove) # type: ignore
# ax_main.cla()
mylogger.info("Display")
new_roi: RoiPoly = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
plt.show()
def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None:
global mylogger
mylogger.info("Ended without saving the mask")
exit()
def on_clicked_add(event: matplotlib.backend_bases.MouseEvent) -> None:
global new_roi
global mask
global image_3color
global display_image
global mylogger
if len(new_roi.x) > 0:
mylogger.info("A ROI with the following coordiantes has been added to the mask")
for i in range(0, len(new_roi.x)):
mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}")
mylogger.info("")
new_mask = new_roi.get_mask(display_image[:, :, 0])
mask[new_mask] = 0.0
display_image = compose_image(image_3color=image_3color, mask=mask)
image_handle.set_data(display_image)
for line in ax_main.lines:
line.remove()
plt.draw()
new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
def on_clicked_remove(event: matplotlib.backend_bases.MouseEvent) -> None:
global new_roi
global mask
global image_3color
global display_image
if len(new_roi.x) > 0:
mylogger.info(
"A ROI with the following coordiantes has been removed from the mask"
)
for i in range(0, len(new_roi.x)):
mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}")
mylogger.info("")
new_mask = new_roi.get_mask(display_image[:, :, 0])
mask[new_mask] = 1.0
display_image = compose_image(image_3color=image_3color, mask=mask)
image_handle.set_data(display_image)
for line in ax_main.lines:
line.remove()
plt.draw()
new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
mylogger = create_logger(
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_3"
)
config = load_config(mylogger=mylogger)
device = get_torch_device(mylogger, config["force_to_cpu"])
path: str = config["ref_image_path"]
use_channel: str = "donor"
image_ref_file: str = os.path.join(path, use_channel + ".npy")
heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")
refined_mask_file: str = os.path.join(path, "mask_not_rotated.npy")
mylogger.info(f"loading image reference file: {image_ref_file}")
image_ref: np.ndarray = np.load(image_ref_file)
image_ref /= image_ref.max()
mylogger.info(f"loading heartbeat mask: {heartbeat_mask_file}")
mask: np.ndarray = np.load(heartbeat_mask_file)
image_3color = np.concatenate(
(
np.zeros_like(image_ref[..., np.newaxis]),
image_ref[..., np.newaxis],
np.zeros_like(image_ref[..., np.newaxis]),
),
axis=-1,
)
mylogger.info("-==- DONE -==-")
fig, ax_main = plt.subplots()
display_image = compose_image(image_3color=image_3color, mask=mask)
image_handle = ax_main.imshow(display_image, vmin=0, vmax=1, cmap="hot")
mylogger.info("Add controls")
axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04))
button_accept = Button(
ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95"
)
button_accept.on_clicked(on_clicked_accept) # type: ignore
axbutton_cancel = fig.add_axes(rect=(0.5, 0.85, 0.2, 0.04))
button_cancel = Button(
ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95"
)
button_cancel.on_clicked(on_clicked_cancel) # type: ignore
axbutton_addmask = fig.add_axes(rect=(0.3, 0.9, 0.2, 0.04))
button_addmask = Button(
ax=axbutton_addmask, label="Add mask", image=None, color="0.85", hovercolor="0.95"
)
button_addmask.on_clicked(on_clicked_add) # type: ignore
axbutton_removemask = fig.add_axes(rect=(0.5, 0.9, 0.2, 0.04))
button_removemask = Button(
ax=axbutton_removemask,
label="Remove mask",
image=None,
color="0.85",
hovercolor="0.95",
)
button_removemask.on_clicked(on_clicked_remove) # type: ignore
# ax_main.cla()
mylogger.info("Display")
new_roi: RoiPoly = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
plt.show()
if __name__ == "__main__":
argh.dispatch_command(main)

View file

@ -20,6 +20,8 @@ from functions.gauss_smear_individual import gauss_smear_individual
from functions.regression import regression
from functions.data_raw_loader import data_raw_loader
import argh
@torch.no_grad()
def process_trial(
@ -889,71 +891,96 @@ def process_trial(
return
mylogger = create_logger(
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_4"
)
config = load_config(mylogger=mylogger)
if (config["save_as_python"] is False) and (config["save_as_matlab"] is False):
mylogger.info("No output will be created. ")
mylogger.info("Change save_as_python and/or save_as_matlab in the config file")
mylogger.info("ERROR: STOP!!!")
exit()
if (len(config["target_camera_donor"]) == 0) and (
len(config["target_camera_acceptor"]) == 0
):
mylogger.info(
"Configure at least target_camera_donor or target_camera_acceptor correctly."
def main(
*,
config_filename: str = "config.json",
experiment_id_overwrite: int = -1,
trial_id_overwrite: int = -1,
) -> None:
mylogger = create_logger(
save_logging_messages=True,
display_logging_messages=True,
log_stage_name="stage_4",
)
mylogger.info("ERROR: STOP!!!")
exit()
device = get_torch_device(mylogger, config["force_to_cpu"])
config = load_config(mylogger=mylogger, filename=config_filename)
mylogger.info(f"Create directory {config['export_path']} in the case it does not exist")
os.makedirs(config["export_path"], exist_ok=True)
if (config["save_as_python"] is False) and (config["save_as_matlab"] is False):
mylogger.info("No output will be created. ")
mylogger.info("Change save_as_python and/or save_as_matlab in the config file")
mylogger.info("ERROR: STOP!!!")
exit()
raw_data_path: str = os.path.join(
config["basic_path"],
config["recoding_data"],
config["mouse_identifier"],
config["raw_path"],
)
if os.path.isdir(raw_data_path) is False:
mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!")
exit()
experiments = get_experiments(raw_data_path)
for experiment_counter in range(0, experiments.shape[0]):
experiment_id = int(experiments[experiment_counter])
trials = get_trials(raw_data_path, experiment_id)
for trial_counter in range(0, trials.shape[0]):
trial_id = int(trials[trial_counter])
mylogger.info("")
if (len(config["target_camera_donor"]) == 0) and (
len(config["target_camera_acceptor"]) == 0
):
mylogger.info(
f"======= EXPERIMENT ID: {experiment_id} ==== TRIAL ID: {trial_id} ======="
"Configure at least target_camera_donor or target_camera_acceptor correctly."
)
mylogger.info("")
mylogger.info("ERROR: STOP!!!")
exit()
try:
process_trial(
config=config,
mylogger=mylogger,
experiment_id=experiment_id,
trial_id=trial_id,
device=device,
)
except torch.cuda.OutOfMemoryError:
mylogger.info("WARNING: RUNNING IN FAILBACK MODE!!!!")
mylogger.info("Not enough GPU memory. Retry on CPU")
process_trial(
config=config,
mylogger=mylogger,
experiment_id=experiment_id,
trial_id=trial_id,
device=torch.device("cpu"),
device = get_torch_device(mylogger, config["force_to_cpu"])
mylogger.info(
f"Create directory {config['export_path']} in the case it does not exist"
)
os.makedirs(config["export_path"], exist_ok=True)
raw_data_path: str = os.path.join(
config["basic_path"],
config["recoding_data"],
config["mouse_identifier"],
config["raw_path"],
)
if os.path.isdir(raw_data_path) is False:
mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!")
exit()
if experiment_id_overwrite == -1:
experiments = get_experiments(raw_data_path)
else:
assert experiment_id_overwrite >= 0
experiments = torch.tensor([experiment_id_overwrite])
for experiment_counter in range(0, experiments.shape[0]):
experiment_id = int(experiments[experiment_counter])
if trial_id_overwrite == -1:
trials = get_trials(raw_data_path, experiment_id)
else:
assert trial_id_overwrite >= 0
trials = torch.tensor([trial_id_overwrite])
for trial_counter in range(0, trials.shape[0]):
trial_id = int(trials[trial_counter])
mylogger.info("")
mylogger.info(
f"======= EXPERIMENT ID: {experiment_id} ==== TRIAL ID: {trial_id} ======="
)
mylogger.info("")
try:
process_trial(
config=config,
mylogger=mylogger,
experiment_id=experiment_id,
trial_id=trial_id,
device=device,
)
except torch.cuda.OutOfMemoryError:
mylogger.info("WARNING: RUNNING IN FAILBACK MODE!!!!")
mylogger.info("Not enough GPU memory. Retry on CPU")
process_trial(
config=config,
mylogger=mylogger,
experiment_id=experiment_id,
trial_id=trial_id,
device=torch.device("cpu"),
)
if __name__ == "__main__":
argh.dispatch_command(main)