Add files via upload
This commit is contained in:
parent
369540f472
commit
9fa12c258e
3 changed files with 184 additions and 33 deletions
|
@ -11,6 +11,7 @@ from functions.get_parts import get_parts
|
||||||
from functions.bandpass import bandpass
|
from functions.bandpass import bandpass
|
||||||
from functions.create_logger import create_logger
|
from functions.create_logger import create_logger
|
||||||
from functions.load_meta_data import load_meta_data
|
from functions.load_meta_data import load_meta_data
|
||||||
|
from functions.get_torch_device import get_torch_device
|
||||||
|
|
||||||
mylogger = create_logger(
|
mylogger = create_logger(
|
||||||
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_1"
|
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_1"
|
||||||
|
@ -20,16 +21,7 @@ mylogger.info("loading config file")
|
||||||
with open("config.json", "r") as file:
|
with open("config.json", "r") as file:
|
||||||
config = json.loads(jsmin(file.read()))
|
config = json.loads(jsmin(file.read()))
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
device = get_torch_device(mylogger, config["force_to_cpu"])
|
||||||
device_name: str = "cuda:0"
|
|
||||||
else:
|
|
||||||
device_name = "cpu"
|
|
||||||
|
|
||||||
if config["force_to_cpu"]:
|
|
||||||
device_name = "cpu"
|
|
||||||
|
|
||||||
mylogger.info(f"Using device: {device_name}")
|
|
||||||
device: torch.device = torch.device(device_name)
|
|
||||||
|
|
||||||
dtype_str: str = config["dtype"]
|
dtype_str: str = config["dtype"]
|
||||||
dtype: torch.dtype = getattr(torch, dtype_str)
|
dtype: torch.dtype = getattr(torch, dtype_str)
|
||||||
|
@ -115,8 +107,8 @@ mylogger.info("-==- Done -==-")
|
||||||
sample_frequency: float = 1.0 / meta_frame_time
|
sample_frequency: float = 1.0 / meta_frame_time
|
||||||
mylogger.info(
|
mylogger.info(
|
||||||
(
|
(
|
||||||
f"Heartbeat power {config['lower_freqency_bandpass']}Hz"
|
f"Heartbeat power {config['lower_freqency_bandpass']} Hz"
|
||||||
f" - {config['upper_freqency_bandpass']}Hz,"
|
f" - {config['upper_freqency_bandpass']} Hz,"
|
||||||
f" sample-rate: {sample_frequency},"
|
f" sample-rate: {sample_frequency},"
|
||||||
f" skipping the first {config['skip_frames_in_the_beginning']} frames"
|
f" skipping the first {config['skip_frames_in_the_beginning']} frames"
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,7 +10,7 @@ from matplotlib.widgets import Slider, Button # type:ignore
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from functions.gauss_smear_individual import gauss_smear_individual
|
from functions.gauss_smear_individual import gauss_smear_individual
|
||||||
from functions.create_logger import create_logger
|
from functions.create_logger import create_logger
|
||||||
|
from functions.get_torch_device import get_torch_device
|
||||||
|
|
||||||
mylogger = create_logger(
|
mylogger = create_logger(
|
||||||
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_2"
|
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_2"
|
||||||
|
@ -22,22 +22,17 @@ with open("config.json", "r") as file:
|
||||||
|
|
||||||
threshold: float = 0.05
|
threshold: float = 0.05
|
||||||
path: str = config["ref_image_path"]
|
path: str = config["ref_image_path"]
|
||||||
|
use_channel: str = "donor"
|
||||||
|
spatial_width: float = 4.0
|
||||||
|
temporal_width: float = 0.1
|
||||||
|
|
||||||
image_ref_file: str = os.path.join(path, "donor.npy")
|
|
||||||
image_var_file: str = os.path.join(path, "donor_var.npy")
|
image_ref_file: str = os.path.join(path, use_channel + ".npy")
|
||||||
|
image_var_file: str = os.path.join(path, use_channel + "_var.npy")
|
||||||
heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")
|
heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")
|
||||||
heartbeat_mask_threshold_file: str = os.path.join(path, "heartbeat_mask_threshold.npy")
|
heartbeat_mask_threshold_file: str = os.path.join(path, "heartbeat_mask_threshold.npy")
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
device = get_torch_device(mylogger, config["force_to_cpu"])
|
||||||
device_name: str = "cuda:0"
|
|
||||||
else:
|
|
||||||
device_name = "cpu"
|
|
||||||
|
|
||||||
if config["force_to_cpu"]:
|
|
||||||
device_name = "cpu"
|
|
||||||
|
|
||||||
mylogger.info(f"Using device: {device_name}")
|
|
||||||
device: torch.device = torch.device(device_name)
|
|
||||||
|
|
||||||
|
|
||||||
def next_frame(
|
def next_frame(
|
||||||
|
@ -58,7 +53,7 @@ def next_frame(
|
||||||
|
|
||||||
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
|
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||||
global threshold
|
global threshold
|
||||||
global volume_3color
|
global image_3color
|
||||||
global path
|
global path
|
||||||
global mylogger
|
global mylogger
|
||||||
global heartbeat_mask_file
|
global heartbeat_mask_file
|
||||||
|
@ -66,7 +61,7 @@ def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||||
|
|
||||||
mylogger.info(f"Threshold: {threshold}")
|
mylogger.info(f"Threshold: {threshold}")
|
||||||
|
|
||||||
mask: np.ndarray = volume_3color[..., 2] >= threshold
|
mask: np.ndarray = image_3color[..., 2] >= threshold
|
||||||
mylogger.info(f"Save mask to: {heartbeat_mask_file}")
|
mylogger.info(f"Save mask to: {heartbeat_mask_file}")
|
||||||
np.save(heartbeat_mask_file, mask)
|
np.save(heartbeat_mask_file, mask)
|
||||||
mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}")
|
mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}")
|
||||||
|
@ -89,15 +84,15 @@ image_var /= image_var.max()
|
||||||
mylogger.info("Smear the image heartbeat power patially")
|
mylogger.info("Smear the image heartbeat power patially")
|
||||||
temp, _ = gauss_smear_individual(
|
temp, _ = gauss_smear_individual(
|
||||||
input=torch.tensor(image_var[..., np.newaxis], device=device),
|
input=torch.tensor(image_var[..., np.newaxis], device=device),
|
||||||
spatial_width=4.0,
|
spatial_width=spatial_width,
|
||||||
temporal_width=0.1,
|
temporal_width=temporal_width,
|
||||||
use_matlab_mask=False,
|
use_matlab_mask=False,
|
||||||
)
|
)
|
||||||
temp /= temp.max()
|
temp /= temp.max()
|
||||||
|
|
||||||
mylogger.info("-==- DONE -==-")
|
mylogger.info("-==- DONE -==-")
|
||||||
|
|
||||||
volume_3color = np.concatenate(
|
image_3color = np.concatenate(
|
||||||
(
|
(
|
||||||
np.zeros_like(image_ref[..., np.newaxis]),
|
np.zeros_like(image_ref[..., np.newaxis]),
|
||||||
image_ref[..., np.newaxis],
|
image_ref[..., np.newaxis],
|
||||||
|
@ -108,9 +103,9 @@ volume_3color = np.concatenate(
|
||||||
|
|
||||||
mylogger.info("Prepare image")
|
mylogger.info("Prepare image")
|
||||||
|
|
||||||
display_image = volume_3color.copy()
|
display_image = image_3color.copy()
|
||||||
display_image[..., 2] = display_image[..., 0]
|
display_image[..., 2] = display_image[..., 0]
|
||||||
mask = np.where(volume_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis]
|
mask = np.where(image_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis]
|
||||||
display_image *= mask
|
display_image *= mask
|
||||||
display_image = np.nan_to_num(display_image, nan=1.0)
|
display_image = np.nan_to_num(display_image, nan=1.0)
|
||||||
|
|
||||||
|
@ -148,7 +143,7 @@ button_cancel = Button(
|
||||||
button_cancel.on_clicked(on_clicked_cancel) # type: ignore
|
button_cancel.on_clicked(on_clicked_cancel) # type: ignore
|
||||||
|
|
||||||
slice_slider.on_changed(
|
slice_slider.on_changed(
|
||||||
partial(next_frame, images=volume_3color, image_handle=image_handle)
|
partial(next_frame, images=image_3color, image_handle=image_handle)
|
||||||
)
|
)
|
||||||
|
|
||||||
mylogger.info("Display")
|
mylogger.info("Display")
|
||||||
|
|
164
new_pipeline/stage_3_refine_mask.py
Normal file
164
new_pipeline/stage_3_refine_mask.py
Normal file
|
@ -0,0 +1,164 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt # type:ignore
|
||||||
|
import matplotlib
|
||||||
|
from matplotlib.widgets import Button # type:ignore
|
||||||
|
|
||||||
|
# pip install roipoly
|
||||||
|
from roipoly import RoiPoly
|
||||||
|
|
||||||
|
from jsmin import jsmin # type:ignore
|
||||||
|
from functions.create_logger import create_logger
|
||||||
|
from functions.get_torch_device import get_torch_device
|
||||||
|
|
||||||
|
|
||||||
|
def compose_image(image_3color: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
||||||
|
display_image = image_3color.copy()
|
||||||
|
display_image[..., 2] = display_image[..., 0]
|
||||||
|
display_image[mask == 0, :] = 1.0
|
||||||
|
return display_image
|
||||||
|
|
||||||
|
|
||||||
|
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||||
|
global mylogger
|
||||||
|
global refined_mask_file
|
||||||
|
global mask
|
||||||
|
|
||||||
|
mylogger.info(f"Save mask to: {refined_mask_file}")
|
||||||
|
np.save(refined_mask_file, mask)
|
||||||
|
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||||
|
global mylogger
|
||||||
|
mylogger.info("Ended without saving the mask")
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
def on_clicked_add(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||||
|
global new_roi
|
||||||
|
global mask
|
||||||
|
global image_3color
|
||||||
|
global display_image
|
||||||
|
global mylogger
|
||||||
|
if len(new_roi.x) > 0:
|
||||||
|
mylogger.info("A ROI with the following coordiantes has been added to the mask")
|
||||||
|
for i in range(0, len(new_roi.x)):
|
||||||
|
mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}")
|
||||||
|
mylogger.info("")
|
||||||
|
new_mask = new_roi.get_mask(display_image[:, :, 0])
|
||||||
|
mask[new_mask] = 0.0
|
||||||
|
display_image = compose_image(image_3color=image_3color, mask=mask)
|
||||||
|
image_handle.set_data(display_image)
|
||||||
|
for line in ax_main.lines:
|
||||||
|
line.remove()
|
||||||
|
plt.draw()
|
||||||
|
|
||||||
|
new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
|
||||||
|
|
||||||
|
|
||||||
|
def on_clicked_remove(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||||
|
global new_roi
|
||||||
|
global mask
|
||||||
|
global image_3color
|
||||||
|
global display_image
|
||||||
|
if len(new_roi.x) > 0:
|
||||||
|
mylogger.info(
|
||||||
|
"A ROI with the following coordiantes has been removed from the mask"
|
||||||
|
)
|
||||||
|
for i in range(0, len(new_roi.x)):
|
||||||
|
mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}")
|
||||||
|
mylogger.info("")
|
||||||
|
new_mask = new_roi.get_mask(display_image[:, :, 0])
|
||||||
|
mask[new_mask] = 1.0
|
||||||
|
display_image = compose_image(image_3color=image_3color, mask=mask)
|
||||||
|
image_handle.set_data(display_image)
|
||||||
|
for line in ax_main.lines:
|
||||||
|
line.remove()
|
||||||
|
plt.draw()
|
||||||
|
new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
|
||||||
|
|
||||||
|
|
||||||
|
mylogger = create_logger(
|
||||||
|
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_3"
|
||||||
|
)
|
||||||
|
|
||||||
|
mylogger.info("loading config file")
|
||||||
|
with open("config.json", "r") as file:
|
||||||
|
config = json.loads(jsmin(file.read()))
|
||||||
|
|
||||||
|
device = get_torch_device(mylogger, config["force_to_cpu"])
|
||||||
|
|
||||||
|
path: str = config["ref_image_path"]
|
||||||
|
use_channel: str = "donor"
|
||||||
|
|
||||||
|
image_ref_file: str = os.path.join(path, use_channel + ".npy")
|
||||||
|
heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")
|
||||||
|
refined_mask_file: str = os.path.join(path, "mask_not_rotated.npy")
|
||||||
|
|
||||||
|
mylogger.info(f"loading image reference file: {image_ref_file}")
|
||||||
|
image_ref: np.ndarray = np.load(image_ref_file)
|
||||||
|
image_ref /= image_ref.max()
|
||||||
|
|
||||||
|
mylogger.info(f"loading heartbeat mask: {heartbeat_mask_file}")
|
||||||
|
mask: np.ndarray = np.load(heartbeat_mask_file)
|
||||||
|
|
||||||
|
image_3color = np.concatenate(
|
||||||
|
(
|
||||||
|
np.zeros_like(image_ref[..., np.newaxis]),
|
||||||
|
image_ref[..., np.newaxis],
|
||||||
|
np.zeros_like(image_ref[..., np.newaxis]),
|
||||||
|
),
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
mylogger.info("-==- DONE -==-")
|
||||||
|
|
||||||
|
fig, ax_main = plt.subplots()
|
||||||
|
|
||||||
|
display_image = compose_image(image_3color=image_3color, mask=mask)
|
||||||
|
image_handle = ax_main.imshow(display_image, vmin=0, vmax=1, cmap="hot")
|
||||||
|
|
||||||
|
mylogger.info("Add controls")
|
||||||
|
|
||||||
|
axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04))
|
||||||
|
button_accept = Button(
|
||||||
|
ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95"
|
||||||
|
)
|
||||||
|
button_accept.on_clicked(on_clicked_accept) # type: ignore
|
||||||
|
|
||||||
|
axbutton_cancel = fig.add_axes(rect=(0.5, 0.85, 0.2, 0.04))
|
||||||
|
button_cancel = Button(
|
||||||
|
ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95"
|
||||||
|
)
|
||||||
|
button_cancel.on_clicked(on_clicked_cancel) # type: ignore
|
||||||
|
|
||||||
|
axbutton_addmask = fig.add_axes(rect=(0.3, 0.9, 0.2, 0.04))
|
||||||
|
button_addmask = Button(
|
||||||
|
ax=axbutton_addmask, label="Add mask", image=None, color="0.85", hovercolor="0.95"
|
||||||
|
)
|
||||||
|
button_addmask.on_clicked(on_clicked_add) # type: ignore
|
||||||
|
|
||||||
|
axbutton_removemask = fig.add_axes(rect=(0.5, 0.9, 0.2, 0.04))
|
||||||
|
button_removemask = Button(
|
||||||
|
ax=axbutton_removemask,
|
||||||
|
label="Remove mask",
|
||||||
|
image=None,
|
||||||
|
color="0.85",
|
||||||
|
hovercolor="0.95",
|
||||||
|
)
|
||||||
|
button_removemask.on_clicked(on_clicked_remove) # type: ignore
|
||||||
|
|
||||||
|
# ax_main.cla()
|
||||||
|
|
||||||
|
mylogger.info("Display")
|
||||||
|
new_roi: RoiPoly = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
# image_handle.remove()
|
||||||
|
#
|
Loading…
Reference in a new issue