diff --git a/new_pipeline/stage_2_make_heartbeat_mask.py b/new_pipeline/stage_2_make_heartbeat_mask.py index be19083..ee7a8f9 100644 --- a/new_pipeline/stage_2_make_heartbeat_mask.py +++ b/new_pipeline/stage_2_make_heartbeat_mask.py @@ -20,17 +20,24 @@ mylogger.info("loading config file") with open("config.json", "r") as file: config = json.loads(jsmin(file.read())) -threshold: float = 0.05 + path: str = config["ref_image_path"] use_channel: str = "donor" spatial_width: float = 4.0 temporal_width: float = 0.1 +threshold: float = 0.05 + +heartbeat_mask_threshold_file: str = os.path.join(path, "heartbeat_mask_threshold.npy") +if os.path.isfile(heartbeat_mask_threshold_file): + mylogger.info(f"loading previous threshold file: {heartbeat_mask_threshold_file}") + threshold = float(np.load(heartbeat_mask_threshold_file)[0]) + +mylogger.info(f"initial threshold is {threshold}") image_ref_file: str = os.path.join(path, use_channel + ".npy") image_var_file: str = os.path.join(path, use_channel + "_var.npy") heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy") -heartbeat_mask_threshold_file: str = os.path.join(path, "heartbeat_mask_threshold.npy") device = get_torch_device(mylogger, config["force_to_cpu"]) @@ -124,7 +131,7 @@ mylogger.info("Add controls") axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03)) slice_slider = Slider( ax=axfreq, - label="Slice", + label="Threshold", valmin=0, valmax=value_sort_max, valinit=threshold,