From d0b8ff58d26b27c79f6e53629c2bf370f69ae72e Mon Sep 17 00:00:00 2001 From: David Rotermund Date: Wed, 19 Mar 2025 18:32:36 +0100 Subject: [PATCH] Add files via upload --- README.md | 59 + functions/Anime.py | 93 ++ functions/ImageAlignment.py | 1015 ++++++++++++ functions/align_refref.py | 60 + functions/bandpass.py | 113 ++ functions/binning.py | 46 + functions/calculate_rotation.py | 40 + functions/calculate_translation.py | 37 + functions/create_logger.py | 37 + functions/data_raw_loader.py | 339 ++++ functions/gauss_smear_individual.py | 168 ++ functions/get_experiments.py | 19 + functions/get_parts.py | 18 + functions/get_torch_device.py | 17 + functions/get_trials.py | 19 + functions/load_config.py | 16 + functions/load_meta_data.py | 68 + functions/perform_donor_volume_rotation.py | 207 +++ functions/perform_donor_volume_translation.py | 210 +++ functions/regression.py | 117 ++ functions/regression_internal.py | 27 + geci/config_M_Sert_Cre_41.json | 67 + geci/config_M_Sert_Cre_42.json | 67 + geci/config_M_Sert_Cre_45.json | 67 + geci/config_M_Sert_Cre_46.json | 67 + geci/config_M_Sert_Cre_49.json | 67 + geci/config_example_GECI.json | 67 + geci/geci_loader.py | 168 ++ geci/geci_plot.py | 181 +++ geci/stage_6_convert_roi.py | 53 + gevi/config_M0134M_2024-11-06_SessionA.json | 67 + gevi/config_M0134M_2024-11-06_SessionB.json | 67 + gevi/config_M0134M_2024-11-07_SessionA.json | 67 + gevi/config_M0134M_2024-11-07_SessionB.json | 67 + gevi/config_M0134M_2024-11-13_SessionA.json | 67 + gevi/config_M0134M_2024-11-13_SessionB.json | 67 + gevi/config_M0134M_2024-11-15_SessionA.json | 67 + gevi/config_M0134M_2024-11-15_SessionB.json | 67 + gevi/config_M0134M_2024-11-18_SessionA.json | 67 + gevi/config_M0134M_2024-11-18_SessionB.json | 67 + gevi/config_M0134M_2024-12-04_SessionA.json | 67 + gevi/config_M0134M_2024-12-04_SessionB.json | 67 + gevi/config_M3905F_SessionB.json | 67 + gevi/config_example_GEVI.json | 66 + gevi/example_load_gevi.py | 56 + other/stage_4b_inspect.py | 532 +++++++ other/stage_4c_viewer.py | 56 + stage_1_get_ref_image.py | 129 ++ stage_2_make_heartbeat_mask.py | 163 ++ stage_3_refine_mask.py | 169 ++ stage_4_process.py | 1413 +++++++++++++++++ stage_5_convert_metadata.py | 57 + 52 files changed, 7041 insertions(+) create mode 100644 README.md create mode 100644 functions/Anime.py create mode 100644 functions/ImageAlignment.py create mode 100644 functions/align_refref.py create mode 100644 functions/bandpass.py create mode 100644 functions/binning.py create mode 100644 functions/calculate_rotation.py create mode 100644 functions/calculate_translation.py create mode 100644 functions/create_logger.py create mode 100644 functions/data_raw_loader.py create mode 100644 functions/gauss_smear_individual.py create mode 100644 functions/get_experiments.py create mode 100644 functions/get_parts.py create mode 100644 functions/get_torch_device.py create mode 100644 functions/get_trials.py create mode 100644 functions/load_config.py create mode 100644 functions/load_meta_data.py create mode 100644 functions/perform_donor_volume_rotation.py create mode 100644 functions/perform_donor_volume_translation.py create mode 100644 functions/regression.py create mode 100644 functions/regression_internal.py create mode 100644 geci/config_M_Sert_Cre_41.json create mode 100644 geci/config_M_Sert_Cre_42.json create mode 100644 geci/config_M_Sert_Cre_45.json create mode 100644 geci/config_M_Sert_Cre_46.json create mode 100644 geci/config_M_Sert_Cre_49.json create mode 100644 
geci/config_example_GECI.json create mode 100644 geci/geci_loader.py create mode 100644 geci/geci_plot.py create mode 100644 geci/stage_6_convert_roi.py create mode 100644 gevi/config_M0134M_2024-11-06_SessionA.json create mode 100644 gevi/config_M0134M_2024-11-06_SessionB.json create mode 100644 gevi/config_M0134M_2024-11-07_SessionA.json create mode 100644 gevi/config_M0134M_2024-11-07_SessionB.json create mode 100644 gevi/config_M0134M_2024-11-13_SessionA.json create mode 100644 gevi/config_M0134M_2024-11-13_SessionB.json create mode 100644 gevi/config_M0134M_2024-11-15_SessionA.json create mode 100644 gevi/config_M0134M_2024-11-15_SessionB.json create mode 100644 gevi/config_M0134M_2024-11-18_SessionA.json create mode 100644 gevi/config_M0134M_2024-11-18_SessionB.json create mode 100644 gevi/config_M0134M_2024-12-04_SessionA.json create mode 100644 gevi/config_M0134M_2024-12-04_SessionB.json create mode 100644 gevi/config_M3905F_SessionB.json create mode 100644 gevi/config_example_GEVI.json create mode 100644 gevi/example_load_gevi.py create mode 100644 other/stage_4b_inspect.py create mode 100644 other/stage_4c_viewer.py create mode 100644 stage_1_get_ref_image.py create mode 100644 stage_2_make_heartbeat_mask.py create mode 100644 stage_3_refine_mask.py create mode 100644 stage_4_process.py create mode 100644 stage_5_convert_metadata.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..79fa2ca --- /dev/null +++ b/README.md @@ -0,0 +1,59 @@ +This code is a reimagining of + +Robert Staadt + +Development of a system for high-volume multi-channel brain imaging of fluorescent voltage signals + +Dissertation + +Ruhr-Universität Bochum, Universitätsbibliothek + +08.02.2024 + +[https://doi.org/10.13154/294-11032](https://doi.org/10.13154/294-11032) + +----------------------------------------------------------------------------------------------------- + +Updated: 19.03.2025 + +Files are now organized in subdirectories to distinguish better between code for GEVI or GECI analysis. + +gevi-geci/ + stage_1*, stage_2*, stage_3*, stage_4*, stage_5* + -> main stages for data preprocessing + -> use e.g.: python stage_1_get_ref_image.py -c config_example_GEVI.json + functions/ + -> functions used by the main stages + +gevi-geci/gevi/ + config_example_GEVI.json + -> typical config file for GEVI (compare to gevi-geci/geci/config_example_GECI.json) + config_M0134M*, config_M3905F* + -> config files for a few recordings (adjust directory names, if necessary!) + example_load_gevi.py + -> simple script demonstrating how to load data + +gevi-geci/geci/ + config_example_GECI.json + -> typical config file for GECI (compare to gevi-geci/gevi/config_example_GEVI.json) + config_M_Sert_Cre_4* + -> config files for a few recordings (adjust directory names, if necessary!) + stage_6_convert_roi.py + -> additional stage for the analysis of Hendrik's recordings + -> use e.g.: python stage_6_convert_roi.py -f config_M_Sert_Cre_41.json + geci_loader.py, geci_plot.py + -> additional code for summarizing the results and plotting with the ROIs + -> use e.g. 
python geci_loader.py --filename config_M_Sert_Cre_41.json + +gevi-geci/other/ + stage_4b_inspect.py, stage_4c_viewer.py + -> temporary code for assisting search for implantation electrode + + + + + + + + + diff --git a/functions/Anime.py b/functions/Anime.py new file mode 100644 index 0000000..bfc4e46 --- /dev/null +++ b/functions/Anime.py @@ -0,0 +1,93 @@ +import numpy as np +import torch +import matplotlib.pyplot as plt +import matplotlib.animation + + +class Anime: + def __init__(self) -> None: + super().__init__() + + def show( + self, + input: torch.Tensor | np.ndarray, + mask: torch.Tensor | np.ndarray | None = None, + vmin: float | None = None, + vmax: float | None = None, + cmap: str = "hot", + axis_off: bool = True, + show_frame_count: bool = True, + interval: int = 100, + repeat: bool = False, + colorbar: bool = True, + vmin_scale: float | None = None, + vmax_scale: float | None = None, + movie_file: str | None = None, + ) -> None: + assert input.ndim == 3 + + if isinstance(input, torch.Tensor): + input_np: np.ndarray = input.cpu().numpy() + if mask is not None: + mask_np: np.ndarray | None = (mask == 0).cpu().numpy() + else: + mask_np = None + else: + input_np = input + if mask is not None: + mask_np = mask == 0 # type: ignore + else: + mask_np = None + + if vmin is None: + vmin = float(np.where(np.isfinite(input_np), input_np, 0.0).min()) + if vmax is None: + vmax = float(np.where(np.isfinite(input_np), input_np, 0.0).max()) + + if vmin_scale is not None: + vmin *= vmin_scale + + if vmax_scale is not None: + vmax *= vmax_scale + + fig = plt.figure() + image = np.nan_to_num(input_np[0, ...], copy=True, nan=0.0) + if mask_np is not None: + image[mask_np] = float("NaN") + image_handle = plt.imshow( + image, + cmap=cmap, + vmin=vmin, + vmax=vmax, + ) + + if colorbar: + plt.colorbar() + + if axis_off: + plt.axis("off") + + def next_frame(i: int) -> None: + image = np.nan_to_num(input_np[i, ...], copy=True, nan=0.0) + if mask_np is not None: + image[mask_np] = float("NaN") + + image_handle.set_data(image) + if show_frame_count: + bar_length: int = 10 + filled_length = int(round(bar_length * i / input_np.shape[0])) + bar = "\u25A0" * filled_length + "\u25A1" * (bar_length - filled_length) + plt.title(f"{bar} {i} of {int(input_np.shape[0]-1)}", loc="left") + return + + ani = matplotlib.animation.FuncAnimation( + fig, + next_frame, + frames=int(input.shape[0]), + interval=interval, + repeat=repeat, + ) + if movie_file is not None: + ani.save(movie_file) + else: + plt.show() diff --git a/functions/ImageAlignment.py b/functions/ImageAlignment.py new file mode 100644 index 0000000..6472d02 --- /dev/null +++ b/functions/ImageAlignment.py @@ -0,0 +1,1015 @@ +import torch +import torchvision as tv # type: ignore + +# The source code is based on: +# https://github.com/matejak/imreg_dft + +# The original LICENSE: +# Copyright (c) 2014, Matěj Týč +# Copyright (c) 2011-2014, Christoph Gohlke +# Copyright (c) 2011-2014, The Regents of the University of California +# Produced at the Laboratory for Fluorescence Dynamics + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. 
+ +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# * Neither the name of the {organization} nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +class ImageAlignment(torch.nn.Module): + device: torch.device + default_dtype: torch.dtype + excess_const: float = 1.1 + exponent: str = "inf" + success: torch.Tensor | None = None + + # The factor that determines how many + # sub-pixels we will shift + scale_factor: int = 4 + + reference_image: torch.Tensor | None = None + + last_scale: torch.Tensor | None = None + last_angle: torch.Tensor | None = None + last_tvec: torch.Tensor | None = None + + # Cache + image_reference_dft: torch.Tensor | None = None + filt: torch.Tensor + pcorr_shape: torch.Tensor + log_base: torch.Tensor + image_reference_logp: torch.Tensor + + def __init__( + self, + device: torch.device | None = None, + default_dtype: torch.dtype | None = None, + ) -> None: + super().__init__() + + assert device is not None + assert default_dtype is not None + self.device = device + self.default_dtype = default_dtype + + def set_new_reference_image(self, new_reference_image: torch.Tensor | None = None): + assert new_reference_image is not None + assert new_reference_image.ndim == 2 + self.reference_image = ( + new_reference_image.detach() + .clone() + .to(device=self.device) + .type(dtype=self.default_dtype) + ) + self.image_reference_dft = None + + def forward( + self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None + ) -> torch.Tensor: + assert input.ndim == 3 + + if new_reference_image is not None: + self.set_new_reference_image(new_reference_image) + + assert self.reference_image is not None + assert self.reference_image.ndim == 2 + assert input.shape[-2] == self.reference_image.shape[-2] + assert input.shape[-1] == self.reference_image.shape[-1] + + self.last_scale, self.last_angle, self.last_tvec, output = self.similarity( + self.reference_image, + input.to(device=self.device).type(dtype=self.default_dtype), + ) + + return output + + def dry_run( + self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + assert input.ndim == 3 + + if new_reference_image is not None: + self.set_new_reference_image(new_reference_image) + + assert self.reference_image is not None + assert self.reference_image.ndim == 2 + assert input.shape[-2] == self.reference_image.shape[-2] + assert input.shape[-1] == self.reference_image.shape[-1] + +
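# Work on a device/dtype copy of the frames; dry_run only estimates scale, angle and shift and never warps the input. +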
images_todo = input.to(device=self.device).type(dtype=self.default_dtype) + image_reference = self.reference_image + + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + + bgval: torch.Tensor = self.get_borderval(img=images_todo, radius=5) + + self.last_scale, self.last_angle, self.last_tvec = self._similarity( + image_reference, + images_todo, + bgval, + ) + + return self.last_scale, self.last_angle, self.last_tvec + + def dry_run_translation( + self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None + ) -> torch.Tensor: + assert input.ndim == 3 + + if new_reference_image is not None: + self.set_new_reference_image(new_reference_image) + + assert self.reference_image is not None + assert self.reference_image.ndim == 2 + assert input.shape[-2] == self.reference_image.shape[-2] + assert input.shape[-1] == self.reference_image.shape[-1] + + images_todo = input.to(device=self.device).type(dtype=self.default_dtype) + image_reference = self.reference_image + + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + + tvec, _ = self._translation(image_reference, images_todo) + + return tvec + + # --------------- + + def dry_run_angle( + self, + input: torch.Tensor, + new_reference_image: torch.Tensor | None = None, + ) -> torch.Tensor: + assert input.ndim == 3 + + if new_reference_image is not None: + self.set_new_reference_image(new_reference_image) + + constraints_dynamic_angle_0: torch.Tensor = torch.zeros( + (input.shape[0]), dtype=self.default_dtype, device=self.device + ) + constraints_dynamic_angle_1: torch.Tensor | None = None + constraints_dynamic_scale_0: torch.Tensor = torch.ones( + (input.shape[0]), dtype=self.default_dtype, device=self.device + ) + constraints_dynamic_scale_1: torch.Tensor | None = None + + assert self.reference_image is not None + assert self.reference_image.ndim == 2 + assert input.shape[-2] == self.reference_image.shape[-2] + assert input.shape[-1] == self.reference_image.shape[-1] + + images_todo = input.to(device=self.device).type(dtype=self.default_dtype) + image_reference = self.reference_image + + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + + _, newangle = self._get_ang_scale( + image_reference, + images_todo, + constraints_dynamic_scale_0, + constraints_dynamic_scale_1, + constraints_dynamic_angle_0, + constraints_dynamic_angle_1, + ) + + return newangle + + # --------------- + + def _get_pcorr_shape(self, shape: torch.Size) -> tuple[int, int]: + ret = (int(max(shape[-2:]) * 1.0),) * 2 + return ret + + def _get_log_base(self, shape: torch.Size, new_r: torch.Tensor) -> torch.Tensor: + old_r = torch.tensor( + (float(shape[-2]) * self.excess_const) / 2.0, + dtype=self.default_dtype, + device=self.device, + ) + log_base = torch.exp(torch.log(old_r) / new_r) + return log_base + + def wrap_angle( + self, angles: torch.Tensor, ceil: float = 2 * torch.pi + ) -> torch.Tensor: + angles += ceil / 2.0 + angles %= ceil + angles -= ceil / 2.0 + return angles + + def get_borderval( + self, img: torch.Tensor, radius: int | None = None + ) -> torch.Tensor: + assert img.ndim == 3 + if radius is None: + mindim = min([int(img.shape[-2]), int(img.shape[-1])]) + radius = max(1, mindim // 20) + mask = torch.zeros( + (int(img.shape[-2]), int(img.shape[-1])), + dtype=torch.bool, + device=self.device, + ) + mask[:, :radius] = True + mask[:, -radius:] = True + mask[:radius, :] = True + mask[-radius:, :] = True + + mean = torch.median(img[:, mask], dim=-1)[0] + return mean + + def get_apofield(self, shape: torch.Size, 
aporad: int) -> torch.Tensor: + if aporad == 0: + return torch.ones( + shape[-2:], + dtype=self.default_dtype, + device=self.device, + ) + + assert int(shape[-2]) > aporad * 2 + assert int(shape[-1]) > aporad * 2 + + apos = torch.hann_window( + aporad * 2, dtype=self.default_dtype, periodic=False, device=self.device + ) + + toapp_0 = torch.ones( + shape[-2], + dtype=self.default_dtype, + device=self.device, + ) + toapp_0[:aporad] = apos[:aporad] + toapp_0[-aporad:] = apos[-aporad:] + + toapp_1 = torch.ones( + shape[-1], + dtype=self.default_dtype, + device=self.device, + ) + toapp_1[:aporad] = apos[:aporad] + toapp_1[-aporad:] = apos[-aporad:] + + apofield = torch.outer(toapp_0, toapp_1) + + return apofield + + def _get_subarr( + self, array: torch.Tensor, center: torch.Tensor, rad: int + ) -> torch.Tensor: + assert array.ndim == 3 + assert center.ndim == 2 + assert array.shape[0] == center.shape[0] + assert center.shape[1] == 2 + + dim = 1 + 2 * rad + subarr = torch.zeros( + (array.shape[0], dim, dim), dtype=self.default_dtype, device=self.device + ) + + corner = center - rad + idx_p = range(0, corner.shape[0]) + for ii in range(0, dim): + yidx = corner[:, 0] + ii + yidx %= array.shape[-2] + for jj in range(0, dim): + xidx = corner[:, 1] + jj + xidx %= array.shape[-1] + subarr[:, ii, jj] = array[idx_p, yidx, xidx] + + return subarr + + def _argmax_2d(self, array: torch.Tensor) -> torch.Tensor: + assert array.ndim == 3 + + max_pos = array.reshape( + (array.shape[0], array.shape[1] * array.shape[2]) + ).argmax(dim=1) + pos_0 = max_pos // array.shape[2] + + max_pos -= pos_0 * array.shape[2] + ret = torch.zeros( + (array.shape[0], 2), dtype=self.default_dtype, device=self.device + ) + ret[:, 0] = pos_0 + ret[:, 1] = max_pos + + return ret.type(dtype=torch.int64) + + def _apodize(self, what: torch.Tensor) -> torch.Tensor: + mindim = min([int(what.shape[-2]), int(what.shape[-1])]) + aporad = int(mindim * 0.12) + + apofield = self.get_apofield(what.shape, aporad).unsqueeze(0) + + res = what * apofield + bg = self.get_borderval(what, aporad // 2).unsqueeze(-1).unsqueeze(-1) + res += bg * (1 - apofield) + return res + + def _logpolar_filter(self, shape: torch.Size) -> torch.Tensor: + yy = torch.linspace( + -torch.pi / 2.0, + torch.pi / 2.0, + shape[-2], + dtype=self.default_dtype, + device=self.device, + ).unsqueeze(1) + + xx = torch.linspace( + -torch.pi / 2.0, + torch.pi / 2.0, + shape[-1], + dtype=self.default_dtype, + device=self.device, + ).unsqueeze(0) + + rads = torch.sqrt(yy**2 + xx**2) + filt = 1.0 - torch.cos(rads) ** 2 + + filt[torch.abs(rads) > torch.pi / 2] = 1 + return filt + + def _get_angles(self, shape: torch.Tensor) -> torch.Tensor: + ret = torch.zeros( + (int(shape[-2]), int(shape[-1])), + dtype=self.default_dtype, + device=self.device, + ) + ret -= torch.linspace( + 0, + torch.pi, + int(shape[-2] + 1), + dtype=self.default_dtype, + device=self.device, + )[:-1].unsqueeze(-1) + + return ret + + def _get_lograd(self, shape: torch.Tensor, log_base: torch.Tensor) -> torch.Tensor: + ret = torch.zeros( + (int(shape[-2]), int(shape[-1])), + dtype=self.default_dtype, + device=self.device, + ) + ret += torch.pow( + log_base, + torch.arange( + 0, + int(shape[-1]), + dtype=self.default_dtype, + device=self.device, + ), + ).unsqueeze(0) + return ret + + def _logpolar( + self, image: torch.Tensor, shape: torch.Tensor, log_base: torch.Tensor + ) -> torch.Tensor: + assert image.ndim == 3 + + imshape: torch.Tensor = torch.tensor( + image.shape[-2:], + dtype=self.default_dtype, + 
device=self.device, + ) + + center: torch.Tensor = imshape.clone() / 2 + + theta: torch.Tensor = self._get_angles(shape) + radius_x: torch.Tensor = self._get_lograd(shape, log_base) + radius_y: torch.Tensor = radius_x.clone() + + ellipse_coef: torch.Tensor = imshape[0] / imshape[1] + radius_x /= ellipse_coef + + y = radius_y * torch.sin(theta) + center[0] + y /= float(image.shape[-2]) + y *= 2 + y -= 1 + + x = radius_x * torch.cos(theta) + center[1] + x /= float(image.shape[-1]) + x *= 2 + x -= 1 + + idx_x = torch.where(torch.abs(x) <= 1.0, 1.0, 0.0) + idx_y = torch.where(torch.abs(y) <= 1.0, 1.0, 0.0) + + normalized_coords = torch.cat( + ( + x.unsqueeze(-1), + y.unsqueeze(-1), + ), + dim=-1, + ).unsqueeze(0) + + output = torch.empty( + (int(image.shape[0]), int(y.shape[0]), int(y.shape[1])), + dtype=self.default_dtype, + device=self.device, + ) + + for id in range(0, int(image.shape[0])): + bgval: torch.Tensor = torch.quantile(image[id, :, :], q=1.0 / 100.0) + + temp = torch.nn.functional.grid_sample( + image[id, :, :].unsqueeze(0).unsqueeze(0), + normalized_coords, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + + output[id, :, :] = torch.where((idx_x * idx_y) == 0.0, bgval, temp) + + return output + + def _argmax_ext(self, array: torch.Tensor, exponent: float | str) -> torch.Tensor: + assert array.ndim == 3 + + if exponent == "inf": + ret = self._argmax_2d(array) + else: + assert isinstance(exponent, float) or isinstance(exponent, int) + + col = ( + torch.arange( + 0, array.shape[-2], dtype=self.default_dtype, device=self.device + ) + .unsqueeze(-1) + .unsqueeze(0) + ) + row = ( + torch.arange( + 0, array.shape[-1], dtype=self.default_dtype, device=self.device + ) + .unsqueeze(0) + .unsqueeze(0) + ) + + arr2 = torch.pow(array, float(exponent)) + arrsum = arr2.sum(dim=-2).sum(dim=-1) + + ret = torch.zeros( + (array.shape[0], 2), dtype=self.default_dtype, device=self.device + ) + + arrprody = (arr2 * col).sum(dim=-1).sum(dim=-1) / arrsum + arrprodx = (arr2 * row).sum(dim=-1).sum(dim=-1) / arrsum + + ret[:, 0] = arrprody.squeeze(-1).squeeze(-1) + ret[:, 1] = arrprodx.squeeze(-1).squeeze(-1) + + idx = torch.where(arrsum == 0.0)[0] + ret[idx, :] = 0.0 + return ret + + def _interpolate( + self, array: torch.Tensor, rough: torch.Tensor, rad: int = 2 + ) -> torch.Tensor: + assert array.ndim == 3 + assert rough.ndim == 2 + + rough = torch.round(rough).type(torch.int64) + + surroundings = self._get_subarr(array, rough, rad) + + com = self._argmax_ext(surroundings, 1.0) + + offset = com - rad + ret = rough + offset + + ret += 0.5 + ret %= ( + torch.tensor(array.shape[-2:], dtype=self.default_dtype, device=self.device) + .type(dtype=torch.int64) + .unsqueeze(0) + ) + ret -= 0.5 + return ret + + def _get_success( + self, array: torch.Tensor, coord: torch.Tensor, radius: int = 2 + ) -> torch.Tensor: + assert array.ndim == 3 + assert coord.ndim == 2 + assert array.shape[0] == coord.shape[0] + assert coord.shape[1] == 2 + + coord = torch.round(coord).type(dtype=torch.int64) + subarr = self._get_subarr( + array, coord, 2 + ) # Not my fault. They want a 2 there. 
Not radius + + theval = subarr.sum(dim=-1).sum(dim=-1) + + theval2 = array[range(0, coord.shape[0]), coord[:, 0], coord[:, 1]] + + success = torch.sqrt(theval * theval2) + return success + + def _get_constraint_mask( + self, + shape: torch.Size, + log_base: torch.Tensor, + constraints_scale_0: torch.Tensor, + constraints_scale_1: torch.Tensor | None, + constraints_angle_0: torch.Tensor, + constraints_angle_1: torch.Tensor | None, + ) -> torch.Tensor: + assert constraints_scale_0 is not None + assert constraints_angle_0 is not None + assert constraints_scale_0.ndim == 1 + assert constraints_angle_0.ndim == 1 + + assert constraints_scale_0.shape[0] == constraints_angle_0.shape[0] + + mask: torch.Tensor = torch.ones( + (constraints_scale_0.shape[0], int(shape[-2]), int(shape[-1])), + device=self.device, + dtype=self.default_dtype, + ) + + scale: torch.Tensor = constraints_scale_0.clone() + if constraints_scale_1 is not None: + sigma: torch.Tensor | None = constraints_scale_1.clone() + else: + sigma = None + + scales = torch.fft.ifftshift( + self._get_lograd( + torch.tensor(shape[-2:], device=self.device, dtype=self.default_dtype), + log_base, + ) + ) + + scales *= log_base ** (-shape[-1] / 2.0) + scales = scales.unsqueeze(0) - (1.0 / scale).unsqueeze(-1).unsqueeze(-1) + + if sigma is not None: + assert sigma.shape[0] == constraints_scale_0.shape[0] + + for p_id in range(0, sigma.shape[0]): + if sigma[p_id] == 0: + ascales = torch.abs(scales[p_id, ...]) + scale_min = ascales.min() + binary_mask = torch.where(ascales > scale_min, 0.0, 1.0) + mask[p_id, ...] *= binary_mask + else: + mask[p_id, ...] *= torch.exp( + -(torch.pow(scales[p_id, ...], 2)) / torch.pow(sigma[p_id], 2) + ) + + angle: torch.Tensor = constraints_angle_0.clone() + if constraints_angle_1 is not None: + sigma = constraints_angle_1.clone() + else: + sigma = None + + angles = self._get_angles( + torch.tensor(shape[-2:], device=self.device, dtype=self.default_dtype) + ) + + angles = angles.unsqueeze(0) + torch.deg2rad(angle).unsqueeze(-1).unsqueeze(-1) + + angles = torch.rad2deg(angles) + + if sigma is not None: + assert sigma.shape[0] == constraints_scale_0.shape[0] + + for p_id in range(0, sigma.shape[0]): + if sigma[p_id] == 0: + aangles = torch.abs(angles[p_id, ...]) + angle_min = aangles.min() + binary_mask = torch.where(aangles > angle_min, 0.0, 1.0) + mask[p_id, ...] 
*= binary_mask + else: + mask *= torch.exp( + -(torch.pow(angles[p_id, ...], 2)) / torch.pow(sigma[p_id], 2) + ) + + mask = torch.fft.fftshift(mask, dim=(-2, -1)) + + return mask + + def argmax_angscale( + self, + array: torch.Tensor, + log_base: torch.Tensor, + constraints_scale_0: torch.Tensor, + constraints_scale_1: torch.Tensor | None, + constraints_angle_0: torch.Tensor, + constraints_angle_1: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert array.ndim == 3 + assert constraints_scale_0 is not None + assert constraints_angle_0 is not None + assert constraints_scale_0.ndim == 1 + assert constraints_angle_0.ndim == 1 + + mask = self._get_constraint_mask( + array.shape[-2:], + log_base, + constraints_scale_0, + constraints_scale_1, + constraints_angle_0, + constraints_angle_1, + ) + + array_orig = array.clone() + + array *= mask + ret = self._argmax_ext(array, self.exponent) + + ret_final = self._interpolate(array, ret) + + success = self._get_success(array_orig, ret_final, 0) + + return ret_final, success + + def argmax_translation( + self, array: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + assert array.ndim == 3 + + array_orig = array.clone() + + ashape = torch.tensor(array.shape[-2:], device=self.device).type( + dtype=torch.int64 + ) + + aporad = (ashape // 6).min() + mask2 = self.get_apofield(torch.Size(ashape), aporad).unsqueeze(0) + array *= mask2 + + tvec = self._argmax_ext(array, "inf") + + tvec = self._interpolate(array_orig, tvec) + + success = self._get_success(array_orig, tvec, 2) + + return tvec, success + + def transform_img( + self, + img: torch.Tensor, + scale: torch.Tensor | None = None, + angle: torch.Tensor | None = None, + tvec: torch.Tensor | None = None, + bgval: torch.Tensor | None = None, + ) -> torch.Tensor: + assert img.ndim == 3 + + if scale is None: + scale = torch.ones( + (img.shape[0],), dtype=self.default_dtype, device=self.device + ) + assert scale.ndim == 1 + assert scale.shape[0] == img.shape[0] + + if angle is None: + angle = torch.zeros( + (img.shape[0],), dtype=self.default_dtype, device=self.device + ) + assert angle.ndim == 1 + assert angle.shape[0] == img.shape[0] + + if tvec is None: + tvec = torch.zeros( + (img.shape[0], 2), dtype=self.default_dtype, device=self.device + ) + assert tvec.ndim == 2 + assert tvec.shape[0] == img.shape[0] + assert tvec.shape[1] == 2 + + if bgval is None: + bgval = self.get_borderval(img) + assert bgval.ndim == 1 + assert bgval.shape[0] == img.shape[0] + + # Otherwise we need to decompose it and put it back together + assert torch.is_complex(img) is False + + output = torch.zeros_like(img) + + for pos in range(0, img.shape[0]): + image_processed = img[pos, :, :].unsqueeze(0).clone() + + temp_shift = [ + int(round(tvec[pos, 1].item() * self.scale_factor)), + int(round(tvec[pos, 0].item() * self.scale_factor)), + ] + + image_processed = torch.nn.functional.interpolate( + image_processed.unsqueeze(0), + scale_factor=self.scale_factor, + mode="bilinear", + ).squeeze(0) + + image_processed = tv.transforms.functional.affine( + img=image_processed, + angle=-float(angle[pos]), + translate=temp_shift, + scale=float(scale[pos]), + shear=[0, 0], + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=float(bgval[pos]), + center=None, + ) + + image_processed = torch.nn.functional.interpolate( + image_processed.unsqueeze(0), + scale_factor=1.0 / self.scale_factor, + mode="bilinear", + ).squeeze(0) + + image_processed = tv.transforms.functional.center_crop( + image_processed, 
img.shape[-2:] + ) + + output[pos, ...] = image_processed.squeeze(0) + + return output + + def transform_img_dict( + self, + img: torch.Tensor, + scale: torch.Tensor | None = None, + angle: torch.Tensor | None = None, + tvec: torch.Tensor | None = None, + bgval: torch.Tensor | None = None, + invert=False, + ) -> torch.Tensor: + if invert: + if scale is not None: + scale = 1.0 / scale + if angle is not None: + angle *= -1 + if tvec is not None: + tvec *= -1 + + res = self.transform_img(img, scale, angle, tvec, bgval=bgval) + return res + + def _phase_correlation( + self, image_reference: torch.Tensor, images_todo: torch.Tensor, callback, *args + ) -> tuple[torch.Tensor, torch.Tensor]: + assert image_reference.ndim == 3 + assert image_reference.shape[0] == 1 + assert images_todo.ndim == 3 + + assert callback is not None + + image_reference_fft = torch.fft.fft2(image_reference, dim=(-2, -1)) + images_todo_fft = torch.fft.fft2(images_todo, dim=(-2, -1)) + + eps = torch.abs(images_todo_fft).max(dim=-1)[0].max(dim=-1)[0] * 1e-15 + + cps = abs( + torch.fft.ifft2( + (image_reference_fft * images_todo_fft.conj()) + / ( + torch.abs(image_reference_fft) * torch.abs(images_todo_fft) + + eps.unsqueeze(-1).unsqueeze(-1) + ), + dim=(-2, -1), + ) + ) + + scps = torch.fft.fftshift(cps, dim=(-2, -1)) + + ret, success = callback(scps, *args) + + ret[:, 0] -= image_reference_fft.shape[-2] // 2 + ret[:, 1] -= image_reference_fft.shape[-1] // 2 + + return ret, success + + def _translation( + self, im0: torch.Tensor, im1: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + assert im0.ndim == 2 + ret, succ = self._phase_correlation( + im0.unsqueeze(0), im1, self.argmax_translation + ) + + return ret, succ + + def _get_ang_scale( + self, + image_reference: torch.Tensor, + images_todo: torch.Tensor, + constraints_scale_0: torch.Tensor, + constraints_scale_1: torch.Tensor | None, + constraints_angle_0: torch.Tensor, + constraints_angle_1: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + assert image_reference.shape[-1] == images_todo.shape[-1] + assert image_reference.shape[-2] == images_todo.shape[-2] + assert constraints_scale_0.shape[0] == images_todo.shape[0] + assert constraints_angle_0.shape[0] == images_todo.shape[0] + + if constraints_scale_1 is not None: + assert constraints_scale_1.shape[0] == images_todo.shape[0] + + if constraints_angle_1 is not None: + assert constraints_angle_1.shape[0] == images_todo.shape[0] + + if self.image_reference_dft is None: + image_reference_apod = self._apodize(image_reference.unsqueeze(0)) + self.image_reference_dft = torch.fft.fftshift( + torch.fft.fft2(image_reference_apod, dim=(-2, -1)), dim=(-2, -1) + ) + self.filt = self._logpolar_filter(image_reference.shape).unsqueeze(0) + self.image_reference_dft *= self.filt + self.pcorr_shape = torch.tensor( + self._get_pcorr_shape(image_reference.shape[-2:]), + dtype=self.default_dtype, + device=self.device, + ) + self.log_base = self._get_log_base( + image_reference.shape, + self.pcorr_shape[1], + ) + self.image_reference_logp = self._logpolar( + torch.abs(self.image_reference_dft), self.pcorr_shape, self.log_base + ) + + images_todo_apod = self._apodize(images_todo) + images_todo_dft = torch.fft.fftshift( + torch.fft.fft2(images_todo_apod, dim=(-2, -1)), dim=(-2, -1) + ) + + images_todo_dft *= self.filt + + images_todo_lopg = self._logpolar( + torch.abs(images_todo_dft), self.pcorr_shape, self.log_base + ) + + temp, _ = 
self._phase_correlation( + self.image_reference_logp, + images_todo_lopg, + self.argmax_angscale, + self.log_base, + constraints_scale_0, + constraints_scale_1, + constraints_angle_0, + constraints_angle_1, + ) + + arg_ang = temp[:, 0].clone() + arg_rad = temp[:, 1].clone() + + angle = -torch.pi * arg_ang / float(self.pcorr_shape[0]) + angle = torch.rad2deg(angle) + + angle = self.wrap_angle(angle, 360) + + scale = torch.pow(self.log_base, arg_rad) + + angle = -angle + scale = 1.0 / scale + + assert torch.where(scale < 2)[0].shape[0] == scale.shape[0] + assert torch.where(scale > 0.5)[0].shape[0] == scale.shape[0] + + return scale, angle + + def translation( + self, im0: torch.Tensor, im1: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + angle = torch.zeros( + (im1.shape[0]), dtype=self.default_dtype, device=self.device + ) + assert im1.ndim == 3 + assert im0.shape[-2] == im1.shape[-2] + assert im0.shape[-1] == im1.shape[-1] + + tvec, succ = self._translation(im0, im1) + tvec2, succ2 = self._translation(im0, torch.rot90(im1, k=2, dims=[-2, -1])) + + assert tvec.shape[0] == tvec2.shape[0] + assert tvec.ndim == 2 + assert tvec2.ndim == 2 + assert tvec.shape[1] == 2 + assert tvec2.shape[1] == 2 + assert succ.shape[0] == succ2.shape[0] + assert succ.ndim == 1 + assert succ2.ndim == 1 + assert tvec.shape[0] == succ.shape[0] + assert angle.shape[0] == tvec.shape[0] + assert angle.ndim == 1 + + for pos in range(0, angle.shape[0]): + pick_rotated = False + if succ2[pos] > succ[pos]: + pick_rotated = True + + if pick_rotated: + tvec[pos, :] = tvec2[pos, :] + succ[pos] = succ2[pos] + angle[pos] += 180 + + return tvec, succ, angle + + def _similarity( + self, + image_reference: torch.Tensor, + images_todo: torch.Tensor, + bgval: torch.Tensor, + ): + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + assert image_reference.shape[-1] == images_todo.shape[-1] + assert image_reference.shape[-2] == images_todo.shape[-2] + + # We are going to iterate and precise scale and angle estimates + scale: torch.Tensor = torch.ones( + (images_todo.shape[0]), dtype=self.default_dtype, device=self.device + ) + angle: torch.Tensor = torch.zeros( + (images_todo.shape[0]), dtype=self.default_dtype, device=self.device + ) + + constraints_dynamic_angle_0: torch.Tensor = torch.zeros( + (images_todo.shape[0]), dtype=self.default_dtype, device=self.device + ) + constraints_dynamic_angle_1: torch.Tensor | None = None + constraints_dynamic_scale_0: torch.Tensor = torch.ones( + (images_todo.shape[0]), dtype=self.default_dtype, device=self.device + ) + constraints_dynamic_scale_1: torch.Tensor | None = None + + newscale, newangle = self._get_ang_scale( + image_reference, + images_todo, + constraints_dynamic_scale_0, + constraints_dynamic_scale_1, + constraints_dynamic_angle_0, + constraints_dynamic_angle_1, + ) + scale *= newscale + angle += newangle + + im2 = self.transform_img(images_todo, scale, angle, bgval=bgval) + + tvec, self.success, res_angle = self.translation(image_reference, im2) + + angle += res_angle + + angle = self.wrap_angle(angle, 360) + + return scale, angle, tvec + + def similarity( + self, + image_reference: torch.Tensor, + images_todo: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + + bgval: torch.Tensor = self.get_borderval(img=images_todo, radius=5) + + scale, angle, tvec = self._similarity( + image_reference, + images_todo, + bgval, + ) + + im2 = 
self.transform_img_dict( + img=images_todo, + scale=scale, + angle=angle, + tvec=tvec, + bgval=bgval, + ) + + return scale, angle, tvec, im2 diff --git a/functions/align_refref.py b/functions/align_refref.py new file mode 100644 index 0000000..3208cf3 --- /dev/null +++ b/functions/align_refref.py @@ -0,0 +1,60 @@ +import torch +import torchvision as tv # type: ignore +import logging +from functions.ImageAlignment import ImageAlignment +from functions.calculate_translation import calculate_translation +from functions.calculate_rotation import calculate_rotation + + +@torch.no_grad() +def align_refref( + mylogger: logging.Logger, + ref_image_acceptor: torch.Tensor, + ref_image_donor: torch.Tensor, + batch_size: int, + fill_value: float = 0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + image_alignment = ImageAlignment( + default_dtype=ref_image_acceptor.dtype, device=ref_image_acceptor.device + ) + + mylogger.info("Rotate ref image acceptor onto donor") + angle_refref = calculate_rotation( + image_alignment=image_alignment, + input=ref_image_acceptor.unsqueeze(0), + reference_image=ref_image_donor, + batch_size=batch_size, + ) + + ref_image_acceptor = tv.transforms.functional.affine( + img=ref_image_acceptor.unsqueeze(0), + angle=-float(angle_refref), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ) + + mylogger.info("Translate ref image acceptor onto donor") + tvec_refref = calculate_translation( + image_alignment=image_alignment, + input=ref_image_acceptor, + reference_image=ref_image_donor, + batch_size=batch_size, + ) + + tvec_refref = tvec_refref[0, :] + + ref_image_acceptor = tv.transforms.functional.affine( + img=ref_image_acceptor, + angle=0, + translate=[tvec_refref[1], tvec_refref[0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + return angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor diff --git a/functions/bandpass.py b/functions/bandpass.py new file mode 100644 index 0000000..2659847 --- /dev/null +++ b/functions/bandpass.py @@ -0,0 +1,113 @@ +import torchaudio as ta # type: ignore +import torch + + +@torch.no_grad() +def filtfilt( + input: torch.Tensor, + butter_a: torch.Tensor, + butter_b: torch.Tensor, +) -> torch.Tensor: + assert butter_a.ndim == 1 + assert butter_b.ndim == 1 + assert butter_a.shape[0] == butter_b.shape[0] + + process_data: torch.Tensor = input.detach().clone() + + padding_length = 12 * int(butter_a.shape[0]) + left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[ + ..., 1 : padding_length + 1 + ].flip(-1) + right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[ + ..., -(padding_length + 1) : -1 + ].flip(-1) + process_data_padded = torch.cat((left_padding, process_data, right_padding), dim=-1) + + output = ta.functional.filtfilt( + process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False + ).squeeze(0) + + output = output[..., padding_length:-padding_length] + return output + + +@torch.no_grad() +def butter_bandpass( + device: torch.device, + low_frequency: float = 0.1, + high_frequency: float = 1.0, + fs: float = 30.0, +) -> tuple[torch.Tensor, torch.Tensor]: + import scipy # type: ignore + + butter_b_np, butter_a_np = scipy.signal.butter( + 4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs + ) + butter_a = torch.tensor(butter_a_np, device=device, dtype=torch.float32) + butter_b = torch.tensor(butter_b_np, 
device=device, dtype=torch.float32) + return butter_a, butter_b + + +@torch.no_grad() +def chunk_iterator(array: torch.Tensor, chunk_size: int): + for i in range(0, array.shape[0], chunk_size): + yield array[i : i + chunk_size] + + +@torch.no_grad() +def bandpass( + data: torch.Tensor, + low_frequency: float = 0.1, + high_frequency: float = 1.0, + fs=30.0, + filtfilt_chuck_size: int = 10, +) -> torch.Tensor: + + try: + return bandpass_internal( + data=data, + low_frequency=low_frequency, + high_frequency=high_frequency, + fs=fs, + filtfilt_chuck_size=filtfilt_chuck_size, + ) + + except torch.cuda.OutOfMemoryError: + + return bandpass_internal( + data=data.cpu(), + low_frequency=low_frequency, + high_frequency=high_frequency, + fs=fs, + filtfilt_chuck_size=filtfilt_chuck_size, + ).to(device=data.device) + + +@torch.no_grad() +def bandpass_internal( + data: torch.Tensor, + low_frequency: float = 0.1, + high_frequency: float = 1.0, + fs=30.0, + filtfilt_chuck_size: int = 10, +) -> torch.Tensor: + butter_a, butter_b = butter_bandpass( + device=data.device, + low_frequency=low_frequency, + high_frequency=high_frequency, + fs=fs, + ) + + index_full_dataset: torch.Tensor = torch.arange( + 0, data.shape[1], device=data.device, dtype=torch.int64 + ) + + for chunk in chunk_iterator(index_full_dataset, filtfilt_chuck_size): + temp_filtfilt = filtfilt( + data[:, chunk, :], + butter_a=butter_a, + butter_b=butter_b, + ) + data[:, chunk, :] = temp_filtfilt + + return data diff --git a/functions/binning.py b/functions/binning.py new file mode 100644 index 0000000..f873433 --- /dev/null +++ b/functions/binning.py @@ -0,0 +1,46 @@ +import torch + + +@torch.no_grad() +def binning( + data: torch.Tensor, + kernel_size: int = 4, + stride: int = 4, + divisor_override: int | None = 1, +) -> torch.Tensor: + + try: + return binning_internal( + data=data, + kernel_size=kernel_size, + stride=stride, + divisor_override=divisor_override, + ) + except torch.cuda.OutOfMemoryError: + return binning_internal( + data=data.cpu(), + kernel_size=kernel_size, + stride=stride, + divisor_override=divisor_override, + ).to(device=data.device) + + +@torch.no_grad() +def binning_internal( + data: torch.Tensor, + kernel_size: int = 4, + stride: int = 4, + divisor_override: int | None = 1, +) -> torch.Tensor: + + assert data.ndim == 4 + return ( + torch.nn.functional.avg_pool2d( + input=data.movedim(0, -1).movedim(0, -1), + kernel_size=kernel_size, + stride=stride, + divisor_override=divisor_override, + ) + .movedim(-1, 0) + .movedim(-1, 0) + ) diff --git a/functions/calculate_rotation.py b/functions/calculate_rotation.py new file mode 100644 index 0000000..6a53afd --- /dev/null +++ b/functions/calculate_rotation.py @@ -0,0 +1,40 @@ +import torch + +from functions.ImageAlignment import ImageAlignment + + +@torch.no_grad() +def calculate_rotation( + image_alignment: ImageAlignment, + input: torch.Tensor, + reference_image: torch.Tensor, + batch_size: int, +) -> torch.Tensor: + angle = torch.zeros((input.shape[0])) + + data_loader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(input), + batch_size=batch_size, + shuffle=False, + ) + start_position: int = 0 + for input_batch in data_loader: + assert len(input_batch) == 1 + + end_position = start_position + input_batch[0].shape[0] + + angle_temp = image_alignment.dry_run_angle( + input=input_batch[0], + new_reference_image=reference_image, + ) + + assert angle_temp is not None + + angle[start_position:end_position] = angle_temp + + start_position += input_batch[0].shape[0] 
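+ # Fold any angle estimate that lands outside the [-180, 180] degree range back into it.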
+ + angle = torch.where(angle >= 180, 360.0 - angle, angle) + angle = torch.where(angle <= -180, 360.0 + angle, angle) + + return angle diff --git a/functions/calculate_translation.py b/functions/calculate_translation.py new file mode 100644 index 0000000..9eadf59 --- /dev/null +++ b/functions/calculate_translation.py @@ -0,0 +1,37 @@ +import torch + +from functions.ImageAlignment import ImageAlignment + + +@torch.no_grad() +def calculate_translation( + image_alignment: ImageAlignment, + input: torch.Tensor, + reference_image: torch.Tensor, + batch_size: int, +) -> torch.Tensor: + tvec = torch.zeros((input.shape[0], 2)) + + data_loader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(input), + batch_size=batch_size, + shuffle=False, + ) + start_position: int = 0 + for input_batch in data_loader: + assert len(input_batch) == 1 + + end_position = start_position + input_batch[0].shape[0] + + tvec_temp = image_alignment.dry_run_translation( + input=input_batch[0], + new_reference_image=reference_image, + ) + + assert tvec_temp is not None + + tvec[start_position:end_position, :] = tvec_temp + + start_position += input_batch[0].shape[0] + + return tvec diff --git a/functions/create_logger.py b/functions/create_logger.py new file mode 100644 index 0000000..b7e746f --- /dev/null +++ b/functions/create_logger.py @@ -0,0 +1,37 @@ +import logging +import datetime +import os + + +def create_logger( + save_logging_messages: bool, display_logging_messages: bool, log_stage_name: str +): + now = datetime.datetime.now() + dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S") + + logger = logging.getLogger("MyLittleLogger") + logger.setLevel(logging.DEBUG) + + if save_logging_messages: + time_format = "%b %d %Y %H:%M:%S" + logformat = "%(asctime)s %(message)s" + file_formatter = logging.Formatter(fmt=logformat, datefmt=time_format) + os.makedirs("logs_" + log_stage_name, exist_ok=True) + file_handler = logging.FileHandler( + os.path.join("logs_" + log_stage_name, f"log_{dt_string_filename}.txt") + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(file_formatter) + logger.addHandler(file_handler) + + if display_logging_messages: + time_format = "%H:%M:%S" + logformat = "%(asctime)s %(message)s" + stream_formatter = logging.Formatter(fmt=logformat, datefmt=time_format) + + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(stream_formatter) + logger.addHandler(stream_handler) + + return logger diff --git a/functions/data_raw_loader.py b/functions/data_raw_loader.py new file mode 100644 index 0000000..67e55cf --- /dev/null +++ b/functions/data_raw_loader.py @@ -0,0 +1,339 @@ +import numpy as np +import torch +import os +import logging +import copy + +from functions.get_experiments import get_experiments +from functions.get_trials import get_trials +from functions.get_parts import get_parts +from functions.load_meta_data import load_meta_data + + +def data_raw_loader( + raw_data_path: str, + mylogger: logging.Logger, + experiment_id: int, + trial_id: int, + device: torch.device, + force_to_cpu_memory: bool, + config: dict, +) -> tuple[list[str], str, str, dict, dict, float, float, str, torch.Tensor]: + + meta_channels: list[str] = [] + meta_mouse_markings: str = "" + meta_recording_date: str = "" + meta_stimulation_times: dict = {} + meta_experiment_names: dict = {} + meta_trial_recording_duration: float = 0.0 + meta_frame_time: float = 0.0 + meta_mouse: str = "" + data: torch.Tensor = torch.zeros((1)) + + dtype_str = 
config["dtype"] + mylogger.info(f"Data precision will be {dtype_str}") + dtype: torch.dtype = getattr(torch, dtype_str) + dtype_np: np.dtype = getattr(np, dtype_str) + + if os.path.isdir(raw_data_path) is False: + mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!") + assert os.path.isdir(raw_data_path) + return ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) + + if (torch.where(get_experiments(raw_data_path) == experiment_id)[0].shape[0]) != 1: + mylogger.info(f"ERROR: could not find experiment id {experiment_id}!!!!") + assert ( + torch.where(get_experiments(raw_data_path) == experiment_id)[0].shape[0] + ) == 1 + return ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) + + if ( + torch.where(get_trials(raw_data_path, experiment_id) == trial_id)[0].shape[0] + ) != 1: + mylogger.info(f"ERROR: could not find trial id {trial_id}!!!!") + assert ( + torch.where(get_trials(raw_data_path, experiment_id) == trial_id)[0].shape[ + 0 + ] + ) == 1 + return ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) + + available_parts: torch.Tensor = get_parts(raw_data_path, experiment_id, trial_id) + if available_parts.shape[0] < 1: + mylogger.info("ERROR: could not find any part files") + assert available_parts.shape[0] >= 1 + + experiment_name = f"Exp{experiment_id:03d}_Trial{trial_id:03d}" + mylogger.info(f"Will work on: {experiment_name}") + + mylogger.info(f"We found {int(available_parts.shape[0])} parts.") + + first_run: bool = True + + mylogger.info("Compare meta data of all parts") + for id in range(0, available_parts.shape[0]): + part_id = available_parts[id] + + filename_meta: str = os.path.join( + raw_data_path, + f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}_meta.txt", + ) + + if os.path.isfile(filename_meta) is False: + mylogger.info(f"Could not load meta data... 
{filename_meta}") + assert os.path.isfile(filename_meta) + return ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) + + ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + ) = load_meta_data( + mylogger=mylogger, filename_meta=filename_meta, silent_mode=True + ) + + if first_run: + first_run = False + master_meta_channels: list[str] = copy.deepcopy(meta_channels) + master_meta_mouse_markings: str = meta_mouse_markings + master_meta_recording_date: str = meta_recording_date + master_meta_stimulation_times: dict = copy.deepcopy(meta_stimulation_times) + master_meta_experiment_names: dict = copy.deepcopy(meta_experiment_names) + master_meta_trial_recording_duration: float = meta_trial_recording_duration + master_meta_frame_time: float = meta_frame_time + master_meta_mouse: str = meta_mouse + + meta_channels_check = master_meta_channels == meta_channels + + # Check channel order + if meta_channels_check: + for channel_a, channel_b in zip(master_meta_channels, meta_channels): + if channel_a != channel_b: + meta_channels_check = False + + meta_mouse_markings_check = master_meta_mouse_markings == meta_mouse_markings + meta_recording_date_check = master_meta_recording_date == meta_recording_date + meta_stimulation_times_check = ( + master_meta_stimulation_times == meta_stimulation_times + ) + meta_experiment_names_check = ( + master_meta_experiment_names == meta_experiment_names + ) + meta_trial_recording_duration_check = ( + master_meta_trial_recording_duration == meta_trial_recording_duration + ) + meta_frame_time_check = master_meta_frame_time == meta_frame_time + meta_mouse_check = master_meta_mouse == meta_mouse + + if meta_channels_check is False: + mylogger.info(f"{filename_meta} failed: channels") + assert meta_channels_check + + if meta_mouse_markings_check is False: + mylogger.info(f"{filename_meta} failed: mouse_markings") + assert meta_mouse_markings_check + + if meta_recording_date_check is False: + mylogger.info(f"{filename_meta} failed: recording_date") + assert meta_recording_date_check + + if meta_stimulation_times_check is False: + mylogger.info(f"{filename_meta} failed: stimulation_times") + assert meta_stimulation_times_check + + if meta_experiment_names_check is False: + mylogger.info(f"{filename_meta} failed: experiment_names") + assert meta_experiment_names_check + + if meta_trial_recording_duration_check is False: + mylogger.info(f"{filename_meta} failed: trial_recording_duration") + assert meta_trial_recording_duration_check + + if meta_frame_time_check is False: + mylogger.info(f"{filename_meta} failed: frame_time_check") + assert meta_frame_time_check + + if meta_mouse_check is False: + mylogger.info(f"{filename_meta} failed: mouse") + assert meta_mouse_check + mylogger.info("-==- Done -==-") + + mylogger.info(f"Will use: {filename_meta} for meta data") + ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + ) = load_meta_data(mylogger=mylogger, filename_meta=filename_meta) + + ################# + # Meta data end # + ################# + + first_run = True + mylogger.info("Count the number of frames in the data of all parts") + frame_count: int = 0 + for id in range(0, 
available_parts.shape[0]): + part_id = available_parts[id] + + filename_data: str = os.path.join( + raw_data_path, + f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}.npy", + ) + + if os.path.isfile(filename_data) is False: + mylogger.info(f"Could not load data... {filename_data}") + assert os.path.isfile(filename_data) + return ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) + data_np: np.ndarray = np.load(filename_data, mmap_mode="r") + + if data_np.ndim != 4: + mylogger.info(f"ERROR: Data needs to have 4 dimensions {filename_data}") + assert data_np.ndim == 4 + + if first_run: + first_run = False + dim_0: int = int(data_np.shape[0]) + dim_1: int = int(data_np.shape[1]) + dim_3: int = int(data_np.shape[3]) + + frame_count += int(data_np.shape[2]) + + if int(data_np.shape[0]) != dim_0: + mylogger.info( + f"ERROR: Data dim 0 is broken {int(data_np.shape[0])} vs {dim_0} {filename_data}" + ) + assert int(data_np.shape[0]) == dim_0 + + if int(data_np.shape[1]) != dim_1: + mylogger.info( + f"ERROR: Data dim 1 is broken {int(data_np.shape[1])} vs {dim_1} {filename_data}" + ) + assert int(data_np.shape[1]) == dim_1 + + if int(data_np.shape[3]) != dim_3: + mylogger.info( + f"ERROR: Data dim 3 is broken {int(data_np.shape[3])} vs {dim_3} {filename_data}" + ) + assert int(data_np.shape[3]) == dim_3 + + mylogger.info( + f"{filename_data}: {int(data_np.shape[2])} frames -> {frame_count} frames total" + ) + + if force_to_cpu_memory: + mylogger.info("Using CPU memory for data") + data = torch.empty( + (dim_0, dim_1, frame_count, dim_3), dtype=dtype, device=torch.device("cpu") + ) + else: + mylogger.info("Using GPU memory for data") + data = torch.empty( + (dim_0, dim_1, frame_count, dim_3), dtype=dtype, device=device + ) + + start_position: int = 0 + end_position: int = 0 + for id in range(0, available_parts.shape[0]): + part_id = available_parts[id] + + filename_data = os.path.join( + raw_data_path, + f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}.npy", + ) + + mylogger.info(f"Will work on {filename_data}") + mylogger.info("Loading data file") + data_np = np.load(filename_data).astype(dtype_np) + + end_position = start_position + int(data_np.shape[2]) + + for i in range(0, len(config["required_order"])): + mylogger.info(f"Move raw data channel: {config['required_order'][i]}") + + idx = meta_channels.index(config["required_order"][i]) + data[..., start_position:end_position, i] = torch.tensor( + data_np[..., idx], dtype=dtype, device=data.device + ) + start_position = end_position + + if start_position != int(data.shape[2]): + mylogger.info("ERROR: data was not filled completely!!!") + assert start_position == int(data.shape[2]) + + mylogger.info("-==- Done -==-") + + ################# + # Raw data end # + ################# + + return ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) diff --git a/functions/gauss_smear_individual.py new file mode 100644 index 0000000..73dba65 --- /dev/null +++ b/functions/gauss_smear_individual.py @@ -0,0 +1,168 @@ +import torch +import math + + +@torch.no_grad() +def gauss_smear_individual( + input: torch.Tensor, + spatial_width: float, + temporal_width: float, + overwrite_fft_gauss: None | torch.Tensor = None, +
use_matlab_mask: bool = True, + epsilon: float = float(torch.finfo(torch.float64).eps), +) -> tuple[torch.Tensor, torch.Tensor]: + try: + return gauss_smear_individual_core( + input=input, + spatial_width=spatial_width, + temporal_width=temporal_width, + overwrite_fft_gauss=overwrite_fft_gauss, + use_matlab_mask=use_matlab_mask, + epsilon=epsilon, + ) + except torch.cuda.OutOfMemoryError: + + if overwrite_fft_gauss is None: + overwrite_fft_gauss_cpu: None | torch.Tensor = None + else: + overwrite_fft_gauss_cpu = overwrite_fft_gauss.cpu() + + input_cpu: torch.Tensor = input.cpu() + + output, overwrite_fft_gauss = gauss_smear_individual_core( + input=input_cpu, + spatial_width=spatial_width, + temporal_width=temporal_width, + overwrite_fft_gauss=overwrite_fft_gauss_cpu, + use_matlab_mask=use_matlab_mask, + epsilon=epsilon, + ) + return ( + output.to(device=input.device), + overwrite_fft_gauss.to(device=input.device), + ) + + +@torch.no_grad() +def gauss_smear_individual_core( + input: torch.Tensor, + spatial_width: float, + temporal_width: float, + overwrite_fft_gauss: None | torch.Tensor = None, + use_matlab_mask: bool = True, + epsilon: float = float(torch.finfo(torch.float64).eps), +) -> tuple[torch.Tensor, torch.Tensor]: + + dim_x: int = int(2 * math.ceil(2 * spatial_width) + 1) + dim_y: int = int(2 * math.ceil(2 * spatial_width) + 1) + dim_t: int = int(2 * math.ceil(2 * temporal_width) + 1) + dims_xyt: torch.Tensor = torch.tensor( + [dim_x, dim_y, dim_t], dtype=torch.int64, device=input.device + ) + + if input.ndim == 2: + input = input.unsqueeze(-1) + + input_padded = torch.nn.functional.pad( + input.unsqueeze(0), + pad=( + dim_t, + dim_t, + dim_y, + dim_y, + dim_x, + dim_x, + ), + mode="replicate", + ).squeeze(0) + + if overwrite_fft_gauss is None: + center_x: int = int(math.ceil(input_padded.shape[0] / 2)) + center_y: int = int(math.ceil(input_padded.shape[1] / 2)) + center_z: int = int(math.ceil(input_padded.shape[2] / 2)) + grid_x: torch.Tensor = ( + torch.arange(0, input_padded.shape[0], device=input.device) - center_x + 1 + ) + grid_y: torch.Tensor = ( + torch.arange(0, input_padded.shape[1], device=input.device) - center_y + 1 + ) + grid_z: torch.Tensor = ( + torch.arange(0, input_padded.shape[2], device=input.device) - center_z + 1 + ) + + grid_x = grid_x.unsqueeze(-1).unsqueeze(-1) ** 2 + grid_y = grid_y.unsqueeze(0).unsqueeze(-1) ** 2 + grid_z = grid_z.unsqueeze(0).unsqueeze(0) ** 2 + + gauss_kernel: torch.Tensor = ( + (grid_x / (spatial_width**2)) + + (grid_y / (spatial_width**2)) + + (grid_z / (temporal_width**2)) + ) + + if use_matlab_mask: + filter_radius: torch.Tensor = (dims_xyt - 1) // 2 + + border_lower: list[int] = [ + center_x - int(filter_radius[0]) - 1, + center_y - int(filter_radius[1]) - 1, + center_z - int(filter_radius[2]) - 1, + ] + + border_upper: list[int] = [ + center_x + int(filter_radius[0]), + center_y + int(filter_radius[1]), + center_z + int(filter_radius[2]), + ] + + matlab_mask: torch.Tensor = torch.zeros_like(gauss_kernel) + matlab_mask[ + border_lower[0] : border_upper[0], + border_lower[1] : border_upper[1], + border_lower[2] : border_upper[2], + ] = 1.0 + + gauss_kernel = torch.exp(-gauss_kernel / 2.0) + if use_matlab_mask: + gauss_kernel = gauss_kernel * matlab_mask + + gauss_kernel[gauss_kernel < (epsilon * gauss_kernel.max())] = 0 + + sum_gauss_kernel: float = float(gauss_kernel.sum()) + + if sum_gauss_kernel != 0.0: + gauss_kernel = gauss_kernel / sum_gauss_kernel + + # FFT Shift + gauss_kernel = torch.cat( + (gauss_kernel[center_x - 1 :, :, 
:], gauss_kernel[: center_x - 1, :, :]), + dim=0, + ) + gauss_kernel = torch.cat( + (gauss_kernel[:, center_y - 1 :, :], gauss_kernel[:, : center_y - 1, :]), + dim=1, + ) + gauss_kernel = torch.cat( + (gauss_kernel[:, :, center_z - 1 :], gauss_kernel[:, :, : center_z - 1]), + dim=2, + ) + overwrite_fft_gauss = torch.fft.fftn(gauss_kernel) + input_padded_gauss_filtered: torch.Tensor = torch.real( + torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss) + ) + else: + input_padded_gauss_filtered = torch.real( + torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss) + ) + + start = dims_xyt + stop = ( + torch.tensor(input_padded.shape, device=dims_xyt.device, dtype=dims_xyt.dtype) + - dims_xyt + ) + + output = input_padded_gauss_filtered[ + start[0] : stop[0], start[1] : stop[1], start[2] : stop[2] + ] + + return (output, overwrite_fft_gauss) diff --git a/functions/get_experiments.py b/functions/get_experiments.py new file mode 100644 index 0000000..d92b936 --- /dev/null +++ b/functions/get_experiments.py @@ -0,0 +1,19 @@ +import torch +import os +import glob + + +@torch.no_grad() +def get_experiments(path: str) -> torch.Tensor: + filename_np: str = os.path.join( + path, + "Exp*_Part001.npy", + ) + + list_str = glob.glob(filename_np) + list_int: list[int] = [] + for i in range(0, len(list_str)): + list_int.append(int(list_str[i].split("Exp")[-1].split("_Trial")[0])) + list_int = sorted(list_int) + + return torch.tensor(list_int).unique() diff --git a/functions/get_parts.py b/functions/get_parts.py new file mode 100644 index 0000000..d68e1ae --- /dev/null +++ b/functions/get_parts.py @@ -0,0 +1,18 @@ +import torch +import os +import glob + + +@torch.no_grad() +def get_parts(path: str, experiment_id: int, trial_id: int) -> torch.Tensor: + filename_np: str = os.path.join( + path, + f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part*.npy", + ) + + list_str = glob.glob(filename_np) + list_int: list[int] = [] + for i in range(0, len(list_str)): + list_int.append(int(list_str[i].split("_Part")[-1].split(".npy")[0])) + list_int = sorted(list_int) + return torch.tensor(list_int).unique() diff --git a/functions/get_torch_device.py b/functions/get_torch_device.py new file mode 100644 index 0000000..9eec5e9 --- /dev/null +++ b/functions/get_torch_device.py @@ -0,0 +1,17 @@ +import torch +import logging + + +def get_torch_device(mylogger: logging.Logger, force_to_cpu: bool) -> torch.device: + + if torch.cuda.is_available(): + device_name: str = "cuda:0" + else: + device_name = "cpu" + + if force_to_cpu: + device_name = "cpu" + + mylogger.info(f"Using device: {device_name}") + device: torch.device = torch.device(device_name) + return device diff --git a/functions/get_trials.py b/functions/get_trials.py new file mode 100644 index 0000000..abe33d2 --- /dev/null +++ b/functions/get_trials.py @@ -0,0 +1,19 @@ +import torch +import os +import glob + + +@torch.no_grad() +def get_trials(path: str, experiment_id: int) -> torch.Tensor: + filename_np: str = os.path.join( + path, + f"Exp{experiment_id:03d}_Trial*_Part001.npy", + ) + + list_str = glob.glob(filename_np) + list_int: list[int] = [] + for i in range(0, len(list_str)): + list_int.append(int(list_str[i].split("_Trial")[-1].split("_Part")[0])) + + list_int = sorted(list_int) + return torch.tensor(list_int).unique() diff --git a/functions/load_config.py b/functions/load_config.py new file mode 100644 index 0000000..c17fa40 --- /dev/null +++ b/functions/load_config.py @@ -0,0 +1,16 @@ +import json +import os +import logging + +from 
jsmin import jsmin # type:ignore
+
+
+def load_config(mylogger: logging.Logger, filename: str = "config.json") -> dict:
+    mylogger.info("loading config file")
+    if os.path.isfile(filename) is False:
+        raise FileNotFoundError(f"{filename} is missing")
+
+    with open(filename, "r") as file:
+        config = json.loads(jsmin(file.read()))
+
+    return config
diff --git a/functions/load_meta_data.py b/functions/load_meta_data.py
new file mode 100644
index 0000000..622473c
--- /dev/null
+++ b/functions/load_meta_data.py
@@ -0,0 +1,68 @@
+import logging
+import json
+
+
+def load_meta_data(
+    mylogger: logging.Logger, filename_meta: str, silent_mode: bool = False
+) -> tuple[list[str], str, str, dict, dict, float, float, str]:
+
+    if silent_mode is False:
+        mylogger.info("Loading meta data")
+    with open(filename_meta, "r") as file_handle:
+        metadata: dict = json.load(file_handle)
+
+    channels: list[str] = metadata["channelKey"]
+
+    if silent_mode is False:
+        mylogger.info(f"meta data: channel order: {channels}")
+
+    if "mouseMarkings" in metadata["sessionMetaData"]:
+        mouse_markings: str = metadata["sessionMetaData"]["mouseMarkings"]
+        if silent_mode is False:
+            mylogger.info(f"meta data: mouse markings: {mouse_markings}")
+    else:
+        mouse_markings = ""
+        if silent_mode is False:
+            mylogger.info("meta data: no mouse markings")
+
+    recording_date: str = metadata["sessionMetaData"]["date"]
+    if silent_mode is False:
+        mylogger.info(f"meta data: recording date: {recording_date}")
+
+    stimulation_times: dict = metadata["sessionMetaData"]["stimulationTimes"]
+    if silent_mode is False:
+        mylogger.info(f"meta data: stimulation times: {stimulation_times}")
+
+    experiment_names: dict = metadata["sessionMetaData"]["experimentNames"]
+    if silent_mode is False:
+        mylogger.info(f"meta data: experiment names: {experiment_names}")
+
+    trial_recording_duration: float = float(
+        metadata["sessionMetaData"]["trialRecordingDuration"]
+    )
+    if silent_mode is False:
+        mylogger.info(
+            f"meta data: trial recording duration: {trial_recording_duration} sec"
+        )
+
+    frame_time: float = float(metadata["sessionMetaData"]["frameTime"])
+    if silent_mode is False:
+        mylogger.info(
+            f"meta data: frame time: {frame_time} sec ; frame rate: {1.0/frame_time} Hz"
+        )
+
+    mouse: str = metadata["sessionMetaData"]["mouse"]
+    if silent_mode is False:
+        mylogger.info(f"meta data: mouse: {mouse}")
+        mylogger.info("-==- Done -==-")
+
+    return (
+        channels,
+        mouse_markings,
+        recording_date,
+        stimulation_times,
+        experiment_names,
+        trial_recording_duration,
+        frame_time,
+        mouse,
+    )
diff --git a/functions/perform_donor_volume_rotation.py b/functions/perform_donor_volume_rotation.py
new file mode 100644
index 0000000..1d2f55b
--- /dev/null
+++ b/functions/perform_donor_volume_rotation.py
@@ -0,0 +1,207 @@
+import torch
+import torchvision as tv # type: ignore
+import logging
+from functions.calculate_rotation import calculate_rotation
+from functions.ImageAlignment import ImageAlignment
+
+
+@torch.no_grad()
+def perform_donor_volume_rotation(
+    mylogger: logging.Logger,
+    acceptor: torch.Tensor,
+    donor: torch.Tensor,
+    oxygenation: torch.Tensor,
+    volume: torch.Tensor,
+    ref_image_donor: torch.Tensor,
+    ref_image_volume: torch.Tensor,
+    batch_size: int,
+    config: dict,
+    fill_value: float = 0,
+) -> tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+    try:
+
+        return perform_donor_volume_rotation_internal(
+            mylogger=mylogger,
+            acceptor=acceptor,
+            donor=donor,
+            oxygenation=oxygenation,
+            volume=volume,
+            ref_image_donor=ref_image_donor,
+            ref_image_volume=ref_image_volume,
+            batch_size=batch_size,
+            config=config,
+            fill_value=fill_value,
+        )
+
+    except torch.cuda.OutOfMemoryError:
+
+        (
+            acceptor_cpu,
+            donor_cpu,
+            oxygenation_cpu,
+            volume_cpu,
+            angle_donor_volume_cpu,
+        ) = perform_donor_volume_rotation_internal(
+            mylogger=mylogger,
+            acceptor=acceptor.cpu(),
+            donor=donor.cpu(),
+            oxygenation=oxygenation.cpu(),
+            volume=volume.cpu(),
+            ref_image_donor=ref_image_donor.cpu(),
+            ref_image_volume=ref_image_volume.cpu(),
+            batch_size=batch_size,
+            config=config,
+            fill_value=fill_value,
+        )
+
+        return (
+            acceptor_cpu.to(device=acceptor.device),
+            donor_cpu.to(device=acceptor.device),
+            oxygenation_cpu.to(device=acceptor.device),
+            volume_cpu.to(device=acceptor.device),
+            angle_donor_volume_cpu.to(device=acceptor.device),
+        )
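+
+
+# When CUDA runs out of memory, the wrapper above redoes the computation on
+# the CPU and moves the results back to the original device afterwards; the
+# same fallback pattern is used in gauss_smear_individual.py and in
+# perform_donor_volume_translation.py.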
+
+
+@torch.no_grad()
+def perform_donor_volume_rotation_internal(
+    mylogger: logging.Logger,
+    acceptor: torch.Tensor,
+    donor: torch.Tensor,
+    oxygenation: torch.Tensor,
+    volume: torch.Tensor,
+    ref_image_donor: torch.Tensor,
+    ref_image_volume: torch.Tensor,
+    batch_size: int,
+    config: dict,
+    fill_value: float = 0,
+) -> tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+
+    image_alignment = ImageAlignment(
+        default_dtype=acceptor.dtype, device=acceptor.device
+    )
+
+    mylogger.info("Calculate rotation between donor data and donor ref image")
+
+    angle_donor = calculate_rotation(
+        input=donor,
+        reference_image=ref_image_donor,
+        image_alignment=image_alignment,
+        batch_size=batch_size,
+    )
+
+    mylogger.info("Calculate rotation between volume data and volume ref image")
+    angle_volume = calculate_rotation(
+        input=volume,
+        reference_image=ref_image_volume,
+        image_alignment=image_alignment,
+        batch_size=batch_size,
+    )
+
+    mylogger.info("Average over both rotations")
+
+    # Outlier angles beyond a percentile-based threshold are first replaced by
+    # the other channel's value; whatever still exceeds the threshold is zeroed.
+    donor_threshold: torch.Tensor = torch.sort(torch.abs(angle_donor))[0]
+    donor_threshold = donor_threshold[
+        int(
+            donor_threshold.shape[0]
+            * float(config["rotation_stabilization_threshold_border"])
+        )
+    ] * float(config["rotation_stabilization_threshold_factor"])
+
+    volume_threshold: torch.Tensor = torch.sort(torch.abs(angle_volume))[0]
+    volume_threshold = volume_threshold[
+        int(
+            volume_threshold.shape[0]
+            * float(config["rotation_stabilization_threshold_border"])
+        )
+    ] * float(config["rotation_stabilization_threshold_factor"])
+
+    donor_idx = torch.where(torch.abs(angle_donor) > donor_threshold)[0]
+    volume_idx = torch.where(torch.abs(angle_volume) > volume_threshold)[0]
+    mylogger.info(
+        f"Border: {config['rotation_stabilization_threshold_border']}, "
+        f"factor {config['rotation_stabilization_threshold_factor']} "
+    )
+    mylogger.info(
+        f"Donor threshold: {donor_threshold:.3e}, volume threshold: {volume_threshold:.3e}"
+    )
+    mylogger.info(
+        f"Found broken rotation values: "
+        f"donor {int(donor_idx.shape[0])}, "
+        f"volume {int(volume_idx.shape[0])}"
+    )
+    angle_donor[donor_idx] = angle_volume[donor_idx]
+    angle_volume[volume_idx] = angle_donor[volume_idx]
+
+    donor_idx = torch.where(torch.abs(angle_donor) > donor_threshold)[0]
+    volume_idx = torch.where(torch.abs(angle_volume) > volume_threshold)[0]
+    mylogger.info(
+        f"After filling in, these broken rotation values remain: "
+        f"donor {int(donor_idx.shape[0])}, "
+        f"volume {int(volume_idx.shape[0])}"
+    )
+    angle_donor[donor_idx] = 0.0
+    angle_volume[volume_idx] = 0.0
+    angle_donor_volume = (angle_donor + angle_volume) / 2.0
+
+    mylogger.info("Rotate acceptor data based on the 
average rotation") + for frame_id in range(0, angle_donor_volume.shape[0]): + acceptor[frame_id, ...] = tv.transforms.functional.affine( + img=acceptor[frame_id, ...].unsqueeze(0), + angle=-float(angle_donor_volume[frame_id]), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + mylogger.info("Rotate donor data based on the average rotation") + for frame_id in range(0, angle_donor_volume.shape[0]): + donor[frame_id, ...] = tv.transforms.functional.affine( + img=donor[frame_id, ...].unsqueeze(0), + angle=-float(angle_donor_volume[frame_id]), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + mylogger.info("Rotate oxygenation data based on the average rotation") + for frame_id in range(0, angle_donor_volume.shape[0]): + oxygenation[frame_id, ...] = tv.transforms.functional.affine( + img=oxygenation[frame_id, ...].unsqueeze(0), + angle=-float(angle_donor_volume[frame_id]), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + mylogger.info("Rotate volume data based on the average rotation") + for frame_id in range(0, angle_donor_volume.shape[0]): + volume[frame_id, ...] = tv.transforms.functional.affine( + img=volume[frame_id, ...].unsqueeze(0), + angle=-float(angle_donor_volume[frame_id]), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + return (acceptor, donor, oxygenation, volume, angle_donor_volume) diff --git a/functions/perform_donor_volume_translation.py b/functions/perform_donor_volume_translation.py new file mode 100644 index 0000000..72e94fa --- /dev/null +++ b/functions/perform_donor_volume_translation.py @@ -0,0 +1,210 @@ +import torch +import torchvision as tv # type: ignore +import logging + +from functions.calculate_translation import calculate_translation +from functions.ImageAlignment import ImageAlignment + + +@torch.no_grad() +def perform_donor_volume_translation( + mylogger: logging.Logger, + acceptor: torch.Tensor, + donor: torch.Tensor, + oxygenation: torch.Tensor, + volume: torch.Tensor, + ref_image_donor: torch.Tensor, + ref_image_volume: torch.Tensor, + batch_size: int, + config: dict, + fill_value: float = 0, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + try: + + return perform_donor_volume_translation_internal( + mylogger=mylogger, + acceptor=acceptor, + donor=donor, + oxygenation=oxygenation, + volume=volume, + ref_image_donor=ref_image_donor, + ref_image_volume=ref_image_volume, + batch_size=batch_size, + config=config, + fill_value=fill_value, + ) + + except torch.cuda.OutOfMemoryError: + + ( + acceptor_cpu, + donor_cpu, + oxygenation_cpu, + volume_cpu, + tvec_donor_volume_cpu, + ) = perform_donor_volume_translation_internal( + mylogger=mylogger, + acceptor=acceptor.cpu(), + donor=donor.cpu(), + oxygenation=oxygenation.cpu(), + volume=volume.cpu(), + ref_image_donor=ref_image_donor.cpu(), + ref_image_volume=ref_image_volume.cpu(), + batch_size=batch_size, + config=config, + fill_value=fill_value, + ) + + return ( + acceptor_cpu.to(device=acceptor.device), + donor_cpu.to(device=acceptor.device), + oxygenation_cpu.to(device=acceptor.device), + volume_cpu.to(device=acceptor.device), + tvec_donor_volume_cpu.to(device=acceptor.device), + ) + + +@torch.no_grad() +def 
perform_donor_volume_translation_internal(
+    mylogger: logging.Logger,
+    acceptor: torch.Tensor,
+    donor: torch.Tensor,
+    oxygenation: torch.Tensor,
+    volume: torch.Tensor,
+    ref_image_donor: torch.Tensor,
+    ref_image_volume: torch.Tensor,
+    batch_size: int,
+    config: dict,
+    fill_value: float = 0,
+) -> tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+
+    image_alignment = ImageAlignment(
+        default_dtype=acceptor.dtype, device=acceptor.device
+    )
+
+    mylogger.info("Calculate translation between donor data and donor ref image")
+    tvec_donor = calculate_translation(
+        input=donor,
+        reference_image=ref_image_donor,
+        image_alignment=image_alignment,
+        batch_size=batch_size,
+    )
+
+    mylogger.info("Calculate translation between volume data and volume ref image")
+    tvec_volume = calculate_translation(
+        input=volume,
+        reference_image=ref_image_volume,
+        image_alignment=image_alignment,
+        batch_size=batch_size,
+    )
+
+    mylogger.info("Average over both translations")
+
+    # The rotation_stabilization_* settings are reused here to stabilize the
+    # translation estimates along each dimension.
+    for i in range(0, 2):
+        mylogger.info(f"Processing dimension {i}")
+        donor_threshold: torch.Tensor = torch.sort(torch.abs(tvec_donor[:, i]))[0]
+        donor_threshold = donor_threshold[
+            int(
+                donor_threshold.shape[0]
+                * float(config["rotation_stabilization_threshold_border"])
+            )
+        ] * float(config["rotation_stabilization_threshold_factor"])
+
+        volume_threshold: torch.Tensor = torch.sort(torch.abs(tvec_volume[:, i]))[0]
+        volume_threshold = volume_threshold[
+            int(
+                volume_threshold.shape[0]
+                * float(config["rotation_stabilization_threshold_border"])
+            )
+        ] * float(config["rotation_stabilization_threshold_factor"])
+
+        donor_idx = torch.where(torch.abs(tvec_donor[:, i]) > donor_threshold)[0]
+        volume_idx = torch.where(torch.abs(tvec_volume[:, i]) > volume_threshold)[0]
+        mylogger.info(
+            f"Border: {config['rotation_stabilization_threshold_border']}, "
+            f"factor {config['rotation_stabilization_threshold_factor']} "
+        )
+        mylogger.info(
+            f"Donor threshold: {donor_threshold:.3e}, volume threshold: {volume_threshold:.3e}"
+        )
+        mylogger.info(
+            f"Found broken translation values: "
+            f"donor {int(donor_idx.shape[0])}, "
+            f"volume {int(volume_idx.shape[0])}"
+        )
+        tvec_donor[donor_idx, i] = tvec_volume[donor_idx, i]
+        tvec_volume[volume_idx, i] = tvec_donor[volume_idx, i]
+
+        donor_idx = torch.where(torch.abs(tvec_donor[:, i]) > donor_threshold)[0]
+        volume_idx = torch.where(torch.abs(tvec_volume[:, i]) > volume_threshold)[0]
+        mylogger.info(
+            f"After filling in, these broken translation values remain: "
+            f"donor {int(donor_idx.shape[0])}, "
+            f"volume {int(volume_idx.shape[0])}"
+        )
+        tvec_donor[donor_idx, i] = 0.0
+        tvec_volume[volume_idx, i] = 0.0
+
+    tvec_donor_volume = (tvec_donor + tvec_volume) / 2.0
+
+    mylogger.info("Translate acceptor data based on the average translation vector")
+    for frame_id in range(0, tvec_donor_volume.shape[0]):
+        acceptor[frame_id, ...] = tv.transforms.functional.affine(
+            img=acceptor[frame_id, ...].unsqueeze(0),
+            angle=0,
+            translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]],
+            scale=1.0,
+            shear=0,
+            interpolation=tv.transforms.InterpolationMode.BILINEAR,
+            fill=fill_value,
+        ).squeeze(0)
+
+    mylogger.info("Translate donor data based on the average translation vector")
+    for frame_id in range(0, tvec_donor_volume.shape[0]):
+        donor[frame_id, ...] 
= tv.transforms.functional.affine( + img=donor[frame_id, ...].unsqueeze(0), + angle=0, + translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + mylogger.info("Translate oxygenation data based on the average translation vector") + for frame_id in range(0, tvec_donor_volume.shape[0]): + oxygenation[frame_id, ...] = tv.transforms.functional.affine( + img=oxygenation[frame_id, ...].unsqueeze(0), + angle=0, + translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + mylogger.info("Translate volume data based on the average translation vector") + for frame_id in range(0, tvec_donor_volume.shape[0]): + volume[frame_id, ...] = tv.transforms.functional.affine( + img=volume[frame_id, ...].unsqueeze(0), + angle=0, + translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + return (acceptor, donor, oxygenation, volume, tvec_donor_volume) diff --git a/functions/regression.py b/functions/regression.py new file mode 100644 index 0000000..d4efac0 --- /dev/null +++ b/functions/regression.py @@ -0,0 +1,117 @@ +import torch +import logging +from functions.regression_internal import regression_internal + + +@torch.no_grad() +def regression( + mylogger: logging.Logger, + target_camera_id: int, + regressor_camera_ids: list[int], + mask: torch.Tensor, + data: torch.Tensor, + data_filtered: torch.Tensor, + first_none_ramp_frame: int, +) -> tuple[torch.Tensor, torch.Tensor]: + + assert len(regressor_camera_ids) > 0 + + mylogger.info("Prepare the target signal - 1.0 (from data_filtered)") + target_signals_train: torch.Tensor = ( + data_filtered[target_camera_id, ..., first_none_ramp_frame:].clone() - 1.0 + ) + target_signals_train[target_signals_train < -1] = 0.0 + + # Check if everything is happy + assert target_signals_train.ndim == 3 + assert target_signals_train.ndim == data[target_camera_id, ...].ndim + assert target_signals_train.shape[0] == data[target_camera_id, ...].shape[0] + assert target_signals_train.shape[1] == data[target_camera_id, ...].shape[1] + assert (target_signals_train.shape[2] + first_none_ramp_frame) == data[ + target_camera_id, ... + ].shape[2] + + mylogger.info("Prepare the regressor signals (linear plus from data_filtered)") + + regressor_signals_train: torch.Tensor = torch.zeros( + ( + data_filtered.shape[1], + data_filtered.shape[2], + data_filtered.shape[3], + len(regressor_camera_ids) + 1, + ), + device=data_filtered.device, + dtype=data_filtered.dtype, + ) + + mylogger.info("Copy the regressor signals - 1.0") + for matrix_id, id in enumerate(regressor_camera_ids): + regressor_signals_train[..., matrix_id] = data_filtered[id, ...] 
- 1.0 + + regressor_signals_train[regressor_signals_train < -1] = 0.0 + + mylogger.info("Create the linear regressor") + trend = torch.arange( + 0, regressor_signals_train.shape[-2], device=data_filtered.device + ) / float(regressor_signals_train.shape[-2] - 1) + trend -= trend.mean() + trend = trend.unsqueeze(0).unsqueeze(0) + trend = trend.tile( + (regressor_signals_train.shape[0], regressor_signals_train.shape[1], 1) + ) + regressor_signals_train[..., -1] = trend + + regressor_signals_train = regressor_signals_train[:, :, first_none_ramp_frame:, :] + + mylogger.info("Calculating the regression coefficients") + coefficients, intercept = regression_internal( + input_regressor=regressor_signals_train, input_target=target_signals_train + ) + del regressor_signals_train + del target_signals_train + + mylogger.info("Prepare the target signal - 1.0 (from data)") + target_signals_perform: torch.Tensor = data[target_camera_id, ...].clone() - 1.0 + + mylogger.info("Prepare the regressor signals (linear plus from data)") + regressor_signals_perform: torch.Tensor = torch.zeros( + ( + data.shape[1], + data.shape[2], + data.shape[3], + len(regressor_camera_ids) + 1, + ), + device=data.device, + dtype=data.dtype, + ) + + mylogger.info("Copy the regressor signals - 1.0 ") + for matrix_id, id in enumerate(regressor_camera_ids): + regressor_signals_perform[..., matrix_id] = data[id] - 1.0 + + mylogger.info("Create the linear regressor") + trend = torch.arange( + 0, regressor_signals_perform.shape[-2], device=data[0].device + ) / float(regressor_signals_perform.shape[-2] - 1) + trend -= trend.mean() + trend = trend.unsqueeze(0).unsqueeze(0) + trend = trend.tile( + (regressor_signals_perform.shape[0], regressor_signals_perform.shape[1], 1) + ) + regressor_signals_perform[..., -1] = trend + + mylogger.info("Remove regressors") + target_signals_perform -= ( + regressor_signals_perform * coefficients.unsqueeze(-2) + ).sum(dim=-1) + + mylogger.info("Remove offset") + target_signals_perform -= intercept.unsqueeze(-1) + + mylogger.info("Remove masked pixels") + target_signals_perform[mask, :] = 0.0 + + mylogger.info("Add an offset of 1.0") + target_signals_perform += 1.0 + + return target_signals_perform, coefficients diff --git a/functions/regression_internal.py b/functions/regression_internal.py new file mode 100644 index 0000000..dd94d3c --- /dev/null +++ b/functions/regression_internal.py @@ -0,0 +1,27 @@ +import torch + + +def regression_internal( + input_regressor: torch.Tensor, input_target: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + + regressor_offset = input_regressor.mean(keepdim=True, dim=-2) + target_offset = input_target.mean(keepdim=True, dim=-1) + + regressor = input_regressor - regressor_offset + target = input_target - target_offset + + try: + coefficients, _, _, _ = torch.linalg.lstsq(regressor, target, rcond=None) + except torch.cuda.OutOfMemoryError: + coefficients_cpu, _, _, _ = torch.linalg.lstsq( + regressor.cpu(), target.cpu(), rcond=None + ) + coefficients = coefficients_cpu.to(regressor.device, copy=True) + del coefficients_cpu + + intercept = target_offset.squeeze(-1) - ( + coefficients * regressor_offset.squeeze(-2) + ).sum(dim=-1) + + return coefficients, intercept diff --git a/geci/config_M_Sert_Cre_41.json b/geci/config_M_Sert_Cre_41.json new file mode 100644 index 0000000..48ede27 --- /dev/null +++ b/geci/config_M_Sert_Cre_41.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/hendrik", + "recoding_data": "2023-07-17", + "mouse_identifier": "M_Sert_Cre_41", + 
"raw_path": "raw", + "export_path": "output/M_Sert_Cre_41", + "ref_image_path": "ref_images/M_Sert_Cre_41", + "heartbeat_remove": true, + "gevi": false, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + // EMPTY FOR GECI "target_camera_acceptor": "acceptor", + "target_camera_acceptor": "", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + // REMOVED FOR GECI "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/geci/config_M_Sert_Cre_42.json b/geci/config_M_Sert_Cre_42.json new file mode 100644 index 0000000..daf6c3e --- /dev/null +++ b/geci/config_M_Sert_Cre_42.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/hendrik", + "recoding_data": "2023-07-18", + "mouse_identifier": "M_Sert_Cre_42", + "raw_path": "raw", + "export_path": "output/M_Sert_Cre_42", + "ref_image_path": "ref_images/M_Sert_Cre_42", + "heartbeat_remove": true, + "gevi": false, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + // EMPTY FOR GECI "target_camera_acceptor": "acceptor", + "target_camera_acceptor": "", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + // REMOVED FOR GECI "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + 
// Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/geci/config_M_Sert_Cre_45.json b/geci/config_M_Sert_Cre_45.json new file mode 100644 index 0000000..875faf5 --- /dev/null +++ b/geci/config_M_Sert_Cre_45.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/hendrik", + "recoding_data": "2023-07-18", + "mouse_identifier": "M_Sert_Cre_45", + "raw_path": "raw", + "export_path": "output/M_Sert_Cre_45", + "ref_image_path": "ref_images/M_Sert_Cre_45", + "heartbeat_remove": true, + "gevi": false, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + // EMPTY FOR GECI "target_camera_acceptor": "acceptor", + "target_camera_acceptor": "", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + // REMOVED FOR GECI "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/geci/config_M_Sert_Cre_46.json b/geci/config_M_Sert_Cre_46.json new file mode 100644 index 0000000..085a80d --- /dev/null +++ b/geci/config_M_Sert_Cre_46.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/hendrik", + "recoding_data": "2023-03-16", + "mouse_identifier": "M_Sert_Cre_46", + "raw_path": "raw", + "export_path": "output/M_Sert_Cre_46", + "ref_image_path": "ref_images/M_Sert_Cre_46", + "heartbeat_remove": true, + "gevi": false, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + // EMPTY FOR GECI "target_camera_acceptor": "acceptor", + "target_camera_acceptor": "", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + // REMOVED FOR GECI "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 
4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/geci/config_M_Sert_Cre_49.json b/geci/config_M_Sert_Cre_49.json new file mode 100644 index 0000000..11e6e8c --- /dev/null +++ b/geci/config_M_Sert_Cre_49.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/hendrik", + "recoding_data": "2023-03-15", + "mouse_identifier": "M_Sert_Cre_49", + "raw_path": "raw", + "export_path": "output/M_Sert_Cre_49", + "ref_image_path": "ref_images/M_Sert_Cre_49", + "heartbeat_remove": true, + "gevi": false, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + // EMPTY FOR GECI "target_camera_acceptor": "acceptor", + "target_camera_acceptor": "", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + // REMOVED FOR GECI "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/geci/config_example_GECI.json b/geci/config_example_GECI.json new file mode 100644 index 0000000..48ede27 --- /dev/null +++ b/geci/config_example_GECI.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/hendrik", + 
"recoding_data": "2023-07-17", + "mouse_identifier": "M_Sert_Cre_41", + "raw_path": "raw", + "export_path": "output/M_Sert_Cre_41", + "ref_image_path": "ref_images/M_Sert_Cre_41", + "heartbeat_remove": true, + "gevi": false, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + // EMPTY FOR GECI "target_camera_acceptor": "acceptor", + "target_camera_acceptor": "", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + // REMOVED FOR GECI "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/geci/geci_loader.py b/geci/geci_loader.py new file mode 100644 index 0000000..a8f4da1 --- /dev/null +++ b/geci/geci_loader.py @@ -0,0 +1,168 @@ +import numpy as np +import os +import json +from jsmin import jsmin # type:ignore +import argh +from functions.get_trials import get_trials +from functions.get_experiments import get_experiments +import scipy # type: ignore + + +def func_pow(x, a, b, c): + return -a * x**b + c + + +def func_exp(x, a, b, c): + return a * np.exp(-x / b) + c + + +def loader( + filename: str = "config_M_Sert_Cre_49.json", + fpath: str|None = None, + skip_timesteps: int = 100, + # If there is no special ROI... Get one! 
This is just a backup
+    roi_control_path_default: str = "roi_controlM_Sert_Cre_49.npy",
+    roi_sdarken_path_default: str = "roi_sdarkenM_Sert_Cre_49.npy",
+    remove_fit: bool = True,
+    fit_power: bool = False, # True => -a*x**b + c ; False => a*exp(-x/b) + c
+) -> None:
+
+    if fpath is None:
+        fpath = os.getcwd()
+
+    if os.path.isfile(filename) is False:
+        print(f"{filename} is missing")
+        exit()
+
+    with open(filename, "r") as file:
+        config = json.loads(jsmin(file.read()))
+
+    raw_data_path: str = os.path.join(
+        config["basic_path"],
+        config["recoding_data"],
+        config["mouse_identifier"],
+        config["raw_path"],
+    )
+
+    if remove_fit:
+        roi_control_path: str = f"roi_control{config['mouse_identifier']}.npy"
+        roi_sdarken_path: str = f"roi_sdarken{config['mouse_identifier']}.npy"
+
+        if os.path.isfile(roi_control_path) is False:
+            print(f"Using replacement ROI: {roi_control_path_default}")
+            roi_control_path = roi_control_path_default
+
+        if os.path.isfile(roi_sdarken_path) is False:
+            print(f"Using replacement ROI: {roi_sdarken_path_default}")
+            roi_sdarken_path = roi_sdarken_path_default
+
+        roi_control: np.ndarray = np.load(roi_control_path)
+        roi_darken: np.ndarray = np.load(roi_sdarken_path)
+
+    experiments = get_experiments(raw_data_path).numpy()
+    n_exp = experiments.shape[0]
+
+    first_run: bool = True
+
+    for i_exp in range(0, n_exp):
+        trials = get_trials(raw_data_path, experiments[i_exp]).numpy()
+        n_tri = trials.shape[0]
+
+        for i_tri in range(0, n_tri):
+
+            experiment_name: str = (
+                f"Exp{experiments[i_exp]:03d}_Trial{trials[i_tri]:03d}"
+            )
+            tmp_fname = os.path.join(
+                fpath,
+                config["export_path"],
+                experiment_name + "_acceptor_donor.npz",
+            )
+            print(f'Processing file "{tmp_fname}"...')
+            tmp = np.load(tmp_fname)
+
+            tmp_data_sequence = tmp["data_donor"]
+            tmp_data_sequence = tmp_data_sequence[:, :, skip_timesteps:]
+            tmp_light_signal = tmp["data_acceptor"]
+            tmp_light_signal = tmp_light_signal[:, :, skip_timesteps:]
+
+            if first_run:
+                mask = tmp["mask"]
+                new_shape = [n_exp, *tmp_data_sequence.shape]
+                data_sequence = np.zeros(new_shape)
+                light_signal = np.zeros(new_shape)
+                first_run = False
+
+            if remove_fit:
+                roi_control *= mask
+                assert roi_control.sum() > 0, "ROI control empty"
+                roi_darken *= mask
+                assert roi_darken.sum() > 0, "ROI sDarken empty"
+
+            if remove_fit:
+                combined_matrix = (roi_darken + roi_control) > 0
+                idx = np.where(combined_matrix)
+                for idx_pos in range(0, idx[0].shape[0]):
+
+                    temp = tmp_data_sequence[idx[0][idx_pos], idx[1][idx_pos], :]
+                    temp -= temp.mean()
+
+                    data_time = np.arange(0, temp.shape[0], dtype=np.float32) + skip_timesteps
+                    data_time /= 100.0
+
+                    data_min = temp.min()
+                    data_max = temp.max()
+                    data_delta = data_max - data_min
+                    a_min = data_min - data_delta
+                    b_min = 0.01
+                    a_max = data_max + data_delta
+                    if fit_power:
+                        b_max = 10.0
+                    else:
+                        b_max = 100.0
+                    c_min = data_min - data_delta
+                    c_max = data_max + data_delta
+
+                    try:
+                        if fit_power:
+                            popt, _ = scipy.optimize.curve_fit(
+                                f=func_pow,
+                                xdata=data_time,
+                                ydata=np.nan_to_num(temp),
+                                bounds=([a_min, b_min, c_min], [a_max, b_max, c_max]),
+                            )
+                            pattern: np.ndarray | None = func_pow(data_time, *popt)
+                        else:
+                            popt, _ = scipy.optimize.curve_fit(
+                                f=func_exp,
+                                xdata=data_time,
+                                ydata=np.nan_to_num(temp),
+                                bounds=([a_min, b_min, c_min], [a_max, b_max, c_max]),
+                            )
+                            pattern = func_exp(data_time, *popt)
+
+                        assert pattern is not None
+                        pattern -= pattern.mean()
+
+                        scale = (temp * pattern).sum() / (pattern**2).sum()
+                        pattern *= scale
+
+                    except ValueError:
+                        print(f"Fit failed: Position ({idx[0][idx_pos]}, {idx[1][idx_pos]})")
+                        pattern = None
+
+                    if pattern is not None:
+                        temp -= pattern
+                        tmp_data_sequence[idx[0][idx_pos], idx[1][idx_pos], :] = temp
+
+            data_sequence[i_exp] += tmp_data_sequence
+            light_signal[i_exp] += tmp_light_signal
+        data_sequence[i_exp] /= n_tri
+        light_signal[i_exp] /= n_tri
+    np.save(os.path.join(fpath, config["export_path"], "dsq_" + config["mouse_identifier"]), data_sequence)
+    np.save(os.path.join(fpath, config["export_path"], "lsq_" + config["mouse_identifier"]), light_signal)
+    np.save(os.path.join(fpath, config["export_path"], "msq_" + config["mouse_identifier"]), mask)
+
+
+if __name__ == "__main__":
+    argh.dispatch_command(loader)
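For orientation, a minimal sketch of how the trial-averaged arrays written by geci_loader.py can be read back. The mouse identifier and export path below are taken from the example configs and are assumptions to adjust; np.save has already appended the ".npy" extension, dsq_/lsq_ arrays have shape (n_exp, x, y, t), and msq_ is the 2D mask:

import os
import numpy as np

mouse = "M_Sert_Cre_49"               # example identifier from the configs above
export_path = "output/M_Sert_Cre_49"  # assumed to match config["export_path"]

# memory-map the large arrays instead of loading them fully, as geci_plot.py does
data_sequence = np.load(os.path.join(export_path, f"dsq_{mouse}.npy"), mmap_mode="r")
light_signal = np.load(os.path.join(export_path, f"lsq_{mouse}.npy"), mmap_mode="r")
mask = np.load(os.path.join(export_path, f"msq_{mouse}.npy"))
print(data_sequence.shape, light_signal.shape, mask.shape)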
diff --git a/geci/geci_plot.py b/geci/geci_plot.py
new file mode 100644
index 0000000..8d32ab4
--- /dev/null
+++ b/geci/geci_plot.py
@@ -0,0 +1,181 @@
+# %%
+
+import numpy as np
+import matplotlib.pyplot as plt
+import argh
+import scipy # type: ignore
+import json
+import os
+from jsmin import jsmin # type:ignore
+
+
+def func_pow(x, a, b, c):
+    return -a * x**b + c
+
+
+def func_exp(x, a, b, c):
+    return a * np.exp(-x / b) + c
+
+
+# mouse: int = 0, 1, 2, 3, 4
+def plot(
+    filename: str = "config_M_Sert_Cre_49.json",
+    fpath: str | None = None,
+    experiment: int = 4,
+    skip_timesteps: int = 100,
+    remove_fit: bool = False,
+    fit_power: bool = False, # True => -a*x**b + c ; False => a*exp(-x/b) + c
+) -> None:
+
+    if fpath is None:
+        fpath = os.getcwd()
+
+    if os.path.isfile(filename) is False:
+        print(f"{filename} is missing")
+        exit()
+
+    with open(filename, "r") as file:
+        config = json.loads(jsmin(file.read()))
+
+    raw_data_path: str = os.path.join(
+        config["basic_path"],
+        config["recoding_data"],
+        config["mouse_identifier"],
+        config["raw_path"],
+    )
+
+    if os.path.isdir(raw_data_path) is False:
+        print(f"ERROR: could not find raw directory {raw_data_path}!!!!")
+        exit()
+
+    with open(f"meta_{config['mouse_identifier']}_exp{experiment:03d}.json", "r") as file:
+        metadata = json.loads(jsmin(file.read()))
+
+    experiment_names = metadata['sessionMetaData']['experimentNames'][str(experiment)]
+
+    roi_control_path: str = f"roi_control{config['mouse_identifier']}.npy"
+    roi_sdarken_path: str = f"roi_sdarken{config['mouse_identifier']}.npy"
+
+    assert os.path.isfile(roi_control_path)
+    assert os.path.isfile(roi_sdarken_path)
+
+    print("Load data...")
+    data = np.load(os.path.join(fpath, config["export_path"], "dsq_" + config["mouse_identifier"] + ".npy"), mmap_mode="r")
+
+    print("Load light signal...")
+    light = np.load(os.path.join(fpath, config["export_path"], "lsq_" + config["mouse_identifier"] + ".npy"), mmap_mode="r")
+
+    print("Load mask...")
+    mask = np.load(os.path.join(fpath, config["export_path"], "msq_" + config["mouse_identifier"] + ".npy"))
+
+    roi_control = np.load(roi_control_path)
+    roi_control *= mask
+    assert roi_control.sum() > 0, "ROI control empty"
+
+    roi_darken = np.load(roi_sdarken_path)
+    roi_darken *= mask
+    assert roi_darken.sum() > 0, "ROI sDarken empty"
+
+    plt.figure(1)
+    a_show = data[experiment - 1, :, :, 1000].copy()
+    a_show[(roi_darken + roi_control) < 0.5] = np.nan
+    plt.imshow(a_show)
+    plt.title(f"{config['mouse_identifier']} -- Experiment: {experiment}")
+    plt.show(block=False)
+
+    plt.figure(2)
+    a_dontshow = data[experiment - 1, :, :, 1000].copy()
+    a_dontshow[(roi_darken + roi_control) > 0.5] = np.nan
+    plt.imshow(a_dontshow)
+    plt.title(f"{config['mouse_identifier']} -- Experiment: {experiment}")
+    plt.show(block=False)
+
+    plt.figure(3)
+    if remove_fit:
+        light_exp = light[experiment - 1, :, :, skip_timesteps:].copy()
+    else:
+        light_exp = light[experiment - 1, :, :, :].copy()
+    light_exp[(roi_darken + roi_control) < 0.5, :] = 0.0
+    light_signal = light_exp.mean(axis=(0, 1))
+    light_signal -= light_signal.min()
+    light_signal /= light_signal.max()
+
+    if remove_fit:
+        a_exp = data[experiment - 1, :, :, skip_timesteps:].copy()
+    else:
+        a_exp = data[experiment - 1, :, :, :].copy()
+
+    if remove_fit:
+        combined_matrix = (roi_darken + roi_control) > 0
+        idx = np.where(combined_matrix)
+        for idx_pos in range(0, idx[0].shape[0]):
+            temp = a_exp[idx[0][idx_pos], idx[1][idx_pos], :]
+            temp -= temp.mean()
+
+            data_time = np.arange(0, temp.shape[0], dtype=np.float32) + skip_timesteps
+            data_time /= 100.0
+
+            data_min = temp.min()
+            data_max = temp.max()
+            data_delta = data_max - data_min
+            a_min = data_min - data_delta
+            b_min = 0.01
+            a_max = data_max + data_delta
+            if fit_power:
+                b_max = 10.0
+            else:
+                b_max = 100.0
+            c_min = data_min - data_delta
+            c_max = data_max + data_delta
+
+            try:
+                if fit_power:
+                    popt, _ = scipy.optimize.curve_fit(
+                        f=func_pow,
+                        xdata=data_time,
+                        ydata=np.nan_to_num(temp),
+                        bounds=([a_min, b_min, c_min], [a_max, b_max, c_max]),
+                    )
+                    pattern: np.ndarray | None = func_pow(data_time, *popt)
+                else:
+                    popt, _ = scipy.optimize.curve_fit(
+                        f=func_exp,
+                        xdata=data_time,
+                        ydata=np.nan_to_num(temp),
+                        bounds=([a_min, b_min, c_min], [a_max, b_max, c_max]),
+                    )
+                    pattern = func_exp(data_time, *popt)
+
+                assert pattern is not None
+                pattern -= pattern.mean()
+
+                scale = (temp * pattern).sum() / (pattern**2).sum()
+                pattern *= scale
+
+            except ValueError:
+                print(f"Fit failed: Position ({idx[0][idx_pos]}, {idx[1][idx_pos]})")
+                pattern = None
+
+            if pattern is not None:
+                temp -= pattern
+                a_exp[idx[0][idx_pos], idx[1][idx_pos], :] = temp
+
+    darken = a_exp[roi_darken > 0.5, :].sum(axis=0) / (roi_darken > 0.5).sum()
+    lighten = a_exp[roi_control > 0.5, :].sum(axis=0) / (roi_control > 0.5).sum()
+
+    light_signal *= darken.max() - darken.min()
+    light_signal += darken.min()
+
+    time_axis = np.arange(0, lighten.shape[-1], dtype=np.float32) + skip_timesteps
+    time_axis /= 100.0
+
+    plt.plot(time_axis, light_signal, c="k", label="light")
+    plt.plot(time_axis, darken, label="sDarken")
+    plt.plot(time_axis, lighten, label="control")
+    plt.title(f"{config['mouse_identifier']} -- Experiment: {experiment} ({experiment_names})")
+    plt.legend()
+    plt.show()
+
+
+if __name__ == "__main__":
+    argh.dispatch_command(plot)
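The bleaching correction in geci_loader.py and geci_plot.py fits a bounded decay model to each ROI pixel and subtracts a least-squares-scaled copy of the fit. Below is a minimal self-contained sketch of that per-trace logic on a synthetic trace; the decay constant, noise level, and initial guess p0 are assumptions for illustration (the scripts above rely on scipy's default initial guess and catch ValueError when the fit fails):

import numpy as np
from scipy.optimize import curve_fit


def func_exp(x, a, b, c):
    return a * np.exp(-x / b) + c


# synthetic bleaching-like trace, mean removed (constants are made up)
rng = np.random.default_rng(0)
t = np.arange(0, 1000, dtype=np.float32) / 100.0  # seconds at 100 Hz
trace = 0.05 * np.exp(-t / 3.0) + 0.001 * rng.standard_normal(t.shape)
trace -= trace.mean()

# parameter bounds derived from the data range, as in the scripts above
d_min, d_max = float(trace.min()), float(trace.max())
d_delta = d_max - d_min
bounds = (
    [d_min - d_delta, 0.01, d_min - d_delta],
    [d_max + d_delta, 100.0, d_max + d_delta],
)

popt, _ = curve_fit(
    func_exp,
    t,
    np.nan_to_num(trace),
    p0=[d_delta, 1.0, 0.0],  # feasible initial guess within the bounds (assumption)
    bounds=bounds,
)
pattern = func_exp(t, *popt)
pattern -= pattern.mean()

# least-squares scaling of the fitted pattern, then subtraction
pattern *= (trace * pattern).sum() / (pattern**2).sum()
detrended = trace - pattern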
diff --git a/geci/stage_6_convert_roi.py b/geci/stage_6_convert_roi.py
new file mode 100644
index 0000000..7bedc29
--- /dev/null
+++ b/geci/stage_6_convert_roi.py
@@ -0,0 +1,53 @@
+import json
+import os
+import argh
+from jsmin import jsmin # type:ignore
+import numpy as np
+import h5py
+
+
+def converter(filename: str = "config_M_Sert_Cre_49.json") -> None:
+
+    if os.path.isfile(filename) is False:
+        print(f"{filename} is missing")
+        exit()
+
+    with open(filename, "r") as file:
+        config = json.loads(jsmin(file.read()))
+
+    raw_data_path: str = os.path.join(
+        config["basic_path"],
+        config["recoding_data"],
+        config["mouse_identifier"],
+        config["raw_path"],
+    )
+
+    if os.path.isdir(raw_data_path) is False:
+        print(f"ERROR: could not find raw directory {raw_data_path}!!!!")
+        exit()
+
+    roi_path: str = os.path.join(
+        config["basic_path"], config["recoding_data"], config["mouse_identifier"]
+    )
+    roi_control_mat: str = os.path.join(roi_path, "ROI_control.mat")
+    roi_sdarken_mat: str = os.path.join(roi_path, "ROI_sDarken.mat")
+
+    if 
os.path.isfile(roi_control_mat): + hf = h5py.File(roi_control_mat, "r") + roi_control = np.array(hf["roi"]).T + filename_out: str = f"roi_control{config['mouse_identifier']}.npy" + np.save(filename_out, roi_control) + else: + print("ROI Control not found") + + if os.path.isfile(roi_sdarken_mat): + hf = h5py.File(roi_sdarken_mat, "r") + roi_darken = np.array(hf["roi"]).T + filename_out: str = f"roi_sdarken{config['mouse_identifier']}.npy" + np.save(filename_out, roi_darken) + else: + print("ROI sDarken not found") + + +if __name__ == "__main__": + argh.dispatch_command(converter) diff --git a/gevi/config_M0134M_2024-11-06_SessionA.json b/gevi/config_M0134M_2024-11-06_SessionA.json new file mode 100644 index 0000000..b6f4da8 --- /dev/null +++ b/gevi/config_M0134M_2024-11-06_SessionA.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/fatma/GEVI/", + "recoding_data": "2024-11-06", + "mouse_identifier": "M0134M_SessionA", + "raw_path": "raw", + "export_path": "output/M0134M_2024-11-06_SessionA", + "ref_image_path": "ref_images/M0134M_2024-11-06_SessionA", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/gevi/config_M0134M_2024-11-06_SessionB.json b/gevi/config_M0134M_2024-11-06_SessionB.json new file mode 100644 index 0000000..b620fbd --- /dev/null +++ b/gevi/config_M0134M_2024-11-06_SessionB.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/fatma/GEVI/", + "recoding_data": "2024-11-06", + "mouse_identifier": "M0134M_SessionB", + "raw_path": "raw", + "export_path": "output/M0134M_2024-11-06_SessionB", + "ref_image_path": "ref_images/M0134M_2024-11-06_SessionB", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": 
"donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/gevi/config_M0134M_2024-11-07_SessionA.json b/gevi/config_M0134M_2024-11-07_SessionA.json new file mode 100644 index 0000000..01fb9b3 --- /dev/null +++ b/gevi/config_M0134M_2024-11-07_SessionA.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/fatma/GEVI/", + "recoding_data": "2024-11-07", + "mouse_identifier": "M0134M_SessionA", + "raw_path": "raw", + "export_path": "output/M0134M_2024-11-07_SessionA", + "ref_image_path": "ref_images/M0134M_2024-11-07_SessionA", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git 
a/gevi/config_M0134M_2024-11-07_SessionB.json b/gevi/config_M0134M_2024-11-07_SessionB.json new file mode 100644 index 0000000..d92b34b --- /dev/null +++ b/gevi/config_M0134M_2024-11-07_SessionB.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/fatma/GEVI/", + "recoding_data": "2024-11-07", + "mouse_identifier": "M0134M_SessionB", + "raw_path": "raw", + "export_path": "output/M0134M_2024-11-07_SessionB", + "ref_image_path": "ref_images/M0134M_2024-11-07_SessionB", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/gevi/config_M0134M_2024-11-13_SessionA.json b/gevi/config_M0134M_2024-11-13_SessionA.json new file mode 100644 index 0000000..eab7d1e --- /dev/null +++ b/gevi/config_M0134M_2024-11-13_SessionA.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/fatma/GEVI/", + "recoding_data": "2024-11-13", + "mouse_identifier": "M0134M_SessionA", + "raw_path": "raw", + "export_path": "output/M0134M_2024-11-13_SessionA", + "ref_image_path": "ref_images/M0134M_2024-11-13_SessionA", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + 
"gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/gevi/config_M0134M_2024-11-13_SessionB.json b/gevi/config_M0134M_2024-11-13_SessionB.json new file mode 100644 index 0000000..0ae7eab --- /dev/null +++ b/gevi/config_M0134M_2024-11-13_SessionB.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/fatma/GEVI/", + "recoding_data": "2024-11-13", + "mouse_identifier": "M0134M_SessionB", + "raw_path": "raw", + "export_path": "output/M0134M_2024-11-13_SessionB", + "ref_image_path": "ref_images/M0134M_2024-11-13_SessionB", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/gevi/config_M0134M_2024-11-15_SessionA.json b/gevi/config_M0134M_2024-11-15_SessionA.json new file mode 100644 index 0000000..c2aabf1 --- /dev/null +++ b/gevi/config_M0134M_2024-11-15_SessionA.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/fatma/GEVI/", + "recoding_data": "2024-11-15", + "mouse_identifier": "M0134M_SessionA", + "raw_path": "raw", + "export_path": "output/M0134M_2024-11-15_SessionA", + "ref_image_path": "ref_images/M0134M_2024-11-15_SessionA", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + 
"classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/gevi/config_M0134M_2024-11-15_SessionB.json b/gevi/config_M0134M_2024-11-15_SessionB.json new file mode 100644 index 0000000..3827bc9 --- /dev/null +++ b/gevi/config_M0134M_2024-11-15_SessionB.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/fatma/GEVI/", + "recoding_data": "2024-11-15", + "mouse_identifier": "M0134M_SessionB", + "raw_path": "raw", + "export_path": "output/M0134M_2024-11-15_SessionB", + "ref_image_path": "ref_images/M0134M_2024-11-15_SessionB", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + 
"save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/gevi/config_M0134M_2024-11-18_SessionA.json b/gevi/config_M0134M_2024-11-18_SessionA.json new file mode 100644 index 0000000..e9e0d00 --- /dev/null +++ b/gevi/config_M0134M_2024-11-18_SessionA.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/fatma/GEVI/", + "recoding_data": "2024-11-18", + "mouse_identifier": "M0134M_SessionA", + "raw_path": "raw", + "export_path": "output/M0134M_2024-11-18_SessionA", + "ref_image_path": "ref_images/M0134M_2024-11-18_SessionA", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/gevi/config_M0134M_2024-11-18_SessionB.json b/gevi/config_M0134M_2024-11-18_SessionB.json new file mode 100644 index 0000000..143817b --- /dev/null +++ b/gevi/config_M0134M_2024-11-18_SessionB.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/fatma/GEVI/", + "recoding_data": "2024-11-18", + "mouse_identifier": "M0134M_SessionB", + "raw_path": "raw", + "export_path": "output/M0134M_2024-11-18_SessionB", + "ref_image_path": "ref_images/M0134M_2024-11-18_SessionB", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // 
Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/gevi/config_M0134M_2024-12-04_SessionA.json b/gevi/config_M0134M_2024-12-04_SessionA.json new file mode 100644 index 0000000..d77e531 --- /dev/null +++ b/gevi/config_M0134M_2024-12-04_SessionA.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/fatma/GEVI/", + "recoding_data": "2024-12-04", + "mouse_identifier": "M0134M_SessionA", + "raw_path": "raw", + "export_path": "output/M0134M_2024-12-04_SessionA", + "ref_image_path": "ref_images/M0134M_2024-12-04_SessionA", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/gevi/config_M0134M_2024-12-04_SessionB.json b/gevi/config_M0134M_2024-12-04_SessionB.json new file mode 100644 index 0000000..36ad83f --- /dev/null +++ b/gevi/config_M0134M_2024-12-04_SessionB.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/fatma/GEVI/", + "recoding_data": "2024-12-04", + "mouse_identifier": "M0134M_SessionB", + "raw_path": "raw", + "export_path": "output/M0134M_2024-12-04_SessionB", + 
"ref_image_path": "ref_images/M0134M_2024-12-04_SessionB", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/gevi/config_M3905F_SessionB.json b/gevi/config_M3905F_SessionB.json new file mode 100644 index 0000000..f2a41bd --- /dev/null +++ b/gevi/config_M3905F_SessionB.json @@ -0,0 +1,67 @@ +{ + "basic_path": "/data_1/fatma/GEVI_GECI_ES", + "recoding_data": "session_B", + "mouse_identifier": "M3905F", + "raw_path": "raw", + "export_path": "output/M3905F_SessionB", + "ref_image_path": "ref_images/M3905F_SessionB", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + 
"save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/gevi/config_example_GEVI.json b/gevi/config_example_GEVI.json new file mode 100644 index 0000000..e7d53ad --- /dev/null +++ b/gevi/config_example_GEVI.json @@ -0,0 +1,66 @@ +{ + "basic_path": "/data_1/fatma/GEVI_GECI_ES", + "recoding_data": "session_B", + "mouse_identifier": "M3905F", + "raw_path": "raw", + "export_path": "output/M3905F_SessionB", + "ref_image_path": "ref_images/M3905F_SessionB", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + "save_gevi_with_donor_acceptor": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/gevi/example_load_gevi.py b/gevi/example_load_gevi.py new file mode 100644 index 0000000..ce8c2e5 --- /dev/null +++ b/gevi/example_load_gevi.py @@ -0,0 +1,56 @@ +# %% +import numpy as np +import matplotlib.pyplot as plt +import os + +output_path = 'output' + +recording_name = 'M0134M_2024-12-04_SessionA' +n_trials_per_experiment = [30, 0, 30, 30, 30, 30, 30, 30, 30,] +name_experiment = ['none', 'visual', '2 uA', '5 uA', '7 uA', '10 uA', '15 uA', '30 uA', '60 uA'] + +# recording_name = 'M0134M_2024-11-06_SessionB' +# n_trials_per_experiment = [15, 15,] +# name_experiment = ['none', 'visual',] + +i_experiment = 8 + +r_avg = None +ad_avg = None +for i_trial in range(n_trials_per_experiment[i_experiment]): + + folder = output_path + os.sep + recording_name + file = f"Exp{i_experiment + 1:03}_Trial{i_trial + 1:03}_ratio_sequence.npz" + fullpath = folder + os.sep + file + + print(f'Loading file "{fullpath}"...') + data = np.load(fullpath) + + print(f"FIle contents: {data.files}") + ratio_sequence = data["ratio_sequence"] + if 'data_acceptor' in data.files: + data_acceptor = data["data_acceptor"] + data_donor = 
data["data_donor"] + + mask = data["mask"][:, :, np.newaxis] + + if i_trial == 0: + r_avg = ratio_sequence + if 'data_acceptor' in data.files: + ad_avg = (data_acceptor / data_donor) * mask + 1 - mask + else: + r_avg += ratio_sequence + if 'data_acceptor' in data.files: + ad_avg += (data_acceptor / data_donor) * mask + 1 - mask + +if r_avg is not None: + r_avg /= n_trials_per_experiment[i_experiment] +if ad_avg is not None: + ad_avg /= n_trials_per_experiment[i_experiment] + +# %% +for t in range(200, 300, 5): + plt.imshow(r_avg[:, :, t], vmin=0.99, vmax=1.01, cmap='seismic') + plt.colorbar() + plt.show() + diff --git a/other/stage_4b_inspect.py b/other/stage_4b_inspect.py new file mode 100644 index 0000000..f8884f5 --- /dev/null +++ b/other/stage_4b_inspect.py @@ -0,0 +1,532 @@ +# %% + +import numpy as np +import torch +import torchvision as tv # type: ignore + +import os +import logging + +from functions.create_logger import create_logger +from functions.get_torch_device import get_torch_device +from functions.load_config import load_config +from functions.get_experiments import get_experiments +from functions.get_trials import get_trials +from functions.binning import binning +from functions.align_refref import align_refref +from functions.perform_donor_volume_rotation import perform_donor_volume_rotation +from functions.perform_donor_volume_translation import perform_donor_volume_translation +from functions.data_raw_loader import data_raw_loader + +import argh + + +@torch.no_grad() +def process_trial( + config: dict, + mylogger: logging.Logger, + experiment_id: int, + trial_id: int, + device: torch.device, +): + + mylogger.info("") + mylogger.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + mylogger.info("~ TRIAL START ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + mylogger.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + mylogger.info("") + + if device != torch.device("cpu"): + torch.cuda.empty_cache() + mylogger.info("Empty CUDA cache") + cuda_total_memory: int = torch.cuda.get_device_properties( + device.index + ).total_memory + else: + cuda_total_memory = 0 + + raw_data_path: str = os.path.join( + config["basic_path"], + config["recoding_data"], + config["mouse_identifier"], + config["raw_path"], + ) + + if config["binning_enable"] and (config["binning_at_the_end"] is False): + force_to_cpu_memory: bool = True + else: + force_to_cpu_memory = False + + meta_channels: list[str] + meta_mouse_markings: str + meta_recording_date: str + meta_stimulation_times: dict + meta_experiment_names: dict + meta_trial_recording_duration: float + meta_frame_time: float + meta_mouse: str + data: torch.Tensor + + ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) = data_raw_loader( + raw_data_path=raw_data_path, + mylogger=mylogger, + experiment_id=experiment_id, + trial_id=trial_id, + device=device, + force_to_cpu_memory=force_to_cpu_memory, + config=config, + ) + experiment_name: str = f"Exp{experiment_id:03d}_Trial{trial_id:03d}" + + dtype_str = config["dtype"] + dtype_np: np.dtype = getattr(np, dtype_str) + + dtype: torch.dtype = data.dtype + + if device != torch.device("cpu"): + free_mem = cuda_total_memory - max( + [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] + ) + mylogger.info(f"CUDA memory: {free_mem // 1024} MByte") + + mylogger.info(f"Data shape: {data.shape}") + mylogger.info("-==- Done -==-") + + 
mylogger.info("Finding limit values in the RAW data and mark them for masking") + limit: float = (2**16) - 1 + for i in range(0, data.shape[3]): + zero_pixel_mask: torch.Tensor = torch.any(data[..., i] >= limit, dim=-1) + data[zero_pixel_mask, :, i] = -100.0 + mylogger.info( + f"{meta_channels[i]}: " + f"found {int(zero_pixel_mask.type(dtype=dtype).sum())} pixel " + f"with limit values " + ) + mylogger.info("-==- Done -==-") + + mylogger.info("Reference images and mask") + + ref_image_path: str = config["ref_image_path"] + + ref_image_path_acceptor: str = os.path.join(ref_image_path, "acceptor.npy") + if os.path.isfile(ref_image_path_acceptor) is False: + mylogger.info(f"Could not load ref file: {ref_image_path_acceptor}") + assert os.path.isfile(ref_image_path_acceptor) + return + + mylogger.info(f"Loading ref file data: {ref_image_path_acceptor}") + ref_image_acceptor: torch.Tensor = torch.tensor( + np.load(ref_image_path_acceptor).astype(dtype_np), + dtype=dtype, + device=data.device, + ) + + ref_image_path_donor: str = os.path.join(ref_image_path, "donor.npy") + if os.path.isfile(ref_image_path_donor) is False: + mylogger.info(f"Could not load ref file: {ref_image_path_donor}") + assert os.path.isfile(ref_image_path_donor) + return + + mylogger.info(f"Loading ref file data: {ref_image_path_donor}") + ref_image_donor: torch.Tensor = torch.tensor( + np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=data.device + ) + + ref_image_path_oxygenation: str = os.path.join(ref_image_path, "oxygenation.npy") + if os.path.isfile(ref_image_path_oxygenation) is False: + mylogger.info(f"Could not load ref file: {ref_image_path_oxygenation}") + assert os.path.isfile(ref_image_path_oxygenation) + return + + mylogger.info(f"Loading ref file data: {ref_image_path_oxygenation}") + ref_image_oxygenation: torch.Tensor = torch.tensor( + np.load(ref_image_path_oxygenation).astype(dtype_np), + dtype=dtype, + device=data.device, + ) + + ref_image_path_volume: str = os.path.join(ref_image_path, "volume.npy") + if os.path.isfile(ref_image_path_volume) is False: + mylogger.info(f"Could not load ref file: {ref_image_path_volume}") + assert os.path.isfile(ref_image_path_volume) + return + + mylogger.info(f"Loading ref file data: {ref_image_path_volume}") + ref_image_volume: torch.Tensor = torch.tensor( + np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=data.device + ) + + refined_mask_file: str = os.path.join(ref_image_path, "mask_not_rotated.npy") + if os.path.isfile(refined_mask_file) is False: + mylogger.info(f"Could not load mask file: {refined_mask_file}") + assert os.path.isfile(refined_mask_file) + return + + mylogger.info(f"Loading mask file data: {refined_mask_file}") + mask: torch.Tensor = torch.tensor( + np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=data.device + ) + mylogger.info("-==- Done -==-") + + if config["binning_enable"] and (config["binning_at_the_end"] is False): + mylogger.info("Binning of data") + mylogger.info( + ( + f"kernel_size={int(config['binning_kernel_size'])}, " + f"stride={int(config['binning_stride'])}, " + f"divisor_override={int(config['binning_divisor_override'])}" + ) + ) + + data = binning( + data, + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ).to(device=data.device) + ref_image_acceptor = ( + binning( + ref_image_acceptor.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + 
stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + ref_image_donor = ( + binning( + ref_image_donor.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + ref_image_oxygenation = ( + binning( + ref_image_oxygenation.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + ref_image_volume = ( + binning( + ref_image_volume.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + mask = ( + binning( + mask.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + mylogger.info(f"Data shape: {data.shape}") + mylogger.info("-==- Done -==-") + + mylogger.info("Preparing alignment") + mylogger.info("Re-order Raw data") + data = data.moveaxis(-2, 0).moveaxis(-1, 0) + mylogger.info(f"Data shape: {data.shape}") + mylogger.info("-==- Done -==-") + + mylogger.info("Alignment of the ref images and the mask") + mylogger.info("Ref image of donor stays fixed.") + mylogger.info("Ref image of volume and the mask doesn't need to be touched") + mylogger.info("Calculate translation and rotation between the reference images") + angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor = align_refref( + mylogger=mylogger, + ref_image_acceptor=ref_image_acceptor, + ref_image_donor=ref_image_donor, + batch_size=config["alignment_batch_size"], + fill_value=-100.0, + ) + mylogger.info(f"Rotation: {round(float(angle_refref[0]), 2)} degree") + mylogger.info( + f"Translation: {round(float(tvec_refref[0]), 1)} x {round(float(tvec_refref[1]), 1)} pixel" + ) + + if config["save_alignment"]: + temp_path: str = os.path.join( + config["export_path"], experiment_name + "_angle_refref.npy" + ) + mylogger.info(f"Save angle to {temp_path}") + np.save(temp_path, angle_refref.cpu()) + + temp_path = os.path.join( + config["export_path"], experiment_name + "_tvec_refref.npy" + ) + mylogger.info(f"Save translation vector to {temp_path}") + np.save(temp_path, tvec_refref.cpu()) + + mylogger.info("Moving & rotating the oxygenation ref image") + ref_image_oxygenation = tv.transforms.functional.affine( # type: ignore + img=ref_image_oxygenation.unsqueeze(0), + angle=-float(angle_refref), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + + ref_image_oxygenation = tv.transforms.functional.affine( # type: ignore + img=ref_image_oxygenation, + angle=0, + translate=[tvec_refref[1], tvec_refref[0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ).squeeze(0) + mylogger.info("-==- Done -==-") + + mylogger.info("Rotate and translate the acceptor and oxygenation data accordingly") + acceptor_index: int = config["required_order"].index("acceptor") + donor_index: int = config["required_order"].index("donor") + oxygenation_index: int = config["required_order"].index("oxygenation") + volume_index: 
int = config["required_order"].index("volume") + + mylogger.info("Rotate acceptor") + data[acceptor_index, ...] = tv.transforms.functional.affine( # type: ignore + img=data[acceptor_index, ...], # type: ignore + angle=-float(angle_refref), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + + mylogger.info("Translate acceptor") + data[acceptor_index, ...] = tv.transforms.functional.affine( # type: ignore + img=data[acceptor_index, ...], + angle=0, + translate=[tvec_refref[1], tvec_refref[0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + + mylogger.info("Rotate oxygenation") + data[oxygenation_index, ...] = tv.transforms.functional.affine( # type: ignore + img=data[oxygenation_index, ...], + angle=-float(angle_refref), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + + mylogger.info("Translate oxygenation") + data[oxygenation_index, ...] = tv.transforms.functional.affine( # type: ignore + img=data[oxygenation_index, ...], + angle=0, + translate=[tvec_refref[1], tvec_refref[0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + mylogger.info("-==- Done -==-") + + mylogger.info("Perform rotation between donor and volume and its ref images") + mylogger.info("for all frames and then rotate all the data accordingly") + + ( + data[acceptor_index, ...], + data[donor_index, ...], + data[oxygenation_index, ...], + data[volume_index, ...], + angle_donor_volume, + ) = perform_donor_volume_rotation( + mylogger=mylogger, + acceptor=data[acceptor_index, ...], + donor=data[donor_index, ...], + oxygenation=data[oxygenation_index, ...], + volume=data[volume_index, ...], + ref_image_donor=ref_image_donor, + ref_image_volume=ref_image_volume, + batch_size=config["alignment_batch_size"], + fill_value=-100.0, + config=config, + ) + + mylogger.info( + f"angles: " + f"min {round(float(angle_donor_volume.min()), 2)} " + f"max {round(float(angle_donor_volume.max()), 2)} " + f"mean {round(float(angle_donor_volume.mean()), 2)} " + ) + + if config["save_alignment"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_angle_donor_volume.npy" + ) + mylogger.info(f"Save angles to {temp_path}") + np.save(temp_path, angle_donor_volume.cpu()) + mylogger.info("-==- Done -==-") + + mylogger.info("Perform translation between donor and volume and its ref images") + mylogger.info("for all frames and then translate all the data accordingly") + + ( + data_acceptor, + data_donor, + data_oxygenation, + data_volume, + _, + ) = perform_donor_volume_translation( + mylogger=mylogger, + acceptor=data[acceptor_index, 0:1, ...], + donor=data[donor_index, 0:1, ...], + oxygenation=data[oxygenation_index, 0:1, ...], + volume=data[volume_index, 0:1, ...], + ref_image_donor=ref_image_donor, + ref_image_volume=ref_image_volume, + batch_size=config["alignment_batch_size"], + fill_value=-100.0, + config=config, + ) + + # + + temp_path = os.path.join( + config["export_path"], experiment_name + "_inspect_images.npz" + ) + mylogger.info(f"Save images for inspection to {temp_path}") + np.savez( + temp_path, + acceptor=data_acceptor.cpu(), + donor=data_donor.cpu(), + oxygenation=data_oxygenation.cpu(), + volume=data_volume.cpu(), + ) + + mylogger.info("") + mylogger.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + mylogger.info("~ TRIAL START 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + mylogger.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + mylogger.info("") + + return + + +def main( + *, + config_filename: str = "config.json", + experiment_id_overwrite: int = -1, + trial_id_overwrite: int = -1, +) -> None: + mylogger = create_logger( + save_logging_messages=True, + display_logging_messages=True, + log_stage_name="stage_4b", + ) + + config = load_config(mylogger=mylogger, filename=config_filename) + + if (config["save_as_python"] is False) and (config["save_as_matlab"] is False): + mylogger.info("No output will be created. ") + mylogger.info("Change save_as_python and/or save_as_matlab in the config file") + mylogger.info("ERROR: STOP!!!") + exit() + + if (len(config["target_camera_donor"]) == 0) and ( + len(config["target_camera_acceptor"]) == 0 + ): + mylogger.info( + "Configure at least target_camera_donor or target_camera_acceptor correctly." + ) + mylogger.info("ERROR: STOP!!!") + exit() + + device = get_torch_device(mylogger, config["force_to_cpu"]) + + mylogger.info( + f"Create directory {config['export_path']} in the case it does not exist" + ) + os.makedirs(config["export_path"], exist_ok=True) + + raw_data_path: str = os.path.join( + config["basic_path"], + config["recoding_data"], + config["mouse_identifier"], + config["raw_path"], + ) + + if os.path.isdir(raw_data_path) is False: + mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!") + exit() + + if experiment_id_overwrite == -1: + experiments = get_experiments(raw_data_path) + else: + assert experiment_id_overwrite >= 0 + experiments = torch.tensor([experiment_id_overwrite]) + + for experiment_counter in range(0, experiments.shape[0]): + experiment_id = int(experiments[experiment_counter]) + + if trial_id_overwrite == -1: + trials = get_trials(raw_data_path, experiment_id) + else: + assert trial_id_overwrite >= 0 + trials = torch.tensor([trial_id_overwrite]) + + for trial_counter in range(0, trials.shape[0]): + trial_id = int(trials[trial_counter]) + + mylogger.info("") + mylogger.info( + f"======= EXPERIMENT ID: {experiment_id} ==== TRIAL ID: {trial_id} =======" + ) + mylogger.info("") + + try: + process_trial( + config=config, + mylogger=mylogger, + experiment_id=experiment_id, + trial_id=trial_id, + device=device, + ) + except torch.cuda.OutOfMemoryError: + mylogger.info("WARNING: RUNNING IN FAILBACK MODE!!!!") + mylogger.info("Not enough GPU memory. 
Retry on CPU") + process_trial( + config=config, + mylogger=mylogger, + experiment_id=experiment_id, + trial_id=trial_id, + device=torch.device("cpu"), + ) + + +if __name__ == "__main__": + argh.dispatch_command(main) diff --git a/other/stage_4c_viewer.py b/other/stage_4c_viewer.py new file mode 100644 index 0000000..9c70616 --- /dev/null +++ b/other/stage_4c_viewer.py @@ -0,0 +1,56 @@ +import os +import numpy as np + +import matplotlib.pyplot as plt # type:ignore + +from functions.create_logger import create_logger +from functions.load_config import load_config + +import argh + + +def main( + *, config_filename: str = "config.json", experiment_id: int = 1, trial_id: int = 1 +) -> None: + + experiment_name: str = f"Exp{experiment_id:03d}_Trial{trial_id:03d}" + mylogger = create_logger( + save_logging_messages=False, + display_logging_messages=False, + log_stage_name="stage_4c", + ) + + config = load_config(mylogger=mylogger, filename=config_filename) + + temp_path = os.path.join( + config["export_path"], experiment_name + "_inspect_images.npz" + ) + data = np.load(temp_path) + + acceptor = data["acceptor"][0, ...] + donor = data["donor"][0, ...] + oxygenation = data["oxygenation"][0, ...] + volume = data["volume"][0, ...] + + plt.figure(1) + plt.imshow(acceptor, cmap="hot") + plt.title(f"Acceptor Experiment: {experiment_id:03d} Trial:{trial_id:03d}") + plt.show(block=False) + plt.figure(2) + plt.imshow(donor, cmap="hot") + plt.title(f"Donor Experiment: {experiment_id:03d} Trial:{trial_id:03d}") + plt.show(block=False) + plt.figure(3) + plt.imshow(oxygenation, cmap="hot") + plt.title(f"Oxygenation Experiment: {experiment_id:03d} Trial:{trial_id:03d}") + plt.show(block=False) + plt.figure(4) + plt.imshow(volume, cmap="hot") + plt.title(f"Volume Experiment: {experiment_id:03d} Trial:{trial_id:03d}") + plt.show(block=True) + + return + + +if __name__ == "__main__": + argh.dispatch_command(main) diff --git a/stage_1_get_ref_image.py b/stage_1_get_ref_image.py new file mode 100644 index 0000000..0e5b6da --- /dev/null +++ b/stage_1_get_ref_image.py @@ -0,0 +1,129 @@ +import os +import torch +import numpy as np +import argh + +from functions.get_experiments import get_experiments +from functions.get_trials import get_trials +from functions.bandpass import bandpass +from functions.create_logger import create_logger +from functions.get_torch_device import get_torch_device +from functions.load_config import load_config +from functions.data_raw_loader import data_raw_loader + + +def main(*, config_filename: str = "config.json") -> None: + mylogger = create_logger( + save_logging_messages=True, + display_logging_messages=True, + log_stage_name="stage_1", + ) + + config = load_config(mylogger=mylogger, filename=config_filename) + + if config["binning_enable"] and (config["binning_at_the_end"] is False): + device: torch.device = torch.device("cpu") + else: + device = get_torch_device(mylogger, config["force_to_cpu"]) + + raw_data_path: str = os.path.join( + config["basic_path"], + config["recoding_data"], + config["mouse_identifier"], + config["raw_path"], + ) + + mylogger.info(f"Using data path: {raw_data_path}") + + first_experiment_id: int = int(get_experiments(raw_data_path).min()) + first_trial_id: int = int(get_trials(raw_data_path, first_experiment_id).min()) + + meta_channels: list[str] + meta_mouse_markings: str + meta_recording_date: str + meta_stimulation_times: dict + meta_experiment_names: dict + meta_trial_recording_duration: float + meta_frame_time: float + meta_mouse: str + data: 
torch.Tensor + + if config["binning_enable"] and (config["binning_at_the_end"] is False): + force_to_cpu_memory: bool = True + else: + force_to_cpu_memory = False + + mylogger.info("Loading data") + + ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) = data_raw_loader( + raw_data_path=raw_data_path, + mylogger=mylogger, + experiment_id=first_experiment_id, + trial_id=first_trial_id, + device=device, + force_to_cpu_memory=force_to_cpu_memory, + config=config, + ) + mylogger.info("-==- Done -==-") + + output_path = config["ref_image_path"] + mylogger.info(f"Create directory {output_path} in the case it does not exist") + os.makedirs(output_path, exist_ok=True) + + mylogger.info("Reference images") + for i in range(0, len(meta_channels)): + temp_path: str = os.path.join(output_path, meta_channels[i] + ".npy") + mylogger.info(f"Extract and save: {temp_path}") + frame_id: int = data.shape[-2] // 2 + mylogger.info(f"Will use frame id: {frame_id}") + ref_image: np.ndarray = ( + data[:, :, frame_id, meta_channels.index(meta_channels[i])] + .clone() + .cpu() + .numpy() + ) + np.save(temp_path, ref_image) + mylogger.info("-==- Done -==-") + + sample_frequency: float = 1.0 / meta_frame_time + mylogger.info( + ( + f"Heartbeat power {config['lower_freqency_bandpass']} Hz" + f" - {config['upper_freqency_bandpass']} Hz," + f" sample-rate: {sample_frequency}," + f" skipping the first {config['skip_frames_in_the_beginning']} frames" + ) + ) + + for i in range(0, len(meta_channels)): + temp_path = os.path.join(output_path, meta_channels[i] + "_var.npy") + mylogger.info(f"Extract and save: {temp_path}") + + heartbeat_ts: torch.Tensor = bandpass( + data=data[..., i], + low_frequency=config["lower_freqency_bandpass"], + high_frequency=config["upper_freqency_bandpass"], + fs=sample_frequency, + filtfilt_chuck_size=10, + ) + + heartbeat_power = heartbeat_ts[ + ..., config["skip_frames_in_the_beginning"] : + ].var(dim=-1) + np.save(temp_path, heartbeat_power) + + mylogger.info("-==- Done -==-") + + +if __name__ == "__main__": + argh.dispatch_command(main) diff --git a/stage_2_make_heartbeat_mask.py b/stage_2_make_heartbeat_mask.py new file mode 100644 index 0000000..dfa8c63 --- /dev/null +++ b/stage_2_make_heartbeat_mask.py @@ -0,0 +1,163 @@ +import matplotlib.pyplot as plt # type:ignore +import matplotlib +import numpy as np +import torch +import os +import argh + +from matplotlib.widgets import Slider, Button # type:ignore +from functools import partial +from functions.gauss_smear_individual import gauss_smear_individual +from functions.create_logger import create_logger +from functions.get_torch_device import get_torch_device +from functions.load_config import load_config + + +def main(*, config_filename: str = "config.json") -> None: + mylogger = create_logger( + save_logging_messages=True, + display_logging_messages=True, + log_stage_name="stage_2", + ) + + config = load_config(mylogger=mylogger, filename=config_filename) + + path: str = config["ref_image_path"] + use_channel: str = "donor" + spatial_width: float = 4.0 + temporal_width: float = 0.1 + + threshold: float = 0.05 + + heartbeat_mask_threshold_file: str = os.path.join( + path, "heartbeat_mask_threshold.npy" + ) + if os.path.isfile(heartbeat_mask_threshold_file): + mylogger.info( + f"loading previous threshold file: {heartbeat_mask_threshold_file}" + ) + threshold = 
float(np.load(heartbeat_mask_threshold_file)[0]) + + mylogger.info(f"initial threshold is {threshold}") + + image_ref_file: str = os.path.join(path, use_channel + ".npy") + image_var_file: str = os.path.join(path, use_channel + "_var.npy") + heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy") + + device = get_torch_device(mylogger, config["force_to_cpu"]) + + mylogger.info(f"loading image reference file: {image_ref_file}") + image_ref: np.ndarray = np.load(image_ref_file) + image_ref /= image_ref.max() + + mylogger.info(f"loading image heartbeat power: {image_var_file}") + image_var: np.ndarray = np.load(image_var_file) + image_var /= image_var.max() + + mylogger.info("Smear the image heartbeat power patially") + temp, _ = gauss_smear_individual( + input=torch.tensor(image_var[..., np.newaxis], device=device), + spatial_width=spatial_width, + temporal_width=temporal_width, + use_matlab_mask=False, + ) + temp /= temp.max() + + mylogger.info("-==- DONE -==-") + + image_3color = np.concatenate( + ( + np.zeros_like(image_ref[..., np.newaxis]), + image_ref[..., np.newaxis], + temp.cpu().numpy(), + ), + axis=-1, + ) + + mylogger.info("Prepare image") + + display_image = image_3color.copy() + display_image[..., 2] = display_image[..., 0] + mask = np.where(image_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis] + display_image *= mask + display_image = np.nan_to_num(display_image, nan=1.0) + + value_sort = np.sort(image_var.flatten()) + value_sort_max = value_sort[int(value_sort.shape[0] * 0.95)] * 3 + print(value_sort_max) + mylogger.info("-==- DONE -==-") + + mylogger.info("Create figure") + + fig: matplotlib.figure.Figure = plt.figure() + + image_handle = plt.imshow(display_image, vmin=0, vmax=1, cmap="hot") + + mylogger.info("Add controls") + + def next_frame( + i: float, images: np.ndarray, image_handle: matplotlib.image.AxesImage + ) -> None: + nonlocal threshold + threshold = i + + display_image: np.ndarray = images.copy() + display_image[..., 2] = display_image[..., 0] + mask: np.ndarray = np.where(images[..., 2] >= i, 1.0, np.nan)[..., np.newaxis] + display_image *= mask + display_image = np.nan_to_num(display_image, nan=1.0) + + image_handle.set_data(display_image) + return + + def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None: + nonlocal threshold + nonlocal image_3color + nonlocal path + nonlocal mylogger + nonlocal heartbeat_mask_file + nonlocal heartbeat_mask_threshold_file + + mylogger.info(f"Threshold: {threshold}") + + mask: np.ndarray = image_3color[..., 2] >= threshold + mylogger.info(f"Save mask to: {heartbeat_mask_file}") + np.save(heartbeat_mask_file, mask) + mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}") + np.save(heartbeat_mask_threshold_file, np.array([threshold])) + exit() + + def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None: + exit() + + axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03)) + slice_slider = Slider( + ax=axfreq, + label="Threshold", + valmin=0, + valmax=value_sort_max, + valinit=threshold, + valstep=value_sort_max / 1000.0, + ) + axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04)) + button_accept = Button( + ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95" + ) + button_accept.on_clicked(on_clicked_accept) # type: ignore + + axbutton_cancel = fig.add_axes(rect=(0.55, 0.85, 0.2, 0.04)) + button_cancel = Button( + ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95" + ) + 
button_cancel.on_clicked(on_clicked_cancel) # type: ignore + + slice_slider.on_changed( + partial(next_frame, images=image_3color, image_handle=image_handle) + ) + + mylogger.info("Display") + plt.show() + + +if __name__ == "__main__": + argh.dispatch_command(main) diff --git a/stage_3_refine_mask.py b/stage_3_refine_mask.py new file mode 100644 index 0000000..f96b3bd --- /dev/null +++ b/stage_3_refine_mask.py @@ -0,0 +1,169 @@ +import os +import numpy as np + +import matplotlib.pyplot as plt # type:ignore +import matplotlib +from matplotlib.widgets import Button # type:ignore + +# pip install roipoly +from roipoly import RoiPoly # type:ignore + +from functions.create_logger import create_logger +from functions.load_config import load_config + +import argh + + +def compose_image(image_3color: np.ndarray, mask: np.ndarray) -> np.ndarray: + display_image = image_3color.copy() + display_image[..., 2] = display_image[..., 0] + display_image[mask == 0, :] = 1.0 + return display_image + + +def main(*, config_filename: str = "config.json") -> None: + mylogger = create_logger( + save_logging_messages=True, + display_logging_messages=True, + log_stage_name="stage_3", + ) + + config = load_config(mylogger=mylogger, filename=config_filename) + + path: str = config["ref_image_path"] + use_channel: str = "donor" + padding: int = 20 + + image_ref_file: str = os.path.join(path, use_channel + ".npy") + heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy") + refined_mask_file: str = os.path.join(path, "mask_not_rotated.npy") + + mylogger.info(f"loading image reference file: {image_ref_file}") + image_ref: np.ndarray = np.load(image_ref_file) + image_ref /= image_ref.max() + image_ref = np.pad(image_ref, pad_width=padding) + + mylogger.info(f"loading heartbeat mask: {heartbeat_mask_file}") + mask: np.ndarray = np.load(heartbeat_mask_file) + mask = np.pad(mask, pad_width=padding) + + image_3color = np.concatenate( + ( + np.zeros_like(image_ref[..., np.newaxis]), + image_ref[..., np.newaxis], + np.zeros_like(image_ref[..., np.newaxis]), + ), + axis=-1, + ) + + mylogger.info("-==- DONE -==-") + + fig, ax_main = plt.subplots() + + display_image = compose_image(image_3color=image_3color, mask=mask) + image_handle = ax_main.imshow(display_image, vmin=0, vmax=1, cmap="hot") + + mylogger.info("Add controls") + + def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None: + nonlocal mylogger + nonlocal refined_mask_file + nonlocal mask + + mylogger.info(f"Save mask to: {refined_mask_file}") + mask = mask[padding:-padding, padding:-padding] + np.save(refined_mask_file, mask) + + exit() + + def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None: + nonlocal mylogger + mylogger.info("Ended without saving the mask") + exit() + + def on_clicked_add(event: matplotlib.backend_bases.MouseEvent) -> None: + nonlocal new_roi # type: ignore + nonlocal mask + nonlocal image_3color + nonlocal display_image + nonlocal mylogger + if len(new_roi.x) > 0: + mylogger.info( + "A ROI with the following coordiantes has been added to the mask" + ) + for i in range(0, len(new_roi.x)): + mylogger.info(f"{round(new_roi.x[i], 1)} x {round(new_roi.y[i], 1)}") + mylogger.info("") + new_mask = new_roi.get_mask(display_image[:, :, 0]) + mask[new_mask] = 0.0 + display_image = compose_image(image_3color=image_3color, mask=mask) + image_handle.set_data(display_image) + for line in ax_main.lines: + line.remove() + plt.draw() + + new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) + + 
def on_clicked_remove(event: matplotlib.backend_bases.MouseEvent) -> None: + nonlocal new_roi # type: ignore + nonlocal mask + nonlocal image_3color + nonlocal display_image + if len(new_roi.x) > 0: + mylogger.info( + "A ROI with the following coordiantes has been removed from the mask" + ) + for i in range(0, len(new_roi.x)): + mylogger.info(f"{round(new_roi.x[i], 1)} x {round(new_roi.y[i], 1)}") + mylogger.info("") + new_mask = new_roi.get_mask(display_image[:, :, 0]) + mask[new_mask] = 1.0 + display_image = compose_image(image_3color=image_3color, mask=mask) + image_handle.set_data(display_image) + for line in ax_main.lines: + line.remove() + plt.draw() + new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) + + axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04)) + button_accept = Button( + ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95" + ) + button_accept.on_clicked(on_clicked_accept) # type: ignore + + axbutton_cancel = fig.add_axes(rect=(0.5, 0.85, 0.2, 0.04)) + button_cancel = Button( + ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95" + ) + button_cancel.on_clicked(on_clicked_cancel) # type: ignore + + axbutton_addmask = fig.add_axes(rect=(0.3, 0.9, 0.2, 0.04)) + button_addmask = Button( + ax=axbutton_addmask, + label="Add mask", + image=None, + color="0.85", + hovercolor="0.95", + ) + button_addmask.on_clicked(on_clicked_add) # type: ignore + + axbutton_removemask = fig.add_axes(rect=(0.5, 0.9, 0.2, 0.04)) + button_removemask = Button( + ax=axbutton_removemask, + label="Remove mask", + image=None, + color="0.85", + hovercolor="0.95", + ) + button_removemask.on_clicked(on_clicked_remove) # type: ignore + + # ax_main.cla() + + mylogger.info("Display") + new_roi: RoiPoly = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) + + plt.show() + + +if __name__ == "__main__": + argh.dispatch_command(main) diff --git a/stage_4_process.py b/stage_4_process.py new file mode 100644 index 0000000..4a020e2 --- /dev/null +++ b/stage_4_process.py @@ -0,0 +1,1413 @@ +# %% + +import numpy as np +import torch +import torchvision as tv # type: ignore + +import os +import logging +import h5py # type: ignore + +from functions.create_logger import create_logger +from functions.get_torch_device import get_torch_device +from functions.load_config import load_config +from functions.get_experiments import get_experiments +from functions.get_trials import get_trials +from functions.binning import binning +from functions.align_refref import align_refref +from functions.perform_donor_volume_rotation import perform_donor_volume_rotation +from functions.perform_donor_volume_translation import perform_donor_volume_translation +from functions.bandpass import bandpass +from functions.gauss_smear_individual import gauss_smear_individual +from functions.regression import regression +from functions.data_raw_loader import data_raw_loader + +import argh + + +@torch.no_grad() +def process_trial( + config: dict, + mylogger: logging.Logger, + experiment_id: int, + trial_id: int, + device: torch.device, +): + + mylogger.info("") + mylogger.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + mylogger.info("~ TRIAL START ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + mylogger.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + mylogger.info("") + + if device != torch.device("cpu"): + torch.cuda.empty_cache() + mylogger.info("Empty CUDA cache") + cuda_total_memory: int = torch.cuda.get_device_properties( + device.index + 
+        ).total_memory
+    else:
+        cuda_total_memory = 0
+
+    mylogger.info("")
+    mylogger.info("(A) LOADING DATA, REFERENCE, AND MASK")
+    mylogger.info("-----------------------------------------------")
+    mylogger.info("")
+
+    raw_data_path: str = os.path.join(
+        config["basic_path"],
+        config["recoding_data"],
+        config["mouse_identifier"],
+        config["raw_path"],
+    )
+
+    if config["binning_enable"] and (config["binning_at_the_end"] is False):
+        force_to_cpu_memory: bool = True
+    else:
+        force_to_cpu_memory = False
+
+    meta_channels: list[str]
+    meta_mouse_markings: str
+    meta_recording_date: str
+    meta_stimulation_times: dict
+    meta_experiment_names: dict
+    meta_trial_recording_duration: float
+    meta_frame_time: float
+    meta_mouse: str
+    data: torch.Tensor
+
+    (
+        meta_channels,
+        meta_mouse_markings,
+        meta_recording_date,
+        meta_stimulation_times,
+        meta_experiment_names,
+        meta_trial_recording_duration,
+        meta_frame_time,
+        meta_mouse,
+        data,
+    ) = data_raw_loader(
+        raw_data_path=raw_data_path,
+        mylogger=mylogger,
+        experiment_id=experiment_id,
+        trial_id=trial_id,
+        device=device,
+        force_to_cpu_memory=force_to_cpu_memory,
+        config=config,
+    )
+    experiment_name: str = f"Exp{experiment_id:03d}_Trial{trial_id:03d}"
+
+    dtype_str = config["dtype"]
+    dtype_np: np.dtype = getattr(np, dtype_str)
+
+    dtype: torch.dtype = data.dtype
+
+    if device != torch.device("cpu"):
+        free_mem = cuda_total_memory - max(
+            [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
+        )
+        mylogger.info(f"CUDA memory: {free_mem // 1024} MByte")
+
+    mylogger.info(f"Data shape: {data.shape}")
+    mylogger.info("-==- Done -==-")
+
+    mylogger.info("Finding limit values in the RAW data and marking them for masking")
+    limit: float = (2**16) - 1
+    for i in range(0, data.shape[3]):
+        zero_pixel_mask: torch.Tensor = torch.any(data[..., i] >= limit, dim=-1)
+        data[zero_pixel_mask, :, i] = -100.0
+        mylogger.info(
+            f"{meta_channels[i]}: "
+            f"found {int(zero_pixel_mask.type(dtype=dtype).sum())} pixels "
+            f"with limit values "
+        )
+    mylogger.info("-==- Done -==-")
+
+    mylogger.info("Reference images and mask")
+
+    ref_image_path: str = config["ref_image_path"]
+
+    ref_image_path_acceptor: str = os.path.join(ref_image_path, "acceptor.npy")
+    if os.path.isfile(ref_image_path_acceptor) is False:
+        mylogger.info(f"Could not load ref file: {ref_image_path_acceptor}")
+        assert os.path.isfile(ref_image_path_acceptor)
+        return
+
+    mylogger.info(f"Loading ref file data: {ref_image_path_acceptor}")
+    ref_image_acceptor: torch.Tensor = torch.tensor(
+        np.load(ref_image_path_acceptor).astype(dtype_np),
+        dtype=dtype,
+        device=data.device,
+    )
+
+    ref_image_path_donor: str = os.path.join(ref_image_path, "donor.npy")
+    if os.path.isfile(ref_image_path_donor) is False:
+        mylogger.info(f"Could not load ref file: {ref_image_path_donor}")
+        assert os.path.isfile(ref_image_path_donor)
+        return
+
+    mylogger.info(f"Loading ref file data: {ref_image_path_donor}")
+    ref_image_donor: torch.Tensor = torch.tensor(
+        np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=data.device
+    )
+
+    ref_image_path_oxygenation: str = os.path.join(ref_image_path, "oxygenation.npy")
+    if os.path.isfile(ref_image_path_oxygenation) is False:
+        mylogger.info(f"Could not load ref file: {ref_image_path_oxygenation}")
+        assert os.path.isfile(ref_image_path_oxygenation)
+        return
+
+    mylogger.info(f"Loading ref file data: {ref_image_path_oxygenation}")
+    ref_image_oxygenation: torch.Tensor = torch.tensor(
np.load(ref_image_path_oxygenation).astype(dtype_np), + dtype=dtype, + device=data.device, + ) + + ref_image_path_volume: str = os.path.join(ref_image_path, "volume.npy") + if os.path.isfile(ref_image_path_volume) is False: + mylogger.info(f"Could not load ref file: {ref_image_path_volume}") + assert os.path.isfile(ref_image_path_volume) + return + + mylogger.info(f"Loading ref file data: {ref_image_path_volume}") + ref_image_volume: torch.Tensor = torch.tensor( + np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=data.device + ) + + refined_mask_file: str = os.path.join(ref_image_path, "mask_not_rotated.npy") + if os.path.isfile(refined_mask_file) is False: + mylogger.info(f"Could not load mask file: {refined_mask_file}") + assert os.path.isfile(refined_mask_file) + return + + mylogger.info(f"Loading mask file data: {refined_mask_file}") + mask: torch.Tensor = torch.tensor( + np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=data.device + ) + mylogger.info("-==- Done -==-") + + if config["binning_enable"] and (config["binning_at_the_end"] is False): + + mylogger.info("") + mylogger.info("(B-OPTIONAL) BINNING") + mylogger.info("-----------------------------------------------") + mylogger.info("") + + mylogger.info("Binning of data") + mylogger.info( + ( + f"kernel_size={int(config['binning_kernel_size'])}, " + f"stride={int(config['binning_stride'])}, " + f"divisor_override={int(config['binning_divisor_override'])}" + ) + ) + + data = binning( + data, + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ).to(device=data.device) + ref_image_acceptor = ( + binning( + ref_image_acceptor.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + ref_image_donor = ( + binning( + ref_image_donor.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + ref_image_oxygenation = ( + binning( + ref_image_oxygenation.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + ref_image_volume = ( + binning( + ref_image_volume.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + mask = ( + binning( + mask.unsqueeze(-1).unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=int(config["binning_divisor_override"]), + ) + .squeeze(-1) + .squeeze(-1) + ) + mylogger.info(f"Data shape: {data.shape}") + mylogger.info("-==- Done -==-") + + mylogger.info("") + mylogger.info("(C) ALIGNMENT OF SECOND TO FIRST CAMERA") + mylogger.info("-----------------------------------------------") + mylogger.info("") + + mylogger.info("Preparing alignment") + mylogger.info("Re-order Raw data") + data = data.moveaxis(-2, 0).moveaxis(-1, 0) + mylogger.info(f"Data shape: {data.shape}") + mylogger.info("-==- Done -==-") + + mylogger.info("Alignment of the ref images and the mask") + 
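# --- Editorial sketch (not part of the pipeline) -----------------------------
# Every alignment step below follows the same two-call torchvision pattern:
# one affine for the rotation, a second affine for the translation, with
# fill=-100.0 marking pixels that left the field of view so they can be
# folded into the mask later. A minimal, self-contained illustration with a
# toy image and made-up angle/translation values (the real values come from
# align_refref):
#
#     import torch
#     import torchvision as tv
#
#     img = torch.rand(1, 64, 64)  # (channel, height, width)
#     angle = 2.5                  # rotation in degrees (assumed value)
#     tvec = [1.0, -3.0]           # translation (dim 0, dim 1) in pixels
#
#     rotated = tv.transforms.functional.affine(
#         img=img, angle=-angle, translate=[0, 0], scale=1.0, shear=0,
#         interpolation=tv.transforms.InterpolationMode.BILINEAR, fill=-100.0,
#     )
#     shifted = tv.transforms.functional.affine(
#         img=rotated, angle=0, translate=[tvec[1], tvec[0]], scale=1.0,
#         shear=0, interpolation=tv.transforms.InterpolationMode.BILINEAR,
#         fill=-100.0,
#     )
# ------------------------------------------------------------------------------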
mylogger.info("Ref image of donor stays fixed.") + mylogger.info("Ref image of volume and the mask doesn't need to be touched") + mylogger.info("Calculate translation and rotation between the reference images") + angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor = align_refref( + mylogger=mylogger, + ref_image_acceptor=ref_image_acceptor, + ref_image_donor=ref_image_donor, + batch_size=config["alignment_batch_size"], + fill_value=-100.0, + ) + mylogger.info(f"Rotation: {round(float(angle_refref[0]), 2)} degree") + mylogger.info( + f"Translation: {round(float(tvec_refref[0]), 1)} x {round(float(tvec_refref[1]), 1)} pixel" + ) + + if config["save_alignment"]: + temp_path: str = os.path.join( + config["export_path"], experiment_name + "_angle_refref.npy" + ) + mylogger.info(f"Save angle to {temp_path}") + np.save(temp_path, angle_refref.cpu()) + + temp_path = os.path.join( + config["export_path"], experiment_name + "_tvec_refref.npy" + ) + mylogger.info(f"Save translation vector to {temp_path}") + np.save(temp_path, tvec_refref.cpu()) + + mylogger.info("Moving & rotating the oxygenation ref image") + ref_image_oxygenation = tv.transforms.functional.affine( # type: ignore + img=ref_image_oxygenation.unsqueeze(0), + angle=-float(angle_refref), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + + ref_image_oxygenation = tv.transforms.functional.affine( # type: ignore + img=ref_image_oxygenation, + angle=0, + translate=[tvec_refref[1], tvec_refref[0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ).squeeze(0) + mylogger.info("-==- Done -==-") + + mylogger.info("Rotate and translate the acceptor and oxygenation data accordingly") + acceptor_index: int = config["required_order"].index("acceptor") + donor_index: int = config["required_order"].index("donor") + oxygenation_index: int = config["required_order"].index("oxygenation") + volume_index: int = config["required_order"].index("volume") + + mylogger.info("Rotate acceptor") + data[acceptor_index, ...] = tv.transforms.functional.affine( # type: ignore + img=data[acceptor_index, ...], # type: ignore + angle=-float(angle_refref), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + + mylogger.info("Translate acceptor") + data[acceptor_index, ...] = tv.transforms.functional.affine( # type: ignore + img=data[acceptor_index, ...], + angle=0, + translate=[tvec_refref[1], tvec_refref[0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + + mylogger.info("Rotate oxygenation") + data[oxygenation_index, ...] = tv.transforms.functional.affine( # type: ignore + img=data[oxygenation_index, ...], + angle=-float(angle_refref), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + + mylogger.info("Translate oxygenation") + data[oxygenation_index, ...] 
= tv.transforms.functional.affine( # type: ignore + img=data[oxygenation_index, ...], + angle=0, + translate=[tvec_refref[1], tvec_refref[0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=-100.0, + ) + mylogger.info("-==- Done -==-") + + mylogger.info("Perform rotation between donor and volume and its ref images") + mylogger.info("for all frames and then rotate all the data accordingly") + + ( + data[acceptor_index, ...], + data[donor_index, ...], + data[oxygenation_index, ...], + data[volume_index, ...], + angle_donor_volume, + ) = perform_donor_volume_rotation( + mylogger=mylogger, + acceptor=data[acceptor_index, ...], + donor=data[donor_index, ...], + oxygenation=data[oxygenation_index, ...], + volume=data[volume_index, ...], + ref_image_donor=ref_image_donor, + ref_image_volume=ref_image_volume, + batch_size=config["alignment_batch_size"], + fill_value=-100.0, + config=config, + ) + + mylogger.info( + f"angles: " + f"min {round(float(angle_donor_volume.min()), 2)} " + f"max {round(float(angle_donor_volume.max()), 2)} " + f"mean {round(float(angle_donor_volume.mean()), 2)} " + ) + + if config["save_alignment"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_angle_donor_volume.npy" + ) + mylogger.info(f"Save angles to {temp_path}") + np.save(temp_path, angle_donor_volume.cpu()) + mylogger.info("-==- Done -==-") + + mylogger.info("Perform translation between donor and volume and its ref images") + mylogger.info("for all frames and then translate all the data accordingly") + ( + data[acceptor_index, ...], + data[donor_index, ...], + data[oxygenation_index, ...], + data[volume_index, ...], + tvec_donor_volume, + ) = perform_donor_volume_translation( + mylogger=mylogger, + acceptor=data[acceptor_index, ...], + donor=data[donor_index, ...], + oxygenation=data[oxygenation_index, ...], + volume=data[volume_index, ...], + ref_image_donor=ref_image_donor, + ref_image_volume=ref_image_volume, + batch_size=config["alignment_batch_size"], + fill_value=-100.0, + config=config, + ) + + mylogger.info( + f"translation dim 0: " + f"min {round(float(tvec_donor_volume[:, 0].min()), 1)} " + f"max {round(float(tvec_donor_volume[:, 0].max()), 1)} " + f"mean {round(float(tvec_donor_volume[:, 0].mean()), 1)} " + ) + mylogger.info( + f"translation dim 1: " + f"min {round(float(tvec_donor_volume[:, 1].min()), 1)} " + f"max {round(float(tvec_donor_volume[:, 1].max()), 1)} " + f"mean {round(float(tvec_donor_volume[:, 1].mean()), 1)} " + ) + + if config["save_alignment"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_tvec_donor_volume.npy" + ) + mylogger.info(f"Save translation vector to {temp_path}") + np.save(temp_path, tvec_donor_volume.cpu()) + mylogger.info("-==- Done -==-") + + mylogger.info("Finding zeros values in the RAW data and mark them for masking") + for i in range(0, data.shape[0]): + zero_pixel_mask = torch.any(data[i, ...] 
== 0, dim=0) + data[i, :, zero_pixel_mask] = -100.0 + mylogger.info( + f"{config['required_order'][i]}: " + f"found {int(zero_pixel_mask.type(dtype=dtype).sum())} pixel " + f"with zeros " + ) + mylogger.info("-==- Done -==-") + + mylogger.info("Update mask with the new regions due to alignment") + + new_mask_area: torch.Tensor = torch.any(torch.any(data < -0.1, dim=0), dim=0).bool() + mask = (mask == 0).bool() + mask = torch.logical_or(mask, new_mask_area) + mask_negative: torch.Tensor = mask.clone() + mask_positve: torch.Tensor = torch.logical_not(mask) + del mask + + mylogger.info("Update the data with the new mask") + data *= mask_positve.unsqueeze(0).unsqueeze(0).type(dtype=dtype) + mylogger.info("-==- Done -==-") + + if config["save_aligned_as_python"]: + + temp_path = os.path.join( + config["export_path"], experiment_name + "_aligned.npz" + ) + mylogger.info(f"Save aligned data and mask to {temp_path}") + np.savez_compressed( + temp_path, + data=data.cpu(), + mask=mask_positve.cpu(), + acceptor_index=acceptor_index, + donor_index=donor_index, + oxygenation_index=oxygenation_index, + volume_index=volume_index, + ) + + if config["save_aligned_as_matlab"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_aligned.hd5" + ) + mylogger.info(f"Save aligned data and mask to {temp_path}") + file_handle = h5py.File(temp_path, "w") + + _ = file_handle.create_dataset( + "mask", + data=mask_positve.movedim(0, -1).type(torch.uint8).cpu(), + compression="gzip", + compression_opts=9, + ) + + _ = file_handle.create_dataset( + "data", + data=data.movedim(1, -1).movedim(0, -1).cpu(), + compression="gzip", + compression_opts=9, + ) + + _ = file_handle.create_dataset( + "acceptor_index", + data=torch.tensor((acceptor_index,)), + compression="gzip", + compression_opts=9, + ) + + _ = file_handle.create_dataset( + "donor_index", + data=torch.tensor((donor_index,)), + compression="gzip", + compression_opts=9, + ) + + _ = file_handle.create_dataset( + "oxygenation_index", + data=torch.tensor((oxygenation_index,)), + compression="gzip", + compression_opts=9, + ) + + _ = file_handle.create_dataset( + "volume_index", + data=torch.tensor((volume_index,)), + compression="gzip", + compression_opts=9, + ) + + mylogger.info("Reminder: How to read with matlab:") + mylogger.info(f"mask = h5read('{temp_path}','/mask');") + mylogger.info(f"data_acceptor = h5read('{temp_path}','/data');") + file_handle.close() + + mylogger.info("") + mylogger.info("(D) INTER-FRAME INTERPOLATION") + mylogger.info("-----------------------------------------------") + mylogger.info("") + + mylogger.info("Interpolate the 'in-between' frames for oxygenation and volume") + data[oxygenation_index, 1:, ...] = ( + data[oxygenation_index, 1:, ...] + data[oxygenation_index, :-1, ...] + ) / 2.0 + data[volume_index, 1:, ...] = ( + data[volume_index, 1:, ...] + data[volume_index, :-1, ...] 
+ ) / 2.0 + mylogger.info("-==- Done -==-") + + sample_frequency: float = 1.0 / meta_frame_time + + if config["gevi"]: + assert config["heartbeat_remove"] + + if config["heartbeat_remove"]: + + mylogger.info("") + mylogger.info("(E-OPTIONAL) HEARTBEAT REMOVAL") + mylogger.info("-----------------------------------------------") + mylogger.info("") + + mylogger.info("Extract heartbeat from volume signal") + heartbeat_ts: torch.Tensor = bandpass( + data=data[volume_index, ...].movedim(0, -1).clone(), + low_frequency=config["lower_freqency_bandpass"], + high_frequency=config["upper_freqency_bandpass"], + fs=sample_frequency, + filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"], + ) + heartbeat_ts = heartbeat_ts.flatten(start_dim=0, end_dim=-2) + mask_flatten: torch.Tensor = mask_positve.flatten(start_dim=0, end_dim=-1) + + heartbeat_ts = heartbeat_ts[mask_flatten, :] + + heartbeat_ts = heartbeat_ts.movedim(0, -1) + heartbeat_ts -= heartbeat_ts.mean(dim=0, keepdim=True) + + try: + volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False) + except torch.cuda.OutOfMemoryError: + mylogger.info("torch.cuda.OutOfMemoryError: Fallback to cpu") + volume_heartbeat_cpu, _, _ = torch.linalg.svd( + heartbeat_ts.cpu(), full_matrices=False + ) + volume_heartbeat = volume_heartbeat_cpu.to(heartbeat_ts.device, copy=True) + del volume_heartbeat_cpu + + volume_heartbeat = volume_heartbeat[:, 0] + volume_heartbeat -= volume_heartbeat[ + config["skip_frames_in_the_beginning"] : + ].mean() + + del heartbeat_ts + + if device != torch.device("cpu"): + torch.cuda.empty_cache() + mylogger.info("Empty CUDA cache") + free_mem = cuda_total_memory - max( + [ + torch.cuda.memory_reserved(device), + torch.cuda.memory_allocated(device), + ] + ) + mylogger.info(f"CUDA memory: {free_mem // 1024} MByte") + + if config["save_heartbeat"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_volume_heartbeat.npy" + ) + mylogger.info(f"Save volume heartbeat to {temp_path}") + np.save(temp_path, volume_heartbeat.cpu()) + mylogger.info("-==- Done -==-") + + volume_heartbeat = volume_heartbeat.unsqueeze(0).unsqueeze(0) + norm_volume_heartbeat = ( + volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] ** 2 + ).sum(dim=-1) + + heartbeat_coefficients: torch.Tensor = torch.zeros( + (data.shape[0], data.shape[-2], data.shape[-1]), + dtype=data.dtype, + device=data.device, + ) + for i in range(0, data.shape[0]): + y = bandpass( + data=data[i, ...].movedim(0, -1).clone(), + low_frequency=config["lower_freqency_bandpass"], + high_frequency=config["upper_freqency_bandpass"], + fs=sample_frequency, + filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"], + )[..., config["skip_frames_in_the_beginning"] :] + y -= y.mean(dim=-1, keepdim=True) + + heartbeat_coefficients[i, ...] = ( + volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] * y + ).sum(dim=-1) / norm_volume_heartbeat + + heartbeat_coefficients[i, ...] 
*= mask_positve.type( + dtype=heartbeat_coefficients.dtype + ) + del y + + if config["save_heartbeat"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_heartbeat_coefficients.npy" + ) + mylogger.info(f"Save heartbeat coefficients to {temp_path}") + np.save(temp_path, heartbeat_coefficients.cpu()) + mylogger.info("-==- Done -==-") + + mylogger.info("Remove heart beat from data") + data -= heartbeat_coefficients.unsqueeze(1) * volume_heartbeat.unsqueeze( + 0 + ).movedim(-1, 1) + # data_herzlos = data.clone() + mylogger.info("-==- Done -==-") + + if config["gevi"]: # UDO scaling performed! + + mylogger.info("") + mylogger.info("(F-OPTIONAL) DONOR/ACCEPTOR SCALING") + mylogger.info("-----------------------------------------------") + mylogger.info("") + + donor_heartbeat_factor = heartbeat_coefficients[donor_index, ...].clone() + acceptor_heartbeat_factor = heartbeat_coefficients[ + acceptor_index, ... + ].clone() + del heartbeat_coefficients + + if device != torch.device("cpu"): + torch.cuda.empty_cache() + mylogger.info("Empty CUDA cache") + free_mem = cuda_total_memory - max( + [ + torch.cuda.memory_reserved(device), + torch.cuda.memory_allocated(device), + ] + ) + mylogger.info(f"CUDA memory: {free_mem // 1024} MByte") + + mylogger.info("Calculate scaling factor for donor and acceptor") + # donor_factor: torch.Tensor = ( + # donor_heartbeat_factor + acceptor_heartbeat_factor + # ) / (2 * donor_heartbeat_factor) + # acceptor_factor: torch.Tensor = ( + # donor_heartbeat_factor + acceptor_heartbeat_factor + # ) / (2 * acceptor_heartbeat_factor) + donor_factor = torch.sqrt( + acceptor_heartbeat_factor / donor_heartbeat_factor + ) + acceptor_factor = 1 / donor_factor + + # import matplotlib.pyplot as plt + # plt.pcolor(donor_factor, vmin=0.5, vmax=2.0) + # plt.colorbar() + # plt.show() + # plt.pcolor(acceptor_factor, vmin=0.5, vmax=2.0) + # plt.colorbar() + # plt.show() + # TODO remove + + del donor_heartbeat_factor + del acceptor_heartbeat_factor + + # import matplotlib.pyplot as plt + # plt.pcolor(torch.std(data[acceptor_index, config["skip_frames_in_the_beginning"] :, ...], axis=0), vmin=0, vmax=500) + # plt.colorbar() + # plt.show() + # plt.pcolor(torch.std(data[donor_index, config["skip_frames_in_the_beginning"] :, ...], axis=0), vmin=0, vmax=500) + # plt.colorbar() + # plt.show() + # TODO remove + + if config["save_factors"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_donor_factor.npy" + ) + mylogger.info(f"Save donor factor to {temp_path}") + np.save(temp_path, donor_factor.cpu()) + + temp_path = os.path.join( + config["export_path"], experiment_name + "_acceptor_factor.npy" + ) + mylogger.info(f"Save acceptor factor to {temp_path}") + np.save(temp_path, acceptor_factor.cpu()) + mylogger.info("-==- Done -==-") + + # TODO we have to calculate means first! + mylogger.info("Extract means for acceptor and donor first") + mean_values_acceptor = data[ + acceptor_index, config["skip_frames_in_the_beginning"] :, ... + ].nanmean(dim=0, keepdim=True) + mean_values_donor = data[ + donor_index, config["skip_frames_in_the_beginning"] :, ... + ].nanmean(dim=0, keepdim=True) + + mylogger.info("Scale acceptor to heart beat amplitude") + mylogger.info("Remove mean") + data[acceptor_index, ...] -= mean_values_acceptor + mylogger.info("Apply acceptor_factor and mask") + # data[acceptor_index, ...] 
*= acceptor_factor.unsqueeze( + # 0 + # ) * mask_positve.unsqueeze(0) + acceptor_factor_correction = torch.sqrt( + mean_values_acceptor / mean_values_donor + ) + data[acceptor_index, ...] *= acceptor_factor.unsqueeze( + 0 + ) * acceptor_factor_correction * mask_positve.unsqueeze(0) + mylogger.info("Add mean") + data[acceptor_index, ...] += mean_values_acceptor + mylogger.info("-==- Done -==-") + + mylogger.info("Scale donor to heart beat amplitude") + mylogger.info("Remove mean") + data[donor_index, ...] -= mean_values_donor + mylogger.info("Apply donor_factor and mask") + # data[donor_index, ...] *= donor_factor.unsqueeze( + # 0 + # ) * mask_positve.unsqueeze(0) + donor_factor_correction = 1 / acceptor_factor_correction + data[donor_index, ...] *= donor_factor.unsqueeze( + 0 + ) * donor_factor_correction * mask_positve.unsqueeze(0) + mylogger.info("Add mean") + data[donor_index, ...] += mean_values_donor + mylogger.info("-==- Done -==-") + + # import matplotlib.pyplot as plt + # plt.pcolor(mean_values_acceptor[0]) + # plt.colorbar() + # plt.show() + # plt.pcolor(mean_values_donor[0]) + # plt.colorbar() + # plt.show() + # TODO remove + + # TODO SCHNUGGEL + else: + mylogger.info("GECI does not require acceptor/donor scaling, skipping!") + mylogger.info("-==- Done -==-") + + mylogger.info("") + mylogger.info("(G) CONVERSION TO RELATIVE SIGNAL CHANGES (DIV/MEAN)") + mylogger.info("-----------------------------------------------") + mylogger.info("") + + mylogger.info("Divide by mean over time") + data /= data[:, config["skip_frames_in_the_beginning"] :, ...].nanmean( + dim=1, + keepdim=True, + ) + mylogger.info("-==- Done -==-") + + mylogger.info("") + mylogger.info("(H) CLEANING BY REGRESSION") + mylogger.info("-----------------------------------------------") + mylogger.info("") + + data = data.nan_to_num(nan=0.0) + mylogger.info("Preparation for regression -- Gauss smear") + spatial_width = float(config["gauss_smear_spatial_width"]) + + if config["binning_enable"] and (config["binning_at_the_end"] is False): + spatial_width /= float(config["binning_kernel_size"]) + + mylogger.info( + f"Mask -- " + f"spatial width: {spatial_width}, " + f"temporal width: {float(config['gauss_smear_temporal_width'])}, " + f"use matlab mode: {bool(config['gauss_smear_use_matlab_mask'])} " + ) + + input_mask = mask_positve.type(dtype=dtype).clone() + + filtered_mask: torch.Tensor + filtered_mask, _ = gauss_smear_individual( + input=input_mask, + spatial_width=spatial_width, + temporal_width=float(config["gauss_smear_temporal_width"]), + use_matlab_mask=bool(config["gauss_smear_use_matlab_mask"]), + epsilon=float(torch.finfo(input_mask.dtype).eps), + ) + + mylogger.info("Creating a copy of the data") + data_filtered = data.clone().movedim(1, -1) + if device != torch.device("cpu"): + torch.cuda.empty_cache() + mylogger.info("Empty CUDA cache") + free_mem = cuda_total_memory - max( + [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] + ) + mylogger.info(f"CUDA memory: {free_mem // 1024} MByte") + + overwrite_fft_gauss: None | torch.Tensor = None + for i in range(0, data_filtered.shape[0]): + mylogger.info( + f"{config['required_order'][i]} -- " + f"spatial width: {spatial_width}, " + f"temporal width: {float(config['gauss_smear_temporal_width'])}, " + f"use matlab mode: {bool(config['gauss_smear_use_matlab_mask'])} " + ) + data_filtered[i, ...] *= input_mask.unsqueeze(-1) + data_filtered[i, ...], overwrite_fft_gauss = gauss_smear_individual( + input=data_filtered[i, ...], + spatial_width=spatial_width, + temporal_width=float(config["gauss_smear_temporal_width"]), + overwrite_fft_gauss=overwrite_fft_gauss, + use_matlab_mask=bool(config["gauss_smear_use_matlab_mask"]), + epsilon=float(torch.finfo(input_mask.dtype).eps), + ) + + data_filtered[i, ...] /= filtered_mask + 1e-20 + data_filtered[i, ...] += 1.0 - input_mask.unsqueeze(-1) + + del filtered_mask + del overwrite_fft_gauss + del input_mask + mylogger.info("data_filtered is populated") + + if device != torch.device("cpu"): + torch.cuda.empty_cache() + mylogger.info("Empty CUDA cache") + free_mem = cuda_total_memory - max( + [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] + ) + mylogger.info(f"CUDA memory: {free_mem // 1024} MByte") + mylogger.info("-==- Done -==-") + + mylogger.info("Preparation for regression") + mylogger.info("Move time dimension of data to the last dimension") + data = data.movedim(1, -1) + + dual_signal_mode: bool = True + if len(config["target_camera_acceptor"]) > 0: + mylogger.info("Regression Acceptor") + mylogger.info(f"Target: {config['target_camera_acceptor']}") + mylogger.info( + f"Regressors: constant, linear and {config['regressor_cameras_acceptor']}" + ) + target_id: int = config["required_order"].index( + config["target_camera_acceptor"] + ) + regressor_id: list[int] = [] + for i in range(0, len(config["regressor_cameras_acceptor"])): + regressor_id.append( + config["required_order"].index(config["regressor_cameras_acceptor"][i]) + ) + + data_acceptor, coefficients_acceptor = regression( + mylogger=mylogger, + target_camera_id=target_id, + regressor_camera_ids=regressor_id, + mask=mask_negative, + data=data, + data_filtered=data_filtered, + first_none_ramp_frame=int(config["skip_frames_in_the_beginning"]), + ) + + if config["save_regression_coefficients"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_coefficients_acceptor.npy" + ) + mylogger.info(f"Save acceptor coefficients to {temp_path}") + np.save(temp_path, coefficients_acceptor.cpu()) + del coefficients_acceptor + + mylogger.info("-==- Done -==-") + else: + dual_signal_mode = False + target_id = config["required_order"].index("acceptor") + data_acceptor = data[target_id, ...].clone() + data_acceptor[mask_negative, :] = 0.0 + + if len(config["target_camera_donor"]) > 0: + mylogger.info("Regression Donor") + mylogger.info(f"Target: {config['target_camera_donor']}") + mylogger.info( + f"Regressors: constant, linear and {config['regressor_cameras_donor']}" + ) + target_id = config["required_order"].index(config["target_camera_donor"]) + regressor_id = [] + for i in range(0, len(config["regressor_cameras_donor"])): + regressor_id.append( + config["required_order"].index(config["regressor_cameras_donor"][i]) + ) + + data_donor, coefficients_donor = regression( + mylogger=mylogger, + target_camera_id=target_id, + regressor_camera_ids=regressor_id, + mask=mask_negative, + data=data, + data_filtered=data_filtered, + first_none_ramp_frame=int(config["skip_frames_in_the_beginning"]), + ) + + if config["save_regression_coefficients"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_coefficients_donor.npy" + ) + mylogger.info(f"Save donor coefficients to {temp_path}") + np.save(temp_path, coefficients_donor.cpu()) + del coefficients_donor + mylogger.info("-==- Done -==-") + else: + dual_signal_mode = False + target_id =
config["required_order"].index("donor") + data_donor = data[target_id, ...].clone() + data_donor[mask_negative, :] = 0.0 + + # TODO clean up ---> + if config["save_oxyvol_as_python"] or config["save_oxyvol_as_matlab"]: + + mylogger.info("") + mylogger.info("(I-OPTIONAL) SAVE OXY/VOL/MASK") + mylogger.info("-----------------------------------------------") + mylogger.info("") + + # extract oxy and vol + mylogger.info("Save Oxygenation/Volume/Mask") + data_oxygenation = data[oxygenation_index, ...].clone() + data_volume = data[volume_index, ...].clone() + data_mask = mask_positve.clone() + + # bin, if required... + if config["binning_enable"] and config["binning_at_the_end"]: + mylogger.info("Binning of data") + mylogger.info( + ( + f"kernel_size={int(config['binning_kernel_size'])}, " + f"stride={int(config['binning_stride'])}, " + "divisor_override=None" + ) + ) + + data_oxygenation = binning( + data_oxygenation.unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ).squeeze(-1) + + data_volume = binning( + data_volume.unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ).squeeze(-1) + + data_mask = ( + binning( + data_mask.unsqueeze(-1).unsqueeze(-1).type(dtype=dtype), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ) + .squeeze(-1) + .squeeze(-1) + ) + data_mask = (data_mask > 0).type(torch.bool) + + if config["save_oxyvol_as_python"]: + + # export it! + temp_path = os.path.join( + config["export_path"], experiment_name + "_oxygenation_volume.npz" + ) + mylogger.info(f"Save data oxygenation and volume to {temp_path}") + np.savez_compressed( + temp_path, + data_oxygenation=data_oxygenation.cpu(), + data_volume=data_volume.cpu(), + data_mask=data_mask.cpu(), + ) + + if config["save_oxyvol_as_matlab"]: + + temp_path = os.path.join( + config["export_path"], experiment_name + "_oxygenation_volume.hd5" + ) + mylogger.info(f"Save data oxygenation and volume to {temp_path}") + file_handle = h5py.File(temp_path, "w") + + data_mask = data_mask.movedim(0, -1) + data_oxygenation = data_oxygenation.movedim(1, -1).movedim(0, -1) + data_volume = data_volume.movedim(1, -1).movedim(0, -1) + _ = file_handle.create_dataset( + "data_mask", + data=data_mask.type(torch.uint8).cpu(), + compression="gzip", + compression_opts=9, + ) + _ = file_handle.create_dataset( + "data_oxygenation", + data=data_oxygenation.cpu(), + compression="gzip", + compression_opts=9, + ) + _ = file_handle.create_dataset( + "data_volume", + data=data_volume.cpu(), + compression="gzip", + compression_opts=9, + ) + mylogger.info("Reminder: How to read with matlab:") + mylogger.info(f"data_mask = h5read('{temp_path}','/data_mask');") + mylogger.info(f"data_oxygenation = h5read('{temp_path}','/data_oxygenation');") + mylogger.info(f"data_volume = h5read('{temp_path}','/data_volume');") + file_handle.close() + # TODO <------ clean up + + del data + del data_filtered + + if device != torch.device("cpu"): + torch.cuda.empty_cache() + mylogger.info("Empty CUDA cache") + free_mem = cuda_total_memory - max( + [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] + ) + mylogger.info(f"CUDA memory: {free_mem // 1024} MByte") + + # ##################### + + if config["gevi"]: + assert dual_signal_mode + else: + assert dual_signal_mode is False + + if dual_signal_mode is False: + + mylogger.info("") + 
mylogger.info("(J1-OPTIONAL) SAVE ACC/DON/MASK (NO RATIO!+OPT BIN@END)") + mylogger.info("-----------------------------------------------") + mylogger.info("") + + mylogger.info("mono signal model") + + mylogger.info("Remove nan") + data_acceptor = torch.nan_to_num(data_acceptor, nan=0.0) + data_donor = torch.nan_to_num(data_donor, nan=0.0) + mylogger.info("-==- Done -==-") + + if config["binning_enable"] and config["binning_at_the_end"]: + mylogger.info("Binning of data") + mylogger.info( + ( + f"kernel_size={int(config['binning_kernel_size'])}, " + f"stride={int(config['binning_stride'])}, " + "divisor_override=None" + ) + ) + + data_acceptor = binning( + data_acceptor.unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ).squeeze(-1) + + data_donor = binning( + data_donor.unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ).squeeze(-1) + + mask_positve = ( + binning( + mask_positve.unsqueeze(-1).unsqueeze(-1).type(dtype=dtype), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ) + .squeeze(-1) + .squeeze(-1) + ) + mask_positve = (mask_positve > 0).type(torch.bool) + + if config["save_as_python"]: + + temp_path = os.path.join( + config["export_path"], experiment_name + "_acceptor_donor.npz" + ) + mylogger.info(f"Save data donor and acceptor and mask to {temp_path}") + np.savez_compressed( + temp_path, + data_acceptor=data_acceptor.cpu(), + data_donor=data_donor.cpu(), + mask=mask_positve.cpu(), + ) + + if config["save_as_matlab"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_acceptor_donor.hd5" + ) + mylogger.info(f"Save data donor and acceptor and mask to {temp_path}") + file_handle = h5py.File(temp_path, "w") + + mask_positve = mask_positve.movedim(0, -1) + data_acceptor = data_acceptor.movedim(1, -1).movedim(0, -1) + data_donor = data_donor.movedim(1, -1).movedim(0, -1) + _ = file_handle.create_dataset( + "mask", + data=mask_positve.type(torch.uint8).cpu(), + compression="gzip", + compression_opts=9, + ) + _ = file_handle.create_dataset( + "data_acceptor", + data=data_acceptor.cpu(), + compression="gzip", + compression_opts=9, + ) + _ = file_handle.create_dataset( + "data_donor", + data=data_donor.cpu(), + compression="gzip", + compression_opts=9, + ) + mylogger.info("Reminder: How to read with matlab:") + mylogger.info(f"mask = h5read('{temp_path}','/mask');") + mylogger.info(f"data_acceptor = h5read('{temp_path}','/data_acceptor');") + mylogger.info(f"data_donor = h5read('{temp_path}','/data_donor');") + file_handle.close() + return + # ##################### + + mylogger.info("") + mylogger.info("(J2-OPTIONAL) BUILD AND SAVE RATIO (+OPT BIN@END)") + mylogger.info("-----------------------------------------------") + mylogger.info("") + + mylogger.info("Calculate ratio sequence") + + if config["classical_ratio_mode"]: + mylogger.info("via acceptor / donor") + ratio_sequence: torch.Tensor = data_acceptor / data_donor + mylogger.info("via / mean over time") + ratio_sequence /= ratio_sequence.mean(dim=-1, keepdim=True) + else: + mylogger.info("via 1.0 + acceptor - donor") + ratio_sequence = 1.0 + data_acceptor - data_donor + + mylogger.info("Remove nan") + ratio_sequence = torch.nan_to_num(ratio_sequence, nan=0.0) + mylogger.info("-==- Done -==-") + + if config["binning_enable"] and config["binning_at_the_end"]: + mylogger.info("Binning 
of data") + mylogger.info( + ( + f"kernel_size={int(config['binning_kernel_size'])}, " + f"stride={int(config['binning_stride'])}, " + "divisor_override=None" + ) + ) + + ratio_sequence = binning( + ratio_sequence.unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ).squeeze(-1) + + if config["save_gevi_with_donor_acceptor"]: + data_acceptor = binning( + data_acceptor.unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ).squeeze(-1) + + data_donor = binning( + data_donor.unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ).squeeze(-1) + + mask_positve = ( + binning( + mask_positve.unsqueeze(-1).unsqueeze(-1).type(dtype=dtype), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ) + .squeeze(-1) + .squeeze(-1) + ) + mask_positve = (mask_positve > 0).type(torch.bool) + + if config["save_as_python"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_ratio_sequence.npz" + ) + mylogger.info(f"Save ratio_sequence and mask to {temp_path}") + if config["save_gevi_with_donor_acceptor"]: + np.savez_compressed( + temp_path, ratio_sequence=ratio_sequence.cpu(), mask=mask_positve.cpu(), data_acceptor=data_acceptor.cpu(), data_donor=data_donor.cpu() + ) + else: + np.savez_compressed( + temp_path, ratio_sequence=ratio_sequence.cpu(), mask=mask_positve.cpu() + ) + + if config["save_as_matlab"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_ratio_sequence.hd5" + ) + mylogger.info(f"Save ratio_sequence and mask to {temp_path}") + file_handle = h5py.File(temp_path, "w") + + mask_positve = mask_positve.movedim(0, -1) + ratio_sequence = ratio_sequence.movedim(1, -1).movedim(0, -1) + _ = file_handle.create_dataset( + "mask", + data=mask_positve.type(torch.uint8).cpu(), + compression="gzip", + compression_opts=9, + ) + _ = file_handle.create_dataset( + "ratio_sequence", + data=ratio_sequence.cpu(), + compression="gzip", + compression_opts=9, + ) + if config["save_gevi_with_donor_acceptor"]: + _ = file_handle.create_dataset( + "data_acceptor", + data=data_acceptor.cpu(), + compression="gzip", + compression_opts=9, + ) + _ = file_handle.create_dataset( + "data_donor", + data=data_donor.cpu(), + compression="gzip", + compression_opts=9, + ) + mylogger.info("Reminder: How to read with matlab:") + mylogger.info(f"mask = h5read('{temp_path}','/mask');") + mylogger.info(f"ratio_sequence = h5read('{temp_path}','/ratio_sequence');") + if config["save_gevi_with_donor_acceptor"]: + mylogger.info(f"data_donor = h5read('{temp_path}','/data_donor');") + mylogger.info(f"data_acceptor = h5read('{temp_path}','/data_acceptor');") + file_handle.close() + + del ratio_sequence + del mask_positve + del mask_negative + + mylogger.info("") + mylogger.info("***********************************************") + mylogger.info("* TRIAL END ***********************************") + mylogger.info("***********************************************") + mylogger.info("") + + return + + +def main( + *, + config_filename: str = "config.json", + experiment_id_overwrite: int = -1, + trial_id_overwrite: int = -1, +) -> None: + mylogger = create_logger( + save_logging_messages=True, + display_logging_messages=True, + log_stage_name="stage_4", + ) + + config = load_config(mylogger=mylogger, 
filename=config_filename) + + if (config["save_as_python"] is False) and (config["save_as_matlab"] is False): + mylogger.info("No output will be created. ") + mylogger.info("Change save_as_python and/or save_as_matlab in the config file") + mylogger.info("ERROR: STOP!!!") + exit() + + if (len(config["target_camera_donor"]) == 0) and ( + len(config["target_camera_acceptor"]) == 0 + ): + mylogger.info( + "Configure at least target_camera_donor or target_camera_acceptor correctly." + ) + mylogger.info("ERROR: STOP!!!") + exit() + + device = get_torch_device(mylogger, config["force_to_cpu"]) + + mylogger.info( + f"Create directory {config['export_path']} in case it does not exist" + ) + os.makedirs(config["export_path"], exist_ok=True) + + raw_data_path: str = os.path.join( + config["basic_path"], + config["recoding_data"], + config["mouse_identifier"], + config["raw_path"], + ) + + if os.path.isdir(raw_data_path) is False: + mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!") + exit() + + if experiment_id_overwrite == -1: + experiments = get_experiments(raw_data_path) + else: + assert experiment_id_overwrite >= 0 + experiments = torch.tensor([experiment_id_overwrite]) + + for experiment_counter in range(0, experiments.shape[0]): + experiment_id = int(experiments[experiment_counter]) + + if trial_id_overwrite == -1: + trials = get_trials(raw_data_path, experiment_id) + else: + assert trial_id_overwrite >= 0 + trials = torch.tensor([trial_id_overwrite]) + + for trial_counter in range(0, trials.shape[0]): + trial_id = int(trials[trial_counter]) + + mylogger.info("") + mylogger.info( + f"======= EXPERIMENT ID: {experiment_id} ==== TRIAL ID: {trial_id} =======" + ) + mylogger.info("") + + try: + process_trial( + config=config, + mylogger=mylogger, + experiment_id=experiment_id, + trial_id=trial_id, + device=device, + ) + except torch.cuda.OutOfMemoryError: + mylogger.info("WARNING: RUNNING IN FALLBACK MODE!!!!") + mylogger.info("Not enough GPU memory. 
Retry on CPU") + process_trial( + config=config, + mylogger=mylogger, + experiment_id=experiment_id, + trial_id=trial_id, + device=torch.device("cpu"), + ) + + +if __name__ == "__main__": + argh.dispatch_command(main) + +# %% diff --git a/stage_5_convert_metadata.py b/stage_5_convert_metadata.py new file mode 100644 index 0000000..ed4ef73 --- /dev/null +++ b/stage_5_convert_metadata.py @@ -0,0 +1,57 @@ +import json +import os +import argh +from jsmin import jsmin # type:ignore +from functions.get_trials import get_trials +from functions.get_experiments import get_experiments + + +def converter(filename: str = "config_M_Sert_Cre_49.json") -> None: + + if os.path.isfile(filename) is False: + print(f"{filename} is missing") + exit() + + with open(filename, "r") as file: + config = json.loads(jsmin(file.read())) + + raw_data_path: str = os.path.join( + config["basic_path"], + config["recoding_data"], + config["mouse_identifier"], + config["raw_path"], + ) + + if os.path.isdir(raw_data_path) is False: + print(f"ERROR: could not find raw directory {raw_data_path}!!!!") + exit() + + experiments = get_experiments(raw_data_path).numpy() + + os.makedirs(config["export_path"], exist_ok=True) + + for experiment in experiments: + + trials = get_trials(raw_data_path, experiment).numpy() + assert trials.shape[0] > 0 + + with open( + os.path.join( + raw_data_path, + f"Exp{experiment:03d}_Trial{trials[0]:03d}_Part001_meta.txt", + ), + "r", + ) as file: + metadata = json.loads(jsmin(file.read())) + + filename_out: str = os.path.join( + config["export_path"], + f"metadata_exp{experiment:03d}.json", + ) + + with open(filename_out, 'w') as file: + json.dump(metadata, file) + + +if __name__ == "__main__": + argh.dispatch_command(converter)