# gevi/stage_4_process.py
#
# 1413 lines
# 51 KiB
# Python
#
# %%
import numpy as np
import torch
import torchvision as tv # type: ignore
import os
import logging
import h5py # type: ignore
from functions.create_logger import create_logger
from functions.get_torch_device import get_torch_device
from functions.load_config import load_config
from functions.get_experiments import get_experiments
from functions.get_trials import get_trials
from functions.binning import binning
from functions.align_refref import align_refref
from functions.perform_donor_volume_rotation import perform_donor_volume_rotation
from functions.perform_donor_volume_translation import perform_donor_volume_translation
from functions.bandpass import bandpass
from functions.gauss_smear_individual import gauss_smear_individual
from functions.regression import regression
from functions.data_raw_loader import data_raw_loader
import argh
@torch.no_grad()
def process_trial(
config: dict,
mylogger: logging.Logger,
experiment_id: int,
trial_id: int,
device: torch.device,
):
mylogger.info("")
mylogger.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
mylogger.info("~ TRIAL START ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
mylogger.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
mylogger.info("")
if device != torch.device("cpu"):
torch.cuda.empty_cache()
mylogger.info("Empty CUDA cache")
cuda_total_memory: int = torch.cuda.get_device_properties(
device.index
).total_memory
else:
cuda_total_memory = 0
mylogger.info("")
mylogger.info("(A) LOADING DATA, REFERENCE, AND MASK")
mylogger.info("-----------------------------------------------")
mylogger.info("")
raw_data_path: str = os.path.join(
config["basic_path"],
config["recoding_data"],
config["mouse_identifier"],
config["raw_path"],
)
if config["binning_enable"] and (config["binning_at_the_end"] is False):
force_to_cpu_memory: bool = True
else:
force_to_cpu_memory = False
meta_channels: list[str]
meta_mouse_markings: str
meta_recording_date: str
meta_stimulation_times: dict
meta_experiment_names: dict
meta_trial_recording_duration: float
meta_frame_time: float
meta_mouse: str
data: torch.Tensor
(
meta_channels,
meta_mouse_markings,
meta_recording_date,
meta_stimulation_times,
meta_experiment_names,
meta_trial_recording_duration,
meta_frame_time,
meta_mouse,
data,
) = data_raw_loader(
raw_data_path=raw_data_path,
mylogger=mylogger,
experiment_id=experiment_id,
trial_id=trial_id,
device=device,
force_to_cpu_memory=force_to_cpu_memory,
config=config,
)
experiment_name: str = f"Exp{experiment_id:03d}_Trial{trial_id:03d}"
dtype_str = config["dtype"]
dtype_np: np.dtype = getattr(np, dtype_str)
dtype: torch.dtype = data.dtype
if device != torch.device("cpu"):
free_mem = cuda_total_memory - max(
[torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
)
mylogger.info(f"CUDA memory: {free_mem // 1024} MByte")
mylogger.info(f"Data shape: {data.shape}")
mylogger.info("-==- Done -==-")
mylogger.info("Finding limit values in the RAW data and mark them for masking")
limit: float = (2**16) - 1
for i in range(0, data.shape[3]):
zero_pixel_mask: torch.Tensor = torch.any(data[..., i] >= limit, dim=-1)
data[zero_pixel_mask, :, i] = -100.0
mylogger.info(
f"{meta_channels[i]}: "
f"found {int(zero_pixel_mask.type(dtype=dtype).sum())} pixel "
f"with limit values "
)
mylogger.info("-==- Done -==-")
mylogger.info("Reference images and mask")
ref_image_path: str = config["ref_image_path"]
ref_image_path_acceptor: str = os.path.join(ref_image_path, "acceptor.npy")
if os.path.isfile(ref_image_path_acceptor) is False:
mylogger.info(f"Could not load ref file: {ref_image_path_acceptor}")
assert os.path.isfile(ref_image_path_acceptor)
return
mylogger.info(f"Loading ref file data: {ref_image_path_acceptor}")
ref_image_acceptor: torch.Tensor = torch.tensor(
np.load(ref_image_path_acceptor).astype(dtype_np),
dtype=dtype,
device=data.device,
)
ref_image_path_donor: str = os.path.join(ref_image_path, "donor.npy")
if os.path.isfile(ref_image_path_donor) is False:
mylogger.info(f"Could not load ref file: {ref_image_path_donor}")
assert os.path.isfile(ref_image_path_donor)
return
mylogger.info(f"Loading ref file data: {ref_image_path_donor}")
ref_image_donor: torch.Tensor = torch.tensor(
np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=data.device
)
ref_image_path_oxygenation: str = os.path.join(ref_image_path, "oxygenation.npy")
if os.path.isfile(ref_image_path_oxygenation) is False:
mylogger.info(f"Could not load ref file: {ref_image_path_oxygenation}")
assert os.path.isfile(ref_image_path_oxygenation)
return
mylogger.info(f"Loading ref file data: {ref_image_path_oxygenation}")
ref_image_oxygenation: torch.Tensor = torch.tensor(
np.load(ref_image_path_oxygenation).astype(dtype_np),
dtype=dtype,
device=data.device,
)
ref_image_path_volume: str = os.path.join(ref_image_path, "volume.npy")
if os.path.isfile(ref_image_path_volume) is False:
mylogger.info(f"Could not load ref file: {ref_image_path_volume}")
assert os.path.isfile(ref_image_path_volume)
return
mylogger.info(f"Loading ref file data: {ref_image_path_volume}")
ref_image_volume: torch.Tensor = torch.tensor(
np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=data.device
)
refined_mask_file: str = os.path.join(ref_image_path, "mask_not_rotated.npy")
if os.path.isfile(refined_mask_file) is False:
mylogger.info(f"Could not load mask file: {refined_mask_file}")
assert os.path.isfile(refined_mask_file)
return
mylogger.info(f"Loading mask file data: {refined_mask_file}")
mask: torch.Tensor = torch.tensor(
np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=data.device
)
mylogger.info("-==- Done -==-")
if config["binning_enable"] and (config["binning_at_the_end"] is False):
mylogger.info("")
mylogger.info("(B-OPTIONAL) BINNING")
mylogger.info("-----------------------------------------------")
mylogger.info("")
mylogger.info("Binning of data")
mylogger.info(
(
f"kernel_size={int(config['binning_kernel_size'])}, "
f"stride={int(config['binning_stride'])}, "
f"divisor_override={int(config['binning_divisor_override'])}"
)
)
data = binning(
data,
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=int(config["binning_divisor_override"]),
).to(device=data.device)
ref_image_acceptor = (
binning(
ref_image_acceptor.unsqueeze(-1).unsqueeze(-1),
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=int(config["binning_divisor_override"]),
)
.squeeze(-1)
.squeeze(-1)
)
ref_image_donor = (
binning(
ref_image_donor.unsqueeze(-1).unsqueeze(-1),
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=int(config["binning_divisor_override"]),
)
.squeeze(-1)
.squeeze(-1)
)
ref_image_oxygenation = (
binning(
ref_image_oxygenation.unsqueeze(-1).unsqueeze(-1),
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=int(config["binning_divisor_override"]),
)
.squeeze(-1)
.squeeze(-1)
)
ref_image_volume = (
binning(
ref_image_volume.unsqueeze(-1).unsqueeze(-1),
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=int(config["binning_divisor_override"]),
)
.squeeze(-1)
.squeeze(-1)
)
mask = (
binning(
mask.unsqueeze(-1).unsqueeze(-1),
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=int(config["binning_divisor_override"]),
)
.squeeze(-1)
.squeeze(-1)
)
mylogger.info(f"Data shape: {data.shape}")
mylogger.info("-==- Done -==-")
mylogger.info("")
mylogger.info("(C) ALIGNMENT OF SECOND TO FIRST CAMERA")
mylogger.info("-----------------------------------------------")
mylogger.info("")
mylogger.info("Preparing alignment")
mylogger.info("Re-order Raw data")
data = data.moveaxis(-2, 0).moveaxis(-1, 0)
mylogger.info(f"Data shape: {data.shape}")
mylogger.info("-==- Done -==-")
mylogger.info("Alignment of the ref images and the mask")
mylogger.info("Ref image of donor stays fixed.")
mylogger.info("Ref image of volume and the mask doesn't need to be touched")
mylogger.info("Calculate translation and rotation between the reference images")
angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor = align_refref(
mylogger=mylogger,
ref_image_acceptor=ref_image_acceptor,
ref_image_donor=ref_image_donor,
batch_size=config["alignment_batch_size"],
fill_value=-100.0,
)
mylogger.info(f"Rotation: {round(float(angle_refref[0]), 2)} degree")
mylogger.info(
f"Translation: {round(float(tvec_refref[0]), 1)} x {round(float(tvec_refref[1]), 1)} pixel"
)
if config["save_alignment"]:
temp_path: str = os.path.join(
config["export_path"], experiment_name + "_angle_refref.npy"
)
mylogger.info(f"Save angle to {temp_path}")
np.save(temp_path, angle_refref.cpu())
temp_path = os.path.join(
config["export_path"], experiment_name + "_tvec_refref.npy"
)
mylogger.info(f"Save translation vector to {temp_path}")
np.save(temp_path, tvec_refref.cpu())
mylogger.info("Moving & rotating the oxygenation ref image")
ref_image_oxygenation = tv.transforms.functional.affine( # type: ignore
img=ref_image_oxygenation.unsqueeze(0),
angle=-float(angle_refref),
translate=[0, 0],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=-100.0,
)
ref_image_oxygenation = tv.transforms.functional.affine( # type: ignore
img=ref_image_oxygenation,
angle=0,
translate=[tvec_refref[1], tvec_refref[0]],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=-100.0,
).squeeze(0)
mylogger.info("-==- Done -==-")
mylogger.info("Rotate and translate the acceptor and oxygenation data accordingly")
acceptor_index: int = config["required_order"].index("acceptor")
donor_index: int = config["required_order"].index("donor")
oxygenation_index: int = config["required_order"].index("oxygenation")
volume_index: int = config["required_order"].index("volume")
mylogger.info("Rotate acceptor")
data[acceptor_index, ...] = tv.transforms.functional.affine( # type: ignore
img=data[acceptor_index, ...], # type: ignore
angle=-float(angle_refref),
translate=[0, 0],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=-100.0,
)
mylogger.info("Translate acceptor")
data[acceptor_index, ...] = tv.transforms.functional.affine( # type: ignore
img=data[acceptor_index, ...],
angle=0,
translate=[tvec_refref[1], tvec_refref[0]],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=-100.0,
)
mylogger.info("Rotate oxygenation")
data[oxygenation_index, ...] = tv.transforms.functional.affine( # type: ignore
img=data[oxygenation_index, ...],
angle=-float(angle_refref),
translate=[0, 0],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=-100.0,
)
mylogger.info("Translate oxygenation")
data[oxygenation_index, ...] = tv.transforms.functional.affine( # type: ignore
img=data[oxygenation_index, ...],
angle=0,
translate=[tvec_refref[1], tvec_refref[0]],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=-100.0,
)
mylogger.info("-==- Done -==-")
mylogger.info("Perform rotation between donor and volume and its ref images")
mylogger.info("for all frames and then rotate all the data accordingly")
(
data[acceptor_index, ...],
data[donor_index, ...],
data[oxygenation_index, ...],
data[volume_index, ...],
angle_donor_volume,
) = perform_donor_volume_rotation(
mylogger=mylogger,
acceptor=data[acceptor_index, ...],
donor=data[donor_index, ...],
oxygenation=data[oxygenation_index, ...],
volume=data[volume_index, ...],
ref_image_donor=ref_image_donor,
ref_image_volume=ref_image_volume,
batch_size=config["alignment_batch_size"],
fill_value=-100.0,
config=config,
)
mylogger.info(
f"angles: "
f"min {round(float(angle_donor_volume.min()), 2)} "
f"max {round(float(angle_donor_volume.max()), 2)} "
f"mean {round(float(angle_donor_volume.mean()), 2)} "
)
if config["save_alignment"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_angle_donor_volume.npy"
)
mylogger.info(f"Save angles to {temp_path}")
np.save(temp_path, angle_donor_volume.cpu())
mylogger.info("-==- Done -==-")
mylogger.info("Perform translation between donor and volume and its ref images")
mylogger.info("for all frames and then translate all the data accordingly")
(
data[acceptor_index, ...],
data[donor_index, ...],
data[oxygenation_index, ...],
data[volume_index, ...],
tvec_donor_volume,
) = perform_donor_volume_translation(
mylogger=mylogger,
acceptor=data[acceptor_index, ...],
donor=data[donor_index, ...],
oxygenation=data[oxygenation_index, ...],
volume=data[volume_index, ...],
ref_image_donor=ref_image_donor,
ref_image_volume=ref_image_volume,
batch_size=config["alignment_batch_size"],
fill_value=-100.0,
config=config,
)
mylogger.info(
f"translation dim 0: "
f"min {round(float(tvec_donor_volume[:, 0].min()), 1)} "
f"max {round(float(tvec_donor_volume[:, 0].max()), 1)} "
f"mean {round(float(tvec_donor_volume[:, 0].mean()), 1)} "
)
mylogger.info(
f"translation dim 1: "
f"min {round(float(tvec_donor_volume[:, 1].min()), 1)} "
f"max {round(float(tvec_donor_volume[:, 1].max()), 1)} "
f"mean {round(float(tvec_donor_volume[:, 1].mean()), 1)} "
)
if config["save_alignment"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_tvec_donor_volume.npy"
)
mylogger.info(f"Save translation vector to {temp_path}")
np.save(temp_path, tvec_donor_volume.cpu())
mylogger.info("-==- Done -==-")
mylogger.info("Finding zeros values in the RAW data and mark them for masking")
for i in range(0, data.shape[0]):
zero_pixel_mask = torch.any(data[i, ...] == 0, dim=0)
data[i, :, zero_pixel_mask] = -100.0
mylogger.info(
f"{config['required_order'][i]}: "
f"found {int(zero_pixel_mask.type(dtype=dtype).sum())} pixel "
f"with zeros "
)
mylogger.info("-==- Done -==-")
mylogger.info("Update mask with the new regions due to alignment")
new_mask_area: torch.Tensor = torch.any(torch.any(data < -0.1, dim=0), dim=0).bool()
mask = (mask == 0).bool()
mask = torch.logical_or(mask, new_mask_area)
mask_negative: torch.Tensor = mask.clone()
mask_positve: torch.Tensor = torch.logical_not(mask)
del mask
mylogger.info("Update the data with the new mask")
data *= mask_positve.unsqueeze(0).unsqueeze(0).type(dtype=dtype)
mylogger.info("-==- Done -==-")
if config["save_aligned_as_python"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_aligned.npz"
)
mylogger.info(f"Save aligned data and mask to {temp_path}")
np.savez_compressed(
temp_path,
data=data.cpu(),
mask=mask_positve.cpu(),
acceptor_index=acceptor_index,
donor_index=donor_index,
oxygenation_index=oxygenation_index,
volume_index=volume_index,
)
if config["save_aligned_as_matlab"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_aligned.hd5"
)
mylogger.info(f"Save aligned data and mask to {temp_path}")
file_handle = h5py.File(temp_path, "w")
_ = file_handle.create_dataset(
"mask",
data=mask_positve.movedim(0, -1).type(torch.uint8).cpu(),
compression="gzip",
compression_opts=9,
)
_ = file_handle.create_dataset(
"data",
data=data.movedim(1, -1).movedim(0, -1).cpu(),
compression="gzip",
compression_opts=9,
)
_ = file_handle.create_dataset(
"acceptor_index",
data=torch.tensor((acceptor_index,)),
compression="gzip",
compression_opts=9,
)
_ = file_handle.create_dataset(
"donor_index",
data=torch.tensor((donor_index,)),
compression="gzip",
compression_opts=9,
)
_ = file_handle.create_dataset(
"oxygenation_index",
data=torch.tensor((oxygenation_index,)),
compression="gzip",
compression_opts=9,
)
_ = file_handle.create_dataset(
"volume_index",
data=torch.tensor((volume_index,)),
compression="gzip",
compression_opts=9,
)
mylogger.info("Reminder: How to read with matlab:")
mylogger.info(f"mask = h5read('{temp_path}','/mask');")
mylogger.info(f"data_acceptor = h5read('{temp_path}','/data');")
file_handle.close()
mylogger.info("")
mylogger.info("(D) INTER-FRAME INTERPOLATION")
mylogger.info("-----------------------------------------------")
mylogger.info("")
mylogger.info("Interpolate the 'in-between' frames for oxygenation and volume")
data[oxygenation_index, 1:, ...] = (
data[oxygenation_index, 1:, ...] + data[oxygenation_index, :-1, ...]
) / 2.0
data[volume_index, 1:, ...] = (
data[volume_index, 1:, ...] + data[volume_index, :-1, ...]
) / 2.0
mylogger.info("-==- Done -==-")
sample_frequency: float = 1.0 / meta_frame_time
if config["gevi"]:
assert config["heartbeat_remove"]
if config["heartbeat_remove"]:
mylogger.info("")
mylogger.info("(E-OPTIONAL) HEARTBEAT REMOVAL")
mylogger.info("-----------------------------------------------")
mylogger.info("")
mylogger.info("Extract heartbeat from volume signal")
heartbeat_ts: torch.Tensor = bandpass(
data=data[volume_index, ...].movedim(0, -1).clone(),
low_frequency=config["lower_freqency_bandpass"],
high_frequency=config["upper_freqency_bandpass"],
fs=sample_frequency,
filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"],
)
heartbeat_ts = heartbeat_ts.flatten(start_dim=0, end_dim=-2)
mask_flatten: torch.Tensor = mask_positve.flatten(start_dim=0, end_dim=-1)
heartbeat_ts = heartbeat_ts[mask_flatten, :]
heartbeat_ts = heartbeat_ts.movedim(0, -1)
heartbeat_ts -= heartbeat_ts.mean(dim=0, keepdim=True)
try:
volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False)
except torch.cuda.OutOfMemoryError:
mylogger.info("torch.cuda.OutOfMemoryError: Fallback to cpu")
volume_heartbeat_cpu, _, _ = torch.linalg.svd(
heartbeat_ts.cpu(), full_matrices=False
)
volume_heartbeat = volume_heartbeat_cpu.to(heartbeat_ts.device, copy=True)
del volume_heartbeat_cpu
volume_heartbeat = volume_heartbeat[:, 0]
volume_heartbeat -= volume_heartbeat[
config["skip_frames_in_the_beginning"] :
].mean()
del heartbeat_ts
if device != torch.device("cpu"):
torch.cuda.empty_cache()
mylogger.info("Empty CUDA cache")
free_mem = cuda_total_memory - max(
[
torch.cuda.memory_reserved(device),
torch.cuda.memory_allocated(device),
]
)
mylogger.info(f"CUDA memory: {free_mem // 1024} MByte")
if config["save_heartbeat"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_volume_heartbeat.npy"
)
mylogger.info(f"Save volume heartbeat to {temp_path}")
np.save(temp_path, volume_heartbeat.cpu())
mylogger.info("-==- Done -==-")
volume_heartbeat = volume_heartbeat.unsqueeze(0).unsqueeze(0)
norm_volume_heartbeat = (
volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] ** 2
).sum(dim=-1)
heartbeat_coefficients: torch.Tensor = torch.zeros(
(data.shape[0], data.shape[-2], data.shape[-1]),
dtype=data.dtype,
device=data.device,
)
for i in range(0, data.shape[0]):
y = bandpass(
data=data[i, ...].movedim(0, -1).clone(),
low_frequency=config["lower_freqency_bandpass"],
high_frequency=config["upper_freqency_bandpass"],
fs=sample_frequency,
filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"],
)[..., config["skip_frames_in_the_beginning"] :]
y -= y.mean(dim=-1, keepdim=True)
heartbeat_coefficients[i, ...] = (
volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] * y
).sum(dim=-1) / norm_volume_heartbeat
heartbeat_coefficients[i, ...] *= mask_positve.type(
dtype=heartbeat_coefficients.dtype
)
del y
if config["save_heartbeat"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_heartbeat_coefficients.npy"
)
mylogger.info(f"Save heartbeat coefficients to {temp_path}")
np.save(temp_path, heartbeat_coefficients.cpu())
mylogger.info("-==- Done -==-")
mylogger.info("Remove heart beat from data")
data -= heartbeat_coefficients.unsqueeze(1) * volume_heartbeat.unsqueeze(
0
).movedim(-1, 1)
# data_herzlos = data.clone()
mylogger.info("-==- Done -==-")
if config["gevi"]: # UDO scaling performed!
mylogger.info("")
mylogger.info("(F-OPTIONAL) DONOR/ACCEPTOR SCALING")
mylogger.info("-----------------------------------------------")
mylogger.info("")
donor_heartbeat_factor = heartbeat_coefficients[donor_index, ...].clone()
acceptor_heartbeat_factor = heartbeat_coefficients[
acceptor_index, ...
].clone()
del heartbeat_coefficients
if device != torch.device("cpu"):
torch.cuda.empty_cache()
mylogger.info("Empty CUDA cache")
free_mem = cuda_total_memory - max(
[
torch.cuda.memory_reserved(device),
torch.cuda.memory_allocated(device),
]
)
mylogger.info(f"CUDA memory: {free_mem // 1024} MByte")
mylogger.info("Calculate scaling factor for donor and acceptor")
# donor_factor: torch.Tensor = (
# donor_heartbeat_factor + acceptor_heartbeat_factor
# ) / (2 * donor_heartbeat_factor)
# acceptor_factor: torch.Tensor = (
# donor_heartbeat_factor + acceptor_heartbeat_factor
# ) / (2 * acceptor_heartbeat_factor)
donor_factor = torch.sqrt(
acceptor_heartbeat_factor / donor_heartbeat_factor
)
acceptor_factor = 1 / donor_factor
# import matplotlib.pyplot as plt
# plt.pcolor(donor_factor, vmin=0.5, vmax=2.0)
# plt.colorbar()
# plt.show()
# plt.pcolor(acceptor_factor, vmin=0.5, vmax=2.0)
# plt.colorbar()
# plt.show()
# TODO remove
del donor_heartbeat_factor
del acceptor_heartbeat_factor
# import matplotlib.pyplot as plt
# plt.pcolor(torch.std(data[acceptor_index, config["skip_frames_in_the_beginning"] :, ...], axis=0), vmin=0, vmax=500)
# plt.colorbar()
# plt.show()
# plt.pcolor(torch.std(data[donor_index, config["skip_frames_in_the_beginning"] :, ...], axis=0), vmin=0, vmax=500)
# plt.colorbar()
# plt.show()
# TODO remove
if config["save_factors"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_donor_factor.npy"
)
mylogger.info(f"Save donor factor to {temp_path}")
np.save(temp_path, donor_factor.cpu())
temp_path = os.path.join(
config["export_path"], experiment_name + "_acceptor_factor.npy"
)
mylogger.info(f"Save acceptor factor to {temp_path}")
np.save(temp_path, acceptor_factor.cpu())
mylogger.info("-==- Done -==-")
# TODO we have to calculate means first!
mylogger.info("Extract means for acceptor and donor first")
mean_values_acceptor = data[
acceptor_index, config["skip_frames_in_the_beginning"] :, ...
].nanmean(dim=0, keepdim=True)
mean_values_donor = data[
donor_index, config["skip_frames_in_the_beginning"] :, ...
].nanmean(dim=0, keepdim=True)
mylogger.info("Scale acceptor to heart beat amplitude")
mylogger.info("Remove mean")
data[acceptor_index, ...] -= mean_values_acceptor
mylogger.info("Apply acceptor_factor and mask")
# data[acceptor_index, ...] *= acceptor_factor.unsqueeze(
# 0
# ) * mask_positve.unsqueeze(0)
acceptor_factor_correction = np.sqrt(
mean_values_acceptor / mean_values_donor
)
data[acceptor_index, ...] *= acceptor_factor.unsqueeze(
0
) * acceptor_factor_correction * mask_positve.unsqueeze(0)
mylogger.info("Add mean")
data[acceptor_index, ...] += mean_values_acceptor
mylogger.info("-==- Done -==-")
mylogger.info("Scale donor to heart beat amplitude")
mylogger.info("Remove mean")
data[donor_index, ...] -= mean_values_donor
mylogger.info("Apply donor_factor and mask")
# data[donor_index, ...] *= donor_factor.unsqueeze(
# 0
# ) * mask_positve.unsqueeze(0)
donor_factor_correction = 1 / acceptor_factor_correction
data[donor_index, ...] *= donor_factor.unsqueeze(
0
) * donor_factor_correction * mask_positve.unsqueeze(0)
mylogger.info("Add mean")
data[donor_index, ...] += mean_values_donor
mylogger.info("-==- Done -==-")
# import matplotlib.pyplot as plt
# plt.pcolor(mean_values_acceptor[0])
# plt.colorbar()
# plt.show()
# plt.pcolor(mean_values_donor[0])
# plt.colorbar()
# plt.show()
# TODO remove
# TODO SCHNUGGEL
else:
mylogger.info("GECI does not require acceptor/donor scaling, skipping!")
mylogger.info("-==- Done -==-")
mylogger.info("")
mylogger.info("(G) CONVERSION TO RELATIVE SIGNAL CHANGES (DIV/MEAN)")
mylogger.info("-----------------------------------------------")
mylogger.info("")
mylogger.info("Divide by mean over time")
data /= data[:, config["skip_frames_in_the_beginning"] :, ...].nanmean(
dim=1,
keepdim=True,
)
mylogger.info("-==- Done -==-")
mylogger.info("")
mylogger.info("(H) CLEANING BY REGRESSION")
mylogger.info("-----------------------------------------------")
mylogger.info("")
data = data.nan_to_num(nan=0.0)
mylogger.info("Preparation for regression -- Gauss smear")
spatial_width = float(config["gauss_smear_spatial_width"])
if config["binning_enable"] and (config["binning_at_the_end"] is False):
spatial_width /= float(config["binning_kernel_size"])
mylogger.info(
f"Mask -- "
f"spatial width: {spatial_width}, "
f"temporal width: {float(config['gauss_smear_temporal_width'])}, "
f"use matlab mode: {bool(config['gauss_smear_use_matlab_mask'])} "
)
input_mask = mask_positve.type(dtype=dtype).clone()
filtered_mask: torch.Tensor
filtered_mask, _ = gauss_smear_individual(
input=input_mask,
spatial_width=spatial_width,
temporal_width=float(config["gauss_smear_temporal_width"]),
use_matlab_mask=bool(config["gauss_smear_use_matlab_mask"]),
epsilon=float(torch.finfo(input_mask.dtype).eps),
)
mylogger.info("creating a copy of the data")
data_filtered = data.clone().movedim(1, -1)
if device != torch.device("cpu"):
torch.cuda.empty_cache()
mylogger.info("Empty CUDA cache")
free_mem = cuda_total_memory - max(
[torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
)
mylogger.info(f"CUDA memory: {free_mem // 1024} MByte")
overwrite_fft_gauss: None | torch.Tensor = None
for i in range(0, data_filtered.shape[0]):
mylogger.info(
f"{config['required_order'][i]} -- "
f"spatial width: {spatial_width}, "
f"temporal width: {float(config['gauss_smear_temporal_width'])}, "
f"use matlab mode: {bool(config['gauss_smear_use_matlab_mask'])} "
)
data_filtered[i, ...] *= input_mask.unsqueeze(-1)
data_filtered[i, ...], overwrite_fft_gauss = gauss_smear_individual(
input=data_filtered[i, ...],
spatial_width=spatial_width,
temporal_width=float(config["gauss_smear_temporal_width"]),
overwrite_fft_gauss=overwrite_fft_gauss,
use_matlab_mask=bool(config["gauss_smear_use_matlab_mask"]),
epsilon=float(torch.finfo(input_mask.dtype).eps),
)
data_filtered[i, ...] /= filtered_mask + 1e-20
data_filtered[i, ...] += 1.0 - input_mask.unsqueeze(-1)
del filtered_mask
del overwrite_fft_gauss
del input_mask
mylogger.info("data_filtered is populated")
if device != torch.device("cpu"):
torch.cuda.empty_cache()
mylogger.info("Empty CUDA cache")
free_mem = cuda_total_memory - max(
[torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
)
mylogger.info(f"CUDA memory: {free_mem // 1024} MByte")
mylogger.info("-==- Done -==-")
mylogger.info("Preperation for Regression")
mylogger.info("Move time dimensions of data to the last dimension")
data = data.movedim(1, -1)
dual_signal_mode: bool = True
if len(config["target_camera_acceptor"]) > 0:
mylogger.info("Regression Acceptor")
mylogger.info(f"Target: {config['target_camera_acceptor']}")
mylogger.info(
f"Regressors: constant, linear and {config['regressor_cameras_acceptor']}"
)
target_id: int = config["required_order"].index(
config["target_camera_acceptor"]
)
regressor_id: list[int] = []
for i in range(0, len(config["regressor_cameras_acceptor"])):
regressor_id.append(
config["required_order"].index(config["regressor_cameras_acceptor"][i])
)
data_acceptor, coefficients_acceptor = regression(
mylogger=mylogger,
target_camera_id=target_id,
regressor_camera_ids=regressor_id,
mask=mask_negative,
data=data,
data_filtered=data_filtered,
first_none_ramp_frame=int(config["skip_frames_in_the_beginning"]),
)
if config["save_regression_coefficients"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_coefficients_acceptor.npy"
)
mylogger.info(f"Save acceptor coefficients to {temp_path}")
np.save(temp_path, coefficients_acceptor.cpu())
del coefficients_acceptor
mylogger.info("-==- Done -==-")
else:
dual_signal_mode = False
target_id = config["required_order"].index("acceptor")
data_acceptor = data[target_id, ...].clone()
data_acceptor[mask_negative, :] = 0.0
if len(config["target_camera_donor"]) > 0:
mylogger.info("Regression Donor")
mylogger.info(f"Target: {config['target_camera_donor']}")
mylogger.info(
f"Regressors: constant, linear and {config['regressor_cameras_donor']}"
)
target_id = config["required_order"].index(config["target_camera_donor"])
regressor_id = []
for i in range(0, len(config["regressor_cameras_donor"])):
regressor_id.append(
config["required_order"].index(config["regressor_cameras_donor"][i])
)
data_donor, coefficients_donor = regression(
mylogger=mylogger,
target_camera_id=target_id,
regressor_camera_ids=regressor_id,
mask=mask_negative,
data=data,
data_filtered=data_filtered,
first_none_ramp_frame=int(config["skip_frames_in_the_beginning"]),
)
if config["save_regression_coefficients"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_coefficients_donor.npy"
)
mylogger.info(f"Save acceptor donor to {temp_path}")
np.save(temp_path, coefficients_donor.cpu())
del coefficients_donor
mylogger.info("-==- Done -==-")
else:
dual_signal_mode = False
target_id = config["required_order"].index("donor")
data_donor = data[target_id, ...].clone()
data_donor[mask_negative, :] = 0.0
# TODO clean up --->
if config["save_oxyvol_as_python"] or config["save_oxyvol_as_matlab"]:
mylogger.info("")
mylogger.info("(I-OPTIONAL) SAVE OXY/VOL/MASK")
mylogger.info("-----------------------------------------------")
mylogger.info("")
# extract oxy and vol
mylogger.info("Save Oxygenation/Volume/Mask")
data_oxygenation = data[oxygenation_index, ...].clone()
data_volume = data[volume_index, ...].clone()
data_mask = mask_positve.clone()
# bin, if required...
if config["binning_enable"] and config["binning_at_the_end"]:
mylogger.info("Binning of data")
mylogger.info(
(
f"kernel_size={int(config['binning_kernel_size'])}, "
f"stride={int(config['binning_stride'])}, "
"divisor_override=None"
)
)
data_oxygenation = binning(
data_oxygenation.unsqueeze(-1),
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=None,
).squeeze(-1)
data_volume = binning(
data_volume.unsqueeze(-1),
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=None,
).squeeze(-1)
data_mask = (
binning(
data_mask.unsqueeze(-1).unsqueeze(-1).type(dtype=dtype),
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=None,
)
.squeeze(-1)
.squeeze(-1)
)
data_mask = (data_mask > 0).type(torch.bool)
if config["save_oxyvol_as_python"]:
# export it!
temp_path = os.path.join(
config["export_path"], experiment_name + "_oxygenation_volume.npz"
)
mylogger.info(f"Save data oxygenation and volume to {temp_path}")
np.savez_compressed(
temp_path,
data_oxygenation=data_oxygenation.cpu(),
data_volume=data_volume.cpu(),
data_mask=data_mask.cpu(),
)
if config["save_oxyvol_as_matlab"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_oxygenation_volume.hd5"
)
mylogger.info(f"Save data oxygenation and volume to {temp_path}")
file_handle = h5py.File(temp_path, "w")
data_mask = data_mask.movedim(0, -1)
data_oxygenation = data_oxygenation.movedim(1, -1).movedim(0, -1)
data_volume = data_volume.movedim(1, -1).movedim(0, -1)
_ = file_handle.create_dataset(
"data_mask",
data=data_mask.type(torch.uint8).cpu(),
compression="gzip",
compression_opts=9,
)
_ = file_handle.create_dataset(
"data_oxygenation",
data=data_oxygenation.cpu(),
compression="gzip",
compression_opts=9,
)
_ = file_handle.create_dataset(
"data_volume",
data=data_volume.cpu(),
compression="gzip",
compression_opts=9,
)
mylogger.info("Reminder: How to read with matlab:")
mylogger.info(f"data_mask = h5read('{temp_path}','/data_mask');")
mylogger.info(f"data_oxygenation = h5read('{temp_path}','/data_oxygenation');")
mylogger.info(f"data_volume = h5read('{temp_path}','/data_volume');")
file_handle.close()
# TODO <------ clean up
del data
del data_filtered
if device != torch.device("cpu"):
torch.cuda.empty_cache()
mylogger.info("Empty CUDA cache")
free_mem = cuda_total_memory - max(
[torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
)
mylogger.info(f"CUDA memory: {free_mem // 1024} MByte")
    # #####################

    # Consistency check: gevi recordings must have produced both signal
    # channels (dual_signal_mode); non-gevi recordings must not have.
    if config["gevi"]:
        assert dual_signal_mode
    else:
        assert dual_signal_mode is False

    if dual_signal_mode is False:
        # --- Mono-signal path: no ratio is built; acceptor, donor and the
        # positive mask are (optionally binned and) exported, then we return.
        mylogger.info("")
        mylogger.info("(J1-OPTIONAL) SAVE ACC/DON/MASK (NO RATIO!+OPT BIN@END)")
        mylogger.info("-----------------------------------------------")
        mylogger.info("")

        mylogger.info("mono signal model")

        mylogger.info("Remove nan")
        data_acceptor = torch.nan_to_num(data_acceptor, nan=0.0)
        data_donor = torch.nan_to_num(data_donor, nan=0.0)
        mylogger.info("-==- Done -==-")

        if config["binning_enable"] and config["binning_at_the_end"]:
            mylogger.info("Binning of data")
            mylogger.info(
                (
                    f"kernel_size={int(config['binning_kernel_size'])}, "
                    f"stride={int(config['binning_stride'])}, "
                    "divisor_override=None"
                )
            )

            data_acceptor = binning(
                data_acceptor.unsqueeze(-1),
                kernel_size=int(config["binning_kernel_size"]),
                stride=int(config["binning_stride"]),
                divisor_override=None,
            ).squeeze(-1)

            data_donor = binning(
                data_donor.unsqueeze(-1),
                kernel_size=int(config["binning_kernel_size"]),
                stride=int(config["binning_stride"]),
                divisor_override=None,
            ).squeeze(-1)

            # NOTE(review): "mask_positve" is a typo for "mask_positive" used
            # consistently in this file; renaming would touch code outside
            # this span, so it is only flagged here.
            mask_positve = (
                binning(
                    mask_positve.unsqueeze(-1).unsqueeze(-1).type(dtype=dtype),
                    kernel_size=int(config["binning_kernel_size"]),
                    stride=int(config["binning_stride"]),
                    divisor_override=None,
                )
                .squeeze(-1)
                .squeeze(-1)
            )
            # A binned cell stays inside the mask if any source pixel was set.
            mask_positve = (mask_positve > 0).type(torch.bool)

        if config["save_as_python"]:
            temp_path = os.path.join(
                config["export_path"], experiment_name + "_acceptor_donor.npz"
            )
            mylogger.info(f"Save data donor and acceptor and mask to {temp_path}")
            np.savez_compressed(
                temp_path,
                data_acceptor=data_acceptor.cpu(),
                data_donor=data_donor.cpu(),
                mask=mask_positve.cpu(),
            )

        if config["save_as_matlab"]:
            temp_path = os.path.join(
                config["export_path"], experiment_name + "_acceptor_donor.hd5"
            )
            mylogger.info(f"Save data donor and acceptor and mask to {temp_path}")
            file_handle = h5py.File(temp_path, "w")

            # Axis reordering for the MATLAB export, matching the h5read
            # reminder lines logged below.
            mask_positve = mask_positve.movedim(0, -1)
            data_acceptor = data_acceptor.movedim(1, -1).movedim(0, -1)
            data_donor = data_donor.movedim(1, -1).movedim(0, -1)
            _ = file_handle.create_dataset(
                "mask",
                data=mask_positve.type(torch.uint8).cpu(),
                compression="gzip",
                compression_opts=9,
            )
            _ = file_handle.create_dataset(
                "data_acceptor",
                data=data_acceptor.cpu(),
                compression="gzip",
                compression_opts=9,
            )
            _ = file_handle.create_dataset(
                "data_donor",
                data=data_donor.cpu(),
                compression="gzip",
                compression_opts=9,
            )
            mylogger.info("Reminder: How to read with matlab:")
            mylogger.info(f"mask = h5read('{temp_path}','/mask');")
            mylogger.info(f"data_acceptor = h5read('{temp_path}','/data_acceptor');")
            mylogger.info(f"data_donor = h5read('{temp_path}','/data_donor');")
            file_handle.close()

        # Mono-signal export done; the ratio section below is dual-signal only.
        return
    # #####################

    mylogger.info("")
    mylogger.info("(J2-OPTIONAL) BUILD AND SAVE RATIO (+OPT BIN@END)")
    mylogger.info("-----------------------------------------------")
    mylogger.info("")

    mylogger.info("Calculate ratio sequence")
    if config["classical_ratio_mode"]:
        # Classical ratio: acceptor / donor, then normalized by its temporal
        # mean so the sequence fluctuates around 1.
        mylogger.info("via acceptor / donor")
        ratio_sequence: torch.Tensor = data_acceptor / data_donor
        mylogger.info("via / mean over time")
        ratio_sequence /= ratio_sequence.mean(dim=-1, keepdim=True)
    else:
        # Additive alternative to the division-based ratio.
        mylogger.info("via 1.0 + acceptor - donor")
        ratio_sequence = 1.0 + data_acceptor - data_donor

    mylogger.info("Remove nan")
    ratio_sequence = torch.nan_to_num(ratio_sequence, nan=0.0)
    mylogger.info("-==- Done -==-")

    if config["binning_enable"] and config["binning_at_the_end"]:
        mylogger.info("Binning of data")
        mylogger.info(
            (
                f"kernel_size={int(config['binning_kernel_size'])}, "
                f"stride={int(config['binning_stride'])}, "
                "divisor_override=None"
            )
        )

        ratio_sequence = binning(
            ratio_sequence.unsqueeze(-1),
            kernel_size=int(config["binning_kernel_size"]),
            stride=int(config["binning_stride"]),
            divisor_override=None,
        ).squeeze(-1)

        # Bin the raw channels too, but only if they will be exported below.
        if config["save_gevi_with_donor_acceptor"]:
            data_acceptor = binning(
                data_acceptor.unsqueeze(-1),
                kernel_size=int(config["binning_kernel_size"]),
                stride=int(config["binning_stride"]),
                divisor_override=None,
            ).squeeze(-1)

            data_donor = binning(
                data_donor.unsqueeze(-1),
                kernel_size=int(config["binning_kernel_size"]),
                stride=int(config["binning_stride"]),
                divisor_override=None,
            ).squeeze(-1)

        # Bin the mask with the same geometry; any hit keeps the cell valid.
        mask_positve = (
            binning(
                mask_positve.unsqueeze(-1).unsqueeze(-1).type(dtype=dtype),
                kernel_size=int(config["binning_kernel_size"]),
                stride=int(config["binning_stride"]),
                divisor_override=None,
            )
            .squeeze(-1)
            .squeeze(-1)
        )
        mask_positve = (mask_positve > 0).type(torch.bool)

    if config["save_as_python"]:
        temp_path = os.path.join(
            config["export_path"], experiment_name + "_ratio_sequence.npz"
        )
        mylogger.info(f"Save ratio_sequence and mask to {temp_path}")

        if config["save_gevi_with_donor_acceptor"]:
            np.savez_compressed(
                temp_path, ratio_sequence=ratio_sequence.cpu(), mask=mask_positve.cpu(), data_acceptor=data_acceptor.cpu(), data_donor=data_donor.cpu()
            )
        else:
            np.savez_compressed(
                temp_path, ratio_sequence=ratio_sequence.cpu(), mask=mask_positve.cpu()
            )

    if config["save_as_matlab"]:
        temp_path = os.path.join(
            config["export_path"], experiment_name + "_ratio_sequence.hd5"
        )
        mylogger.info(f"Save ratio_sequence and mask to {temp_path}")
        file_handle = h5py.File(temp_path, "w")

        # Axis reordering for the MATLAB export, matching the h5read
        # reminder lines logged below. NOTE(review): data_acceptor/data_donor
        # are NOT reordered here before being written in the optional branch
        # below -- their on-disk layout differs from ratio_sequence; confirm
        # this is what the MATLAB consumer expects.
        mask_positve = mask_positve.movedim(0, -1)
        ratio_sequence = ratio_sequence.movedim(1, -1).movedim(0, -1)
        _ = file_handle.create_dataset(
            "mask",
            data=mask_positve.type(torch.uint8).cpu(),
            compression="gzip",
            compression_opts=9,
        )
        _ = file_handle.create_dataset(
            "ratio_sequence",
            data=ratio_sequence.cpu(),
            compression="gzip",
            compression_opts=9,
        )
        if config["save_gevi_with_donor_acceptor"]:
            _ = file_handle.create_dataset(
                "data_acceptor",
                data=data_acceptor.cpu(),
                compression="gzip",
                compression_opts=9,
            )
            _ = file_handle.create_dataset(
                "data_donor",
                data=data_donor.cpu(),
                compression="gzip",
                compression_opts=9,
            )
        mylogger.info("Reminder: How to read with matlab:")
        mylogger.info(f"mask = h5read('{temp_path}','/mask');")
        mylogger.info(f"ratio_sequence = h5read('{temp_path}','/ratio_sequence');")
        if config["save_gevi_with_donor_acceptor"]:
            mylogger.info(f"data_donor = h5read('{temp_path}','/data_donor');")
            mylogger.info(f"data_acceptor = h5read('{temp_path}','/data_acceptor');")
        file_handle.close()

    # Release the remaining large tensors before the trial ends.
    del ratio_sequence
    del mask_positve
    del mask_negative

    mylogger.info("")
    mylogger.info("***********************************************")
    mylogger.info("* TRIAL END ***********************************")
    mylogger.info("***********************************************")
    mylogger.info("")

    return
def main(
    *,
    config_filename: str = "config.json",
    experiment_id_overwrite: int = -1,
    trial_id_overwrite: int = -1,
) -> None:
    """Run stage-4 processing for all (or selected) experiments and trials.

    Parameters
    ----------
    config_filename:
        Path of the JSON configuration file.
    experiment_id_overwrite:
        If >= 0, process only this experiment id instead of scanning the
        raw-data directory for experiments.
    trial_id_overwrite:
        If >= 0, process only this trial id instead of scanning the
        raw-data directory for trials.

    Raises
    ------
    SystemExit
        If the configuration would produce no output, no target camera is
        configured, or the raw-data directory does not exist.
    """
    mylogger = create_logger(
        save_logging_messages=True,
        display_logging_messages=True,
        log_stage_name="stage_4",
    )

    config = load_config(mylogger=mylogger, filename=config_filename)

    # NOTE(review): this guard ignores save_oxyvol_as_python /
    # save_oxyvol_as_matlab, which also create output files -- confirm
    # whether those flags alone should be allowed to proceed.
    if (config["save_as_python"] is False) and (config["save_as_matlab"] is False):
        mylogger.info("No output will be created. ")
        mylogger.info("Change save_as_python and/or save_as_matlab in the config file")
        mylogger.info("ERROR: STOP!!!")
        # raise SystemExit instead of exit(): exit() is a site-module
        # convenience that may be absent, and exit() without an argument
        # reports success (status 0) despite the logged ERROR.
        raise SystemExit(1)

    if (len(config["target_camera_donor"]) == 0) and (
        len(config["target_camera_acceptor"]) == 0
    ):
        mylogger.info(
            "Configure at least target_camera_donor or target_camera_acceptor correctly."
        )
        mylogger.info("ERROR: STOP!!!")
        raise SystemExit(1)

    device = get_torch_device(mylogger, config["force_to_cpu"])

    mylogger.info(
        f"Create directory {config['export_path']} in the case it does not exist"
    )
    os.makedirs(config["export_path"], exist_ok=True)

    raw_data_path: str = os.path.join(
        config["basic_path"],
        config["recoding_data"],
        config["mouse_identifier"],
        config["raw_path"],
    )

    if not os.path.isdir(raw_data_path):
        mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!")
        raise SystemExit(1)

    # Either scan the raw directory for experiment ids or use the override.
    if experiment_id_overwrite == -1:
        experiments = get_experiments(raw_data_path)
    else:
        assert experiment_id_overwrite >= 0
        experiments = torch.tensor([experiment_id_overwrite])

    for experiment_counter in range(0, experiments.shape[0]):
        experiment_id = int(experiments[experiment_counter])

        # Either scan for trial ids of this experiment or use the override.
        if trial_id_overwrite == -1:
            trials = get_trials(raw_data_path, experiment_id)
        else:
            assert trial_id_overwrite >= 0
            trials = torch.tensor([trial_id_overwrite])

        for trial_counter in range(0, trials.shape[0]):
            trial_id = int(trials[trial_counter])

            mylogger.info("")
            mylogger.info(
                f"======= EXPERIMENT ID: {experiment_id} ==== TRIAL ID: {trial_id} ======="
            )
            mylogger.info("")

            try:
                process_trial(
                    config=config,
                    mylogger=mylogger,
                    experiment_id=experiment_id,
                    trial_id=trial_id,
                    device=device,
                )
            except torch.cuda.OutOfMemoryError:
                # The GPU ran out of memory: retry the whole trial on the CPU.
                mylogger.info("WARNING: RUNNING IN FAILBACK MODE!!!!")
                mylogger.info("Not enough GPU memory. Retry on CPU")
                process_trial(
                    config=config,
                    mylogger=mylogger,
                    experiment_id=experiment_id,
                    trial_id=trial_id,
                    device=torch.device("cpu"),
                )
if __name__ == "__main__":
    # Expose ``main`` as a command-line interface; argh maps the
    # keyword-only parameters of ``main`` to command-line options.
    argh.dispatch_command(main)

# %%