Add files via upload
This commit is contained in:
parent
db43df93eb
commit
2290dfe0d9
8 changed files with 648 additions and 420 deletions
60
config_M3879M_2021-10-05.json
Normal file
60
config_M3879M_2021-10-05.json
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
{
|
||||||
|
"basic_path": "/data_1/robert",
|
||||||
|
"recoding_data": "2021-10-05",
|
||||||
|
"mouse_identifier": "M3879M",
|
||||||
|
"raw_path": "raw",
|
||||||
|
"export_path": "output_M3879M_2021-10-05",
|
||||||
|
"ref_image_path": "ref_images_M3879M_2021-10-05",
|
||||||
|
// Ratio Sequence
|
||||||
|
"classical_ratio_mode": true, // true: a/d false: 1+a-d
|
||||||
|
// Regression
|
||||||
|
//"target_camera_acceptor": "acceptor",
|
||||||
|
"target_camera_acceptor": "",
|
||||||
|
"regressor_cameras_acceptor": [
|
||||||
|
"oxygenation",
|
||||||
|
"volume"
|
||||||
|
],
|
||||||
|
"target_camera_donor": "donor",
|
||||||
|
"regressor_cameras_donor": [
|
||||||
|
// "oxygenation",
|
||||||
|
"volume"
|
||||||
|
],
|
||||||
|
// binning
|
||||||
|
"binning_enable": true,
|
||||||
|
"binning_at_the_end": false,
|
||||||
|
"binning_kernel_size": 4,
|
||||||
|
"binning_stride": 4,
|
||||||
|
"binning_divisor_override": 1,
|
||||||
|
// alignment
|
||||||
|
"alignment_batch_size": 200,
|
||||||
|
"rotation_stabilization_threshold_factor": 3.0, // >= 1.0
|
||||||
|
"rotation_stabilization_threshold_border": 0.9, // <= 1.0
|
||||||
|
// Heart beat detection
|
||||||
|
"lower_freqency_bandpass": 5.0, // Hz
|
||||||
|
"upper_freqency_bandpass": 14.0, // Hz
|
||||||
|
"heartbeat_filtfilt_chuck_size": 10,
|
||||||
|
// Gauss smear
|
||||||
|
"gauss_smear_spatial_width": 8,
|
||||||
|
"gauss_smear_temporal_width": 0.1,
|
||||||
|
"gauss_smear_use_matlab_mask": false,
|
||||||
|
// LED Ramp on
|
||||||
|
"skip_frames_in_the_beginning": 100, // Frames
|
||||||
|
// PyTorch
|
||||||
|
"dtype": "float32",
|
||||||
|
"force_to_cpu": false,
|
||||||
|
// Save
|
||||||
|
"save_as_python": true, // produces .npz files (compressed)
|
||||||
|
"save_as_matlab": false, // produces .hd5 file (compressed)
|
||||||
|
// Save extra information
|
||||||
|
"save_alignment": false,
|
||||||
|
"save_heartbeat": false,
|
||||||
|
"save_factors": false,
|
||||||
|
"save_regression_coefficients": false,
|
||||||
|
// Not important parameter
|
||||||
|
"required_order": [
|
||||||
|
"acceptor",
|
||||||
|
"donor",
|
||||||
|
"oxygenation",
|
||||||
|
"volume"
|
||||||
|
]
|
||||||
|
}
|
60
config_M_Sert_Cre_41.json
Normal file
60
config_M_Sert_Cre_41.json
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
{
|
||||||
|
"basic_path": "/data_1/hendrik",
|
||||||
|
"recoding_data": "2023-07-17",
|
||||||
|
"mouse_identifier": "M_Sert_Cre_41",
|
||||||
|
"raw_path": "raw",
|
||||||
|
"export_path": "output_M_Sert_Cre_41",
|
||||||
|
"ref_image_path": "ref_images_M_Sert_Cre_41",
|
||||||
|
// Ratio Sequence
|
||||||
|
"classical_ratio_mode": true, // true: a/d false: 1+a-d
|
||||||
|
// Regression
|
||||||
|
//"target_camera_acceptor": "acceptor",
|
||||||
|
"target_camera_acceptor": "",
|
||||||
|
"regressor_cameras_acceptor": [
|
||||||
|
"oxygenation",
|
||||||
|
"volume"
|
||||||
|
],
|
||||||
|
"target_camera_donor": "donor",
|
||||||
|
"regressor_cameras_donor": [
|
||||||
|
// "oxygenation",
|
||||||
|
"volume"
|
||||||
|
],
|
||||||
|
// binning
|
||||||
|
"binning_enable": true,
|
||||||
|
"binning_at_the_end": false,
|
||||||
|
"binning_kernel_size": 4,
|
||||||
|
"binning_stride": 4,
|
||||||
|
"binning_divisor_override": 1,
|
||||||
|
// alignment
|
||||||
|
"alignment_batch_size": 200,
|
||||||
|
"rotation_stabilization_threshold_factor": 3.0, // >= 1.0
|
||||||
|
"rotation_stabilization_threshold_border": 0.9, // <= 1.0
|
||||||
|
// Heart beat detection
|
||||||
|
"lower_freqency_bandpass": 5.0, // Hz
|
||||||
|
"upper_freqency_bandpass": 14.0, // Hz
|
||||||
|
"heartbeat_filtfilt_chuck_size": 10,
|
||||||
|
// Gauss smear
|
||||||
|
"gauss_smear_spatial_width": 8,
|
||||||
|
"gauss_smear_temporal_width": 0.1,
|
||||||
|
"gauss_smear_use_matlab_mask": false,
|
||||||
|
// LED Ramp on
|
||||||
|
"skip_frames_in_the_beginning": 100, // Frames
|
||||||
|
// PyTorch
|
||||||
|
"dtype": "float32",
|
||||||
|
"force_to_cpu": false,
|
||||||
|
// Save
|
||||||
|
"save_as_python": true, // produces .npz files (compressed)
|
||||||
|
"save_as_matlab": false, // produces .hd5 file (compressed)
|
||||||
|
// Save extra information
|
||||||
|
"save_alignment": false,
|
||||||
|
"save_heartbeat": false,
|
||||||
|
"save_factors": false,
|
||||||
|
"save_regression_coefficients": false,
|
||||||
|
// Not important parameter
|
||||||
|
"required_order": [
|
||||||
|
"acceptor",
|
||||||
|
"donor",
|
||||||
|
"oxygenation",
|
||||||
|
"volume"
|
||||||
|
]
|
||||||
|
}
|
60
config_M_Sert_Cre_49.json
Normal file
60
config_M_Sert_Cre_49.json
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
{
|
||||||
|
"basic_path": "/data_1/hendrik",
|
||||||
|
"recoding_data": "2023-03-15",
|
||||||
|
"mouse_identifier": "M_Sert_Cre_49",
|
||||||
|
"raw_path": "raw",
|
||||||
|
"export_path": "output_M_Sert_Cre_49",
|
||||||
|
"ref_image_path": "ref_images_M_Sert_Cre_49",
|
||||||
|
// Ratio Sequence
|
||||||
|
"classical_ratio_mode": true, // true: a/d false: 1+a-d
|
||||||
|
// Regression
|
||||||
|
//"target_camera_acceptor": "acceptor",
|
||||||
|
"target_camera_acceptor": "",
|
||||||
|
"regressor_cameras_acceptor": [
|
||||||
|
"oxygenation",
|
||||||
|
"volume"
|
||||||
|
],
|
||||||
|
"target_camera_donor": "donor",
|
||||||
|
"regressor_cameras_donor": [
|
||||||
|
// "oxygenation",
|
||||||
|
"volume"
|
||||||
|
],
|
||||||
|
// binning
|
||||||
|
"binning_enable": true,
|
||||||
|
"binning_at_the_end": false,
|
||||||
|
"binning_kernel_size": 4,
|
||||||
|
"binning_stride": 4,
|
||||||
|
"binning_divisor_override": 1,
|
||||||
|
// alignment
|
||||||
|
"alignment_batch_size": 200,
|
||||||
|
"rotation_stabilization_threshold_factor": 3.0, // >= 1.0
|
||||||
|
"rotation_stabilization_threshold_border": 0.9, // <= 1.0
|
||||||
|
// Heart beat detection
|
||||||
|
"lower_freqency_bandpass": 5.0, // Hz
|
||||||
|
"upper_freqency_bandpass": 14.0, // Hz
|
||||||
|
"heartbeat_filtfilt_chuck_size": 10,
|
||||||
|
// Gauss smear
|
||||||
|
"gauss_smear_spatial_width": 8,
|
||||||
|
"gauss_smear_temporal_width": 0.1,
|
||||||
|
"gauss_smear_use_matlab_mask": false,
|
||||||
|
// LED Ramp on
|
||||||
|
"skip_frames_in_the_beginning": 100, // Frames
|
||||||
|
// PyTorch
|
||||||
|
"dtype": "float32",
|
||||||
|
"force_to_cpu": false,
|
||||||
|
// Save
|
||||||
|
"save_as_python": true, // produces .npz files (compressed)
|
||||||
|
"save_as_matlab": false, // produces .hd5 file (compressed)
|
||||||
|
// Save extra information
|
||||||
|
"save_alignment": false,
|
||||||
|
"save_heartbeat": false,
|
||||||
|
"save_factors": false,
|
||||||
|
"save_regression_coefficients": false,
|
||||||
|
// Not important parameter
|
||||||
|
"required_order": [
|
||||||
|
"acceptor",
|
||||||
|
"donor",
|
||||||
|
"oxygenation",
|
||||||
|
"volume"
|
||||||
|
]
|
||||||
|
}
|
|
@ -23,7 +23,7 @@ mylogger = create_logger(
|
||||||
)
|
)
|
||||||
config = load_config(mylogger=mylogger)
|
config = load_config(mylogger=mylogger)
|
||||||
|
|
||||||
experiment_id: int = 1
|
experiment_id: int = 2
|
||||||
|
|
||||||
raw_data_path: str = os.path.join(
|
raw_data_path: str = os.path.join(
|
||||||
config["basic_path"],
|
config["basic_path"],
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import argh
|
||||||
|
|
||||||
from functions.get_experiments import get_experiments
|
from functions.get_experiments import get_experiments
|
||||||
from functions.get_trials import get_trials
|
from functions.get_trials import get_trials
|
||||||
|
@ -11,21 +11,21 @@ from functions.get_torch_device import get_torch_device
|
||||||
from functions.load_config import load_config
|
from functions.load_config import load_config
|
||||||
from functions.data_raw_loader import data_raw_loader
|
from functions.data_raw_loader import data_raw_loader
|
||||||
|
|
||||||
|
|
||||||
|
def main(*, config_filename: str = "config.json") -> None:
|
||||||
mylogger = create_logger(
|
mylogger = create_logger(
|
||||||
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_1"
|
save_logging_messages=True,
|
||||||
|
display_logging_messages=True,
|
||||||
|
log_stage_name="stage_1",
|
||||||
)
|
)
|
||||||
|
|
||||||
config = load_config(mylogger=mylogger)
|
config = load_config(mylogger=mylogger, filename=config_filename)
|
||||||
|
|
||||||
if config["binning_enable"] and (config["binning_at_the_end"] is False):
|
if config["binning_enable"] and (config["binning_at_the_end"] is False):
|
||||||
device: torch.device = torch.device("cpu")
|
device: torch.device = torch.device("cpu")
|
||||||
else:
|
else:
|
||||||
device = get_torch_device(mylogger, config["force_to_cpu"])
|
device = get_torch_device(mylogger, config["force_to_cpu"])
|
||||||
|
|
||||||
|
|
||||||
dtype_str: str = config["dtype"]
|
|
||||||
dtype: torch.dtype = getattr(torch, dtype_str)
|
|
||||||
|
|
||||||
raw_data_path: str = os.path.join(
|
raw_data_path: str = os.path.join(
|
||||||
config["basic_path"],
|
config["basic_path"],
|
||||||
config["recoding_data"],
|
config["recoding_data"],
|
||||||
|
@ -117,9 +117,13 @@ for i in range(0, len(meta_channels)):
|
||||||
filtfilt_chuck_size=10,
|
filtfilt_chuck_size=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
heartbeat_power = heartbeat_ts[..., config["skip_frames_in_the_beginning"] :].var(
|
heartbeat_power = heartbeat_ts[
|
||||||
dim=-1
|
..., config["skip_frames_in_the_beginning"] :
|
||||||
)
|
].var(dim=-1)
|
||||||
np.save(temp_path, heartbeat_power)
|
np.save(temp_path, heartbeat_power)
|
||||||
|
|
||||||
mylogger.info("-==- Done -==-")
|
mylogger.info("-==- Done -==-")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
argh.dispatch_command(main)
|
||||||
|
|
|
@ -3,6 +3,7 @@ import matplotlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
|
import argh
|
||||||
|
|
||||||
from matplotlib.widgets import Slider, Button # type:ignore
|
from matplotlib.widgets import Slider, Button # type:ignore
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
@ -11,11 +12,15 @@ from functions.create_logger import create_logger
|
||||||
from functions.get_torch_device import get_torch_device
|
from functions.get_torch_device import get_torch_device
|
||||||
from functions.load_config import load_config
|
from functions.load_config import load_config
|
||||||
|
|
||||||
|
|
||||||
|
def main(*, config_filename: str = "config.json") -> None:
|
||||||
mylogger = create_logger(
|
mylogger = create_logger(
|
||||||
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_2"
|
save_logging_messages=True,
|
||||||
|
display_logging_messages=True,
|
||||||
|
log_stage_name="stage_2",
|
||||||
)
|
)
|
||||||
|
|
||||||
config = load_config(mylogger=mylogger)
|
config = load_config(mylogger=mylogger, filename=config_filename)
|
||||||
|
|
||||||
path: str = config["ref_image_path"]
|
path: str = config["ref_image_path"]
|
||||||
use_channel: str = "donor"
|
use_channel: str = "donor"
|
||||||
|
@ -24,9 +29,13 @@ temporal_width: float = 0.1
|
||||||
|
|
||||||
threshold: float = 0.05
|
threshold: float = 0.05
|
||||||
|
|
||||||
heartbeat_mask_threshold_file: str = os.path.join(path, "heartbeat_mask_threshold.npy")
|
heartbeat_mask_threshold_file: str = os.path.join(
|
||||||
|
path, "heartbeat_mask_threshold.npy"
|
||||||
|
)
|
||||||
if os.path.isfile(heartbeat_mask_threshold_file):
|
if os.path.isfile(heartbeat_mask_threshold_file):
|
||||||
mylogger.info(f"loading previous threshold file: {heartbeat_mask_threshold_file}")
|
mylogger.info(
|
||||||
|
f"loading previous threshold file: {heartbeat_mask_threshold_file}"
|
||||||
|
)
|
||||||
threshold = float(np.load(heartbeat_mask_threshold_file)[0])
|
threshold = float(np.load(heartbeat_mask_threshold_file)[0])
|
||||||
|
|
||||||
mylogger.info(f"initial threshold is {threshold}")
|
mylogger.info(f"initial threshold is {threshold}")
|
||||||
|
@ -37,45 +46,6 @@ heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")
|
||||||
|
|
||||||
device = get_torch_device(mylogger, config["force_to_cpu"])
|
device = get_torch_device(mylogger, config["force_to_cpu"])
|
||||||
|
|
||||||
|
|
||||||
def next_frame(
|
|
||||||
i: float, images: np.ndarray, image_handle: matplotlib.image.AxesImage
|
|
||||||
) -> None:
|
|
||||||
global threshold
|
|
||||||
threshold = i
|
|
||||||
|
|
||||||
display_image: np.ndarray = images.copy()
|
|
||||||
display_image[..., 2] = display_image[..., 0]
|
|
||||||
mask: np.ndarray = np.where(images[..., 2] >= i, 1.0, np.nan)[..., np.newaxis]
|
|
||||||
display_image *= mask
|
|
||||||
display_image = np.nan_to_num(display_image, nan=1.0)
|
|
||||||
|
|
||||||
image_handle.set_data(display_image)
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
|
|
||||||
global threshold
|
|
||||||
global image_3color
|
|
||||||
global path
|
|
||||||
global mylogger
|
|
||||||
global heartbeat_mask_file
|
|
||||||
global heartbeat_mask_threshold_file
|
|
||||||
|
|
||||||
mylogger.info(f"Threshold: {threshold}")
|
|
||||||
|
|
||||||
mask: np.ndarray = image_3color[..., 2] >= threshold
|
|
||||||
mylogger.info(f"Save mask to: {heartbeat_mask_file}")
|
|
||||||
np.save(heartbeat_mask_file, mask)
|
|
||||||
mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}")
|
|
||||||
np.save(heartbeat_mask_threshold_file, np.array([threshold]))
|
|
||||||
exit()
|
|
||||||
|
|
||||||
|
|
||||||
def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None:
|
|
||||||
exit()
|
|
||||||
|
|
||||||
|
|
||||||
mylogger.info(f"loading image reference file: {image_ref_file}")
|
mylogger.info(f"loading image reference file: {image_ref_file}")
|
||||||
image_ref: np.ndarray = np.load(image_ref_file)
|
image_ref: np.ndarray = np.load(image_ref_file)
|
||||||
image_ref /= image_ref.max()
|
image_ref /= image_ref.max()
|
||||||
|
@ -124,6 +94,41 @@ image_handle = plt.imshow(display_image, vmin=0, vmax=1, cmap="hot")
|
||||||
|
|
||||||
mylogger.info("Add controls")
|
mylogger.info("Add controls")
|
||||||
|
|
||||||
|
def next_frame(
|
||||||
|
i: float, images: np.ndarray, image_handle: matplotlib.image.AxesImage
|
||||||
|
) -> None:
|
||||||
|
nonlocal threshold
|
||||||
|
threshold = i
|
||||||
|
|
||||||
|
display_image: np.ndarray = images.copy()
|
||||||
|
display_image[..., 2] = display_image[..., 0]
|
||||||
|
mask: np.ndarray = np.where(images[..., 2] >= i, 1.0, np.nan)[..., np.newaxis]
|
||||||
|
display_image *= mask
|
||||||
|
display_image = np.nan_to_num(display_image, nan=1.0)
|
||||||
|
|
||||||
|
image_handle.set_data(display_image)
|
||||||
|
return
|
||||||
|
|
||||||
|
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||||
|
nonlocal threshold
|
||||||
|
nonlocal image_3color
|
||||||
|
nonlocal path
|
||||||
|
nonlocal mylogger
|
||||||
|
nonlocal heartbeat_mask_file
|
||||||
|
nonlocal heartbeat_mask_threshold_file
|
||||||
|
|
||||||
|
mylogger.info(f"Threshold: {threshold}")
|
||||||
|
|
||||||
|
mask: np.ndarray = image_3color[..., 2] >= threshold
|
||||||
|
mylogger.info(f"Save mask to: {heartbeat_mask_file}")
|
||||||
|
np.save(heartbeat_mask_file, mask)
|
||||||
|
mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}")
|
||||||
|
np.save(heartbeat_mask_threshold_file, np.array([threshold]))
|
||||||
|
exit()
|
||||||
|
|
||||||
|
def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||||
|
exit()
|
||||||
|
|
||||||
axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03))
|
axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03))
|
||||||
slice_slider = Slider(
|
slice_slider = Slider(
|
||||||
ax=axfreq,
|
ax=axfreq,
|
||||||
|
@ -151,3 +156,7 @@ slice_slider.on_changed(
|
||||||
|
|
||||||
mylogger.info("Display")
|
mylogger.info("Display")
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
argh.dispatch_command(main)
|
||||||
|
|
|
@ -9,9 +9,10 @@ from matplotlib.widgets import Button # type:ignore
|
||||||
from roipoly import RoiPoly # type:ignore
|
from roipoly import RoiPoly # type:ignore
|
||||||
|
|
||||||
from functions.create_logger import create_logger
|
from functions.create_logger import create_logger
|
||||||
from functions.get_torch_device import get_torch_device
|
|
||||||
from functions.load_config import load_config
|
from functions.load_config import load_config
|
||||||
|
|
||||||
|
import argh
|
||||||
|
|
||||||
|
|
||||||
def compose_image(image_3color: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
def compose_image(image_3color: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
||||||
display_image = image_3color.copy()
|
display_image = image_3color.copy()
|
||||||
|
@ -20,74 +21,14 @@ def compose_image(image_3color: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
||||||
return display_image
|
return display_image
|
||||||
|
|
||||||
|
|
||||||
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
|
def main(*, config_filename: str = "config.json") -> None:
|
||||||
global mylogger
|
|
||||||
global refined_mask_file
|
|
||||||
global mask
|
|
||||||
|
|
||||||
mylogger.info(f"Save mask to: {refined_mask_file}")
|
|
||||||
np.save(refined_mask_file, mask)
|
|
||||||
|
|
||||||
exit()
|
|
||||||
|
|
||||||
|
|
||||||
def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None:
|
|
||||||
global mylogger
|
|
||||||
mylogger.info("Ended without saving the mask")
|
|
||||||
exit()
|
|
||||||
|
|
||||||
|
|
||||||
def on_clicked_add(event: matplotlib.backend_bases.MouseEvent) -> None:
|
|
||||||
global new_roi
|
|
||||||
global mask
|
|
||||||
global image_3color
|
|
||||||
global display_image
|
|
||||||
global mylogger
|
|
||||||
if len(new_roi.x) > 0:
|
|
||||||
mylogger.info("A ROI with the following coordiantes has been added to the mask")
|
|
||||||
for i in range(0, len(new_roi.x)):
|
|
||||||
mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}")
|
|
||||||
mylogger.info("")
|
|
||||||
new_mask = new_roi.get_mask(display_image[:, :, 0])
|
|
||||||
mask[new_mask] = 0.0
|
|
||||||
display_image = compose_image(image_3color=image_3color, mask=mask)
|
|
||||||
image_handle.set_data(display_image)
|
|
||||||
for line in ax_main.lines:
|
|
||||||
line.remove()
|
|
||||||
plt.draw()
|
|
||||||
|
|
||||||
new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
|
|
||||||
|
|
||||||
|
|
||||||
def on_clicked_remove(event: matplotlib.backend_bases.MouseEvent) -> None:
|
|
||||||
global new_roi
|
|
||||||
global mask
|
|
||||||
global image_3color
|
|
||||||
global display_image
|
|
||||||
if len(new_roi.x) > 0:
|
|
||||||
mylogger.info(
|
|
||||||
"A ROI with the following coordiantes has been removed from the mask"
|
|
||||||
)
|
|
||||||
for i in range(0, len(new_roi.x)):
|
|
||||||
mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}")
|
|
||||||
mylogger.info("")
|
|
||||||
new_mask = new_roi.get_mask(display_image[:, :, 0])
|
|
||||||
mask[new_mask] = 1.0
|
|
||||||
display_image = compose_image(image_3color=image_3color, mask=mask)
|
|
||||||
image_handle.set_data(display_image)
|
|
||||||
for line in ax_main.lines:
|
|
||||||
line.remove()
|
|
||||||
plt.draw()
|
|
||||||
new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
|
|
||||||
|
|
||||||
|
|
||||||
mylogger = create_logger(
|
mylogger = create_logger(
|
||||||
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_3"
|
save_logging_messages=True,
|
||||||
|
display_logging_messages=True,
|
||||||
|
log_stage_name="stage_3",
|
||||||
)
|
)
|
||||||
|
|
||||||
config = load_config(mylogger=mylogger)
|
config = load_config(mylogger=mylogger, filename=config_filename)
|
||||||
|
|
||||||
device = get_torch_device(mylogger, config["force_to_cpu"])
|
|
||||||
|
|
||||||
path: str = config["ref_image_path"]
|
path: str = config["ref_image_path"]
|
||||||
use_channel: str = "donor"
|
use_channel: str = "donor"
|
||||||
|
@ -121,6 +62,65 @@ image_handle = ax_main.imshow(display_image, vmin=0, vmax=1, cmap="hot")
|
||||||
|
|
||||||
mylogger.info("Add controls")
|
mylogger.info("Add controls")
|
||||||
|
|
||||||
|
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||||
|
nonlocal mylogger
|
||||||
|
nonlocal refined_mask_file
|
||||||
|
nonlocal mask
|
||||||
|
|
||||||
|
mylogger.info(f"Save mask to: {refined_mask_file}")
|
||||||
|
np.save(refined_mask_file, mask)
|
||||||
|
|
||||||
|
exit()
|
||||||
|
|
||||||
|
def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||||
|
nonlocal mylogger
|
||||||
|
mylogger.info("Ended without saving the mask")
|
||||||
|
exit()
|
||||||
|
|
||||||
|
def on_clicked_add(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||||
|
nonlocal new_roi # type: ignore
|
||||||
|
nonlocal mask
|
||||||
|
nonlocal image_3color
|
||||||
|
nonlocal display_image
|
||||||
|
nonlocal mylogger
|
||||||
|
if len(new_roi.x) > 0:
|
||||||
|
mylogger.info(
|
||||||
|
"A ROI with the following coordiantes has been added to the mask"
|
||||||
|
)
|
||||||
|
for i in range(0, len(new_roi.x)):
|
||||||
|
mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}")
|
||||||
|
mylogger.info("")
|
||||||
|
new_mask = new_roi.get_mask(display_image[:, :, 0])
|
||||||
|
mask[new_mask] = 0.0
|
||||||
|
display_image = compose_image(image_3color=image_3color, mask=mask)
|
||||||
|
image_handle.set_data(display_image)
|
||||||
|
for line in ax_main.lines:
|
||||||
|
line.remove()
|
||||||
|
plt.draw()
|
||||||
|
|
||||||
|
new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
|
||||||
|
|
||||||
|
def on_clicked_remove(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||||
|
nonlocal new_roi # type: ignore
|
||||||
|
nonlocal mask
|
||||||
|
nonlocal image_3color
|
||||||
|
nonlocal display_image
|
||||||
|
if len(new_roi.x) > 0:
|
||||||
|
mylogger.info(
|
||||||
|
"A ROI with the following coordiantes has been removed from the mask"
|
||||||
|
)
|
||||||
|
for i in range(0, len(new_roi.x)):
|
||||||
|
mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}")
|
||||||
|
mylogger.info("")
|
||||||
|
new_mask = new_roi.get_mask(display_image[:, :, 0])
|
||||||
|
mask[new_mask] = 1.0
|
||||||
|
display_image = compose_image(image_3color=image_3color, mask=mask)
|
||||||
|
image_handle.set_data(display_image)
|
||||||
|
for line in ax_main.lines:
|
||||||
|
line.remove()
|
||||||
|
plt.draw()
|
||||||
|
new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
|
||||||
|
|
||||||
axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04))
|
axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04))
|
||||||
button_accept = Button(
|
button_accept = Button(
|
||||||
ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95"
|
ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95"
|
||||||
|
@ -135,7 +135,11 @@ button_cancel.on_clicked(on_clicked_cancel) # type: ignore
|
||||||
|
|
||||||
axbutton_addmask = fig.add_axes(rect=(0.3, 0.9, 0.2, 0.04))
|
axbutton_addmask = fig.add_axes(rect=(0.3, 0.9, 0.2, 0.04))
|
||||||
button_addmask = Button(
|
button_addmask = Button(
|
||||||
ax=axbutton_addmask, label="Add mask", image=None, color="0.85", hovercolor="0.95"
|
ax=axbutton_addmask,
|
||||||
|
label="Add mask",
|
||||||
|
image=None,
|
||||||
|
color="0.85",
|
||||||
|
hovercolor="0.95",
|
||||||
)
|
)
|
||||||
button_addmask.on_clicked(on_clicked_add) # type: ignore
|
button_addmask.on_clicked(on_clicked_add) # type: ignore
|
||||||
|
|
||||||
|
@ -155,3 +159,7 @@ mylogger.info("Display")
|
||||||
new_roi: RoiPoly = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
|
new_roi: RoiPoly = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
argh.dispatch_command(main)
|
||||||
|
|
|
@ -20,6 +20,8 @@ from functions.gauss_smear_individual import gauss_smear_individual
|
||||||
from functions.regression import regression
|
from functions.regression import regression
|
||||||
from functions.data_raw_loader import data_raw_loader
|
from functions.data_raw_loader import data_raw_loader
|
||||||
|
|
||||||
|
import argh
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def process_trial(
|
def process_trial(
|
||||||
|
@ -889,10 +891,19 @@ def process_trial(
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
*,
|
||||||
|
config_filename: str = "config.json",
|
||||||
|
experiment_id_overwrite: int = -1,
|
||||||
|
trial_id_overwrite: int = -1,
|
||||||
|
) -> None:
|
||||||
mylogger = create_logger(
|
mylogger = create_logger(
|
||||||
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_4"
|
save_logging_messages=True,
|
||||||
|
display_logging_messages=True,
|
||||||
|
log_stage_name="stage_4",
|
||||||
)
|
)
|
||||||
config = load_config(mylogger=mylogger)
|
|
||||||
|
config = load_config(mylogger=mylogger, filename=config_filename)
|
||||||
|
|
||||||
if (config["save_as_python"] is False) and (config["save_as_matlab"] is False):
|
if (config["save_as_python"] is False) and (config["save_as_matlab"] is False):
|
||||||
mylogger.info("No output will be created. ")
|
mylogger.info("No output will be created. ")
|
||||||
|
@ -911,7 +922,9 @@ if (len(config["target_camera_donor"]) == 0) and (
|
||||||
|
|
||||||
device = get_torch_device(mylogger, config["force_to_cpu"])
|
device = get_torch_device(mylogger, config["force_to_cpu"])
|
||||||
|
|
||||||
mylogger.info(f"Create directory {config['export_path']} in the case it does not exist")
|
mylogger.info(
|
||||||
|
f"Create directory {config['export_path']} in the case it does not exist"
|
||||||
|
)
|
||||||
os.makedirs(config["export_path"], exist_ok=True)
|
os.makedirs(config["export_path"], exist_ok=True)
|
||||||
|
|
||||||
raw_data_path: str = os.path.join(
|
raw_data_path: str = os.path.join(
|
||||||
|
@ -925,11 +938,21 @@ if os.path.isdir(raw_data_path) is False:
|
||||||
mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!")
|
mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!")
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
|
if experiment_id_overwrite == -1:
|
||||||
experiments = get_experiments(raw_data_path)
|
experiments = get_experiments(raw_data_path)
|
||||||
|
else:
|
||||||
|
assert experiment_id_overwrite >= 0
|
||||||
|
experiments = torch.tensor([experiment_id_overwrite])
|
||||||
|
|
||||||
for experiment_counter in range(0, experiments.shape[0]):
|
for experiment_counter in range(0, experiments.shape[0]):
|
||||||
experiment_id = int(experiments[experiment_counter])
|
experiment_id = int(experiments[experiment_counter])
|
||||||
|
|
||||||
|
if trial_id_overwrite == -1:
|
||||||
trials = get_trials(raw_data_path, experiment_id)
|
trials = get_trials(raw_data_path, experiment_id)
|
||||||
|
else:
|
||||||
|
assert trial_id_overwrite >= 0
|
||||||
|
trials = torch.tensor([trial_id_overwrite])
|
||||||
|
|
||||||
for trial_counter in range(0, trials.shape[0]):
|
for trial_counter in range(0, trials.shape[0]):
|
||||||
trial_id = int(trials[trial_counter])
|
trial_id = int(trials[trial_counter])
|
||||||
|
|
||||||
|
@ -957,3 +980,7 @@ for experiment_counter in range(0, experiments.shape[0]):
|
||||||
trial_id=trial_id,
|
trial_id=trial_id,
|
||||||
device=torch.device("cpu"),
|
device=torch.device("cpu"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
argh.dispatch_command(main)
|
||||||
|
|
Loading…
Reference in a new issue