Add files via upload
This commit is contained in:
parent
840bae8628
commit
369540f472
10 changed files with 682 additions and 0 deletions
26
new_pipeline/config.json
Normal file
26
new_pipeline/config.json
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
{
|
||||||
|
"basic_path": "/data_1/robert",
|
||||||
|
"recoding_data": "2021-05-05",
|
||||||
|
"mouse_identifier": "M3852M",
|
||||||
|
"raw_path": "raw",
|
||||||
|
"export_path": "output",
|
||||||
|
"ref_image_path": "ref_images",
|
||||||
|
"required_order": [
|
||||||
|
"acceptor",
|
||||||
|
"donor",
|
||||||
|
"oxygenation",
|
||||||
|
"volume"
|
||||||
|
],
|
||||||
|
"dtype": "float32",
|
||||||
|
"binning_enable": false,
|
||||||
|
"binning_kernel_size": 4,
|
||||||
|
"binning_stride": 4,
|
||||||
|
"binning_divisor_override": 1,
|
||||||
|
// Heart beat detection
|
||||||
|
"lower_freqency_bandpass": 5.0, // Hz
|
||||||
|
"upper_freqency_bandpass": 14.0, // Hz
|
||||||
|
// LED Ramp on
|
||||||
|
"skip_frames_in_the_beginning": 100, // Frames
|
||||||
|
// PyTorch
|
||||||
|
"force_to_cpu": false
|
||||||
|
}
|
85
new_pipeline/functions/bandpass.py
Normal file
85
new_pipeline/functions/bandpass.py
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
import torchaudio as ta # type: ignore
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def filtfilt(
|
||||||
|
input: torch.Tensor,
|
||||||
|
butter_a: torch.Tensor,
|
||||||
|
butter_b: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert butter_a.ndim == 1
|
||||||
|
assert butter_b.ndim == 1
|
||||||
|
assert butter_a.shape[0] == butter_b.shape[0]
|
||||||
|
|
||||||
|
process_data: torch.Tensor = input.detach().clone()
|
||||||
|
|
||||||
|
padding_length = 12 * int(butter_a.shape[0])
|
||||||
|
left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[
|
||||||
|
..., 1 : padding_length + 1
|
||||||
|
].flip(-1)
|
||||||
|
right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[
|
||||||
|
..., -(padding_length + 1) : -1
|
||||||
|
].flip(-1)
|
||||||
|
process_data_padded = torch.cat((left_padding, process_data, right_padding), dim=-1)
|
||||||
|
|
||||||
|
output = ta.functional.filtfilt(
|
||||||
|
process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False
|
||||||
|
).squeeze(0)
|
||||||
|
|
||||||
|
output = output[..., padding_length:-padding_length]
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def butter_bandpass(
|
||||||
|
device: torch.device,
|
||||||
|
low_frequency: float = 0.1,
|
||||||
|
high_frequency: float = 1.0,
|
||||||
|
fs: float = 30.0,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
import scipy # type: ignore
|
||||||
|
|
||||||
|
butter_b_np, butter_a_np = scipy.signal.butter(
|
||||||
|
4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs
|
||||||
|
)
|
||||||
|
butter_a = torch.tensor(butter_a_np, device=device, dtype=torch.float32)
|
||||||
|
butter_b = torch.tensor(butter_b_np, device=device, dtype=torch.float32)
|
||||||
|
return butter_a, butter_b
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def chunk_iterator(array: torch.Tensor, chunk_size: int):
|
||||||
|
for i in range(0, array.shape[0], chunk_size):
|
||||||
|
yield array[i : i + chunk_size]
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def bandpass(
|
||||||
|
data: torch.Tensor,
|
||||||
|
device: torch.device,
|
||||||
|
low_frequency: float = 0.1,
|
||||||
|
high_frequency: float = 1.0,
|
||||||
|
fs=30.0,
|
||||||
|
filtfilt_chuck_size: int = 10,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
butter_a, butter_b = butter_bandpass(
|
||||||
|
device=device,
|
||||||
|
low_frequency=low_frequency,
|
||||||
|
high_frequency=high_frequency,
|
||||||
|
fs=fs,
|
||||||
|
)
|
||||||
|
|
||||||
|
index_full_dataset: torch.Tensor = torch.arange(
|
||||||
|
0, data.shape[1], device=device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk in chunk_iterator(index_full_dataset, filtfilt_chuck_size):
|
||||||
|
temp_filtfilt = filtfilt(
|
||||||
|
data[:, chunk, :],
|
||||||
|
butter_a=butter_a,
|
||||||
|
butter_b=butter_b,
|
||||||
|
)
|
||||||
|
data[:, chunk, :] = temp_filtfilt
|
||||||
|
|
||||||
|
return data
|
37
new_pipeline/functions/create_logger.py
Normal file
37
new_pipeline/functions/create_logger.py
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
import logging
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def create_logger(
|
||||||
|
save_logging_messages: bool, display_logging_messages: bool, log_stage_name: str
|
||||||
|
):
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S")
|
||||||
|
|
||||||
|
logger = logging.getLogger("MyLittleLogger")
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
if save_logging_messages:
|
||||||
|
time_format = "%b %-d %Y %H:%M:%S"
|
||||||
|
logformat = "%(asctime)s %(message)s"
|
||||||
|
file_formatter = logging.Formatter(fmt=logformat, datefmt=time_format)
|
||||||
|
os.makedirs("logs_" + log_stage_name, exist_ok=True)
|
||||||
|
file_handler = logging.FileHandler(
|
||||||
|
os.path.join("logs_" + log_stage_name, f"log_{dt_string_filename}.txt")
|
||||||
|
)
|
||||||
|
file_handler.setLevel(logging.INFO)
|
||||||
|
file_handler.setFormatter(file_formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
if display_logging_messages:
|
||||||
|
time_format = "%H:%M:%S"
|
||||||
|
logformat = "%(asctime)s %(message)s"
|
||||||
|
stream_formatter = logging.Formatter(fmt=logformat, datefmt=time_format)
|
||||||
|
|
||||||
|
stream_handler = logging.StreamHandler()
|
||||||
|
stream_handler.setLevel(logging.INFO)
|
||||||
|
stream_handler.setFormatter(stream_formatter)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
return logger
|
127
new_pipeline/functions/gauss_smear_individual.py
Normal file
127
new_pipeline/functions/gauss_smear_individual.py
Normal file
|
@ -0,0 +1,127 @@
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def gauss_smear_individual(
|
||||||
|
input: torch.Tensor,
|
||||||
|
spatial_width: float,
|
||||||
|
temporal_width: float,
|
||||||
|
overwrite_fft_gauss: None | torch.Tensor = None,
|
||||||
|
use_matlab_mask: bool = True,
|
||||||
|
epsilon: float = float(torch.finfo(torch.float64).eps),
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
dim_x: int = int(2 * math.ceil(2 * spatial_width) + 1)
|
||||||
|
dim_y: int = int(2 * math.ceil(2 * spatial_width) + 1)
|
||||||
|
dim_t: int = int(2 * math.ceil(2 * temporal_width) + 1)
|
||||||
|
dims_xyt: torch.Tensor = torch.tensor(
|
||||||
|
[dim_x, dim_y, dim_t], dtype=torch.int64, device=input.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if input.ndim == 2:
|
||||||
|
input = input.unsqueeze(-1)
|
||||||
|
|
||||||
|
input_padded = torch.nn.functional.pad(
|
||||||
|
input.unsqueeze(0),
|
||||||
|
pad=(
|
||||||
|
dim_t,
|
||||||
|
dim_t,
|
||||||
|
dim_y,
|
||||||
|
dim_y,
|
||||||
|
dim_x,
|
||||||
|
dim_x,
|
||||||
|
),
|
||||||
|
mode="replicate",
|
||||||
|
).squeeze(0)
|
||||||
|
|
||||||
|
if overwrite_fft_gauss is None:
|
||||||
|
center_x: int = int(math.ceil(input_padded.shape[0] / 2))
|
||||||
|
center_y: int = int(math.ceil(input_padded.shape[1] / 2))
|
||||||
|
center_z: int = int(math.ceil(input_padded.shape[2] / 2))
|
||||||
|
grid_x: torch.Tensor = (
|
||||||
|
torch.arange(0, input_padded.shape[0], device=input.device) - center_x + 1
|
||||||
|
)
|
||||||
|
grid_y: torch.Tensor = (
|
||||||
|
torch.arange(0, input_padded.shape[1], device=input.device) - center_y + 1
|
||||||
|
)
|
||||||
|
grid_z: torch.Tensor = (
|
||||||
|
torch.arange(0, input_padded.shape[2], device=input.device) - center_z + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
grid_x = grid_x.unsqueeze(-1).unsqueeze(-1) ** 2
|
||||||
|
grid_y = grid_y.unsqueeze(0).unsqueeze(-1) ** 2
|
||||||
|
grid_z = grid_z.unsqueeze(0).unsqueeze(0) ** 2
|
||||||
|
|
||||||
|
gauss_kernel: torch.Tensor = (
|
||||||
|
(grid_x / (spatial_width**2))
|
||||||
|
+ (grid_y / (spatial_width**2))
|
||||||
|
+ (grid_z / (temporal_width**2))
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_matlab_mask:
|
||||||
|
filter_radius: torch.Tensor = (dims_xyt - 1) // 2
|
||||||
|
|
||||||
|
border_lower: list[int] = [
|
||||||
|
center_x - int(filter_radius[0]) - 1,
|
||||||
|
center_y - int(filter_radius[1]) - 1,
|
||||||
|
center_z - int(filter_radius[2]) - 1,
|
||||||
|
]
|
||||||
|
|
||||||
|
border_upper: list[int] = [
|
||||||
|
center_x + int(filter_radius[0]),
|
||||||
|
center_y + int(filter_radius[1]),
|
||||||
|
center_z + int(filter_radius[2]),
|
||||||
|
]
|
||||||
|
|
||||||
|
matlab_mask: torch.Tensor = torch.zeros_like(gauss_kernel)
|
||||||
|
matlab_mask[
|
||||||
|
border_lower[0] : border_upper[0],
|
||||||
|
border_lower[1] : border_upper[1],
|
||||||
|
border_lower[2] : border_upper[2],
|
||||||
|
] = 1.0
|
||||||
|
|
||||||
|
gauss_kernel = torch.exp(-gauss_kernel / 2.0)
|
||||||
|
if use_matlab_mask:
|
||||||
|
gauss_kernel = gauss_kernel * matlab_mask
|
||||||
|
|
||||||
|
gauss_kernel[gauss_kernel < (epsilon * gauss_kernel.max())] = 0
|
||||||
|
|
||||||
|
sum_gauss_kernel: float = float(gauss_kernel.sum())
|
||||||
|
|
||||||
|
if sum_gauss_kernel != 0.0:
|
||||||
|
gauss_kernel = gauss_kernel / sum_gauss_kernel
|
||||||
|
|
||||||
|
# FFT Shift
|
||||||
|
gauss_kernel = torch.cat(
|
||||||
|
(gauss_kernel[center_x - 1 :, :, :], gauss_kernel[: center_x - 1, :, :]),
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
gauss_kernel = torch.cat(
|
||||||
|
(gauss_kernel[:, center_y - 1 :, :], gauss_kernel[:, : center_y - 1, :]),
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
gauss_kernel = torch.cat(
|
||||||
|
(gauss_kernel[:, :, center_z - 1 :], gauss_kernel[:, :, : center_z - 1]),
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
overwrite_fft_gauss = torch.fft.fftn(gauss_kernel)
|
||||||
|
input_padded_gauss_filtered: torch.Tensor = torch.real(
|
||||||
|
torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
input_padded_gauss_filtered = torch.real(
|
||||||
|
torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss)
|
||||||
|
)
|
||||||
|
|
||||||
|
start = dims_xyt
|
||||||
|
stop = (
|
||||||
|
torch.tensor(input_padded.shape, device=dims_xyt.device, dtype=dims_xyt.dtype)
|
||||||
|
- dims_xyt
|
||||||
|
)
|
||||||
|
|
||||||
|
output = input_padded_gauss_filtered[
|
||||||
|
start[0] : stop[0], start[1] : stop[1], start[2] : stop[2]
|
||||||
|
]
|
||||||
|
|
||||||
|
return (output, overwrite_fft_gauss)
|
19
new_pipeline/functions/get_experiments.py
Normal file
19
new_pipeline/functions/get_experiments.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_experiments(path: str) -> torch.Tensor:
|
||||||
|
filename_np: str = os.path.join(
|
||||||
|
path,
|
||||||
|
"Exp*_Part001.npy",
|
||||||
|
)
|
||||||
|
|
||||||
|
list_str = glob.glob(filename_np)
|
||||||
|
list_int: list[int] = []
|
||||||
|
for i in range(0, len(list_str)):
|
||||||
|
list_int.append(int(list_str[i].split("Exp")[-1].split("_Trial")[0]))
|
||||||
|
list_int = sorted(list_int)
|
||||||
|
|
||||||
|
return torch.tensor(list_int).unique()
|
18
new_pipeline/functions/get_parts.py
Normal file
18
new_pipeline/functions/get_parts.py
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_parts(path: str, experiment_id: int, trial_id: int) -> torch.Tensor:
|
||||||
|
filename_np: str = os.path.join(
|
||||||
|
path,
|
||||||
|
f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part*.npy",
|
||||||
|
)
|
||||||
|
|
||||||
|
list_str = glob.glob(filename_np)
|
||||||
|
list_int: list[int] = []
|
||||||
|
for i in range(0, len(list_str)):
|
||||||
|
list_int.append(int(list_str[i].split("_Part")[-1].split(".npy")[0]))
|
||||||
|
list_int = sorted(list_int)
|
||||||
|
return torch.tensor(list_int).unique()
|
18
new_pipeline/functions/get_trials.py
Normal file
18
new_pipeline/functions/get_trials.py
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_trials(path: str, experiment_id: int) -> torch.Tensor:
|
||||||
|
filename_np: str = os.path.join(
|
||||||
|
path,
|
||||||
|
f"Exp{experiment_id:03d}_Trial*_Part001.npy",
|
||||||
|
)
|
||||||
|
|
||||||
|
list_str = glob.glob(filename_np)
|
||||||
|
list_int: list[int] = []
|
||||||
|
for i in range(0, len(list_str)):
|
||||||
|
list_int.append(int(list_str[i].split("_Trial")[-1].split("_Part")[0]))
|
||||||
|
list_int = sorted(list_int)
|
||||||
|
return torch.tensor(list_int).unique()
|
54
new_pipeline/functions/load_meta_data.py
Normal file
54
new_pipeline/functions/load_meta_data.py
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def load_meta_data(
|
||||||
|
mylogger: logging.Logger, filename_meta: str
|
||||||
|
) -> tuple[list[str], str, str, dict, dict, float, float, str]:
|
||||||
|
|
||||||
|
mylogger.info("Loading meta data")
|
||||||
|
with open(filename_meta, "r") as file_handle:
|
||||||
|
metadata: dict = json.load(file_handle)
|
||||||
|
|
||||||
|
channels: list[str] = metadata["channelKey"]
|
||||||
|
|
||||||
|
mylogger.info(f"meta data: channel order: {channels}")
|
||||||
|
|
||||||
|
mouse_markings: str = metadata["sessionMetaData"]["mouseMarkings"]
|
||||||
|
mylogger.info(f"meta data: mouse markings: {mouse_markings}")
|
||||||
|
|
||||||
|
recording_date: str = metadata["sessionMetaData"]["date"]
|
||||||
|
mylogger.info(f"meta data: recording data: {recording_date}")
|
||||||
|
|
||||||
|
stimulation_times: dict = metadata["sessionMetaData"]["stimulationTimes"]
|
||||||
|
mylogger.info(f"meta data: stimulation times: {stimulation_times}")
|
||||||
|
|
||||||
|
experiment_names: dict = metadata["sessionMetaData"]["experimentNames"]
|
||||||
|
mylogger.info(f"meta data: experiment names: {experiment_names}")
|
||||||
|
|
||||||
|
trial_recording_duration: float = float(
|
||||||
|
metadata["sessionMetaData"]["trialRecordingDuration"]
|
||||||
|
)
|
||||||
|
mylogger.info(
|
||||||
|
f"meta data: trial recording duration: {trial_recording_duration} sec"
|
||||||
|
)
|
||||||
|
|
||||||
|
frame_time: float = float(metadata["sessionMetaData"]["frameTime"])
|
||||||
|
mylogger.info(
|
||||||
|
f"meta data: frame time: {frame_time} sec ; frame rate: {1.0/frame_time}Hz"
|
||||||
|
)
|
||||||
|
|
||||||
|
mouse: str = metadata["sessionMetaData"]["mouse"]
|
||||||
|
mylogger.info(f"meta data: mouse: {mouse}")
|
||||||
|
mylogger.info("-==- Done -==-")
|
||||||
|
|
||||||
|
return (
|
||||||
|
channels,
|
||||||
|
mouse_markings,
|
||||||
|
recording_date,
|
||||||
|
stimulation_times,
|
||||||
|
experiment_names,
|
||||||
|
trial_recording_duration,
|
||||||
|
frame_time,
|
||||||
|
mouse,
|
||||||
|
)
|
143
new_pipeline/stage_1_get_ref_image.py
Normal file
143
new_pipeline/stage_1_get_ref_image.py
Normal file
|
@ -0,0 +1,143 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from jsmin import jsmin # type: ignore
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
from functions.get_experiments import get_experiments
|
||||||
|
from functions.get_trials import get_trials
|
||||||
|
from functions.get_parts import get_parts
|
||||||
|
from functions.bandpass import bandpass
|
||||||
|
from functions.create_logger import create_logger
|
||||||
|
from functions.load_meta_data import load_meta_data
|
||||||
|
|
||||||
|
mylogger = create_logger(
|
||||||
|
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_1"
|
||||||
|
)
|
||||||
|
|
||||||
|
mylogger.info("loading config file")
|
||||||
|
with open("config.json", "r") as file:
|
||||||
|
config = json.loads(jsmin(file.read()))
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device_name: str = "cuda:0"
|
||||||
|
else:
|
||||||
|
device_name = "cpu"
|
||||||
|
|
||||||
|
if config["force_to_cpu"]:
|
||||||
|
device_name = "cpu"
|
||||||
|
|
||||||
|
mylogger.info(f"Using device: {device_name}")
|
||||||
|
device: torch.device = torch.device(device_name)
|
||||||
|
|
||||||
|
dtype_str: str = config["dtype"]
|
||||||
|
dtype: torch.dtype = getattr(torch, dtype_str)
|
||||||
|
|
||||||
|
raw_data_path: str = os.path.join(
|
||||||
|
config["basic_path"],
|
||||||
|
config["recoding_data"],
|
||||||
|
config["mouse_identifier"],
|
||||||
|
config["raw_path"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mylogger.info(f"Using data path: {raw_data_path}")
|
||||||
|
|
||||||
|
first_experiment_id: int = int(get_experiments(raw_data_path).min())
|
||||||
|
first_trial_id: int = int(get_trials(raw_data_path, first_experiment_id).min())
|
||||||
|
first_part_id: int = int(
|
||||||
|
get_parts(raw_data_path, first_experiment_id, first_trial_id).min()
|
||||||
|
)
|
||||||
|
|
||||||
|
filename_data: str = os.path.join(
|
||||||
|
raw_data_path,
|
||||||
|
f"Exp{first_experiment_id:03d}_Trial{first_trial_id:03d}_Part{first_part_id:03d}.npy",
|
||||||
|
)
|
||||||
|
|
||||||
|
mylogger.info(f"Will use: {filename_data} for data")
|
||||||
|
|
||||||
|
filename_meta: str = os.path.join(
|
||||||
|
raw_data_path,
|
||||||
|
f"Exp{first_experiment_id:03d}_Trial{first_trial_id:03d}_Part{first_part_id:03d}_meta.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
mylogger.info(f"Will use: {filename_meta} for meta data")
|
||||||
|
|
||||||
|
meta_channels: list[str]
|
||||||
|
meta_mouse_markings: str
|
||||||
|
meta_recording_date: str
|
||||||
|
meta_stimulation_times: dict
|
||||||
|
meta_experiment_names: dict
|
||||||
|
meta_trial_recording_duration: float
|
||||||
|
meta_frame_time: float
|
||||||
|
meta_mouse: str
|
||||||
|
|
||||||
|
(
|
||||||
|
meta_channels,
|
||||||
|
meta_mouse_markings,
|
||||||
|
meta_recording_date,
|
||||||
|
meta_stimulation_times,
|
||||||
|
meta_experiment_names,
|
||||||
|
meta_trial_recording_duration,
|
||||||
|
meta_frame_time,
|
||||||
|
meta_mouse,
|
||||||
|
) = load_meta_data(mylogger=mylogger, filename_meta=filename_meta)
|
||||||
|
|
||||||
|
|
||||||
|
dtype_str = config["dtype"]
|
||||||
|
dtype_np: np.dtype = getattr(np, dtype_str)
|
||||||
|
|
||||||
|
mylogger.info("Loading data")
|
||||||
|
data = torch.tensor(
|
||||||
|
np.load(filename_data).astype(dtype_np), dtype=dtype, device=torch.device("cpu")
|
||||||
|
)
|
||||||
|
mylogger.info("-==- Done -==-")
|
||||||
|
|
||||||
|
output_path = config["ref_image_path"]
|
||||||
|
mylogger.info(f"Create directory {output_path} in the case it does not exist")
|
||||||
|
os.makedirs(output_path, exist_ok=True)
|
||||||
|
|
||||||
|
mylogger.info("Reference images")
|
||||||
|
for i in range(0, len(meta_channels)):
|
||||||
|
temp_path: str = os.path.join(output_path, meta_channels[i] + ".npy")
|
||||||
|
mylogger.info(f"Extract and save: {temp_path}")
|
||||||
|
frame_id: int = data.shape[-2] // 2
|
||||||
|
mylogger.info(f"Will use frame id: {frame_id}")
|
||||||
|
ref_image: np.ndarray = (
|
||||||
|
data[:, :, frame_id, meta_channels.index(meta_channels[i])]
|
||||||
|
.clone()
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
)
|
||||||
|
np.save(temp_path, ref_image)
|
||||||
|
mylogger.info("-==- Done -==-")
|
||||||
|
|
||||||
|
sample_frequency: float = 1.0 / meta_frame_time
|
||||||
|
mylogger.info(
|
||||||
|
(
|
||||||
|
f"Heartbeat power {config['lower_freqency_bandpass']}Hz"
|
||||||
|
f" - {config['upper_freqency_bandpass']}Hz,"
|
||||||
|
f" sample-rate: {sample_frequency},"
|
||||||
|
f" skipping the first {config['skip_frames_in_the_beginning']} frames"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(0, len(meta_channels)):
|
||||||
|
temp_path = os.path.join(output_path, meta_channels[i] + "_var.npy")
|
||||||
|
mylogger.info(f"Extract and save: {temp_path}")
|
||||||
|
|
||||||
|
heartbeat_ts: torch.Tensor = bandpass(
|
||||||
|
data=data[..., i],
|
||||||
|
device=data.device,
|
||||||
|
low_frequency=config["lower_freqency_bandpass"],
|
||||||
|
high_frequency=config["upper_freqency_bandpass"],
|
||||||
|
fs=sample_frequency,
|
||||||
|
filtfilt_chuck_size=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
heartbeat_power = heartbeat_ts[..., config["skip_frames_in_the_beginning"] :].var(
|
||||||
|
dim=-1
|
||||||
|
)
|
||||||
|
np.save(temp_path, heartbeat_power)
|
||||||
|
|
||||||
|
mylogger.info("-==- Done -==-")
|
155
new_pipeline/stage_2_make_heartbeat_mask.py
Normal file
155
new_pipeline/stage_2_make_heartbeat_mask.py
Normal file
|
@ -0,0 +1,155 @@
|
||||||
|
import matplotlib.pyplot as plt # type:ignore
|
||||||
|
import matplotlib
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from jsmin import jsmin # type:ignore
|
||||||
|
from matplotlib.widgets import Slider, Button # type:ignore
|
||||||
|
from functools import partial
|
||||||
|
from functions.gauss_smear_individual import gauss_smear_individual
|
||||||
|
from functions.create_logger import create_logger
|
||||||
|
|
||||||
|
|
||||||
|
mylogger = create_logger(
|
||||||
|
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_2"
|
||||||
|
)
|
||||||
|
|
||||||
|
mylogger.info("loading config file")
|
||||||
|
with open("config.json", "r") as file:
|
||||||
|
config = json.loads(jsmin(file.read()))
|
||||||
|
|
||||||
|
threshold: float = 0.05
|
||||||
|
path: str = config["ref_image_path"]
|
||||||
|
|
||||||
|
image_ref_file: str = os.path.join(path, "donor.npy")
|
||||||
|
image_var_file: str = os.path.join(path, "donor_var.npy")
|
||||||
|
heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")
|
||||||
|
heartbeat_mask_threshold_file: str = os.path.join(path, "heartbeat_mask_threshold.npy")
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device_name: str = "cuda:0"
|
||||||
|
else:
|
||||||
|
device_name = "cpu"
|
||||||
|
|
||||||
|
if config["force_to_cpu"]:
|
||||||
|
device_name = "cpu"
|
||||||
|
|
||||||
|
mylogger.info(f"Using device: {device_name}")
|
||||||
|
device: torch.device = torch.device(device_name)
|
||||||
|
|
||||||
|
|
||||||
|
def next_frame(
|
||||||
|
i: float, images: np.ndarray, image_handle: matplotlib.image.AxesImage
|
||||||
|
) -> None:
|
||||||
|
global threshold
|
||||||
|
threshold = i
|
||||||
|
|
||||||
|
display_image: np.ndarray = images.copy()
|
||||||
|
display_image[..., 2] = display_image[..., 0]
|
||||||
|
mask: np.ndarray = np.where(images[..., 2] >= i, 1.0, np.nan)[..., np.newaxis]
|
||||||
|
display_image *= mask
|
||||||
|
display_image = np.nan_to_num(display_image, nan=1.0)
|
||||||
|
|
||||||
|
image_handle.set_data(display_image)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||||
|
global threshold
|
||||||
|
global volume_3color
|
||||||
|
global path
|
||||||
|
global mylogger
|
||||||
|
global heartbeat_mask_file
|
||||||
|
global heartbeat_mask_threshold_file
|
||||||
|
|
||||||
|
mylogger.info(f"Threshold: {threshold}")
|
||||||
|
|
||||||
|
mask: np.ndarray = volume_3color[..., 2] >= threshold
|
||||||
|
mylogger.info(f"Save mask to: {heartbeat_mask_file}")
|
||||||
|
np.save(heartbeat_mask_file, mask)
|
||||||
|
mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}")
|
||||||
|
np.save(heartbeat_mask_threshold_file, np.array([threshold]))
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None:
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
mylogger.info(f"loading image reference file: {image_ref_file}")
|
||||||
|
image_ref: np.ndarray = np.load(image_ref_file)
|
||||||
|
image_ref /= image_ref.max()
|
||||||
|
|
||||||
|
mylogger.info(f"loading image heartbeat power: {image_var_file}")
|
||||||
|
image_var: np.ndarray = np.load(image_var_file)
|
||||||
|
image_var /= image_var.max()
|
||||||
|
|
||||||
|
mylogger.info("Smear the image heartbeat power patially")
|
||||||
|
temp, _ = gauss_smear_individual(
|
||||||
|
input=torch.tensor(image_var[..., np.newaxis], device=device),
|
||||||
|
spatial_width=4.0,
|
||||||
|
temporal_width=0.1,
|
||||||
|
use_matlab_mask=False,
|
||||||
|
)
|
||||||
|
temp /= temp.max()
|
||||||
|
|
||||||
|
mylogger.info("-==- DONE -==-")
|
||||||
|
|
||||||
|
volume_3color = np.concatenate(
|
||||||
|
(
|
||||||
|
np.zeros_like(image_ref[..., np.newaxis]),
|
||||||
|
image_ref[..., np.newaxis],
|
||||||
|
temp.cpu().numpy(),
|
||||||
|
),
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
mylogger.info("Prepare image")
|
||||||
|
|
||||||
|
display_image = volume_3color.copy()
|
||||||
|
display_image[..., 2] = display_image[..., 0]
|
||||||
|
mask = np.where(volume_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis]
|
||||||
|
display_image *= mask
|
||||||
|
display_image = np.nan_to_num(display_image, nan=1.0)
|
||||||
|
|
||||||
|
value_sort = np.sort(image_var.flatten())
|
||||||
|
value_sort_max = value_sort[int(value_sort.shape[0] * 0.95)]
|
||||||
|
mylogger.info("-==- DONE -==-")
|
||||||
|
|
||||||
|
mylogger.info("Create figure")
|
||||||
|
|
||||||
|
fig: matplotlib.figure.Figure = plt.figure()
|
||||||
|
|
||||||
|
image_handle = plt.imshow(display_image, vmin=0, vmax=1, cmap="hot")
|
||||||
|
|
||||||
|
mylogger.info("Add controls")
|
||||||
|
|
||||||
|
axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03))
|
||||||
|
slice_slider = Slider(
|
||||||
|
ax=axfreq,
|
||||||
|
label="Slice",
|
||||||
|
valmin=0,
|
||||||
|
valmax=value_sort_max,
|
||||||
|
valinit=threshold,
|
||||||
|
valstep=value_sort_max / 100.0,
|
||||||
|
)
|
||||||
|
axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04))
|
||||||
|
button_accept = Button(
|
||||||
|
ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95"
|
||||||
|
)
|
||||||
|
button_accept.on_clicked(on_clicked_accept) # type: ignore
|
||||||
|
|
||||||
|
axbutton_cancel = fig.add_axes(rect=(0.55, 0.85, 0.2, 0.04))
|
||||||
|
button_cancel = Button(
|
||||||
|
ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95"
|
||||||
|
)
|
||||||
|
button_cancel.on_clicked(on_clicked_cancel) # type: ignore
|
||||||
|
|
||||||
|
slice_slider.on_changed(
|
||||||
|
partial(next_frame, images=volume_3color, image_handle=image_handle)
|
||||||
|
)
|
||||||
|
|
||||||
|
mylogger.info("Display")
|
||||||
|
plt.show()
|
Loading…
Reference in a new issue