Add files via upload

This commit is contained in:
David Rotermund 2024-02-26 13:00:47 +01:00 committed by GitHub
parent 369540f472
commit 9fa12c258e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 184 additions and 33 deletions

View file

@ -11,6 +11,7 @@ from functions.get_parts import get_parts
from functions.bandpass import bandpass
from functions.create_logger import create_logger
from functions.load_meta_data import load_meta_data
from functions.get_torch_device import get_torch_device
mylogger = create_logger(
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_1"
@ -20,16 +21,7 @@ mylogger.info("loading config file")
with open("config.json", "r") as file:
config = json.loads(jsmin(file.read()))
if torch.cuda.is_available():
device_name: str = "cuda:0"
else:
device_name = "cpu"
if config["force_to_cpu"]:
device_name = "cpu"
mylogger.info(f"Using device: {device_name}")
device: torch.device = torch.device(device_name)
device = get_torch_device(mylogger, config["force_to_cpu"])
dtype_str: str = config["dtype"]
dtype: torch.dtype = getattr(torch, dtype_str)

View file

@ -10,7 +10,7 @@ from matplotlib.widgets import Slider, Button # type:ignore
from functools import partial
from functions.gauss_smear_individual import gauss_smear_individual
from functions.create_logger import create_logger
from functions.get_torch_device import get_torch_device
mylogger = create_logger(
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_2"
@ -22,22 +22,17 @@ with open("config.json", "r") as file:
threshold: float = 0.05
path: str = config["ref_image_path"]
use_channel: str = "donor"
spatial_width: float = 4.0
temporal_width: float = 0.1
image_ref_file: str = os.path.join(path, "donor.npy")
image_var_file: str = os.path.join(path, "donor_var.npy")
image_ref_file: str = os.path.join(path, use_channel + ".npy")
image_var_file: str = os.path.join(path, use_channel + "_var.npy")
heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")
heartbeat_mask_threshold_file: str = os.path.join(path, "heartbeat_mask_threshold.npy")
if torch.cuda.is_available():
device_name: str = "cuda:0"
else:
device_name = "cpu"
if config["force_to_cpu"]:
device_name = "cpu"
mylogger.info(f"Using device: {device_name}")
device: torch.device = torch.device(device_name)
device = get_torch_device(mylogger, config["force_to_cpu"])
def next_frame(
@ -58,7 +53,7 @@ def next_frame(
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
global threshold
global volume_3color
global image_3color
global path
global mylogger
global heartbeat_mask_file
@ -66,7 +61,7 @@ def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
mylogger.info(f"Threshold: {threshold}")
mask: np.ndarray = volume_3color[..., 2] >= threshold
mask: np.ndarray = image_3color[..., 2] >= threshold
mylogger.info(f"Save mask to: {heartbeat_mask_file}")
np.save(heartbeat_mask_file, mask)
mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}")
@ -89,15 +84,15 @@ image_var /= image_var.max()
mylogger.info("Smear the image heartbeat power patially")
temp, _ = gauss_smear_individual(
input=torch.tensor(image_var[..., np.newaxis], device=device),
spatial_width=4.0,
temporal_width=0.1,
spatial_width=spatial_width,
temporal_width=temporal_width,
use_matlab_mask=False,
)
temp /= temp.max()
mylogger.info("-==- DONE -==-")
volume_3color = np.concatenate(
image_3color = np.concatenate(
(
np.zeros_like(image_ref[..., np.newaxis]),
image_ref[..., np.newaxis],
@ -108,9 +103,9 @@ volume_3color = np.concatenate(
mylogger.info("Prepare image")
display_image = volume_3color.copy()
display_image = image_3color.copy()
display_image[..., 2] = display_image[..., 0]
mask = np.where(volume_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis]
mask = np.where(image_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis]
display_image *= mask
display_image = np.nan_to_num(display_image, nan=1.0)
@ -148,7 +143,7 @@ button_cancel = Button(
button_cancel.on_clicked(on_clicked_cancel) # type: ignore
slice_slider.on_changed(
partial(next_frame, images=volume_3color, image_handle=image_handle)
partial(next_frame, images=image_3color, image_handle=image_handle)
)
mylogger.info("Display")

View file

@ -0,0 +1,164 @@
import os
import json
import numpy as np
import matplotlib.pyplot as plt # type:ignore
import matplotlib
from matplotlib.widgets import Button # type:ignore
# pip install roipoly
from roipoly import RoiPoly
from jsmin import jsmin # type:ignore
from functions.create_logger import create_logger
from functions.get_torch_device import get_torch_device
def compose_image(image_3color: np.ndarray, mask: np.ndarray) -> np.ndarray:
display_image = image_3color.copy()
display_image[..., 2] = display_image[..., 0]
display_image[mask == 0, :] = 1.0
return display_image
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
global mylogger
global refined_mask_file
global mask
mylogger.info(f"Save mask to: {refined_mask_file}")
np.save(refined_mask_file, mask)
exit()
def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None:
global mylogger
mylogger.info("Ended without saving the mask")
exit()
def on_clicked_add(event: matplotlib.backend_bases.MouseEvent) -> None:
global new_roi
global mask
global image_3color
global display_image
global mylogger
if len(new_roi.x) > 0:
mylogger.info("A ROI with the following coordiantes has been added to the mask")
for i in range(0, len(new_roi.x)):
mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}")
mylogger.info("")
new_mask = new_roi.get_mask(display_image[:, :, 0])
mask[new_mask] = 0.0
display_image = compose_image(image_3color=image_3color, mask=mask)
image_handle.set_data(display_image)
for line in ax_main.lines:
line.remove()
plt.draw()
new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
def on_clicked_remove(event: matplotlib.backend_bases.MouseEvent) -> None:
global new_roi
global mask
global image_3color
global display_image
if len(new_roi.x) > 0:
mylogger.info(
"A ROI with the following coordiantes has been removed from the mask"
)
for i in range(0, len(new_roi.x)):
mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}")
mylogger.info("")
new_mask = new_roi.get_mask(display_image[:, :, 0])
mask[new_mask] = 1.0
display_image = compose_image(image_3color=image_3color, mask=mask)
image_handle.set_data(display_image)
for line in ax_main.lines:
line.remove()
plt.draw()
new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
mylogger = create_logger(
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_3"
)
mylogger.info("loading config file")
with open("config.json", "r") as file:
config = json.loads(jsmin(file.read()))
device = get_torch_device(mylogger, config["force_to_cpu"])
path: str = config["ref_image_path"]
use_channel: str = "donor"
image_ref_file: str = os.path.join(path, use_channel + ".npy")
heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")
refined_mask_file: str = os.path.join(path, "mask_not_rotated.npy")
mylogger.info(f"loading image reference file: {image_ref_file}")
image_ref: np.ndarray = np.load(image_ref_file)
image_ref /= image_ref.max()
mylogger.info(f"loading heartbeat mask: {heartbeat_mask_file}")
mask: np.ndarray = np.load(heartbeat_mask_file)
image_3color = np.concatenate(
(
np.zeros_like(image_ref[..., np.newaxis]),
image_ref[..., np.newaxis],
np.zeros_like(image_ref[..., np.newaxis]),
),
axis=-1,
)
mylogger.info("-==- DONE -==-")
fig, ax_main = plt.subplots()
display_image = compose_image(image_3color=image_3color, mask=mask)
image_handle = ax_main.imshow(display_image, vmin=0, vmax=1, cmap="hot")
mylogger.info("Add controls")
axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04))
button_accept = Button(
ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95"
)
button_accept.on_clicked(on_clicked_accept) # type: ignore
axbutton_cancel = fig.add_axes(rect=(0.5, 0.85, 0.2, 0.04))
button_cancel = Button(
ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95"
)
button_cancel.on_clicked(on_clicked_cancel) # type: ignore
axbutton_addmask = fig.add_axes(rect=(0.3, 0.9, 0.2, 0.04))
button_addmask = Button(
ax=axbutton_addmask, label="Add mask", image=None, color="0.85", hovercolor="0.95"
)
button_addmask.on_clicked(on_clicked_add) # type: ignore
axbutton_removemask = fig.add_axes(rect=(0.5, 0.9, 0.2, 0.04))
button_removemask = Button(
ax=axbutton_removemask,
label="Remove mask",
image=None,
color="0.85",
hovercolor="0.95",
)
button_removemask.on_clicked(on_clicked_remove) # type: ignore
# ax_main.cla()
mylogger.info("Display")
new_roi: RoiPoly = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
plt.show()
# image_handle.remove()
#