mirror of
https://github.com/davrot/gevi.git
synced 2025-04-20 22:26:43 +02:00
Delete stage_2_make_heartbeat_mask.py
This commit is contained in:
parent
2d1772b948
commit
580a85b964
1 changed files with 0 additions and 163 deletions
|
@ -1,163 +0,0 @@
|
|||
import matplotlib.pyplot as plt # type:ignore
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
import argh
|
||||
|
||||
from matplotlib.widgets import Slider, Button # type:ignore
|
||||
from functools import partial
|
||||
from functions.gauss_smear_individual import gauss_smear_individual
|
||||
from functions.create_logger import create_logger
|
||||
from functions.get_torch_device import get_torch_device
|
||||
from functions.load_config import load_config
|
||||
|
||||
|
||||
def main(*, config_filename: str = "config.json") -> None:
|
||||
mylogger = create_logger(
|
||||
save_logging_messages=True,
|
||||
display_logging_messages=True,
|
||||
log_stage_name="stage_2",
|
||||
)
|
||||
|
||||
config = load_config(mylogger=mylogger, filename=config_filename)
|
||||
|
||||
path: str = config["ref_image_path"]
|
||||
use_channel: str = "donor"
|
||||
spatial_width: float = 4.0
|
||||
temporal_width: float = 0.1
|
||||
|
||||
threshold: float = 0.05
|
||||
|
||||
heartbeat_mask_threshold_file: str = os.path.join(
|
||||
path, "heartbeat_mask_threshold.npy"
|
||||
)
|
||||
if os.path.isfile(heartbeat_mask_threshold_file):
|
||||
mylogger.info(
|
||||
f"loading previous threshold file: {heartbeat_mask_threshold_file}"
|
||||
)
|
||||
threshold = float(np.load(heartbeat_mask_threshold_file)[0])
|
||||
|
||||
mylogger.info(f"initial threshold is {threshold}")
|
||||
|
||||
image_ref_file: str = os.path.join(path, use_channel + ".npy")
|
||||
image_var_file: str = os.path.join(path, use_channel + "_var.npy")
|
||||
heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")
|
||||
|
||||
device = get_torch_device(mylogger, config["force_to_cpu"])
|
||||
|
||||
mylogger.info(f"loading image reference file: {image_ref_file}")
|
||||
image_ref: np.ndarray = np.load(image_ref_file)
|
||||
image_ref /= image_ref.max()
|
||||
|
||||
mylogger.info(f"loading image heartbeat power: {image_var_file}")
|
||||
image_var: np.ndarray = np.load(image_var_file)
|
||||
image_var /= image_var.max()
|
||||
|
||||
mylogger.info("Smear the image heartbeat power patially")
|
||||
temp, _ = gauss_smear_individual(
|
||||
input=torch.tensor(image_var[..., np.newaxis], device=device),
|
||||
spatial_width=spatial_width,
|
||||
temporal_width=temporal_width,
|
||||
use_matlab_mask=False,
|
||||
)
|
||||
temp /= temp.max()
|
||||
|
||||
mylogger.info("-==- DONE -==-")
|
||||
|
||||
image_3color = np.concatenate(
|
||||
(
|
||||
np.zeros_like(image_ref[..., np.newaxis]),
|
||||
image_ref[..., np.newaxis],
|
||||
temp.cpu().numpy(),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
mylogger.info("Prepare image")
|
||||
|
||||
display_image = image_3color.copy()
|
||||
display_image[..., 2] = display_image[..., 0]
|
||||
mask = np.where(image_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis]
|
||||
display_image *= mask
|
||||
display_image = np.nan_to_num(display_image, nan=1.0)
|
||||
|
||||
value_sort = np.sort(image_var.flatten())
|
||||
value_sort_max = value_sort[int(value_sort.shape[0] * 0.95)] * 3
|
||||
print(value_sort_max)
|
||||
mylogger.info("-==- DONE -==-")
|
||||
|
||||
mylogger.info("Create figure")
|
||||
|
||||
fig: matplotlib.figure.Figure = plt.figure()
|
||||
|
||||
image_handle = plt.imshow(display_image, vmin=0, vmax=1, cmap="hot")
|
||||
|
||||
mylogger.info("Add controls")
|
||||
|
||||
def next_frame(
|
||||
i: float, images: np.ndarray, image_handle: matplotlib.image.AxesImage
|
||||
) -> None:
|
||||
nonlocal threshold
|
||||
threshold = i
|
||||
|
||||
display_image: np.ndarray = images.copy()
|
||||
display_image[..., 2] = display_image[..., 0]
|
||||
mask: np.ndarray = np.where(images[..., 2] >= i, 1.0, np.nan)[..., np.newaxis]
|
||||
display_image *= mask
|
||||
display_image = np.nan_to_num(display_image, nan=1.0)
|
||||
|
||||
image_handle.set_data(display_image)
|
||||
return
|
||||
|
||||
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||
nonlocal threshold
|
||||
nonlocal image_3color
|
||||
nonlocal path
|
||||
nonlocal mylogger
|
||||
nonlocal heartbeat_mask_file
|
||||
nonlocal heartbeat_mask_threshold_file
|
||||
|
||||
mylogger.info(f"Threshold: {threshold}")
|
||||
|
||||
mask: np.ndarray = image_3color[..., 2] >= threshold
|
||||
mylogger.info(f"Save mask to: {heartbeat_mask_file}")
|
||||
np.save(heartbeat_mask_file, mask)
|
||||
mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}")
|
||||
np.save(heartbeat_mask_threshold_file, np.array([threshold]))
|
||||
exit()
|
||||
|
||||
def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||
exit()
|
||||
|
||||
axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03))
|
||||
slice_slider = Slider(
|
||||
ax=axfreq,
|
||||
label="Threshold",
|
||||
valmin=0,
|
||||
valmax=value_sort_max,
|
||||
valinit=threshold,
|
||||
valstep=value_sort_max / 1000.0,
|
||||
)
|
||||
axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04))
|
||||
button_accept = Button(
|
||||
ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95"
|
||||
)
|
||||
button_accept.on_clicked(on_clicked_accept) # type: ignore
|
||||
|
||||
axbutton_cancel = fig.add_axes(rect=(0.55, 0.85, 0.2, 0.04))
|
||||
button_cancel = Button(
|
||||
ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95"
|
||||
)
|
||||
button_cancel.on_clicked(on_clicked_cancel) # type: ignore
|
||||
|
||||
slice_slider.on_changed(
|
||||
partial(next_frame, images=image_3color, image_handle=image_handle)
|
||||
)
|
||||
|
||||
mylogger.info("Display")
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
argh.dispatch_command(main)
|
Loading…
Add table
Reference in a new issue