Add files via upload

parent 84c254ae76
commit 77dc69eb13

5 changed files with 1416 additions and 0 deletions
62  config.json  Normal file
@@ -0,0 +1,62 @@
{
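    // NOTE: "//" comments are not valid in strict JSON; it is assumed that
    // load_config() strips them before handing the text to the JSON parser.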
    "basic_path": "/data_1/hendrik",
    "recoding_data": "2021-06-17",
    "mouse_identifier": "M3859M",
    //"basic_path": "/data_1/robert",
    //"recoding_data": "2021-10-05",
    //"mouse_identifier": "M3879M",
    "raw_path": "raw",
    "export_path": "output",
    "ref_image_path": "ref_images",
    // Ratio Sequence
    "classical_ratio_mode": true, // true: a/d, false: 1+a-d
    // Regression
    "target_camera_acceptor": "acceptor",
    "regressor_cameras_acceptor": [
        "oxygenation",
        "volume"
    ],
    "target_camera_donor": "donor",
    "regressor_cameras_donor": [
        "oxygenation",
        "volume"
    ],
    // binning
    "binning_enable": true,
    "binning_at_the_end": false,
    "binning_kernel_size": 4,
    "binning_stride": 4,
    "binning_divisor_override": 1,
    // alignment
    "alignment_batch_size": 200,
    "rotation_stabilization_threshold_factor": 3.0, // >= 1.0
    "rotation_stabilization_threshold_border": 0.9, // <= 1.0
    // Heart beat detection
    "lower_freqency_bandpass": 5.0, // Hz
    "upper_freqency_bandpass": 14.0, // Hz
    "heartbeat_filtfilt_chuck_size": 10,
    // Gauss smear
    "gauss_smear_spatial_width": 8,
    "gauss_smear_temporal_width": 0.1,
    "gauss_smear_use_matlab_mask": false,
    // LED Ramp on
    "skip_frames_in_the_beginning": 100, // Frames
    // PyTorch
    "dtype": "float32",
    "force_to_cpu": false,
    // Save
    "save_as_python": true, // produces .npz files (compressed)
    "save_as_matlab": false, // produces .hd5 file (compressed)
    // Save extra information
    "save_alignment": false,
    "save_heartbeat": false,
    "save_factors": false,
    "save_regression_coefficients": false,
    // Not an important parameter
    "required_order": [
        "acceptor",
        "donor",
        "oxygenation",
        "volume"
    ]
}
126  stage_1_get_ref_image.py  Normal file
@@ -0,0 +1,126 @@
import os
import torch
import numpy as np


from functions.get_experiments import get_experiments
from functions.get_trials import get_trials
from functions.bandpass import bandpass
from functions.create_logger import create_logger
from functions.get_torch_device import get_torch_device
from functions.load_config import load_config
from functions.data_raw_loader import data_raw_loader

mylogger = create_logger(
    save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_1"
)

config = load_config(mylogger=mylogger)
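
# Note: with early binning the full-resolution stack is processed on the CPU;
# presumably it would not reliably fit into GPU memory before downsampling.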
if config["binning_enable"] and (config["binning_at_the_end"] is False):
    device: torch.device = torch.device("cpu")
else:
    device = get_torch_device(mylogger, config["force_to_cpu"])

dtype_str: str = config["dtype"]
dtype: torch.dtype = getattr(torch, dtype_str)

raw_data_path: str = os.path.join(
    config["basic_path"],
    config["recoding_data"],
    config["mouse_identifier"],
    config["raw_path"],
)

mylogger.info(f"Using data path: {raw_data_path}")

first_experiment_id: int = int(get_experiments(raw_data_path).min())
first_trial_id: int = int(get_trials(raw_data_path, first_experiment_id).min())

meta_channels: list[str]
meta_mouse_markings: str
meta_recording_date: str
meta_stimulation_times: dict
meta_experiment_names: dict
meta_trial_recording_duration: float
meta_frame_time: float
meta_mouse: str
data: torch.Tensor

if config["binning_enable"] and (config["binning_at_the_end"] is False):
    force_to_cpu_memory: bool = True
else:
    force_to_cpu_memory = False

mylogger.info("Loading data")

(
    meta_channels,
    meta_mouse_markings,
    meta_recording_date,
    meta_stimulation_times,
    meta_experiment_names,
    meta_trial_recording_duration,
    meta_frame_time,
    meta_mouse,
    data,
) = data_raw_loader(
    raw_data_path=raw_data_path,
    mylogger=mylogger,
    experiment_id=first_experiment_id,
    trial_id=first_trial_id,
    device=device,
    force_to_cpu_memory=force_to_cpu_memory,
    config=config,
)
mylogger.info("-==- Done -==-")

output_path = config["ref_image_path"]
mylogger.info(f"Create directory {output_path} in case it does not exist")
os.makedirs(output_path, exist_ok=True)
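
# The per-channel reference image is simply the middle frame of the first trial;
# a single frame is assumed to be representative enough for the later alignment.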
mylogger.info("Reference images")
for i in range(0, len(meta_channels)):
    temp_path: str = os.path.join(output_path, meta_channels[i] + ".npy")
    mylogger.info(f"Extract and save: {temp_path}")
    frame_id: int = data.shape[-2] // 2
    mylogger.info(f"Will use frame id: {frame_id}")
    ref_image: np.ndarray = data[:, :, frame_id, i].clone().cpu().numpy()
    np.save(temp_path, ref_image)
mylogger.info("-==- Done -==-")

sample_frequency: float = 1.0 / meta_frame_time
mylogger.info(
    (
        f"Heartbeat power {config['lower_freqency_bandpass']} Hz"
        f" - {config['upper_freqency_bandpass']} Hz,"
        f" sample-rate: {sample_frequency},"
        f" skipping the first {config['skip_frames_in_the_beginning']} frames"
    )
)
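
# Heartbeat "power" per pixel: the variance of the band-passed time series.
# Pixels that pulse strongly with the heartbeat stand out in this map, and
# stage 2 thresholds it to build the mask.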
for i in range(0, len(meta_channels)):
    temp_path = os.path.join(output_path, meta_channels[i] + "_var.npy")
    mylogger.info(f"Extract and save: {temp_path}")

    heartbeat_ts: torch.Tensor = bandpass(
        data=data[..., i],
        device=data.device,
        low_frequency=config["lower_freqency_bandpass"],
        high_frequency=config["upper_freqency_bandpass"],
        fs=sample_frequency,
        filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"],
    )

    heartbeat_power = heartbeat_ts[..., config["skip_frames_in_the_beginning"] :].var(
        dim=-1
    )
    # np.save cannot serialize CUDA tensors; move the result to the host first
    np.save(temp_path, heartbeat_power.cpu())

mylogger.info("-==- Done -==-")
153  stage_2_make_heartbeat_mask.py  Normal file
@@ -0,0 +1,153 @@
import matplotlib.pyplot as plt  # type:ignore
import matplotlib
import numpy as np
import torch
import os

from matplotlib.widgets import Slider, Button  # type:ignore
from functools import partial
from functions.gauss_smear_individual import gauss_smear_individual
from functions.create_logger import create_logger
from functions.get_torch_device import get_torch_device
from functions.load_config import load_config

mylogger = create_logger(
    save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_2"
)

config = load_config(mylogger=mylogger)

path: str = config["ref_image_path"]
use_channel: str = "donor"
spatial_width: float = 4.0
temporal_width: float = 0.1

threshold: float = 0.05
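
# Interactive stage: a slider sets a threshold on the smoothed heartbeat-power map.
# Pixels at or above the threshold stay in the mask; presumably these are the
# perfused (on-tissue) pixels where the heartbeat is clearly visible.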
heartbeat_mask_threshold_file: str = os.path.join(path, "heartbeat_mask_threshold.npy")
if os.path.isfile(heartbeat_mask_threshold_file):
    mylogger.info(f"loading previous threshold file: {heartbeat_mask_threshold_file}")
    threshold = float(np.load(heartbeat_mask_threshold_file)[0])

mylogger.info(f"initial threshold is {threshold}")

image_ref_file: str = os.path.join(path, use_channel + ".npy")
image_var_file: str = os.path.join(path, use_channel + "_var.npy")
heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")

device = get_torch_device(mylogger, config["force_to_cpu"])


def next_frame(
    i: float, images: np.ndarray, image_handle: matplotlib.image.AxesImage
) -> None:
    global threshold
    threshold = i

    display_image: np.ndarray = images.copy()
    display_image[..., 2] = display_image[..., 0]
    mask: np.ndarray = np.where(images[..., 2] >= i, 1.0, np.nan)[..., np.newaxis]
    display_image *= mask
    display_image = np.nan_to_num(display_image, nan=1.0)

    image_handle.set_data(display_image)
    return


def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
    global threshold
    global image_3color
    global path
    global mylogger
    global heartbeat_mask_file
    global heartbeat_mask_threshold_file

    mylogger.info(f"Threshold: {threshold}")

    mask: np.ndarray = image_3color[..., 2] >= threshold
    mylogger.info(f"Save mask to: {heartbeat_mask_file}")
    np.save(heartbeat_mask_file, mask)
    mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}")
    np.save(heartbeat_mask_threshold_file, np.array([threshold]))
    exit()


def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None:
    exit()


mylogger.info(f"loading image reference file: {image_ref_file}")
image_ref: np.ndarray = np.load(image_ref_file)
image_ref /= image_ref.max()

mylogger.info(f"loading image heartbeat power: {image_var_file}")
image_var: np.ndarray = np.load(image_var_file)
image_var /= image_var.max()

mylogger.info("Smear the image heartbeat power spatially")
temp, _ = gauss_smear_individual(
    input=torch.tensor(image_var[..., np.newaxis], device=device),
    spatial_width=spatial_width,
    temporal_width=temporal_width,
    use_matlab_mask=False,
)
temp /= temp.max()

mylogger.info("-==- DONE -==-")
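
# Three-channel working image: R = 0, G = anatomical reference, B = smoothed
# heartbeat power. Only B is thresholded; for display B is blanked (copied from R)
# and sub-threshold pixels are painted white.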
image_3color = np.concatenate(
    (
        np.zeros_like(image_ref[..., np.newaxis]),
        image_ref[..., np.newaxis],
        temp.cpu().numpy(),
    ),
    axis=-1,
)

mylogger.info("Prepare image")

display_image = image_3color.copy()
display_image[..., 2] = display_image[..., 0]
mask = np.where(image_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis]
display_image *= mask
display_image = np.nan_to_num(display_image, nan=1.0)

value_sort = np.sort(image_var.flatten())
value_sort_max = value_sort[int(value_sort.shape[0] * 0.95)]
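# The slider range is capped at the 95th percentile of the power values so that a
# few extreme pixels do not compress the usable threshold range.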
mylogger.info("-==- DONE -==-")

mylogger.info("Create figure")

fig: matplotlib.figure.Figure = plt.figure()

image_handle = plt.imshow(display_image, vmin=0, vmax=1, cmap="hot")

mylogger.info("Add controls")

axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03))
slice_slider = Slider(
    ax=axfreq,
    label="Threshold",
    valmin=0,
    valmax=value_sort_max,
    valinit=threshold,
    valstep=value_sort_max / 100.0,
)
axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04))
button_accept = Button(
    ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95"
)
button_accept.on_clicked(on_clicked_accept)  # type: ignore

axbutton_cancel = fig.add_axes(rect=(0.55, 0.85, 0.2, 0.04))
button_cancel = Button(
    ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95"
)
button_cancel.on_clicked(on_clicked_cancel)  # type: ignore

slice_slider.on_changed(
    partial(next_frame, images=image_3color, image_handle=image_handle)
)

mylogger.info("Display")
plt.show()
157  stage_3_refine_mask.py  Normal file
@@ -0,0 +1,157 @@
import os
import numpy as np

import matplotlib.pyplot as plt  # type:ignore
import matplotlib
from matplotlib.widgets import Button  # type:ignore

# pip install roipoly
from roipoly import RoiPoly  # type:ignore

from functions.create_logger import create_logger
from functions.get_torch_device import get_torch_device
from functions.load_config import load_config


def compose_image(image_3color: np.ndarray, mask: np.ndarray) -> np.ndarray:
    display_image = image_3color.copy()
    display_image[..., 2] = display_image[..., 0]
    display_image[mask == 0, :] = 1.0
    return display_image
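

# compose_image() paints every pixel outside the current mask white, so the green
# area on screen is exactly the region that will be kept for stage 4.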
def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None:
    global mylogger
    global refined_mask_file
    global mask

    mylogger.info(f"Save mask to: {refined_mask_file}")
    np.save(refined_mask_file, mask)

    exit()


def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None:
    global mylogger
    mylogger.info("Ended without saving the mask")
    exit()


def on_clicked_add(event: matplotlib.backend_bases.MouseEvent) -> None:
    global new_roi
    global mask
    global image_3color
    global display_image
    global mylogger
    if len(new_roi.x) > 0:
        mylogger.info("A ROI with the following coordinates has been added to the mask")
        for i in range(0, len(new_roi.x)):
            mylogger.info(f"{round(new_roi.x[i], 1)} x {round(new_roi.y[i], 1)}")
        mylogger.info("")
        new_mask = new_roi.get_mask(display_image[:, :, 0])
        mask[new_mask] = 0.0
        display_image = compose_image(image_3color=image_3color, mask=mask)
        image_handle.set_data(display_image)
        for line in ax_main.lines:
            line.remove()
        plt.draw()

    new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)


def on_clicked_remove(event: matplotlib.backend_bases.MouseEvent) -> None:
    global new_roi
    global mask
    global image_3color
    global display_image
    if len(new_roi.x) > 0:
        mylogger.info(
            "A ROI with the following coordinates has been removed from the mask"
        )
        for i in range(0, len(new_roi.x)):
            mylogger.info(f"{round(new_roi.x[i], 1)} x {round(new_roi.y[i], 1)}")
        mylogger.info("")
        new_mask = new_roi.get_mask(display_image[:, :, 0])
        mask[new_mask] = 1.0
        display_image = compose_image(image_3color=image_3color, mask=mask)
        image_handle.set_data(display_image)
        for line in ax_main.lines:
            line.remove()
        plt.draw()

    new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)
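

# Button semantics: "Add mask" excludes the drawn ROI from the kept region
# (mask -> 0), while "Remove mask" re-includes it (mask -> 1).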
mylogger = create_logger(
    save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_3"
)

config = load_config(mylogger=mylogger)

device = get_torch_device(mylogger, config["force_to_cpu"])

path: str = config["ref_image_path"]
use_channel: str = "donor"

image_ref_file: str = os.path.join(path, use_channel + ".npy")
heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy")
refined_mask_file: str = os.path.join(path, "mask_not_rotated.npy")
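# "mask_not_rotated" presumably means the mask is stored in raw camera coordinates,
# before stage 4 applies its rotation/translation alignment.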

mylogger.info(f"loading image reference file: {image_ref_file}")
image_ref: np.ndarray = np.load(image_ref_file)
image_ref /= image_ref.max()

mylogger.info(f"loading heartbeat mask: {heartbeat_mask_file}")
mask: np.ndarray = np.load(heartbeat_mask_file)

image_3color = np.concatenate(
    (
        np.zeros_like(image_ref[..., np.newaxis]),
        image_ref[..., np.newaxis],
        np.zeros_like(image_ref[..., np.newaxis]),
    ),
    axis=-1,
)

mylogger.info("-==- DONE -==-")

fig, ax_main = plt.subplots()

display_image = compose_image(image_3color=image_3color, mask=mask)
image_handle = ax_main.imshow(display_image, vmin=0, vmax=1, cmap="hot")

mylogger.info("Add controls")

axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04))
button_accept = Button(
    ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95"
)
button_accept.on_clicked(on_clicked_accept)  # type: ignore

axbutton_cancel = fig.add_axes(rect=(0.5, 0.85, 0.2, 0.04))
button_cancel = Button(
    ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95"
)
button_cancel.on_clicked(on_clicked_cancel)  # type: ignore

axbutton_addmask = fig.add_axes(rect=(0.3, 0.9, 0.2, 0.04))
button_addmask = Button(
    ax=axbutton_addmask, label="Add mask", image=None, color="0.85", hovercolor="0.95"
)
button_addmask.on_clicked(on_clicked_add)  # type: ignore

axbutton_removemask = fig.add_axes(rect=(0.5, 0.9, 0.2, 0.04))
button_removemask = Button(
    ax=axbutton_removemask,
    label="Remove mask",
    image=None,
    color="0.85",
    hovercolor="0.95",
)
button_removemask.on_clicked(on_clicked_remove)  # type: ignore

# ax_main.cla()

mylogger.info("Display")
new_roi: RoiPoly = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False)

plt.show()
918  stage_4_process.py  Normal file
@@ -0,0 +1,918 @@
import numpy as np
import torch
import torchvision as tv  # type: ignore

import os
import logging
import h5py  # type: ignore

from functions.create_logger import create_logger
from functions.get_torch_device import get_torch_device
from functions.load_config import load_config
from functions.get_experiments import get_experiments
from functions.get_trials import get_trials
from functions.binning import binning
from functions.ImageAlignment import ImageAlignment
from functions.align_refref import align_refref
from functions.perform_donor_volume_rotation import perform_donor_volume_rotation
from functions.perform_donor_volume_translation import perform_donor_volume_translation
from functions.bandpass import bandpass
from functions.gauss_smear_individual import gauss_smear_individual
from functions.regression import regression
from functions.data_raw_loader import data_raw_loader


@torch.no_grad()
def process_trial(
    config: dict,
    mylogger: logging.Logger,
    experiment_id: int,
    trial_id: int,
    device: torch.device,
):

    mylogger.info("")
    mylogger.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    mylogger.info("~ TRIAL START ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    mylogger.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    mylogger.info("")

    if device != torch.device("cpu"):
        torch.cuda.empty_cache()
        mylogger.info("Empty CUDA cache")
        cuda_total_memory: int = torch.cuda.get_device_properties(
            device.index
        ).total_memory
    else:
        cuda_total_memory = 0

    raw_data_path: str = os.path.join(
        config["basic_path"],
        config["recoding_data"],
        config["mouse_identifier"],
        config["raw_path"],
    )

    if config["binning_enable"] and (config["binning_at_the_end"] is False):
        force_to_cpu_memory: bool = True
    else:
        force_to_cpu_memory = False

    meta_channels: list[str]
    meta_mouse_markings: str
    meta_recording_date: str
    meta_stimulation_times: dict
    meta_experiment_names: dict
    meta_trial_recording_duration: float
    meta_frame_time: float
    meta_mouse: str
    data: torch.Tensor

    (
        meta_channels,
        meta_mouse_markings,
        meta_recording_date,
        meta_stimulation_times,
        meta_experiment_names,
        meta_trial_recording_duration,
        meta_frame_time,
        meta_mouse,
        data,
    ) = data_raw_loader(
        raw_data_path=raw_data_path,
        mylogger=mylogger,
        experiment_id=experiment_id,
        trial_id=trial_id,
        device=device,
        force_to_cpu_memory=force_to_cpu_memory,
        config=config,
    )
    experiment_name: str = f"Exp{experiment_id:03d}_Trial{trial_id:03d}"

    dtype_str = config["dtype"]
    dtype_np: np.dtype = getattr(np, dtype_str)

    dtype: torch.dtype = data.dtype

    if device != torch.device("cpu"):
        free_mem = cuda_total_memory - max(
            [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
        )
        mylogger.info(f"CUDA memory: {free_mem // (1024 * 1024)} MByte")

    mylogger.info(f"Data shape: {data.shape}")
    mylogger.info("-==- Done -==-")

    mylogger.info("Finding limit values in the RAW data and marking them for masking")
    limit: float = (2**16) - 1
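    # Pixels that hit the 16-bit ceiling at any point in time are stamped with the
    # sentinel value -100.0; the "data < -0.1" test further below folds them, along
    # with the alignment borders, into the final mask.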
    for i in range(0, data.shape[3]):
        zero_pixel_mask: torch.Tensor = torch.any(data[..., i] >= limit, dim=-1)
        data[zero_pixel_mask, :, i] = -100.0
        mylogger.info(
            f"{meta_channels[i]}: "
            f"found {int(zero_pixel_mask.type(dtype=dtype).sum())} pixels "
            f"with limit values "
        )
    mylogger.info("-==- Done -==-")

    mylogger.info("Reference images and mask")

    ref_image_path: str = config["ref_image_path"]

    ref_image_path_acceptor: str = os.path.join(ref_image_path, "acceptor.npy")
    if os.path.isfile(ref_image_path_acceptor) is False:
        mylogger.info(f"Could not load ref file: {ref_image_path_acceptor}")
        assert os.path.isfile(ref_image_path_acceptor)
        return

    mylogger.info(f"Loading ref file data: {ref_image_path_acceptor}")
    ref_image_acceptor: torch.Tensor = torch.tensor(
        np.load(ref_image_path_acceptor).astype(dtype_np), dtype=dtype, device=device
    )

    ref_image_path_donor: str = os.path.join(ref_image_path, "donor.npy")
    if os.path.isfile(ref_image_path_donor) is False:
        mylogger.info(f"Could not load ref file: {ref_image_path_donor}")
        assert os.path.isfile(ref_image_path_donor)
        return

    mylogger.info(f"Loading ref file data: {ref_image_path_donor}")
    ref_image_donor: torch.Tensor = torch.tensor(
        np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=device
    )

    ref_image_path_oxygenation: str = os.path.join(ref_image_path, "oxygenation.npy")
    if os.path.isfile(ref_image_path_oxygenation) is False:
        mylogger.info(f"Could not load ref file: {ref_image_path_oxygenation}")
        assert os.path.isfile(ref_image_path_oxygenation)
        return

    mylogger.info(f"Loading ref file data: {ref_image_path_oxygenation}")
    ref_image_oxygenation: torch.Tensor = torch.tensor(
        np.load(ref_image_path_oxygenation).astype(dtype_np), dtype=dtype, device=device
    )

    ref_image_path_volume: str = os.path.join(ref_image_path, "volume.npy")
    if os.path.isfile(ref_image_path_volume) is False:
        mylogger.info(f"Could not load ref file: {ref_image_path_volume}")
        assert os.path.isfile(ref_image_path_volume)
        return

    mylogger.info(f"Loading ref file data: {ref_image_path_volume}")
    ref_image_volume: torch.Tensor = torch.tensor(
        np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=device
    )

    refined_mask_file: str = os.path.join(ref_image_path, "mask_not_rotated.npy")
    if os.path.isfile(refined_mask_file) is False:
        mylogger.info(f"Could not load mask file: {refined_mask_file}")
        assert os.path.isfile(refined_mask_file)
        return

    mylogger.info(f"Loading mask file data: {refined_mask_file}")
    mask: torch.Tensor = torch.tensor(
        np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=device
    )
    mylogger.info("-==- Done -==-")

    if config["binning_enable"] and (config["binning_at_the_end"] is False):
        mylogger.info("Binning of data")
        mylogger.info(
            (
                f"kernel_size={int(config['binning_kernel_size'])}, "
                f"stride={int(config['binning_stride'])}, "
                f"divisor_override={int(config['binning_divisor_override'])}"
            )
        )

        data = binning(
            data,
            kernel_size=int(config["binning_kernel_size"]),
            stride=int(config["binning_stride"]),
            divisor_override=int(config["binning_divisor_override"]),
        ).to(device=device)
        ref_image_acceptor = (
            binning(
                ref_image_acceptor.unsqueeze(-1).unsqueeze(-1),
                kernel_size=int(config["binning_kernel_size"]),
                stride=int(config["binning_stride"]),
                divisor_override=int(config["binning_divisor_override"]),
            )
            .squeeze(-1)
            .squeeze(-1)
        )
        ref_image_donor = (
            binning(
                ref_image_donor.unsqueeze(-1).unsqueeze(-1),
                kernel_size=int(config["binning_kernel_size"]),
                stride=int(config["binning_stride"]),
                divisor_override=int(config["binning_divisor_override"]),
            )
            .squeeze(-1)
            .squeeze(-1)
        )
        ref_image_oxygenation = (
            binning(
                ref_image_oxygenation.unsqueeze(-1).unsqueeze(-1),
                kernel_size=int(config["binning_kernel_size"]),
                stride=int(config["binning_stride"]),
                divisor_override=int(config["binning_divisor_override"]),
            )
            .squeeze(-1)
            .squeeze(-1)
        )
        ref_image_volume = (
            binning(
                ref_image_volume.unsqueeze(-1).unsqueeze(-1),
                kernel_size=int(config["binning_kernel_size"]),
                stride=int(config["binning_stride"]),
                divisor_override=int(config["binning_divisor_override"]),
            )
            .squeeze(-1)
            .squeeze(-1)
        )
        mask = (
            binning(
                mask.unsqueeze(-1).unsqueeze(-1),
                kernel_size=int(config["binning_kernel_size"]),
                stride=int(config["binning_stride"]),
                divisor_override=int(config["binning_divisor_override"]),
            )
            .squeeze(-1)
            .squeeze(-1)
        )
        mylogger.info(f"Data shape: {data.shape}")
        mylogger.info("-==- Done -==-")

    mylogger.info("Preparing alignment")
    image_alignment = ImageAlignment(default_dtype=dtype, device=device)

    mylogger.info("Re-order Raw data")
    data = data.moveaxis(-2, 0).moveaxis(-1, 0)
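    # Layout after the two moveaxis calls: (channel, time, height, width); everything
    # below addresses the channels through config["required_order"].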
    mylogger.info(f"Data shape: {data.shape}")
    mylogger.info("-==- Done -==-")

    mylogger.info("Alignment of the ref images and the mask")
    mylogger.info("Ref image of donor stays fixed.")
    mylogger.info("Ref image of volume and the mask don't need to be touched")
    mylogger.info("Calculate translation and rotation between the reference images")
    angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor = align_refref(
        mylogger=mylogger,
        ref_image_acceptor=ref_image_acceptor,
        ref_image_donor=ref_image_donor,
        image_alignment=image_alignment,
        batch_size=config["alignment_batch_size"],
        fill_value=-100.0,
    )
    mylogger.info(f"Rotation: {round(float(angle_refref[0]), 2)} degree")
    mylogger.info(
        f"Translation: {round(float(tvec_refref[0]), 1)} x {round(float(tvec_refref[1]), 1)} pixel"
    )

    if config["save_alignment"]:
        temp_path: str = os.path.join(
            config["export_path"], experiment_name + "_angle_refref.npy"
        )
        mylogger.info(f"Save angle to {temp_path}")
        np.save(temp_path, angle_refref.cpu())

        temp_path = os.path.join(
            config["export_path"], experiment_name + "_tvec_refref.npy"
        )
        mylogger.info(f"Save translation vector to {temp_path}")
        np.save(temp_path, tvec_refref.cpu())

    mylogger.info("Moving & rotating the oxygenation ref image")
    ref_image_oxygenation = tv.transforms.functional.affine(
        img=ref_image_oxygenation.unsqueeze(0),
        angle=-float(angle_refref),
        translate=[0, 0],
        scale=1.0,
        shear=0,
        interpolation=tv.transforms.InterpolationMode.BILINEAR,
        fill=-100.0,
    )

    ref_image_oxygenation = tv.transforms.functional.affine(
        img=ref_image_oxygenation,
        angle=0,
        translate=[tvec_refref[1], tvec_refref[0]],
        scale=1.0,
        shear=0,
        interpolation=tv.transforms.InterpolationMode.BILINEAR,
        fill=-100.0,
    ).squeeze(0)
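    # Rotation and translation are applied as two separate affine passes; fill=-100.0
    # stamps pixels that entered from outside the field of view so they join the mask.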
    mylogger.info("-==- Done -==-")

    mylogger.info("Rotate and translate the acceptor and oxygenation data accordingly")
    acceptor_index: int = config["required_order"].index("acceptor")
    donor_index: int = config["required_order"].index("donor")
    oxygenation_index: int = config["required_order"].index("oxygenation")
    volume_index: int = config["required_order"].index("volume")

    mylogger.info("Rotate acceptor")
    data[acceptor_index, ...] = tv.transforms.functional.affine(
        img=data[acceptor_index, ...],
        angle=-float(angle_refref),
        translate=[0, 0],
        scale=1.0,
        shear=0,
        interpolation=tv.transforms.InterpolationMode.BILINEAR,
        fill=-100.0,
    )

    mylogger.info("Translate acceptor")
    data[acceptor_index, ...] = tv.transforms.functional.affine(
        img=data[acceptor_index, ...],
        angle=0,
        translate=[tvec_refref[1], tvec_refref[0]],
        scale=1.0,
        shear=0,
        interpolation=tv.transforms.InterpolationMode.BILINEAR,
        fill=-100.0,
    )

    mylogger.info("Rotate oxygenation")
    data[oxygenation_index, ...] = tv.transforms.functional.affine(
        img=data[oxygenation_index, ...],
        angle=-float(angle_refref),
        translate=[0, 0],
        scale=1.0,
        shear=0,
        interpolation=tv.transforms.InterpolationMode.BILINEAR,
        fill=-100.0,
    )

    mylogger.info("Translate oxygenation")
    data[oxygenation_index, ...] = tv.transforms.functional.affine(
        img=data[oxygenation_index, ...],
        angle=0,
        translate=[tvec_refref[1], tvec_refref[0]],
        scale=1.0,
        shear=0,
        interpolation=tv.transforms.InterpolationMode.BILINEAR,
        fill=-100.0,
    )
    mylogger.info("-==- Done -==-")

    mylogger.info("Perform rotation between donor and volume and their ref images")
    mylogger.info("for all frames and then rotate all the data accordingly")
    (
        data[acceptor_index, ...],
        data[donor_index, ...],
        data[oxygenation_index, ...],
        data[volume_index, ...],
        angle_donor_volume,
    ) = perform_donor_volume_rotation(
        mylogger=mylogger,
        acceptor=data[acceptor_index, ...],
        donor=data[donor_index, ...],
        oxygenation=data[oxygenation_index, ...],
        volume=data[volume_index, ...],
        ref_image_donor=ref_image_donor,
        ref_image_volume=ref_image_volume,
        image_alignment=image_alignment,
        batch_size=config["alignment_batch_size"],
        fill_value=-100.0,
        config=config,
    )

    mylogger.info(
        f"angles: "
        f"min {round(float(angle_donor_volume.min()), 2)} "
        f"max {round(float(angle_donor_volume.max()), 2)} "
        f"mean {round(float(angle_donor_volume.mean()), 2)} "
    )

    if config["save_alignment"]:
        temp_path = os.path.join(
            config["export_path"], experiment_name + "_angle_donor_volume.npy"
        )
        mylogger.info(f"Save angles to {temp_path}")
        np.save(temp_path, angle_donor_volume.cpu())
    mylogger.info("-==- Done -==-")

    mylogger.info("Perform translation between donor and volume and their ref images")
    mylogger.info("for all frames and then translate all the data accordingly")
    (
        data[acceptor_index, ...],
        data[donor_index, ...],
        data[oxygenation_index, ...],
        data[volume_index, ...],
        tvec_donor_volume,
    ) = perform_donor_volume_translation(
        mylogger=mylogger,
        acceptor=data[acceptor_index, ...],
        donor=data[donor_index, ...],
        oxygenation=data[oxygenation_index, ...],
        volume=data[volume_index, ...],
        ref_image_donor=ref_image_donor,
        ref_image_volume=ref_image_volume,
        image_alignment=image_alignment,
        batch_size=config["alignment_batch_size"],
        fill_value=-100.0,
        config=config,
    )

    mylogger.info(
        f"translation dim 0: "
        f"min {round(float(tvec_donor_volume[:, 0].min()), 1)} "
        f"max {round(float(tvec_donor_volume[:, 0].max()), 1)} "
        f"mean {round(float(tvec_donor_volume[:, 0].mean()), 1)} "
    )
    mylogger.info(
        f"translation dim 1: "
        f"min {round(float(tvec_donor_volume[:, 1].min()), 1)} "
        f"max {round(float(tvec_donor_volume[:, 1].max()), 1)} "
        f"mean {round(float(tvec_donor_volume[:, 1].mean()), 1)} "
    )

    if config["save_alignment"]:
        temp_path = os.path.join(
            config["export_path"], experiment_name + "_tvec_donor_volume.npy"
        )
        mylogger.info(f"Save translation vector to {temp_path}")
        np.save(temp_path, tvec_donor_volume.cpu())
    mylogger.info("-==- Done -==-")

    mylogger.info("Finding zero values in the RAW data and marking them for masking")
    for i in range(0, data.shape[0]):
        zero_pixel_mask = torch.any(data[i, ...] == 0, dim=0)
        data[i, :, zero_pixel_mask] = -100.0
        mylogger.info(
            f"{config['required_order'][i]}: "
            f"found {int(zero_pixel_mask.type(dtype=dtype).sum())} pixels "
            f"with zeros "
        )
    mylogger.info("-==- Done -==-")

    mylogger.info("Update mask with the new regions due to alignment")

    new_mask_area: torch.Tensor = torch.any(torch.any(data < -0.1, dim=0), dim=0).bool()
    mask = (mask == 0).bool()
    mask = torch.logical_or(mask, new_mask_area)
    mask_negative: torch.Tensor = mask.clone()
    mask_positve: torch.Tensor = torch.logical_not(mask)
    del mask

    mylogger.info("Update the data with the new mask")
    data *= mask_positve.unsqueeze(0).unsqueeze(0).type(dtype=dtype)
    mylogger.info("-==- Done -==-")

    mylogger.info("Interpolate the 'in-between' frames for oxygenation and volume")
    data[oxygenation_index, 1:, ...] = (
        data[oxygenation_index, 1:, ...] + data[oxygenation_index, :-1, ...]
    ) / 2.0
    data[volume_index, 1:, ...] = (
        data[volume_index, 1:, ...] + data[volume_index, :-1, ...]
    ) / 2.0
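    # Averaging each frame with its predecessor shifts the oxygenation and volume
    # channels by half a frame; presumably these channels are recorded interleaved
    # with acceptor/donor, and this lines up the effective sample times.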
    mylogger.info("-==- Done -==-")

    sample_frequency: float = 1.0 / meta_frame_time

    mylogger.info("Extract heartbeat from volume signal")
    heartbeat_ts: torch.Tensor = bandpass(
        data=data[volume_index, ...].movedim(0, -1).clone(),
        device=data.device,
        low_frequency=config["lower_freqency_bandpass"],
        high_frequency=config["upper_freqency_bandpass"],
        fs=sample_frequency,
        filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"],
    )
    heartbeat_ts = heartbeat_ts.flatten(start_dim=0, end_dim=-2)
    mask_flatten: torch.Tensor = mask_positve.flatten(start_dim=0, end_dim=-1)

    heartbeat_ts = heartbeat_ts[mask_flatten, :]

    heartbeat_ts = heartbeat_ts.movedim(0, -1)
    heartbeat_ts -= heartbeat_ts.mean(dim=0, keepdim=True)
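    # SVD over the masked pixels: the first left-singular vector is the single time
    # course that explains the most variance of the band-passed volume signal, and it
    # serves as the per-trial heartbeat template.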
    volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False)
    volume_heartbeat = volume_heartbeat[:, 0]
    volume_heartbeat -= volume_heartbeat[
        config["skip_frames_in_the_beginning"] :
    ].mean()

    del heartbeat_ts

    if device != torch.device("cpu"):
        torch.cuda.empty_cache()
        mylogger.info("Empty CUDA cache")
        free_mem = cuda_total_memory - max(
            [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
        )
        mylogger.info(f"CUDA memory: {free_mem // (1024 * 1024)} MByte")

    if config["save_heartbeat"]:
        temp_path = os.path.join(
            config["export_path"], experiment_name + "_volume_heartbeat.npy"
        )
        mylogger.info(f"Save volume heartbeat to {temp_path}")
        np.save(temp_path, volume_heartbeat.cpu())
    mylogger.info("-==- Done -==-")

    volume_heartbeat = volume_heartbeat.unsqueeze(0).unsqueeze(0)
    norm_volume_heartbeat = (
        volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] ** 2
    ).sum(dim=-1)
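    # Per-pixel projection onto the heartbeat template: (template . y) / ||template||^2
    # is the least-squares amplitude of the heartbeat in each pixel of each channel.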
    heartbeat_coefficients: torch.Tensor = torch.zeros(
        (data.shape[0], data.shape[-2], data.shape[-1]),
        dtype=data.dtype,
        device=data.device,
    )
    for i in range(0, data.shape[0]):
        y = bandpass(
            data=data[i, ...].movedim(0, -1).clone(),
            device=data.device,
            low_frequency=config["lower_freqency_bandpass"],
            high_frequency=config["upper_freqency_bandpass"],
            fs=sample_frequency,
            filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"],
        )[..., config["skip_frames_in_the_beginning"] :]
        y -= y.mean(dim=-1, keepdim=True)

        heartbeat_coefficients[i, ...] = (
            volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] * y
        ).sum(dim=-1) / norm_volume_heartbeat

        heartbeat_coefficients[i, ...] *= mask_positve.type(
            dtype=heartbeat_coefficients.dtype
        )
    del y

    if config["save_heartbeat"]:
        temp_path = os.path.join(
            config["export_path"], experiment_name + "_heartbeat_coefficients.npy"
        )
        mylogger.info(f"Save heartbeat coefficients to {temp_path}")
        np.save(temp_path, heartbeat_coefficients.cpu())
    mylogger.info("-==- Done -==-")

    mylogger.info("Remove heart beat from data")
    data -= heartbeat_coefficients.unsqueeze(1) * volume_heartbeat.unsqueeze(0).movedim(
        -1, 1
    )
    mylogger.info("-==- Done -==-")

    donor_heartbeat_factor = heartbeat_coefficients[donor_index, ...].clone()
    acceptor_heartbeat_factor = heartbeat_coefficients[acceptor_index, ...].clone()
    del heartbeat_coefficients

    if device != torch.device("cpu"):
        torch.cuda.empty_cache()
        mylogger.info("Empty CUDA cache")
        free_mem = cuda_total_memory - max(
            [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
        )
        mylogger.info(f"CUDA memory: {free_mem // (1024 * 1024)} MByte")

    mylogger.info("Calculate scaling factor for donor and acceptor")
    donor_factor: torch.Tensor = (
        donor_heartbeat_factor + acceptor_heartbeat_factor
    ) / (2 * donor_heartbeat_factor)
    acceptor_factor: torch.Tensor = (
        donor_heartbeat_factor + acceptor_heartbeat_factor
    ) / (2 * acceptor_heartbeat_factor)
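    # Scaling both channels toward the mean heartbeat amplitude equalizes their gain:
    # donor_factor * donor_amp == acceptor_factor * acceptor_amp
    # == (donor_amp + acceptor_amp) / 2.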
    del donor_heartbeat_factor
    del acceptor_heartbeat_factor

    if config["save_factors"]:
        temp_path = os.path.join(
            config["export_path"], experiment_name + "_donor_factor.npy"
        )
        mylogger.info(f"Save donor factor to {temp_path}")
        np.save(temp_path, donor_factor.cpu())

        temp_path = os.path.join(
            config["export_path"], experiment_name + "_acceptor_factor.npy"
        )
        mylogger.info(f"Save acceptor factor to {temp_path}")
        np.save(temp_path, acceptor_factor.cpu())
    mylogger.info("-==- Done -==-")

    mylogger.info("Scale acceptor to heart beat amplitude")
    mylogger.info("Calculate mean")
    mean_values_acceptor = data[
        acceptor_index, config["skip_frames_in_the_beginning"] :, ...
    ].nanmean(dim=0, keepdim=True)

    mylogger.info("Remove mean")
    data[acceptor_index, ...] -= mean_values_acceptor
    mylogger.info("Apply acceptor_factor and mask")
    data[acceptor_index, ...] *= acceptor_factor.unsqueeze(0) * mask_positve.unsqueeze(
        0
    )
    mylogger.info("Add mean")
    data[acceptor_index, ...] += mean_values_acceptor
    mylogger.info("-==- Done -==-")

    mylogger.info("Scale donor to heart beat amplitude")
    mylogger.info("Calculate mean")
    mean_values_donor = data[
        donor_index, config["skip_frames_in_the_beginning"] :, ...
    ].nanmean(dim=0, keepdim=True)
    mylogger.info("Remove mean")
    data[donor_index, ...] -= mean_values_donor
    mylogger.info("Apply donor_factor and mask")
    data[donor_index, ...] *= donor_factor.unsqueeze(0) * mask_positve.unsqueeze(0)
    mylogger.info("Add mean")
    data[donor_index, ...] += mean_values_donor
    mylogger.info("-==- Done -==-")

    mylogger.info("Divide by mean over time")
    data /= data[:, config["skip_frames_in_the_beginning"] :, ...].nanmean(
        dim=1,
        keepdim=True,
    )
    data = data.nan_to_num(nan=0.0)
    mylogger.info("-==- Done -==-")

    mylogger.info("Preparation for regression -- Gauss smear")
    spatial_width = float(config["gauss_smear_spatial_width"])

    if config["binning_enable"] and (config["binning_at_the_end"] is False):
        spatial_width /= float(config["binning_kernel_size"])

    mylogger.info(
        f"Mask -- "
        f"spatial width: {spatial_width}, "
        f"temporal width: {float(config['gauss_smear_temporal_width'])}, "
        f"use matlab mode: {bool(config['gauss_smear_use_matlab_mask'])} "
    )

    input_mask = mask_positve.type(dtype=dtype).clone()

    filtered_mask: torch.Tensor
    filtered_mask, _ = gauss_smear_individual(
        input=input_mask,
        spatial_width=spatial_width,
        temporal_width=float(config["gauss_smear_temporal_width"]),
        use_matlab_mask=bool(config["gauss_smear_use_matlab_mask"]),
        epsilon=float(torch.finfo(input_mask.dtype).eps),
    )

    mylogger.info("creating a copy of the data")
    data_filtered = data.clone().movedim(1, -1)
    if device != torch.device("cpu"):
        torch.cuda.empty_cache()
        mylogger.info("Empty CUDA cache")
        free_mem = cuda_total_memory - max(
            [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
        )
        mylogger.info(f"CUDA memory: {free_mem // (1024 * 1024)} MByte")

    overwrite_fft_gauss: None | torch.Tensor = None
    for i in range(0, data_filtered.shape[0]):
        mylogger.info(
            f"{config['required_order'][i]} -- "
            f"spatial width: {spatial_width}, "
            f"temporal width: {float(config['gauss_smear_temporal_width'])}, "
            f"use matlab mode: {bool(config['gauss_smear_use_matlab_mask'])} "
        )
        data_filtered[i, ...] *= input_mask.unsqueeze(-1)
        data_filtered[i, ...], overwrite_fft_gauss = gauss_smear_individual(
            input=data_filtered[i, ...],
            spatial_width=spatial_width,
            temporal_width=float(config["gauss_smear_temporal_width"]),
            overwrite_fft_gauss=overwrite_fft_gauss,
            use_matlab_mask=bool(config["gauss_smear_use_matlab_mask"]),
            epsilon=float(torch.finfo(input_mask.dtype).eps),
        )

        data_filtered[i, ...] /= filtered_mask + 1e-20
        data_filtered[i, ...] += 1.0 - input_mask.unsqueeze(-1)
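        # Normalized convolution: dividing by the smeared mask undoes the attenuation
        # near mask borders, and masked-out pixels are parked at the neutral value 1.0
        # for the steps downstream.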
    del filtered_mask
    del overwrite_fft_gauss
    del input_mask
    mylogger.info("data_filtered is populated")

    if device != torch.device("cpu"):
        torch.cuda.empty_cache()
        mylogger.info("Empty CUDA cache")
        free_mem = cuda_total_memory - max(
            [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
        )
        mylogger.info(f"CUDA memory: {free_mem // (1024 * 1024)} MByte")
    mylogger.info("-==- Done -==-")

    mylogger.info("Preparation for Regression")
    mylogger.info("Move time dimensions of data to the last dimension")
    data = data.movedim(1, -1)

    mylogger.info("Regression Acceptor")
    mylogger.info(f"Target: {config['target_camera_acceptor']}")
    mylogger.info(
        f"Regressors: constant, linear and {config['regressor_cameras_acceptor']}"
    )
    target_id: int = config["required_order"].index(config["target_camera_acceptor"])
    regressor_id: list[int] = []
    for i in range(0, len(config["regressor_cameras_acceptor"])):
        regressor_id.append(
            config["required_order"].index(config["regressor_cameras_acceptor"][i])
        )

    data_acceptor, coefficients_acceptor = regression(
        mylogger=mylogger,
        target_camera_id=target_id,
        regressor_camera_ids=regressor_id,
        mask=mask_negative,
        data=data,
        data_filtered=data_filtered,
        first_none_ramp_frame=int(config["skip_frames_in_the_beginning"]),
    )

    if config["save_regression_coefficients"]:
        temp_path = os.path.join(
            config["export_path"], experiment_name + "_coefficients_acceptor.npy"
        )
        mylogger.info(f"Save acceptor coefficients to {temp_path}")
        np.save(temp_path, coefficients_acceptor.cpu())
    del coefficients_acceptor

    mylogger.info("-==- Done -==-")

    mylogger.info("Regression Donor")
    mylogger.info(f"Target: {config['target_camera_donor']}")
    mylogger.info(
        f"Regressors: constant, linear and {config['regressor_cameras_donor']}"
    )
    target_id = config["required_order"].index(config["target_camera_donor"])
    regressor_id = []
    for i in range(0, len(config["regressor_cameras_donor"])):
        regressor_id.append(
            config["required_order"].index(config["regressor_cameras_donor"][i])
        )

    data_donor, coefficients_donor = regression(
        mylogger=mylogger,
        target_camera_id=target_id,
        regressor_camera_ids=regressor_id,
        mask=mask_negative,
        data=data,
        data_filtered=data_filtered,
        first_none_ramp_frame=int(config["skip_frames_in_the_beginning"]),
    )

    if config["save_regression_coefficients"]:
        temp_path = os.path.join(
            config["export_path"], experiment_name + "_coefficients_donor.npy"
        )
        mylogger.info(f"Save donor coefficients to {temp_path}")
        np.save(temp_path, coefficients_donor.cpu())
    del coefficients_donor
    mylogger.info("-==- Done -==-")

    del data
    del data_filtered

    if device != torch.device("cpu"):
        torch.cuda.empty_cache()
        mylogger.info("Empty CUDA cache")
        free_mem = cuda_total_memory - max(
            [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
        )
        mylogger.info(f"CUDA memory: {free_mem // (1024 * 1024)} MByte")

    mylogger.info("Calculate ratio sequence")
    if config["classical_ratio_mode"]:
        mylogger.info("via acceptor / donor")
        ratio_sequence: torch.Tensor = data_acceptor / data_donor
        mylogger.info("then normalize by the mean over time")
        ratio_sequence /= ratio_sequence.mean(dim=-1, keepdim=True)
    else:
        mylogger.info("via 1.0 + acceptor - donor")
        ratio_sequence = 1.0 + data_acceptor - data_donor
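        # 1 + a - d is the first-order approximation of a/d when both signals hover
        # around 1.0, which they do here after division by their temporal means.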
    mylogger.info("Remove nan")
    ratio_sequence = torch.nan_to_num(ratio_sequence, nan=0.0)
    mylogger.info("-==- Done -==-")

    if config["binning_enable"] and config["binning_at_the_end"]:
        mylogger.info("Binning of data")
        mylogger.info(
            (
                f"kernel_size={int(config['binning_kernel_size'])}, "
                f"stride={int(config['binning_stride'])}, "
                "divisor_override=None"
            )
        )

        ratio_sequence = binning(
            ratio_sequence.unsqueeze(-1),
            kernel_size=int(config["binning_kernel_size"]),
            stride=int(config["binning_stride"]),
            divisor_override=None,
        ).squeeze(-1)

        mask_positve = (
            binning(
                mask_positve.unsqueeze(-1).unsqueeze(-1).type(dtype=dtype),
                kernel_size=int(config["binning_kernel_size"]),
                stride=int(config["binning_stride"]),
                divisor_override=None,
            )
            .squeeze(-1)
            .squeeze(-1)
        )
        mask_positve = (mask_positve > 0).type(torch.bool)
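        # After average pooling, any bin that contained at least one kept pixel is
        # > 0 and therefore stays in the binned mask.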
    if config["save_as_python"]:
        temp_path = os.path.join(
            config["export_path"], experiment_name + "_ratio_sequence.npz"
        )
        mylogger.info(f"Save ratio_sequence and mask to {temp_path}")
        np.savez_compressed(
            temp_path, ratio_sequence=ratio_sequence.cpu(), mask=mask_positve.cpu()
        )

    if config["save_as_matlab"]:
        temp_path = os.path.join(
            config["export_path"], experiment_name + "_ratio_sequence.hd5"
        )
        mylogger.info(f"Save ratio_sequence and mask to {temp_path}")
        file_handle = h5py.File(temp_path, "w")

        mask_positve = mask_positve.movedim(0, -1)
        ratio_sequence = ratio_sequence.movedim(1, -1).movedim(0, -1)
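        # MATLAB's h5read reports dimensions in reverse order relative to the
        # row-major HDF5 layout, so the axes are permuted here to come out in the
        # expected order on the MATLAB side.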
        _ = file_handle.create_dataset(
            "mask",
            data=mask_positve.type(torch.uint8).cpu(),
            compression="gzip",
            compression_opts=9,
        )
        _ = file_handle.create_dataset(
            "ratio_sequence",
            data=ratio_sequence.cpu(),
            compression="gzip",
            compression_opts=9,
        )
        mylogger.info("Reminder: How to read with matlab:")
        mylogger.info(f"mask = h5read('{temp_path}','/mask');")
        mylogger.info(f"ratio_sequence = h5read('{temp_path}','/ratio_sequence');")
        file_handle.close()

    del ratio_sequence
    del mask_positve
    del mask_negative

    mylogger.info("")
    mylogger.info("***********************************************")
    mylogger.info("* TRIAL END ***********************************")
    mylogger.info("***********************************************")
    mylogger.info("")

    return


mylogger = create_logger(
    save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_4"
)
config = load_config(mylogger=mylogger)

if (config["save_as_python"] is False) and (config["save_as_matlab"] is False):
    mylogger.info("No output will be created.")
    mylogger.info("Change save_as_python and/or save_as_matlab in the config file")
    mylogger.info("ERROR: STOP!!!")
    exit()

device = get_torch_device(mylogger, config["force_to_cpu"])

mylogger.info(f"Create directory {config['export_path']} in case it does not exist")
os.makedirs(config["export_path"], exist_ok=True)

raw_data_path: str = os.path.join(
    config["basic_path"],
    config["recoding_data"],
    config["mouse_identifier"],
    config["raw_path"],
)

if os.path.isdir(raw_data_path) is False:
    mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!")
    exit()

experiments = get_experiments(raw_data_path)

for experiment_counter in range(0, experiments.shape[0]):
    experiment_id = int(experiments[experiment_counter])
    trials = get_trials(raw_data_path, experiment_id)
    for trial_counter in range(0, trials.shape[0]):
        trial_id = int(trials[trial_counter])

        mylogger.info("")
        mylogger.info(
            f"======= EXPERIMENT ID: {experiment_id} ==== TRIAL ID: {trial_id} ======="
        )
        mylogger.info("")

        process_trial(
            config=config,
            mylogger=mylogger,
            experiment_id=experiment_id,
            trial_id=trial_id,
            device=device,
        )