Add files via upload

David Rotermund 2024-02-24 17:28:26 +01:00 committed by GitHub
parent 840bae8628
commit 369540f472
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 682 additions and 0 deletions

new_pipeline/config.json Normal file
View file

@@ -0,0 +1,26 @@
{
"basic_path": "/data_1/robert",
"recoding_data": "2021-05-05",
"mouse_identifier": "M3852M",
"raw_path": "raw",
"export_path": "output",
"ref_image_path": "ref_images",
"required_order": [
"acceptor",
"donor",
"oxygenation",
"volume"
],
"dtype": "float32",
"binning_enable": false,
"binning_kernel_size": 4,
"binning_stride": 4,
"binning_divisor_override": 1,
// Heart beat detection
"lower_freqency_bandpass": 5.0, // Hz
"upper_freqency_bandpass": 14.0, // Hz
// LED Ramp on
"skip_frames_in_the_beginning": 100, // Frames
// PyTorch
"force_to_cpu": false
}
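
The // comments make this file JSON-with-comments rather than strict JSON; the stage scripts therefore strip them with jsmin before parsing. A minimal loading sketch, mirroring what stage 1 and stage 2 do:

import json
from jsmin import jsmin  # type: ignore  # strips the // comments so json.loads() accepts the file

with open("config.json", "r") as file:
    config = json.loads(jsmin(file.read()))
print(config["mouse_identifier"])  # -> M3852M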

new_pipeline/functions/bandpass.py Normal file
View file

@@ -0,0 +1,85 @@
import torchaudio as ta # type: ignore
import torch


@torch.no_grad()
def filtfilt(
input: torch.Tensor,
butter_a: torch.Tensor,
butter_b: torch.Tensor,
) -> torch.Tensor:
assert butter_a.ndim == 1
assert butter_b.ndim == 1
assert butter_a.shape[0] == butter_b.shape[0]
process_data: torch.Tensor = input.detach().clone()
padding_length = 12 * int(butter_a.shape[0])
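    # Pad both ends by odd reflection (2 * edge value minus mirrored samples) so the
    # forward-backward IIR filter settles before it reaches the real data; the padded
    # margin is trimmed off again after filtering.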
left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[
..., 1 : padding_length + 1
].flip(-1)
right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[
..., -(padding_length + 1) : -1
].flip(-1)
process_data_padded = torch.cat((left_padding, process_data, right_padding), dim=-1)
output = ta.functional.filtfilt(
process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False
).squeeze(0)
output = output[..., padding_length:-padding_length]
return output


@torch.no_grad()
def butter_bandpass(
device: torch.device,
low_frequency: float = 0.1,
high_frequency: float = 1.0,
fs: float = 30.0,
) -> tuple[torch.Tensor, torch.Tensor]:
import scipy # type: ignore
butter_b_np, butter_a_np = scipy.signal.butter(
4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs
)
butter_a = torch.tensor(butter_a_np, device=device, dtype=torch.float32)
butter_b = torch.tensor(butter_b_np, device=device, dtype=torch.float32)
return butter_a, butter_b


@torch.no_grad()
def chunk_iterator(array: torch.Tensor, chunk_size: int):
for i in range(0, array.shape[0], chunk_size):
yield array[i : i + chunk_size]


@torch.no_grad()
def bandpass(
data: torch.Tensor,
device: torch.device,
low_frequency: float = 0.1,
high_frequency: float = 1.0,
    fs: float = 30.0,
filtfilt_chuck_size: int = 10,
) -> torch.Tensor:
butter_a, butter_b = butter_bandpass(
device=device,
low_frequency=low_frequency,
high_frequency=high_frequency,
fs=fs,
)
index_full_dataset: torch.Tensor = torch.arange(
0, data.shape[1], device=device, dtype=torch.int64
)
for chunk in chunk_iterator(index_full_dataset, filtfilt_chuck_size):
temp_filtfilt = filtfilt(
data[:, chunk, :],
butter_a=butter_a,
butter_b=butter_b,
)
data[:, chunk, :] = temp_filtfilt
return data
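
A minimal usage sketch for bandpass(), assuming the module is importable as functions.bandpass as in the stage scripts. The input is treated as an (x, y, time) stack: filtering runs along the last axis, and the y axis is processed in chunks of filtfilt_chuck_size rows to limit memory. The shape, sample rate and band edges below are toy values:

import torch
from functions.bandpass import bandpass

data = torch.rand(64, 64, 400)  # hypothetical stack: 64 x 64 pixels, 400 frames
filtered = bandpass(
    data=data.clone(),       # bandpass() writes its result back into the tensor it receives
    device=data.device,
    low_frequency=5.0,       # Hz
    high_frequency=14.0,     # Hz
    fs=100.0,                # Hz, toy sample rate
    filtfilt_chuck_size=10,
)
print(filtered.shape)  # torch.Size([64, 64, 400])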

new_pipeline/functions/create_logger.py Normal file
View file

@@ -0,0 +1,37 @@
import logging
import datetime
import os


def create_logger(
save_logging_messages: bool, display_logging_messages: bool, log_stage_name: str
):
now = datetime.datetime.now()
dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S")
logger = logging.getLogger("MyLittleLogger")
logger.setLevel(logging.DEBUG)
if save_logging_messages:
time_format = "%b %-d %Y %H:%M:%S"
logformat = "%(asctime)s %(message)s"
file_formatter = logging.Formatter(fmt=logformat, datefmt=time_format)
os.makedirs("logs_" + log_stage_name, exist_ok=True)
file_handler = logging.FileHandler(
os.path.join("logs_" + log_stage_name, f"log_{dt_string_filename}.txt")
)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
if display_logging_messages:
time_format = "%H:%M:%S"
logformat = "%(asctime)s %(message)s"
stream_formatter = logging.Formatter(fmt=logformat, datefmt=time_format)
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(stream_formatter)
logger.addHandler(stream_handler)
return logger
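
A minimal usage sketch, assuming the module is importable as functions.create_logger. With both flags enabled, messages go to the console and to logs_<stage name>/log_<timestamp>.txt; the stage name below is a placeholder:

from functions.create_logger import create_logger

mylogger = create_logger(
    save_logging_messages=True,
    display_logging_messages=True,
    log_stage_name="demo",  # hypothetical stage name; log files land in logs_demo/
)
mylogger.info("pipeline step started")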

new_pipeline/functions/gauss_smear_individual.py Normal file
View file

@@ -0,0 +1,127 @@
import torch
import math


@torch.no_grad()
def gauss_smear_individual(
input: torch.Tensor,
spatial_width: float,
temporal_width: float,
overwrite_fft_gauss: None | torch.Tensor = None,
use_matlab_mask: bool = True,
epsilon: float = float(torch.finfo(torch.float64).eps),
) -> tuple[torch.Tensor, torch.Tensor]:
dim_x: int = int(2 * math.ceil(2 * spatial_width) + 1)
dim_y: int = int(2 * math.ceil(2 * spatial_width) + 1)
dim_t: int = int(2 * math.ceil(2 * temporal_width) + 1)
dims_xyt: torch.Tensor = torch.tensor(
[dim_x, dim_y, dim_t], dtype=torch.int64, device=input.device
)
if input.ndim == 2:
input = input.unsqueeze(-1)
input_padded = torch.nn.functional.pad(
input.unsqueeze(0),
pad=(
dim_t,
dim_t,
dim_y,
dim_y,
dim_x,
dim_x,
),
mode="replicate",
).squeeze(0)
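    # The 3-D Gaussian kernel and its FFT are built only when no precomputed FFT is
    # supplied; the FFT is returned so later calls on same-sized data can reuse it.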
if overwrite_fft_gauss is None:
center_x: int = int(math.ceil(input_padded.shape[0] / 2))
center_y: int = int(math.ceil(input_padded.shape[1] / 2))
center_z: int = int(math.ceil(input_padded.shape[2] / 2))
grid_x: torch.Tensor = (
torch.arange(0, input_padded.shape[0], device=input.device) - center_x + 1
)
grid_y: torch.Tensor = (
torch.arange(0, input_padded.shape[1], device=input.device) - center_y + 1
)
grid_z: torch.Tensor = (
torch.arange(0, input_padded.shape[2], device=input.device) - center_z + 1
)
grid_x = grid_x.unsqueeze(-1).unsqueeze(-1) ** 2
grid_y = grid_y.unsqueeze(0).unsqueeze(-1) ** 2
grid_z = grid_z.unsqueeze(0).unsqueeze(0) ** 2
gauss_kernel: torch.Tensor = (
(grid_x / (spatial_width**2))
+ (grid_y / (spatial_width**2))
+ (grid_z / (temporal_width**2))
)
if use_matlab_mask:
filter_radius: torch.Tensor = (dims_xyt - 1) // 2
border_lower: list[int] = [
center_x - int(filter_radius[0]) - 1,
center_y - int(filter_radius[1]) - 1,
center_z - int(filter_radius[2]) - 1,
]
border_upper: list[int] = [
center_x + int(filter_radius[0]),
center_y + int(filter_radius[1]),
center_z + int(filter_radius[2]),
]
matlab_mask: torch.Tensor = torch.zeros_like(gauss_kernel)
matlab_mask[
border_lower[0] : border_upper[0],
border_lower[1] : border_upper[1],
border_lower[2] : border_upper[2],
] = 1.0
gauss_kernel = torch.exp(-gauss_kernel / 2.0)
if use_matlab_mask:
gauss_kernel = gauss_kernel * matlab_mask
gauss_kernel[gauss_kernel < (epsilon * gauss_kernel.max())] = 0
sum_gauss_kernel: float = float(gauss_kernel.sum())
if sum_gauss_kernel != 0.0:
gauss_kernel = gauss_kernel / sum_gauss_kernel
# FFT Shift
gauss_kernel = torch.cat(
(gauss_kernel[center_x - 1 :, :, :], gauss_kernel[: center_x - 1, :, :]),
dim=0,
)
gauss_kernel = torch.cat(
(gauss_kernel[:, center_y - 1 :, :], gauss_kernel[:, : center_y - 1, :]),
dim=1,
)
gauss_kernel = torch.cat(
(gauss_kernel[:, :, center_z - 1 :], gauss_kernel[:, :, : center_z - 1]),
dim=2,
)
overwrite_fft_gauss = torch.fft.fftn(gauss_kernel)
input_padded_gauss_filtered: torch.Tensor = torch.real(
torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss)
)
else:
input_padded_gauss_filtered = torch.real(
torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss)
)
start = dims_xyt
stop = (
torch.tensor(input_padded.shape, device=dims_xyt.device, dtype=dims_xyt.dtype)
- dims_xyt
)
output = input_padded_gauss_filtered[
start[0] : stop[0], start[1] : stop[1], start[2] : stop[2]
]
return (output, overwrite_fft_gauss)
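
A minimal usage sketch, assuming the module is importable as functions.gauss_smear_individual. The second return value is the FFT of the Gaussian kernel; passing it back in as overwrite_fft_gauss skips the kernel construction when further volumes of the same shape are smoothed. Shapes and widths below are toy values:

import torch
from functions.gauss_smear_individual import gauss_smear_individual

volume = torch.rand(64, 64, 100)  # hypothetical (x, y, time) stack
smoothed, kernel_fft = gauss_smear_individual(
    input=volume,
    spatial_width=4.0,   # width of the spatial Gaussian in pixels
    temporal_width=0.1,  # width of the temporal Gaussian in frames
)

# Reuse the precomputed kernel FFT for a second stack of identical shape
volume_2 = torch.rand(64, 64, 100)
smoothed_2, _ = gauss_smear_individual(
    input=volume_2,
    spatial_width=4.0,
    temporal_width=0.1,
    overwrite_fft_gauss=kernel_fft,
)
print(smoothed.shape)  # torch.Size([64, 64, 100])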

new_pipeline/functions/get_experiments.py Normal file
View file

@@ -0,0 +1,19 @@
import torch
import os
import glob


@torch.no_grad()
def get_experiments(path: str) -> torch.Tensor:
filename_np: str = os.path.join(
path,
"Exp*_Part001.npy",
)
list_str = glob.glob(filename_np)
list_int: list[int] = []
for i in range(0, len(list_str)):
list_int.append(int(list_str[i].split("Exp")[-1].split("_Trial")[0]))
list_int = sorted(list_int)
return torch.tensor(list_int).unique()
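
The helpers in this and the next two files assume raw recordings named Exp###_Trial###_Part###.npy with zero-padded three-digit indices. A minimal call, with the directory assembled from the values in config.json:

from functions.get_experiments import get_experiments

raw_data_path = "/data_1/robert/2021-05-05/M3852M/raw"  # basic_path/recoding_data/mouse_identifier/raw_path
experiment_ids = get_experiments(raw_data_path)  # e.g. tensor([1, 2, 3]) if three experiments were recorded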

new_pipeline/functions/get_parts.py Normal file
View file

@@ -0,0 +1,18 @@
import torch
import os
import glob


@torch.no_grad()
def get_parts(path: str, experiment_id: int, trial_id: int) -> torch.Tensor:
filename_np: str = os.path.join(
path,
f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part*.npy",
)
list_str = glob.glob(filename_np)
list_int: list[int] = []
for i in range(0, len(list_str)):
list_int.append(int(list_str[i].split("_Part")[-1].split(".npy")[0]))
list_int = sorted(list_int)
return torch.tensor(list_int).unique()

new_pipeline/functions/get_trials.py Normal file
View file

@@ -0,0 +1,18 @@
import torch
import os
import glob


@torch.no_grad()
def get_trials(path: str, experiment_id: int) -> torch.Tensor:
filename_np: str = os.path.join(
path,
f"Exp{experiment_id:03d}_Trial*_Part001.npy",
)
list_str = glob.glob(filename_np)
list_int: list[int] = []
for i in range(0, len(list_str)):
list_int.append(int(list_str[i].split("_Trial")[-1].split("_Part")[0]))
list_int = sorted(list_int)
return torch.tensor(list_int).unique()
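
Together, get_experiments, get_trials and get_parts enumerate the whole recording hierarchy. A sketch that walks every experiment, trial and part and reconstructs the corresponding file names (same assumed directory as above):

from functions.get_experiments import get_experiments
from functions.get_trials import get_trials
from functions.get_parts import get_parts

raw_data_path = "/data_1/robert/2021-05-05/M3852M/raw"
for experiment_id in get_experiments(raw_data_path):
    for trial_id in get_trials(raw_data_path, int(experiment_id)):
        for part_id in get_parts(raw_data_path, int(experiment_id), int(trial_id)):
            print(
                f"Exp{int(experiment_id):03d}"
                f"_Trial{int(trial_id):03d}"
                f"_Part{int(part_id):03d}.npy"
            )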

new_pipeline/functions/load_meta_data.py Normal file
View file

@@ -0,0 +1,54 @@
import logging
import json


def load_meta_data(
mylogger: logging.Logger, filename_meta: str
) -> tuple[list[str], str, str, dict, dict, float, float, str]:
mylogger.info("Loading meta data")
with open(filename_meta, "r") as file_handle:
metadata: dict = json.load(file_handle)
channels: list[str] = metadata["channelKey"]
mylogger.info(f"meta data: channel order: {channels}")
mouse_markings: str = metadata["sessionMetaData"]["mouseMarkings"]
mylogger.info(f"meta data: mouse markings: {mouse_markings}")
recording_date: str = metadata["sessionMetaData"]["date"]
mylogger.info(f"meta data: recording data: {recording_date}")
stimulation_times: dict = metadata["sessionMetaData"]["stimulationTimes"]
mylogger.info(f"meta data: stimulation times: {stimulation_times}")
experiment_names: dict = metadata["sessionMetaData"]["experimentNames"]
mylogger.info(f"meta data: experiment names: {experiment_names}")
trial_recording_duration: float = float(
metadata["sessionMetaData"]["trialRecordingDuration"]
)
mylogger.info(
f"meta data: trial recording duration: {trial_recording_duration} sec"
)
frame_time: float = float(metadata["sessionMetaData"]["frameTime"])
mylogger.info(
f"meta data: frame time: {frame_time} sec ; frame rate: {1.0/frame_time}Hz"
)
mouse: str = metadata["sessionMetaData"]["mouse"]
mylogger.info(f"meta data: mouse: {mouse}")
mylogger.info("-==- Done -==-")
return (
channels,
mouse_markings,
recording_date,
stimulation_times,
experiment_names,
trial_recording_duration,
frame_time,
mouse,
)
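
A minimal usage sketch, assuming the module is importable as functions.load_meta_data and that a *_meta.txt file in the JSON format read above is available (the file name below is hypothetical):

from functions.create_logger import create_logger
from functions.load_meta_data import load_meta_data

mylogger = create_logger(
    save_logging_messages=False, display_logging_messages=True, log_stage_name="meta_demo"
)
(
    channels, mouse_markings, recording_date, stimulation_times,
    experiment_names, trial_recording_duration, frame_time, mouse,
) = load_meta_data(
    mylogger=mylogger,
    filename_meta="Exp001_Trial001_Part001_meta.txt",  # hypothetical file name
)
print(channels, 1.0 / frame_time)  # channel order and frame rate in Hz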

View file

@@ -0,0 +1,143 @@
import json
import os
from jsmin import jsmin # type: ignore
import torch
import numpy as np
from functions.get_experiments import get_experiments
from functions.get_trials import get_trials
from functions.get_parts import get_parts
from functions.bandpass import bandpass
from functions.create_logger import create_logger
from functions.load_meta_data import load_meta_data
mylogger = create_logger(
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_1"
)
mylogger.info("loading config file")
with open("config.json", "r") as file:
config = json.loads(jsmin(file.read()))
if torch.cuda.is_available():
device_name: str = "cuda:0"
else:
device_name = "cpu"
if config["force_to_cpu"]:
device_name = "cpu"
mylogger.info(f"Using device: {device_name}")
device: torch.device = torch.device(device_name)
dtype_str: str = config["dtype"]
dtype: torch.dtype = getattr(torch, dtype_str)
raw_data_path: str = os.path.join(
config["basic_path"],
config["recoding_data"],
config["mouse_identifier"],
config["raw_path"],
)
mylogger.info(f"Using data path: {raw_data_path}")
first_experiment_id: int = int(get_experiments(raw_data_path).min())
first_trial_id: int = int(get_trials(raw_data_path, first_experiment_id).min())
first_part_id: int = int(
get_parts(raw_data_path, first_experiment_id, first_trial_id).min()
)
filename_data: str = os.path.join(
raw_data_path,
f"Exp{first_experiment_id:03d}_Trial{first_trial_id:03d}_Part{first_part_id:03d}.npy",
)
mylogger.info(f"Will use: {filename_data} for data")
filename_meta: str = os.path.join(
raw_data_path,
f"Exp{first_experiment_id:03d}_Trial{first_trial_id:03d}_Part{first_part_id:03d}_meta.txt",
)
mylogger.info(f"Will use: {filename_meta} for meta data")
meta_channels: list[str]
meta_mouse_markings: str
meta_recording_date: str
meta_stimulation_times: dict
meta_experiment_names: dict
meta_trial_recording_duration: float
meta_frame_time: float
meta_mouse: str
(
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
) = load_meta_data(mylogger=mylogger, filename_meta=filename_meta)
dtype_str = config["dtype"]
dtype_np: np.dtype = getattr(np, dtype_str)
mylogger.info("Loading data")
data = torch.tensor(
np.load(filename_data).astype(dtype_np), dtype=dtype, device=torch.device("cpu")
)
mylogger.info("-==- Done -==-")
output_path = config["ref_image_path"]
mylogger.info(f"Create directory {output_path} in the case it does not exist")
os.makedirs(output_path, exist_ok=True)
mylogger.info("Reference images")
for i in range(0, len(meta_channels)):
temp_path: str = os.path.join(output_path, meta_channels[i] + ".npy")
mylogger.info(f"Extract and save: {temp_path}")
frame_id: int = data.shape[-2] // 2
mylogger.info(f"Will use frame id: {frame_id}")
ref_image: np.ndarray = (
data[:, :, frame_id, meta_channels.index(meta_channels[i])]
.clone()
.cpu()
.numpy()
)
np.save(temp_path, ref_image)
mylogger.info("-==- Done -==-")
sample_frequency: float = 1.0 / meta_frame_time
mylogger.info(
(
f"Heartbeat power {config['lower_freqency_bandpass']}Hz"
f" - {config['upper_freqency_bandpass']}Hz,"
f" sample-rate: {sample_frequency},"
f" skipping the first {config['skip_frames_in_the_beginning']} frames"
)
)
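# Per channel: band-pass every pixel's time course around the heartbeat band and use the
# temporal variance (after skipping the LED ramp-on frames) as a heartbeat-power map.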
for i in range(0, len(meta_channels)):
temp_path = os.path.join(output_path, meta_channels[i] + "_var.npy")
mylogger.info(f"Extract and save: {temp_path}")
heartbeat_ts: torch.Tensor = bandpass(
data=data[..., i],
device=data.device,
low_frequency=config["lower_freqency_bandpass"],
high_frequency=config["upper_freqency_bandpass"],
fs=sample_frequency,
filtfilt_chuck_size=10,
)
heartbeat_power = heartbeat_ts[..., config["skip_frames_in_the_beginning"] :].var(
dim=-1
)
np.save(temp_path, heartbeat_power)
mylogger.info("-==- Done -==-")

View file

@@ -0,0 +1,155 @@
import matplotlib.pyplot as plt # type:ignore
import matplotlib
import numpy as np
import torch
import os
import json
from jsmin import jsmin # type:ignore
from matplotlib.widgets import Slider, Button # type:ignore
from functools import partial
from functions.gauss_smear_individual import gauss_smear_individual
from functions.create_logger import create_logger
mylogger = create_logger(
save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_2"
)
mylogger.info("loading config file")
with open("config.json", "r") as file:
config = json.loads(jsmin(file.read()))
threshold: float = 0.05
path: str = config["ref_image_path"]
image_ref_file: str = os.path.join(path, "donor.npy")
image_var_file: str = os.path.join(path, "donor_var.npy")
heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")
heartbeat_mask_threshold_file: str = os.path.join(path, "heartbeat_mask_threshold.npy")
if torch.cuda.is_available():
device_name: str = "cuda:0"
else:
device_name = "cpu"
if config["force_to_cpu"]:
device_name = "cpu"
mylogger.info(f"Using device: {device_name}")
device: torch.device = torch.device(device_name)


def next_frame(
i: float, images: np.ndarray, image_handle: matplotlib.image.AxesImage
) -> None:
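    # Slider callback: blank the heartbeat-power (blue) channel for display and white out
    # every pixel whose smoothed heartbeat power falls below the selected threshold.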
global threshold
threshold = i
display_image: np.ndarray = images.copy()
display_image[..., 2] = display_image[..., 0]
mask: np.ndarray = np.where(images[..., 2] >= i, 1.0, np.nan)[..., np.newaxis]
display_image *= mask
display_image = np.nan_to_num(display_image, nan=1.0)
image_handle.set_data(display_image)
return


def on_clicked_accept(
global threshold
global volume_3color
global path
global mylogger
global heartbeat_mask_file
global heartbeat_mask_threshold_file
mylogger.info(f"Threshold: {threshold}")
mask: np.ndarray = volume_3color[..., 2] >= threshold
mylogger.info(f"Save mask to: {heartbeat_mask_file}")
np.save(heartbeat_mask_file, mask)
mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}")
np.save(heartbeat_mask_threshold_file, np.array([threshold]))
exit()


def on_clicked_cancel(
exit()
mylogger.info(f"loading image reference file: {image_ref_file}")
image_ref: np.ndarray = np.load(image_ref_file)
image_ref /= image_ref.max()
mylogger.info(f"loading image heartbeat power: {image_var_file}")
image_var: np.ndarray = np.load(image_var_file)
image_var /= image_var.max()
mylogger.info("Smear the image heartbeat power patially")
temp, _ = gauss_smear_individual(
input=torch.tensor(image_var[..., np.newaxis], device=device),
spatial_width=4.0,
temporal_width=0.1,
use_matlab_mask=False,
)
temp /= temp.max()
mylogger.info("-==- DONE -==-")
volume_3color = np.concatenate(
(
np.zeros_like(image_ref[..., np.newaxis]),
image_ref[..., np.newaxis],
temp.cpu().numpy(),
),
axis=-1,
)
mylogger.info("Prepare image")
display_image = volume_3color.copy()
display_image[..., 2] = display_image[..., 0]
mask = np.where(volume_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis]
display_image *= mask
display_image = np.nan_to_num(display_image, nan=1.0)
value_sort = np.sort(image_var.flatten())
value_sort_max = value_sort[int(value_sort.shape[0] * 0.95)]
mylogger.info("-==- DONE -==-")
mylogger.info("Create figure")
fig: matplotlib.figure.Figure = plt.figure()
image_handle = plt.imshow(display_image, vmin=0, vmax=1, cmap="hot")
mylogger.info("Add controls")
axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03))
slice_slider = Slider(
ax=axfreq,
label="Slice",
valmin=0,
valmax=value_sort_max,
valinit=threshold,
valstep=value_sort_max / 100.0,
)
axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04))
button_accept = Button(
ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95"
)
button_accept.on_clicked(on_clicked_accept) # type: ignore
axbutton_cancel = fig.add_axes(rect=(0.55, 0.85, 0.2, 0.04))
button_cancel = Button(
ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95"
)
button_cancel.on_clicked(on_clicked_cancel) # type: ignore
slice_slider.on_changed(
partial(next_frame, images=volume_3color, image_handle=image_handle)
)
mylogger.info("Display")
plt.show()
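
Pressing Accept writes the binary mask and the chosen threshold into the ref_image_path directory as heartbeat_mask.npy and heartbeat_mask_threshold.npy. A small sketch that reloads them and applies the mask to the donor reference image:

import numpy as np

mask = np.load("ref_images/heartbeat_mask.npy")  # boolean, one entry per pixel
threshold = float(np.load("ref_images/heartbeat_mask_threshold.npy")[0])
donor_ref = np.load("ref_images/donor.npy")
masked_ref = np.where(mask, donor_ref, np.nan)   # keep only pixels inside the mask
print(f"threshold used: {threshold}, pixels kept: {int(mask.sum())}")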