From a348475f3926aeb33d17e3c19fcac8014236bcd2 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Wed, 28 Feb 2024 16:15:38 +0100 Subject: [PATCH] Delete new_pipeline directory --- new_pipeline/config.json | 62 - new_pipeline/functions/ImageAlignment.py | 1015 ----------------- new_pipeline/functions/align_refref.py | 57 - new_pipeline/functions/bandpass.py | 85 -- new_pipeline/functions/binning.py | 21 - new_pipeline/functions/calculate_rotation.py | 40 - .../functions/calculate_translation.py | 37 - new_pipeline/functions/create_logger.py | 37 - new_pipeline/functions/data_raw_loader.py | 339 ------ .../functions/gauss_smear_individual.py | 127 --- new_pipeline/functions/get_experiments.py | 19 - new_pipeline/functions/get_parts.py | 18 - new_pipeline/functions/get_torch_device.py | 17 - new_pipeline/functions/get_trials.py | 18 - new_pipeline/functions/load_config.py | 16 - new_pipeline/functions/load_meta_data.py | 63 - .../perform_donor_volume_rotation.py | 140 --- .../perform_donor_volume_translation.py | 143 --- new_pipeline/functions/regression.py | 117 -- new_pipeline/functions/regression_internal.py | 20 - new_pipeline/stage_1_get_ref_image.py | 126 -- new_pipeline/stage_2_make_heartbeat_mask.py | 153 --- new_pipeline/stage_3_refine_mask.py | 157 --- new_pipeline/stage_4_process.py | 918 --------------- 24 files changed, 3745 deletions(-) delete mode 100644 new_pipeline/config.json delete mode 100644 new_pipeline/functions/ImageAlignment.py delete mode 100644 new_pipeline/functions/align_refref.py delete mode 100644 new_pipeline/functions/bandpass.py delete mode 100644 new_pipeline/functions/binning.py delete mode 100644 new_pipeline/functions/calculate_rotation.py delete mode 100644 new_pipeline/functions/calculate_translation.py delete mode 100644 new_pipeline/functions/create_logger.py delete mode 100644 new_pipeline/functions/data_raw_loader.py delete mode 100644 new_pipeline/functions/gauss_smear_individual.py delete mode 100644 new_pipeline/functions/get_experiments.py delete mode 100644 new_pipeline/functions/get_parts.py delete mode 100644 new_pipeline/functions/get_torch_device.py delete mode 100644 new_pipeline/functions/get_trials.py delete mode 100644 new_pipeline/functions/load_config.py delete mode 100644 new_pipeline/functions/load_meta_data.py delete mode 100644 new_pipeline/functions/perform_donor_volume_rotation.py delete mode 100644 new_pipeline/functions/perform_donor_volume_translation.py delete mode 100644 new_pipeline/functions/regression.py delete mode 100644 new_pipeline/functions/regression_internal.py delete mode 100644 new_pipeline/stage_1_get_ref_image.py delete mode 100644 new_pipeline/stage_2_make_heartbeat_mask.py delete mode 100644 new_pipeline/stage_3_refine_mask.py delete mode 100644 new_pipeline/stage_4_process.py diff --git a/new_pipeline/config.json b/new_pipeline/config.json deleted file mode 100644 index 2746a54..0000000 --- a/new_pipeline/config.json +++ /dev/null @@ -1,62 +0,0 @@ -{ - "basic_path": "/data_1/hendrik", - "recoding_data": "2021-06-17", - "mouse_identifier": "M3859M", - //"basic_path": "/data_1/robert", - //"recoding_data": "2021-10-05", - //"mouse_identifier": "M3879M", - "raw_path": "raw", - "export_path": "output", - "ref_image_path": "ref_images", - // Ratio Sequence - "classical_ratio_mode": true, // true: a/d false: 1+a-d - // Regression - "target_camera_acceptor": "acceptor", - "regressor_cameras_acceptor": [ - "oxygenation", - "volume" - ], - "target_camera_donor": "donor", - "regressor_cameras_donor": [ - "oxygenation", - "volume" - ], - // binning - "binning_enable": true, - "binning_at_the_end": false, - "binning_kernel_size": 4, - "binning_stride": 4, - "binning_divisor_override": 1, - // alignment - "alignment_batch_size": 200, - "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 - "rotation_stabilization_threshold_border": 0.9, // <= 1.0 - // Heart beat detection - "lower_freqency_bandpass": 5.0, // Hz - "upper_freqency_bandpass": 14.0, // Hz - "heartbeat_filtfilt_chuck_size": 10, - // Gauss smear - "gauss_smear_spatial_width": 8, - "gauss_smear_temporal_width": 0.1, - "gauss_smear_use_matlab_mask": false, - // LED Ramp on - "skip_frames_in_the_beginning": 100, // Frames - // PyTorch - "dtype": "float32", - "force_to_cpu": false, - // Save - "save_as_python": true, // produces .npz files (compressed) - "save_as_matlab": false, // produces .hd5 file (compressed) - // Save extra information - "save_alignment": false, - "save_heartbeat": false, - "save_factors": false, - "save_regression_coefficients": false, - // Not important parameter - "required_order": [ - "acceptor", - "donor", - "oxygenation", - "volume" - ] -} \ No newline at end of file diff --git a/new_pipeline/functions/ImageAlignment.py b/new_pipeline/functions/ImageAlignment.py deleted file mode 100644 index 6472d02..0000000 --- a/new_pipeline/functions/ImageAlignment.py +++ /dev/null @@ -1,1015 +0,0 @@ -import torch -import torchvision as tv # type: ignore - -# The source code is based on: -# https://github.com/matejak/imreg_dft - -# The original LICENSE: -# Copyright (c) 2014, Matěj Týč -# Copyright (c) 2011-2014, Christoph Gohlke -# Copyright (c) 2011-2014, The Regents of the University of California -# Produced at the Laboratory for Fluorescence Dynamics - -# All rights reserved. - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: - -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. - -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. - -# * Neither the name of the {organization} nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. - -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -class ImageAlignment(torch.nn.Module): - device: torch.device - default_dtype: torch.dtype - excess_const: float = 1.1 - exponent: str = "inf" - success: torch.Tensor | None = None - - # The factor that detmines how many - # sub-pixel we will shift - scale_factor: int = 4 - - reference_image: torch.Tensor | None = None - - last_scale: torch.Tensor | None = None - last_angle: torch.Tensor | None = None - last_tvec: torch.Tensor | None = None - - # Cache - image_reference_dft: torch.Tensor | None = None - filt: torch.Tensor - pcorr_shape: torch.Tensor - log_base: torch.Tensor - image_reference_logp: torch.Tensor - - def __init__( - self, - device: torch.device | None = None, - default_dtype: torch.dtype | None = None, - ) -> None: - super().__init__() - - assert device is not None - assert default_dtype is not None - self.device = device - self.default_dtype = default_dtype - - def set_new_reference_image(self, new_reference_image: torch.Tensor | None = None): - assert new_reference_image is not None - assert new_reference_image.ndim == 2 - self.reference_image = ( - new_reference_image.detach() - .clone() - .to(device=self.device) - .type(dtype=self.default_dtype) - ) - self.image_reference_dft = None - - def forward( - self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None - ) -> torch.Tensor: - assert input.ndim == 3 - - if new_reference_image is not None: - self.set_new_reference_image(new_reference_image) - - assert self.reference_image is not None - assert self.reference_image.ndim == 2 - assert input.shape[-2] == self.reference_image.shape[-2] - assert input.shape[-1] == self.reference_image.shape[-1] - - self.last_scale, self.last_angle, self.last_tvec, output = self.similarity( - self.reference_image, - input.to(device=self.device).type(dtype=self.default_dtype), - ) - - return output - - def dry_run( - self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: - assert input.ndim == 3 - - if new_reference_image is not None: - self.set_new_reference_image(new_reference_image) - - assert self.reference_image is not None - assert self.reference_image.ndim == 2 - assert input.shape[-2] == self.reference_image.shape[-2] - assert input.shape[-1] == self.reference_image.shape[-1] - - images_todo = input.to(device=self.device).type(dtype=self.default_dtype) - image_reference = self.reference_image - - assert image_reference.ndim == 2 - assert images_todo.ndim == 3 - - bgval: torch.Tensor = self.get_borderval(img=images_todo, radius=5) - - self.last_scale, self.last_angle, self.last_tvec = self._similarity( - image_reference, - images_todo, - bgval, - ) - - return self.last_scale, self.last_angle, self.last_tvec - - def dry_run_translation( - self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None - ) -> torch.Tensor: - assert input.ndim == 3 - - if new_reference_image is not None: - self.set_new_reference_image(new_reference_image) - - assert self.reference_image is not None - assert self.reference_image.ndim == 2 - assert input.shape[-2] == self.reference_image.shape[-2] - assert input.shape[-1] == self.reference_image.shape[-1] - - images_todo = input.to(device=self.device).type(dtype=self.default_dtype) - image_reference = self.reference_image - - assert image_reference.ndim == 2 - assert images_todo.ndim == 3 - - tvec, _ = self._translation(image_reference, images_todo) - - return tvec - - # --------------- - - def dry_run_angle( - self, - input: torch.Tensor, - new_reference_image: torch.Tensor | None = None, - ) -> torch.Tensor: - assert input.ndim == 3 - - if new_reference_image is not None: - self.set_new_reference_image(new_reference_image) - - constraints_dynamic_angle_0: torch.Tensor = torch.zeros( - (input.shape[0]), dtype=self.default_dtype, device=self.device - ) - constraints_dynamic_angle_1: torch.Tensor | None = None - constraints_dynamic_scale_0: torch.Tensor = torch.ones( - (input.shape[0]), dtype=self.default_dtype, device=self.device - ) - constraints_dynamic_scale_1: torch.Tensor | None = None - - assert self.reference_image is not None - assert self.reference_image.ndim == 2 - assert input.shape[-2] == self.reference_image.shape[-2] - assert input.shape[-1] == self.reference_image.shape[-1] - - images_todo = input.to(device=self.device).type(dtype=self.default_dtype) - image_reference = self.reference_image - - assert image_reference.ndim == 2 - assert images_todo.ndim == 3 - - _, newangle = self._get_ang_scale( - image_reference, - images_todo, - constraints_dynamic_scale_0, - constraints_dynamic_scale_1, - constraints_dynamic_angle_0, - constraints_dynamic_angle_1, - ) - - return newangle - - # --------------- - - def _get_pcorr_shape(self, shape: torch.Size) -> tuple[int, int]: - ret = (int(max(shape[-2:]) * 1.0),) * 2 - return ret - - def _get_log_base(self, shape: torch.Size, new_r: torch.Tensor) -> torch.Tensor: - old_r = torch.tensor( - (float(shape[-2]) * self.excess_const) / 2.0, - dtype=self.default_dtype, - device=self.device, - ) - log_base = torch.exp(torch.log(old_r) / new_r) - return log_base - - def wrap_angle( - self, angles: torch.Tensor, ceil: float = 2 * torch.pi - ) -> torch.Tensor: - angles += ceil / 2.0 - angles %= ceil - angles -= ceil / 2.0 - return angles - - def get_borderval( - self, img: torch.Tensor, radius: int | None = None - ) -> torch.Tensor: - assert img.ndim == 3 - if radius is None: - mindim = min([int(img.shape[-2]), int(img.shape[-1])]) - radius = max(1, mindim // 20) - mask = torch.zeros( - (int(img.shape[-2]), int(img.shape[-1])), - dtype=torch.bool, - device=self.device, - ) - mask[:, :radius] = True - mask[:, -radius:] = True - mask[:radius, :] = True - mask[-radius:, :] = True - - mean = torch.median(img[:, mask], dim=-1)[0] - return mean - - def get_apofield(self, shape: torch.Size, aporad: int) -> torch.Tensor: - if aporad == 0: - return torch.ones( - shape[-2:], - dtype=self.default_dtype, - device=self.device, - ) - - assert int(shape[-2]) > aporad * 2 - assert int(shape[-1]) > aporad * 2 - - apos = torch.hann_window( - aporad * 2, dtype=self.default_dtype, periodic=False, device=self.device - ) - - toapp_0 = torch.ones( - shape[-2], - dtype=self.default_dtype, - device=self.device, - ) - toapp_0[:aporad] = apos[:aporad] - toapp_0[-aporad:] = apos[-aporad:] - - toapp_1 = torch.ones( - shape[-1], - dtype=self.default_dtype, - device=self.device, - ) - toapp_1[:aporad] = apos[:aporad] - toapp_1[-aporad:] = apos[-aporad:] - - apofield = torch.outer(toapp_0, toapp_1) - - return apofield - - def _get_subarr( - self, array: torch.Tensor, center: torch.Tensor, rad: int - ) -> torch.Tensor: - assert array.ndim == 3 - assert center.ndim == 2 - assert array.shape[0] == center.shape[0] - assert center.shape[1] == 2 - - dim = 1 + 2 * rad - subarr = torch.zeros( - (array.shape[0], dim, dim), dtype=self.default_dtype, device=self.device - ) - - corner = center - rad - idx_p = range(0, corner.shape[0]) - for ii in range(0, dim): - yidx = corner[:, 0] + ii - yidx %= array.shape[-2] - for jj in range(0, dim): - xidx = corner[:, 1] + jj - xidx %= array.shape[-1] - subarr[:, ii, jj] = array[idx_p, yidx, xidx] - - return subarr - - def _argmax_2d(self, array: torch.Tensor) -> torch.Tensor: - assert array.ndim == 3 - - max_pos = array.reshape( - (array.shape[0], array.shape[1] * array.shape[2]) - ).argmax(dim=1) - pos_0 = max_pos // array.shape[2] - - max_pos -= pos_0 * array.shape[2] - ret = torch.zeros( - (array.shape[0], 2), dtype=self.default_dtype, device=self.device - ) - ret[:, 0] = pos_0 - ret[:, 1] = max_pos - - return ret.type(dtype=torch.int64) - - def _apodize(self, what: torch.Tensor) -> torch.Tensor: - mindim = min([int(what.shape[-2]), int(what.shape[-1])]) - aporad = int(mindim * 0.12) - - apofield = self.get_apofield(what.shape, aporad).unsqueeze(0) - - res = what * apofield - bg = self.get_borderval(what, aporad // 2).unsqueeze(-1).unsqueeze(-1) - res += bg * (1 - apofield) - return res - - def _logpolar_filter(self, shape: torch.Size) -> torch.Tensor: - yy = torch.linspace( - -torch.pi / 2.0, - torch.pi / 2.0, - shape[-2], - dtype=self.default_dtype, - device=self.device, - ).unsqueeze(1) - - xx = torch.linspace( - -torch.pi / 2.0, - torch.pi / 2.0, - shape[-1], - dtype=self.default_dtype, - device=self.device, - ).unsqueeze(0) - - rads = torch.sqrt(yy**2 + xx**2) - filt = 1.0 - torch.cos(rads) ** 2 - - filt[torch.abs(rads) > torch.pi / 2] = 1 - return filt - - def _get_angles(self, shape: torch.Tensor) -> torch.Tensor: - ret = torch.zeros( - (int(shape[-2]), int(shape[-1])), - dtype=self.default_dtype, - device=self.device, - ) - ret -= torch.linspace( - 0, - torch.pi, - int(shape[-2] + 1), - dtype=self.default_dtype, - device=self.device, - )[:-1].unsqueeze(-1) - - return ret - - def _get_lograd(self, shape: torch.Tensor, log_base: torch.Tensor) -> torch.Tensor: - ret = torch.zeros( - (int(shape[-2]), int(shape[-1])), - dtype=self.default_dtype, - device=self.device, - ) - ret += torch.pow( - log_base, - torch.arange( - 0, - int(shape[-1]), - dtype=self.default_dtype, - device=self.device, - ), - ).unsqueeze(0) - return ret - - def _logpolar( - self, image: torch.Tensor, shape: torch.Tensor, log_base: torch.Tensor - ) -> torch.Tensor: - assert image.ndim == 3 - - imshape: torch.Tensor = torch.tensor( - image.shape[-2:], - dtype=self.default_dtype, - device=self.device, - ) - - center: torch.Tensor = imshape.clone() / 2 - - theta: torch.Tensor = self._get_angles(shape) - radius_x: torch.Tensor = self._get_lograd(shape, log_base) - radius_y: torch.Tensor = radius_x.clone() - - ellipse_coef: torch.Tensor = imshape[0] / imshape[1] - radius_x /= ellipse_coef - - y = radius_y * torch.sin(theta) + center[0] - y /= float(image.shape[-2]) - y *= 2 - y -= 1 - - x = radius_x * torch.cos(theta) + center[1] - x /= float(image.shape[-1]) - x *= 2 - x -= 1 - - idx_x = torch.where(torch.abs(x) <= 1.0, 1.0, 0.0) - idx_y = torch.where(torch.abs(y) <= 1.0, 1.0, 0.0) - - normalized_coords = torch.cat( - ( - x.unsqueeze(-1), - y.unsqueeze(-1), - ), - dim=-1, - ).unsqueeze(0) - - output = torch.empty( - (int(image.shape[0]), int(y.shape[0]), int(y.shape[1])), - dtype=self.default_dtype, - device=self.device, - ) - - for id in range(0, int(image.shape[0])): - bgval: torch.Tensor = torch.quantile(image[id, :, :], q=1.0 / 100.0) - - temp = torch.nn.functional.grid_sample( - image[id, :, :].unsqueeze(0).unsqueeze(0), - normalized_coords, - mode="bilinear", - padding_mode="zeros", - align_corners=False, - ) - - output[id, :, :] = torch.where((idx_x * idx_y) == 0.0, bgval, temp) - - return output - - def _argmax_ext(self, array: torch.Tensor, exponent: float | str) -> torch.Tensor: - assert array.ndim == 3 - - if exponent == "inf": - ret = self._argmax_2d(array) - else: - assert isinstance(exponent, float) or isinstance(exponent, int) - - col = ( - torch.arange( - 0, array.shape[-2], dtype=self.default_dtype, device=self.device - ) - .unsqueeze(-1) - .unsqueeze(0) - ) - row = ( - torch.arange( - 0, array.shape[-1], dtype=self.default_dtype, device=self.device - ) - .unsqueeze(0) - .unsqueeze(0) - ) - - arr2 = torch.pow(array, float(exponent)) - arrsum = arr2.sum(dim=-2).sum(dim=-1) - - ret = torch.zeros( - (array.shape[0], 2), dtype=self.default_dtype, device=self.device - ) - - arrprody = (arr2 * col).sum(dim=-1).sum(dim=-1) / arrsum - arrprodx = (arr2 * row).sum(dim=-1).sum(dim=-1) / arrsum - - ret[:, 0] = arrprody.squeeze(-1).squeeze(-1) - ret[:, 1] = arrprodx.squeeze(-1).squeeze(-1) - - idx = torch.where(arrsum == 0.0)[0] - ret[idx, :] = 0.0 - return ret - - def _interpolate( - self, array: torch.Tensor, rough: torch.Tensor, rad: int = 2 - ) -> torch.Tensor: - assert array.ndim == 3 - assert rough.ndim == 2 - - rough = torch.round(rough).type(torch.int64) - - surroundings = self._get_subarr(array, rough, rad) - - com = self._argmax_ext(surroundings, 1.0) - - offset = com - rad - ret = rough + offset - - ret += 0.5 - ret %= ( - torch.tensor(array.shape[-2:], dtype=self.default_dtype, device=self.device) - .type(dtype=torch.int64) - .unsqueeze(0) - ) - ret -= 0.5 - return ret - - def _get_success( - self, array: torch.Tensor, coord: torch.Tensor, radius: int = 2 - ) -> torch.Tensor: - assert array.ndim == 3 - assert coord.ndim == 2 - assert array.shape[0] == coord.shape[0] - assert coord.shape[1] == 2 - - coord = torch.round(coord).type(dtype=torch.int64) - subarr = self._get_subarr( - array, coord, 2 - ) # Not my fault. They want a 2 there. Not radius - - theval = subarr.sum(dim=-1).sum(dim=-1) - - theval2 = array[range(0, coord.shape[0]), coord[:, 0], coord[:, 1]] - - success = torch.sqrt(theval * theval2) - return success - - def _get_constraint_mask( - self, - shape: torch.Size, - log_base: torch.Tensor, - constraints_scale_0: torch.Tensor, - constraints_scale_1: torch.Tensor | None, - constraints_angle_0: torch.Tensor, - constraints_angle_1: torch.Tensor | None, - ) -> torch.Tensor: - assert constraints_scale_0 is not None - assert constraints_angle_0 is not None - assert constraints_scale_0.ndim == 1 - assert constraints_angle_0.ndim == 1 - - assert constraints_scale_0.shape[0] == constraints_angle_0.shape[0] - - mask: torch.Tensor = torch.ones( - (constraints_scale_0.shape[0], int(shape[-2]), int(shape[-1])), - device=self.device, - dtype=self.default_dtype, - ) - - scale: torch.Tensor = constraints_scale_0.clone() - if constraints_scale_1 is not None: - sigma: torch.Tensor | None = constraints_scale_1.clone() - else: - sigma = None - - scales = torch.fft.ifftshift( - self._get_lograd( - torch.tensor(shape[-2:], device=self.device, dtype=self.default_dtype), - log_base, - ) - ) - - scales *= log_base ** (-shape[-1] / 2.0) - scales = scales.unsqueeze(0) - (1.0 / scale).unsqueeze(-1).unsqueeze(-1) - - if sigma is not None: - assert sigma.shape[0] == constraints_scale_0.shape[0] - - for p_id in range(0, sigma.shape[0]): - if sigma[p_id] == 0: - ascales = torch.abs(scales[p_id, ...]) - scale_min = ascales.min() - binary_mask = torch.where(ascales > scale_min, 0.0, 1.0) - mask[p_id, ...] *= binary_mask - else: - mask[p_id, ...] *= torch.exp( - -(torch.pow(scales[p_id, ...], 2)) / torch.pow(sigma[p_id], 2) - ) - - angle: torch.Tensor = constraints_angle_0.clone() - if constraints_angle_1 is not None: - sigma = constraints_angle_1.clone() - else: - sigma = None - - angles = self._get_angles( - torch.tensor(shape[-2:], device=self.device, dtype=self.default_dtype) - ) - - angles = angles.unsqueeze(0) + torch.deg2rad(angle).unsqueeze(-1).unsqueeze(-1) - - angles = torch.rad2deg(angles) - - if sigma is not None: - assert sigma.shape[0] == constraints_scale_0.shape[0] - - for p_id in range(0, sigma.shape[0]): - if sigma[p_id] == 0: - aangles = torch.abs(angles[p_id, ...]) - angle_min = aangles.min() - binary_mask = torch.where(aangles > angle_min, 0.0, 1.0) - mask[p_id, ...] *= binary_mask - else: - mask *= torch.exp( - -(torch.pow(angles[p_id, ...], 2)) / torch.pow(sigma[p_id], 2) - ) - - mask = torch.fft.fftshift(mask, dim=(-2, -1)) - - return mask - - def argmax_angscale( - self, - array: torch.Tensor, - log_base: torch.Tensor, - constraints_scale_0: torch.Tensor, - constraints_scale_1: torch.Tensor | None, - constraints_angle_0: torch.Tensor, - constraints_angle_1: torch.Tensor | None, - ) -> tuple[torch.Tensor, torch.Tensor]: - assert array.ndim == 3 - assert constraints_scale_0 is not None - assert constraints_angle_0 is not None - assert constraints_scale_0.ndim == 1 - assert constraints_angle_0.ndim == 1 - - mask = self._get_constraint_mask( - array.shape[-2:], - log_base, - constraints_scale_0, - constraints_scale_1, - constraints_angle_0, - constraints_angle_1, - ) - - array_orig = array.clone() - - array *= mask - ret = self._argmax_ext(array, self.exponent) - - ret_final = self._interpolate(array, ret) - - success = self._get_success(array_orig, ret_final, 0) - - return ret_final, success - - def argmax_translation( - self, array: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - assert array.ndim == 3 - - array_orig = array.clone() - - ashape = torch.tensor(array.shape[-2:], device=self.device).type( - dtype=torch.int64 - ) - - aporad = (ashape // 6).min() - mask2 = self.get_apofield(torch.Size(ashape), aporad).unsqueeze(0) - array *= mask2 - - tvec = self._argmax_ext(array, "inf") - - tvec = self._interpolate(array_orig, tvec) - - success = self._get_success(array_orig, tvec, 2) - - return tvec, success - - def transform_img( - self, - img: torch.Tensor, - scale: torch.Tensor | None = None, - angle: torch.Tensor | None = None, - tvec: torch.Tensor | None = None, - bgval: torch.Tensor | None = None, - ) -> torch.Tensor: - assert img.ndim == 3 - - if scale is None: - scale = torch.ones( - (img.shape[0],), dtype=self.default_dtype, device=self.device - ) - assert scale.ndim == 1 - assert scale.shape[0] == img.shape[0] - - if angle is None: - angle = torch.zeros( - (img.shape[0],), dtype=self.default_dtype, device=self.device - ) - assert angle.ndim == 1 - assert angle.shape[0] == img.shape[0] - - if tvec is None: - tvec = torch.zeros( - (img.shape[0], 2), dtype=self.default_dtype, device=self.device - ) - assert tvec.ndim == 2 - assert tvec.shape[0] == img.shape[0] - assert tvec.shape[1] == 2 - - if bgval is None: - bgval = self.get_borderval(img) - assert bgval.ndim == 1 - assert bgval.shape[0] == img.shape[0] - - # Otherwise we need to decompose it and put it back together - assert torch.is_complex(img) is False - - output = torch.zeros_like(img) - - for pos in range(0, img.shape[0]): - image_processed = img[pos, :, :].unsqueeze(0).clone() - - temp_shift = [ - int(round(tvec[pos, 1].item() * self.scale_factor)), - int(round(tvec[pos, 0].item() * self.scale_factor)), - ] - - image_processed = torch.nn.functional.interpolate( - image_processed.unsqueeze(0), - scale_factor=self.scale_factor, - mode="bilinear", - ).squeeze(0) - - image_processed = tv.transforms.functional.affine( - img=image_processed, - angle=-float(angle[pos]), - translate=temp_shift, - scale=float(scale[pos]), - shear=[0, 0], - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=float(bgval[pos]), - center=None, - ) - - image_processed = torch.nn.functional.interpolate( - image_processed.unsqueeze(0), - scale_factor=1.0 / self.scale_factor, - mode="bilinear", - ).squeeze(0) - - image_processed = tv.transforms.functional.center_crop( - image_processed, img.shape[-2:] - ) - - output[pos, ...] = image_processed.squeeze(0) - - return output - - def transform_img_dict( - self, - img: torch.Tensor, - scale: torch.Tensor | None = None, - angle: torch.Tensor | None = None, - tvec: torch.Tensor | None = None, - bgval: torch.Tensor | None = None, - invert=False, - ) -> torch.Tensor: - if invert: - if scale is not None: - scale = 1.0 / scale - if angle is not None: - angle *= -1 - if tvec is not None: - tvec *= -1 - - res = self.transform_img(img, scale, angle, tvec, bgval=bgval) - return res - - def _phase_correlation( - self, image_reference: torch.Tensor, images_todo: torch.Tensor, callback, *args - ) -> tuple[torch.Tensor, torch.Tensor]: - assert image_reference.ndim == 3 - assert image_reference.shape[0] == 1 - assert images_todo.ndim == 3 - - assert callback is not None - - image_reference_fft = torch.fft.fft2(image_reference, dim=(-2, -1)) - images_todo_fft = torch.fft.fft2(images_todo, dim=(-2, -1)) - - eps = torch.abs(images_todo_fft).max(dim=-1)[0].max(dim=-1)[0] * 1e-15 - - cps = abs( - torch.fft.ifft2( - (image_reference_fft * images_todo_fft.conj()) - / ( - torch.abs(image_reference_fft) * torch.abs(images_todo_fft) - + eps.unsqueeze(-1).unsqueeze(-1) - ), - dim=(-2, -1), - ) - ) - - scps = torch.fft.fftshift(cps, dim=(-2, -1)) - - ret, success = callback(scps, *args) - - ret[:, 0] -= image_reference_fft.shape[-2] // 2 - ret[:, 1] -= image_reference_fft.shape[-1] // 2 - - return ret, success - - def _translation( - self, im0: torch.Tensor, im1: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - assert im0.ndim == 2 - ret, succ = self._phase_correlation( - im0.unsqueeze(0), im1, self.argmax_translation - ) - - return ret, succ - - def _get_ang_scale( - self, - image_reference: torch.Tensor, - images_todo: torch.Tensor, - constraints_scale_0: torch.Tensor, - constraints_scale_1: torch.Tensor | None, - constraints_angle_0: torch.Tensor, - constraints_angle_1: torch.Tensor | None, - ) -> tuple[torch.Tensor, torch.Tensor]: - assert image_reference.ndim == 2 - assert images_todo.ndim == 3 - assert image_reference.shape[-1] == images_todo.shape[-1] - assert image_reference.shape[-2] == images_todo.shape[-2] - assert constraints_scale_0.shape[0] == images_todo.shape[0] - assert constraints_angle_0.shape[0] == images_todo.shape[0] - - if constraints_scale_1 is not None: - assert constraints_scale_1.shape[0] == images_todo.shape[0] - - if constraints_angle_1 is not None: - assert constraints_angle_1.shape[0] == images_todo.shape[0] - - if self.image_reference_dft is None: - image_reference_apod = self._apodize(image_reference.unsqueeze(0)) - self.image_reference_dft = torch.fft.fftshift( - torch.fft.fft2(image_reference_apod, dim=(-2, -1)), dim=(-2, -1) - ) - self.filt = self._logpolar_filter(image_reference.shape).unsqueeze(0) - self.image_reference_dft *= self.filt - self.pcorr_shape = torch.tensor( - self._get_pcorr_shape(image_reference.shape[-2:]), - dtype=self.default_dtype, - device=self.device, - ) - self.log_base = self._get_log_base( - image_reference.shape, - self.pcorr_shape[1], - ) - self.image_reference_logp = self._logpolar( - torch.abs(self.image_reference_dft), self.pcorr_shape, self.log_base - ) - - images_todo_apod = self._apodize(images_todo) - images_todo_dft = torch.fft.fftshift( - torch.fft.fft2(images_todo_apod, dim=(-2, -1)), dim=(-2, -1) - ) - - images_todo_dft *= self.filt - - images_todo_lopg = self._logpolar( - torch.abs(images_todo_dft), self.pcorr_shape, self.log_base - ) - - temp, _ = self._phase_correlation( - self.image_reference_logp, - images_todo_lopg, - self.argmax_angscale, - self.log_base, - constraints_scale_0, - constraints_scale_1, - constraints_angle_0, - constraints_angle_1, - ) - - arg_ang = temp[:, 0].clone() - arg_rad = temp[:, 1].clone() - - angle = -torch.pi * arg_ang / float(self.pcorr_shape[0]) - angle = torch.rad2deg(angle) - - angle = self.wrap_angle(angle, 360) - - scale = torch.pow(self.log_base, arg_rad) - - angle = -angle - scale = 1.0 / scale - - assert torch.where(scale < 2)[0].shape[0] == scale.shape[0] - assert torch.where(scale > 0.5)[0].shape[0] == scale.shape[0] - - return scale, angle - - def translation( - self, im0: torch.Tensor, im1: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - angle = torch.zeros( - (im1.shape[0]), dtype=self.default_dtype, device=self.device - ) - assert im1.ndim == 3 - assert im0.shape[-2] == im1.shape[-2] - assert im0.shape[-1] == im1.shape[-1] - - tvec, succ = self._translation(im0, im1) - tvec2, succ2 = self._translation(im0, torch.rot90(im1, k=2, dims=[-2, -1])) - - assert tvec.shape[0] == tvec2.shape[0] - assert tvec.ndim == 2 - assert tvec2.ndim == 2 - assert tvec.shape[1] == 2 - assert tvec2.shape[1] == 2 - assert succ.shape[0] == succ2.shape[0] - assert succ.ndim == 1 - assert succ2.ndim == 1 - assert tvec.shape[0] == succ.shape[0] - assert angle.shape[0] == tvec.shape[0] - assert angle.ndim == 1 - - for pos in range(0, angle.shape[0]): - pick_rotated = False - if succ2[pos] > succ[pos]: - pick_rotated = True - - if pick_rotated: - tvec[pos, :] = tvec2[pos, :] - succ[pos] = succ2[pos] - angle[pos] += 180 - - return tvec, succ, angle - - def _similarity( - self, - image_reference: torch.Tensor, - images_todo: torch.Tensor, - bgval: torch.Tensor, - ): - assert image_reference.ndim == 2 - assert images_todo.ndim == 3 - assert image_reference.shape[-1] == images_todo.shape[-1] - assert image_reference.shape[-2] == images_todo.shape[-2] - - # We are going to iterate and precise scale and angle estimates - scale: torch.Tensor = torch.ones( - (images_todo.shape[0]), dtype=self.default_dtype, device=self.device - ) - angle: torch.Tensor = torch.zeros( - (images_todo.shape[0]), dtype=self.default_dtype, device=self.device - ) - - constraints_dynamic_angle_0: torch.Tensor = torch.zeros( - (images_todo.shape[0]), dtype=self.default_dtype, device=self.device - ) - constraints_dynamic_angle_1: torch.Tensor | None = None - constraints_dynamic_scale_0: torch.Tensor = torch.ones( - (images_todo.shape[0]), dtype=self.default_dtype, device=self.device - ) - constraints_dynamic_scale_1: torch.Tensor | None = None - - newscale, newangle = self._get_ang_scale( - image_reference, - images_todo, - constraints_dynamic_scale_0, - constraints_dynamic_scale_1, - constraints_dynamic_angle_0, - constraints_dynamic_angle_1, - ) - scale *= newscale - angle += newangle - - im2 = self.transform_img(images_todo, scale, angle, bgval=bgval) - - tvec, self.success, res_angle = self.translation(image_reference, im2) - - angle += res_angle - - angle = self.wrap_angle(angle, 360) - - return scale, angle, tvec - - def similarity( - self, - image_reference: torch.Tensor, - images_todo: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - assert image_reference.ndim == 2 - assert images_todo.ndim == 3 - - bgval: torch.Tensor = self.get_borderval(img=images_todo, radius=5) - - scale, angle, tvec = self._similarity( - image_reference, - images_todo, - bgval, - ) - - im2 = self.transform_img_dict( - img=images_todo, - scale=scale, - angle=angle, - tvec=tvec, - bgval=bgval, - ) - - return scale, angle, tvec, im2 diff --git a/new_pipeline/functions/align_refref.py b/new_pipeline/functions/align_refref.py deleted file mode 100644 index 249665c..0000000 --- a/new_pipeline/functions/align_refref.py +++ /dev/null @@ -1,57 +0,0 @@ -import torch -import torchvision as tv # type: ignore -import logging -from functions.ImageAlignment import ImageAlignment -from functions.calculate_translation import calculate_translation -from functions.calculate_rotation import calculate_rotation - - -@torch.no_grad() -def align_refref( - mylogger: logging.Logger, - ref_image_acceptor: torch.Tensor, - ref_image_donor: torch.Tensor, - image_alignment: ImageAlignment, - batch_size: int, - fill_value: float = 0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - - mylogger.info("Rotate ref image acceptor onto donor") - angle_refref = calculate_rotation( - image_alignment=image_alignment, - input=ref_image_acceptor.unsqueeze(0), - reference_image=ref_image_donor, - batch_size=batch_size, - ) - - ref_image_acceptor = tv.transforms.functional.affine( - img=ref_image_acceptor.unsqueeze(0), - angle=-float(angle_refref), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ) - - mylogger.info("Translate ref image acceptor onto donor") - tvec_refref = calculate_translation( - image_alignment=image_alignment, - input=ref_image_acceptor, - reference_image=ref_image_donor, - batch_size=batch_size, - ) - - tvec_refref = tvec_refref[0, :] - - ref_image_acceptor = tv.transforms.functional.affine( - img=ref_image_acceptor, - angle=0, - translate=[tvec_refref[1], tvec_refref[0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - return angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor diff --git a/new_pipeline/functions/bandpass.py b/new_pipeline/functions/bandpass.py deleted file mode 100644 index 171baf5..0000000 --- a/new_pipeline/functions/bandpass.py +++ /dev/null @@ -1,85 +0,0 @@ -import torchaudio as ta # type: ignore -import torch - - -@torch.no_grad() -def filtfilt( - input: torch.Tensor, - butter_a: torch.Tensor, - butter_b: torch.Tensor, -) -> torch.Tensor: - assert butter_a.ndim == 1 - assert butter_b.ndim == 1 - assert butter_a.shape[0] == butter_b.shape[0] - - process_data: torch.Tensor = input.detach().clone() - - padding_length = 12 * int(butter_a.shape[0]) - left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[ - ..., 1 : padding_length + 1 - ].flip(-1) - right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[ - ..., -(padding_length + 1) : -1 - ].flip(-1) - process_data_padded = torch.cat((left_padding, process_data, right_padding), dim=-1) - - output = ta.functional.filtfilt( - process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False - ).squeeze(0) - - output = output[..., padding_length:-padding_length] - return output - - -@torch.no_grad() -def butter_bandpass( - device: torch.device, - low_frequency: float = 0.1, - high_frequency: float = 1.0, - fs: float = 30.0, -) -> tuple[torch.Tensor, torch.Tensor]: - import scipy # type: ignore - - butter_b_np, butter_a_np = scipy.signal.butter( - 4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs - ) - butter_a = torch.tensor(butter_a_np, device=device, dtype=torch.float32) - butter_b = torch.tensor(butter_b_np, device=device, dtype=torch.float32) - return butter_a, butter_b - - -@torch.no_grad() -def chunk_iterator(array: torch.Tensor, chunk_size: int): - for i in range(0, array.shape[0], chunk_size): - yield array[i : i + chunk_size] - - -@torch.no_grad() -def bandpass( - data: torch.Tensor, - device: torch.device, - low_frequency: float = 0.1, - high_frequency: float = 1.0, - fs=30.0, - filtfilt_chuck_size: int = 10, -) -> torch.Tensor: - butter_a, butter_b = butter_bandpass( - device=device, - low_frequency=low_frequency, - high_frequency=high_frequency, - fs=fs, - ) - - index_full_dataset: torch.Tensor = torch.arange( - 0, data.shape[1], device=device, dtype=torch.int64 - ) - - for chunk in chunk_iterator(index_full_dataset, filtfilt_chuck_size): - temp_filtfilt = filtfilt( - data[:, chunk, :], - butter_a=butter_a, - butter_b=butter_b, - ) - data[:, chunk, :] = temp_filtfilt - - return data diff --git a/new_pipeline/functions/binning.py b/new_pipeline/functions/binning.py deleted file mode 100644 index ccfe657..0000000 --- a/new_pipeline/functions/binning.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch - - -def binning( - data: torch.Tensor, - kernel_size: int = 4, - stride: int = 4, - divisor_override: int | None = 1, -) -> torch.Tensor: - - assert data.ndim == 4 - return ( - torch.nn.functional.avg_pool2d( - input=data.movedim(0, -1).movedim(0, -1), - kernel_size=kernel_size, - stride=stride, - divisor_override=divisor_override, - ) - .movedim(-1, 0) - .movedim(-1, 0) - ) diff --git a/new_pipeline/functions/calculate_rotation.py b/new_pipeline/functions/calculate_rotation.py deleted file mode 100644 index 6a53afd..0000000 --- a/new_pipeline/functions/calculate_rotation.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch - -from functions.ImageAlignment import ImageAlignment - - -@torch.no_grad() -def calculate_rotation( - image_alignment: ImageAlignment, - input: torch.Tensor, - reference_image: torch.Tensor, - batch_size: int, -) -> torch.Tensor: - angle = torch.zeros((input.shape[0])) - - data_loader = torch.utils.data.DataLoader( - torch.utils.data.TensorDataset(input), - batch_size=batch_size, - shuffle=False, - ) - start_position: int = 0 - for input_batch in data_loader: - assert len(input_batch) == 1 - - end_position = start_position + input_batch[0].shape[0] - - angle_temp = image_alignment.dry_run_angle( - input=input_batch[0], - new_reference_image=reference_image, - ) - - assert angle_temp is not None - - angle[start_position:end_position] = angle_temp - - start_position += input_batch[0].shape[0] - - angle = torch.where(angle >= 180, 360.0 - angle, angle) - angle = torch.where(angle <= -180, 360.0 + angle, angle) - - return angle diff --git a/new_pipeline/functions/calculate_translation.py b/new_pipeline/functions/calculate_translation.py deleted file mode 100644 index 9eadf59..0000000 --- a/new_pipeline/functions/calculate_translation.py +++ /dev/null @@ -1,37 +0,0 @@ -import torch - -from functions.ImageAlignment import ImageAlignment - - -@torch.no_grad() -def calculate_translation( - image_alignment: ImageAlignment, - input: torch.Tensor, - reference_image: torch.Tensor, - batch_size: int, -) -> torch.Tensor: - tvec = torch.zeros((input.shape[0], 2)) - - data_loader = torch.utils.data.DataLoader( - torch.utils.data.TensorDataset(input), - batch_size=batch_size, - shuffle=False, - ) - start_position: int = 0 - for input_batch in data_loader: - assert len(input_batch) == 1 - - end_position = start_position + input_batch[0].shape[0] - - tvec_temp = image_alignment.dry_run_translation( - input=input_batch[0], - new_reference_image=reference_image, - ) - - assert tvec_temp is not None - - tvec[start_position:end_position, :] = tvec_temp - - start_position += input_batch[0].shape[0] - - return tvec diff --git a/new_pipeline/functions/create_logger.py b/new_pipeline/functions/create_logger.py deleted file mode 100644 index 8fcfa8a..0000000 --- a/new_pipeline/functions/create_logger.py +++ /dev/null @@ -1,37 +0,0 @@ -import logging -import datetime -import os - - -def create_logger( - save_logging_messages: bool, display_logging_messages: bool, log_stage_name: str -): - now = datetime.datetime.now() - dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S") - - logger = logging.getLogger("MyLittleLogger") - logger.setLevel(logging.DEBUG) - - if save_logging_messages: - time_format = "%b %-d %Y %H:%M:%S" - logformat = "%(asctime)s %(message)s" - file_formatter = logging.Formatter(fmt=logformat, datefmt=time_format) - os.makedirs("logs_" + log_stage_name, exist_ok=True) - file_handler = logging.FileHandler( - os.path.join("logs_" + log_stage_name, f"log_{dt_string_filename}.txt") - ) - file_handler.setLevel(logging.INFO) - file_handler.setFormatter(file_formatter) - logger.addHandler(file_handler) - - if display_logging_messages: - time_format = "%H:%M:%S" - logformat = "%(asctime)s %(message)s" - stream_formatter = logging.Formatter(fmt=logformat, datefmt=time_format) - - stream_handler = logging.StreamHandler() - stream_handler.setLevel(logging.INFO) - stream_handler.setFormatter(stream_formatter) - logger.addHandler(stream_handler) - - return logger diff --git a/new_pipeline/functions/data_raw_loader.py b/new_pipeline/functions/data_raw_loader.py deleted file mode 100644 index 67e55cf..0000000 --- a/new_pipeline/functions/data_raw_loader.py +++ /dev/null @@ -1,339 +0,0 @@ -import numpy as np -import torch -import os -import logging -import copy - -from functions.get_experiments import get_experiments -from functions.get_trials import get_trials -from functions.get_parts import get_parts -from functions.load_meta_data import load_meta_data - - -def data_raw_loader( - raw_data_path: str, - mylogger: logging.Logger, - experiment_id: int, - trial_id: int, - device: torch.device, - force_to_cpu_memory: bool, - config: dict, -) -> tuple[list[str], str, str, dict, dict, float, float, str, torch.Tensor]: - - meta_channels: list[str] = [] - meta_mouse_markings: str = "" - meta_recording_date: str = "" - meta_stimulation_times: dict = {} - meta_experiment_names: dict = {} - meta_trial_recording_duration: float = 0.0 - meta_frame_time: float = 0.0 - meta_mouse: str = "" - data: torch.Tensor = torch.zeros((1)) - - dtype_str = config["dtype"] - mylogger.info(f"Data precision will be {dtype_str}") - dtype: torch.dtype = getattr(torch, dtype_str) - dtype_np: np.dtype = getattr(np, dtype_str) - - if os.path.isdir(raw_data_path) is False: - mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!") - assert os.path.isdir(raw_data_path) - return ( - meta_channels, - meta_mouse_markings, - meta_recording_date, - meta_stimulation_times, - meta_experiment_names, - meta_trial_recording_duration, - meta_frame_time, - meta_mouse, - data, - ) - - if (torch.where(get_experiments(raw_data_path) == experiment_id)[0].shape[0]) != 1: - mylogger.info(f"ERROR: could not find experiment id {experiment_id}!!!!") - assert ( - torch.where(get_experiments(raw_data_path) == experiment_id)[0].shape[0] - ) == 1 - return ( - meta_channels, - meta_mouse_markings, - meta_recording_date, - meta_stimulation_times, - meta_experiment_names, - meta_trial_recording_duration, - meta_frame_time, - meta_mouse, - data, - ) - - if ( - torch.where(get_trials(raw_data_path, experiment_id) == trial_id)[0].shape[0] - ) != 1: - mylogger.info(f"ERROR: could not find trial id {trial_id}!!!!") - assert ( - torch.where(get_trials(raw_data_path, experiment_id) == trial_id)[0].shape[ - 0 - ] - ) == 1 - return ( - meta_channels, - meta_mouse_markings, - meta_recording_date, - meta_stimulation_times, - meta_experiment_names, - meta_trial_recording_duration, - meta_frame_time, - meta_mouse, - data, - ) - - available_parts: torch.Tensor = get_parts(raw_data_path, experiment_id, trial_id) - if available_parts.shape[0] < 1: - mylogger.info("ERROR: could not find any part files") - assert available_parts.shape[0] >= 1 - - experiment_name = f"Exp{experiment_id:03d}_Trial{trial_id:03d}" - mylogger.info(f"Will work on: {experiment_name}") - - mylogger.info(f"We found {int(available_parts.shape[0])} parts.") - - first_run: bool = True - - mylogger.info("Compare meta data of all parts") - for id in range(0, available_parts.shape[0]): - part_id = available_parts[id] - - filename_meta: str = os.path.join( - raw_data_path, - f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}_meta.txt", - ) - - if os.path.isfile(filename_meta) is False: - mylogger.info(f"Could not load meta data... {filename_meta}") - assert os.path.isfile(filename_meta) - return ( - meta_channels, - meta_mouse_markings, - meta_recording_date, - meta_stimulation_times, - meta_experiment_names, - meta_trial_recording_duration, - meta_frame_time, - meta_mouse, - data, - ) - - ( - meta_channels, - meta_mouse_markings, - meta_recording_date, - meta_stimulation_times, - meta_experiment_names, - meta_trial_recording_duration, - meta_frame_time, - meta_mouse, - ) = load_meta_data( - mylogger=mylogger, filename_meta=filename_meta, silent_mode=True - ) - - if first_run: - first_run = False - master_meta_channels: list[str] = copy.deepcopy(meta_channels) - master_meta_mouse_markings: str = meta_mouse_markings - master_meta_recording_date: str = meta_recording_date - master_meta_stimulation_times: dict = copy.deepcopy(meta_stimulation_times) - master_meta_experiment_names: dict = copy.deepcopy(meta_experiment_names) - master_meta_trial_recording_duration: float = meta_trial_recording_duration - master_meta_frame_time: float = meta_frame_time - master_meta_mouse: str = meta_mouse - - meta_channels_check = master_meta_channels == meta_channels - - # Check channel order - if meta_channels_check: - for channel_a, channel_b in zip(master_meta_channels, meta_channels): - if channel_a != channel_b: - meta_channels_check = False - - meta_mouse_markings_check = master_meta_mouse_markings == meta_mouse_markings - meta_recording_date_check = master_meta_recording_date == meta_recording_date - meta_stimulation_times_check = ( - master_meta_stimulation_times == meta_stimulation_times - ) - meta_experiment_names_check = ( - master_meta_experiment_names == meta_experiment_names - ) - meta_trial_recording_duration_check = ( - master_meta_trial_recording_duration == meta_trial_recording_duration - ) - meta_frame_time_check = master_meta_frame_time == meta_frame_time - meta_mouse_check = master_meta_mouse == meta_mouse - - if meta_channels_check is False: - mylogger.info(f"{filename_meta} failed: channels") - assert meta_channels_check - - if meta_mouse_markings_check is False: - mylogger.info(f"{filename_meta} failed: mouse_markings") - assert meta_mouse_markings_check - - if meta_recording_date_check is False: - mylogger.info(f"{filename_meta} failed: recording_date") - assert meta_recording_date_check - - if meta_stimulation_times_check is False: - mylogger.info(f"{filename_meta} failed: stimulation_times") - assert meta_stimulation_times_check - - if meta_experiment_names_check is False: - mylogger.info(f"{filename_meta} failed: experiment_names") - assert meta_experiment_names_check - - if meta_trial_recording_duration_check is False: - mylogger.info(f"{filename_meta} failed: trial_recording_duration") - assert meta_trial_recording_duration_check - - if meta_frame_time_check is False: - mylogger.info(f"{filename_meta} failed: frame_time_check") - assert meta_frame_time_check - - if meta_mouse_check is False: - mylogger.info(f"{filename_meta} failed: mouse") - assert meta_mouse_check - mylogger.info("-==- Done -==-") - - mylogger.info(f"Will use: {filename_meta} for meta data") - ( - meta_channels, - meta_mouse_markings, - meta_recording_date, - meta_stimulation_times, - meta_experiment_names, - meta_trial_recording_duration, - meta_frame_time, - meta_mouse, - ) = load_meta_data(mylogger=mylogger, filename_meta=filename_meta) - - ################# - # Meta data end # - ################# - - first_run = True - mylogger.info("Count the number of frames in the data of all parts") - frame_count: int = 0 - for id in range(0, available_parts.shape[0]): - part_id = available_parts[id] - - filename_data: str = os.path.join( - raw_data_path, - f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}.npy", - ) - - if os.path.isfile(filename_data) is False: - mylogger.info(f"Could not load data... {filename_data}") - assert os.path.isfile(filename_data) - return ( - meta_channels, - meta_mouse_markings, - meta_recording_date, - meta_stimulation_times, - meta_experiment_names, - meta_trial_recording_duration, - meta_frame_time, - meta_mouse, - data, - ) - data_np: np.ndarray = np.load(filename_data, mmap_mode="r") - - if data_np.ndim != 4: - mylogger.info(f"ERROR: Data needs to have 4 dimensions {filename_data}") - assert data_np.ndim == 4 - - if first_run: - first_run = False - dim_0: int = int(data_np.shape[0]) - dim_1: int = int(data_np.shape[1]) - dim_3: int = int(data_np.shape[3]) - - frame_count += int(data_np.shape[2]) - - if int(data_np.shape[0]) != dim_0: - mylogger.info( - f"ERROR: Data dim 0 is broken {int(data_np.shape[0])} vs {dim_0} {filename_data}" - ) - assert int(data_np.shape[0]) == dim_0 - - if int(data_np.shape[1]) != dim_1: - mylogger.info( - f"ERROR: Data dim 1 is broken {int(data_np.shape[1])} vs {dim_1} {filename_data}" - ) - assert int(data_np.shape[1]) == dim_1 - - if int(data_np.shape[3]) != dim_3: - mylogger.info( - f"ERROR: Data dim 3 is broken {int(data_np.shape[3])} vs {dim_3} {filename_data}" - ) - assert int(data_np.shape[3]) == dim_3 - - mylogger.info( - f"{filename_data}: {int(data_np.shape[2])} frames -> {frame_count} frames total" - ) - - if force_to_cpu_memory: - mylogger.info("Using CPU memory for data") - data = torch.empty( - (dim_0, dim_1, frame_count, dim_3), dtype=dtype, device=torch.device("cpu") - ) - else: - mylogger.info("Using GPU memory for data") - data = torch.empty( - (dim_0, dim_1, frame_count, dim_3), dtype=dtype, device=device - ) - - start_position: int = 0 - end_position: int = 0 - for id in range(0, available_parts.shape[0]): - part_id = available_parts[id] - - filename_data = os.path.join( - raw_data_path, - f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}.npy", - ) - - mylogger.info(f"Will work on {filename_data}") - mylogger.info("Loading data file") - data_np = np.load(filename_data).astype(dtype_np) - - end_position = start_position + int(data_np.shape[2]) - - for i in range(0, len(config["required_order"])): - mylogger.info(f"Move raw data channel: {config['required_order'][i]}") - - idx = meta_channels.index(config["required_order"][i]) - data[..., start_position:end_position, i] = torch.tensor( - data_np[..., idx], dtype=dtype, device=data.device - ) - start_position = end_position - - if start_position != int(data.shape[2]): - mylogger.info("ERROR: data was not fulled fully!!!") - assert start_position == int(data.shape[2]) - - mylogger.info("-==- Done -==-") - - ################# - # Raw data end # - ################# - - return ( - meta_channels, - meta_mouse_markings, - meta_recording_date, - meta_stimulation_times, - meta_experiment_names, - meta_trial_recording_duration, - meta_frame_time, - meta_mouse, - data, - ) diff --git a/new_pipeline/functions/gauss_smear_individual.py b/new_pipeline/functions/gauss_smear_individual.py deleted file mode 100644 index 36700e7..0000000 --- a/new_pipeline/functions/gauss_smear_individual.py +++ /dev/null @@ -1,127 +0,0 @@ -import torch -import math - - -@torch.no_grad() -def gauss_smear_individual( - input: torch.Tensor, - spatial_width: float, - temporal_width: float, - overwrite_fft_gauss: None | torch.Tensor = None, - use_matlab_mask: bool = True, - epsilon: float = float(torch.finfo(torch.float64).eps), -) -> tuple[torch.Tensor, torch.Tensor]: - - dim_x: int = int(2 * math.ceil(2 * spatial_width) + 1) - dim_y: int = int(2 * math.ceil(2 * spatial_width) + 1) - dim_t: int = int(2 * math.ceil(2 * temporal_width) + 1) - dims_xyt: torch.Tensor = torch.tensor( - [dim_x, dim_y, dim_t], dtype=torch.int64, device=input.device - ) - - if input.ndim == 2: - input = input.unsqueeze(-1) - - input_padded = torch.nn.functional.pad( - input.unsqueeze(0), - pad=( - dim_t, - dim_t, - dim_y, - dim_y, - dim_x, - dim_x, - ), - mode="replicate", - ).squeeze(0) - - if overwrite_fft_gauss is None: - center_x: int = int(math.ceil(input_padded.shape[0] / 2)) - center_y: int = int(math.ceil(input_padded.shape[1] / 2)) - center_z: int = int(math.ceil(input_padded.shape[2] / 2)) - grid_x: torch.Tensor = ( - torch.arange(0, input_padded.shape[0], device=input.device) - center_x + 1 - ) - grid_y: torch.Tensor = ( - torch.arange(0, input_padded.shape[1], device=input.device) - center_y + 1 - ) - grid_z: torch.Tensor = ( - torch.arange(0, input_padded.shape[2], device=input.device) - center_z + 1 - ) - - grid_x = grid_x.unsqueeze(-1).unsqueeze(-1) ** 2 - grid_y = grid_y.unsqueeze(0).unsqueeze(-1) ** 2 - grid_z = grid_z.unsqueeze(0).unsqueeze(0) ** 2 - - gauss_kernel: torch.Tensor = ( - (grid_x / (spatial_width**2)) - + (grid_y / (spatial_width**2)) - + (grid_z / (temporal_width**2)) - ) - - if use_matlab_mask: - filter_radius: torch.Tensor = (dims_xyt - 1) // 2 - - border_lower: list[int] = [ - center_x - int(filter_radius[0]) - 1, - center_y - int(filter_radius[1]) - 1, - center_z - int(filter_radius[2]) - 1, - ] - - border_upper: list[int] = [ - center_x + int(filter_radius[0]), - center_y + int(filter_radius[1]), - center_z + int(filter_radius[2]), - ] - - matlab_mask: torch.Tensor = torch.zeros_like(gauss_kernel) - matlab_mask[ - border_lower[0] : border_upper[0], - border_lower[1] : border_upper[1], - border_lower[2] : border_upper[2], - ] = 1.0 - - gauss_kernel = torch.exp(-gauss_kernel / 2.0) - if use_matlab_mask: - gauss_kernel = gauss_kernel * matlab_mask - - gauss_kernel[gauss_kernel < (epsilon * gauss_kernel.max())] = 0 - - sum_gauss_kernel: float = float(gauss_kernel.sum()) - - if sum_gauss_kernel != 0.0: - gauss_kernel = gauss_kernel / sum_gauss_kernel - - # FFT Shift - gauss_kernel = torch.cat( - (gauss_kernel[center_x - 1 :, :, :], gauss_kernel[: center_x - 1, :, :]), - dim=0, - ) - gauss_kernel = torch.cat( - (gauss_kernel[:, center_y - 1 :, :], gauss_kernel[:, : center_y - 1, :]), - dim=1, - ) - gauss_kernel = torch.cat( - (gauss_kernel[:, :, center_z - 1 :], gauss_kernel[:, :, : center_z - 1]), - dim=2, - ) - overwrite_fft_gauss = torch.fft.fftn(gauss_kernel) - input_padded_gauss_filtered: torch.Tensor = torch.real( - torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss) - ) - else: - input_padded_gauss_filtered = torch.real( - torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss) - ) - - start = dims_xyt - stop = ( - torch.tensor(input_padded.shape, device=dims_xyt.device, dtype=dims_xyt.dtype) - - dims_xyt - ) - - output = input_padded_gauss_filtered[ - start[0] : stop[0], start[1] : stop[1], start[2] : stop[2] - ] - - return (output, overwrite_fft_gauss) diff --git a/new_pipeline/functions/get_experiments.py b/new_pipeline/functions/get_experiments.py deleted file mode 100644 index d92b936..0000000 --- a/new_pipeline/functions/get_experiments.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch -import os -import glob - - -@torch.no_grad() -def get_experiments(path: str) -> torch.Tensor: - filename_np: str = os.path.join( - path, - "Exp*_Part001.npy", - ) - - list_str = glob.glob(filename_np) - list_int: list[int] = [] - for i in range(0, len(list_str)): - list_int.append(int(list_str[i].split("Exp")[-1].split("_Trial")[0])) - list_int = sorted(list_int) - - return torch.tensor(list_int).unique() diff --git a/new_pipeline/functions/get_parts.py b/new_pipeline/functions/get_parts.py deleted file mode 100644 index d68e1ae..0000000 --- a/new_pipeline/functions/get_parts.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch -import os -import glob - - -@torch.no_grad() -def get_parts(path: str, experiment_id: int, trial_id: int) -> torch.Tensor: - filename_np: str = os.path.join( - path, - f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part*.npy", - ) - - list_str = glob.glob(filename_np) - list_int: list[int] = [] - for i in range(0, len(list_str)): - list_int.append(int(list_str[i].split("_Part")[-1].split(".npy")[0])) - list_int = sorted(list_int) - return torch.tensor(list_int).unique() diff --git a/new_pipeline/functions/get_torch_device.py b/new_pipeline/functions/get_torch_device.py deleted file mode 100644 index 9eec5e9..0000000 --- a/new_pipeline/functions/get_torch_device.py +++ /dev/null @@ -1,17 +0,0 @@ -import torch -import logging - - -def get_torch_device(mylogger: logging.Logger, force_to_cpu: bool) -> torch.device: - - if torch.cuda.is_available(): - device_name: str = "cuda:0" - else: - device_name = "cpu" - - if force_to_cpu: - device_name = "cpu" - - mylogger.info(f"Using device: {device_name}") - device: torch.device = torch.device(device_name) - return device diff --git a/new_pipeline/functions/get_trials.py b/new_pipeline/functions/get_trials.py deleted file mode 100644 index 8c687d9..0000000 --- a/new_pipeline/functions/get_trials.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch -import os -import glob - - -@torch.no_grad() -def get_trials(path: str, experiment_id: int) -> torch.Tensor: - filename_np: str = os.path.join( - path, - f"Exp{experiment_id:03d}_Trial*_Part001.npy", - ) - - list_str = glob.glob(filename_np) - list_int: list[int] = [] - for i in range(0, len(list_str)): - list_int.append(int(list_str[i].split("_Trial")[-1].split("_Part")[0])) - list_int = sorted(list_int) - return torch.tensor(list_int).unique() diff --git a/new_pipeline/functions/load_config.py b/new_pipeline/functions/load_config.py deleted file mode 100644 index c17fa40..0000000 --- a/new_pipeline/functions/load_config.py +++ /dev/null @@ -1,16 +0,0 @@ -import json -import os -import logging - -from jsmin import jsmin # type:ignore - - -def load_config(mylogger: logging.Logger, filename: str = "config.json") -> dict: - mylogger.info("loading config file") - if os.path.isfile(filename) is False: - mylogger.info(f"{filename} is missing") - - with open(filename, "r") as file: - config = json.loads(jsmin(file.read())) - - return config diff --git a/new_pipeline/functions/load_meta_data.py b/new_pipeline/functions/load_meta_data.py deleted file mode 100644 index 2a893e0..0000000 --- a/new_pipeline/functions/load_meta_data.py +++ /dev/null @@ -1,63 +0,0 @@ -import logging -import json - - -def load_meta_data( - mylogger: logging.Logger, filename_meta: str, silent_mode=False -) -> tuple[list[str], str, str, dict, dict, float, float, str]: - - if silent_mode is False: - mylogger.info("Loading meta data") - with open(filename_meta, "r") as file_handle: - metadata: dict = json.load(file_handle) - - channels: list[str] = metadata["channelKey"] - - if silent_mode is False: - mylogger.info(f"meta data: channel order: {channels}") - - mouse_markings: str = metadata["sessionMetaData"]["mouseMarkings"] - if silent_mode is False: - mylogger.info(f"meta data: mouse markings: {mouse_markings}") - - recording_date: str = metadata["sessionMetaData"]["date"] - if silent_mode is False: - mylogger.info(f"meta data: recording data: {recording_date}") - - stimulation_times: dict = metadata["sessionMetaData"]["stimulationTimes"] - if silent_mode is False: - mylogger.info(f"meta data: stimulation times: {stimulation_times}") - - experiment_names: dict = metadata["sessionMetaData"]["experimentNames"] - if silent_mode is False: - mylogger.info(f"meta data: experiment names: {experiment_names}") - - trial_recording_duration: float = float( - metadata["sessionMetaData"]["trialRecordingDuration"] - ) - if silent_mode is False: - mylogger.info( - f"meta data: trial recording duration: {trial_recording_duration} sec" - ) - - frame_time: float = float(metadata["sessionMetaData"]["frameTime"]) - if silent_mode is False: - mylogger.info( - f"meta data: frame time: {frame_time} sec ; frame rate: {1.0/frame_time}Hz" - ) - - mouse: str = metadata["sessionMetaData"]["mouse"] - if silent_mode is False: - mylogger.info(f"meta data: mouse: {mouse}") - mylogger.info("-==- Done -==-") - - return ( - channels, - mouse_markings, - recording_date, - stimulation_times, - experiment_names, - trial_recording_duration, - frame_time, - mouse, - ) diff --git a/new_pipeline/functions/perform_donor_volume_rotation.py b/new_pipeline/functions/perform_donor_volume_rotation.py deleted file mode 100644 index 590630f..0000000 --- a/new_pipeline/functions/perform_donor_volume_rotation.py +++ /dev/null @@ -1,140 +0,0 @@ -import torch -import torchvision as tv # type: ignore -import logging -from functions.calculate_rotation import calculate_rotation -from functions.ImageAlignment import ImageAlignment - - -@torch.no_grad() -def perform_donor_volume_rotation( - mylogger: logging.Logger, - acceptor: torch.Tensor, - donor: torch.Tensor, - oxygenation: torch.Tensor, - volume: torch.Tensor, - ref_image_donor: torch.Tensor, - ref_image_volume: torch.Tensor, - image_alignment: ImageAlignment, - batch_size: int, - config: dict, - fill_value: float = 0, -) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - - mylogger.info("Calculate rotation between donor data and donor ref image") - - angle_donor = calculate_rotation( - input=donor, - reference_image=ref_image_donor, - image_alignment=image_alignment, - batch_size=batch_size, - ) - - mylogger.info("Calculate rotation between volume data and volume ref image") - angle_volume = calculate_rotation( - input=volume, - reference_image=ref_image_volume, - image_alignment=image_alignment, - batch_size=batch_size, - ) - - mylogger.info("Average over both rotations") - - donor_threshold: torch.Tensor = torch.sort(torch.abs(angle_donor))[0] - donor_threshold = donor_threshold[ - int( - donor_threshold.shape[0] - * float(config["rotation_stabilization_threshold_border"]) - ) - ] * float(config["rotation_stabilization_threshold_factor"]) - - volume_threshold: torch.Tensor = torch.sort(torch.abs(angle_volume))[0] - volume_threshold = volume_threshold[ - int( - volume_threshold.shape[0] - * float(config["rotation_stabilization_threshold_border"]) - ) - ] * float(config["rotation_stabilization_threshold_factor"]) - - donor_idx = torch.where(torch.abs(angle_donor) > donor_threshold)[0] - volume_idx = torch.where(torch.abs(angle_volume) > volume_threshold)[0] - mylogger.info( - f"Border: {config['rotation_stabilization_threshold_border']}, " - f"factor {config['rotation_stabilization_threshold_factor']} " - ) - mylogger.info( - f"Donor threshold: {donor_threshold:.3e}, volume threshold: {volume_threshold:.3e}" - ) - mylogger.info( - f"Found broken rotation values: " - f"donor {int(donor_idx.shape[0])}, " - f"volume {int(volume_idx.shape[0])}" - ) - angle_donor[donor_idx] = angle_volume[donor_idx] - angle_volume[volume_idx] = angle_donor[volume_idx] - - donor_idx = torch.where(torch.abs(angle_donor) > donor_threshold)[0] - volume_idx = torch.where(torch.abs(angle_volume) > volume_threshold)[0] - mylogger.info( - f"After fill in these broken rotation values remain: " - f"donor {int(donor_idx.shape[0])}, " - f"volume {int(volume_idx.shape[0])}" - ) - angle_donor[donor_idx] = 0.0 - angle_volume[volume_idx] = 0.0 - angle_donor_volume = (angle_donor + angle_volume) / 2.0 - - mylogger.info("Rotate acceptor data based on the average rotation") - for frame_id in range(0, angle_donor_volume.shape[0]): - acceptor[frame_id, ...] = tv.transforms.functional.affine( - img=acceptor[frame_id, ...].unsqueeze(0), - angle=-float(angle_donor_volume[frame_id]), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - mylogger.info("Rotate donor data based on the average rotation") - for frame_id in range(0, angle_donor_volume.shape[0]): - donor[frame_id, ...] = tv.transforms.functional.affine( - img=donor[frame_id, ...].unsqueeze(0), - angle=-float(angle_donor_volume[frame_id]), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - mylogger.info("Rotate oxygenation data based on the average rotation") - for frame_id in range(0, angle_donor_volume.shape[0]): - oxygenation[frame_id, ...] = tv.transforms.functional.affine( - img=oxygenation[frame_id, ...].unsqueeze(0), - angle=-float(angle_donor_volume[frame_id]), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - mylogger.info("Rotate volume data based on the average rotation") - for frame_id in range(0, angle_donor_volume.shape[0]): - volume[frame_id, ...] = tv.transforms.functional.affine( - img=volume[frame_id, ...].unsqueeze(0), - angle=-float(angle_donor_volume[frame_id]), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - return (acceptor, donor, oxygenation, volume, angle_donor_volume) diff --git a/new_pipeline/functions/perform_donor_volume_translation.py b/new_pipeline/functions/perform_donor_volume_translation.py deleted file mode 100644 index 7091add..0000000 --- a/new_pipeline/functions/perform_donor_volume_translation.py +++ /dev/null @@ -1,143 +0,0 @@ -import torch -import torchvision as tv # type: ignore -import logging - -from functions.calculate_translation import calculate_translation -from functions.ImageAlignment import ImageAlignment - - -@torch.no_grad() -def perform_donor_volume_translation( - mylogger: logging.Logger, - acceptor: torch.Tensor, - donor: torch.Tensor, - oxygenation: torch.Tensor, - volume: torch.Tensor, - ref_image_donor: torch.Tensor, - ref_image_volume: torch.Tensor, - image_alignment: ImageAlignment, - batch_size: int, - config: dict, - fill_value: float = 0, -) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - - mylogger.info("Calculate translation between donor data and donor ref image") - tvec_donor = calculate_translation( - input=donor, - reference_image=ref_image_donor, - image_alignment=image_alignment, - batch_size=batch_size, - ) - - mylogger.info("Calculate translation between volume data and volume ref image") - tvec_volume = calculate_translation( - input=volume, - reference_image=ref_image_volume, - image_alignment=image_alignment, - batch_size=batch_size, - ) - - mylogger.info("Average over both translations") - - for i in range(0, 2): - mylogger.info(f"Processing dimension {i}") - donor_threshold: torch.Tensor = torch.sort(torch.abs(tvec_donor[:, i]))[0] - donor_threshold = donor_threshold[ - int( - donor_threshold.shape[0] - * float(config["rotation_stabilization_threshold_border"]) - ) - ] * float(config["rotation_stabilization_threshold_factor"]) - - volume_threshold: torch.Tensor = torch.sort(torch.abs(tvec_volume[:, i]))[0] - volume_threshold = volume_threshold[ - int( - volume_threshold.shape[0] - * float(config["rotation_stabilization_threshold_border"]) - ) - ] * float(config["rotation_stabilization_threshold_factor"]) - - donor_idx = torch.where(torch.abs(tvec_donor[:, i]) > donor_threshold)[0] - volume_idx = torch.where(torch.abs(tvec_volume[:, i]) > volume_threshold)[0] - mylogger.info( - f"Border: {config['rotation_stabilization_threshold_border']}, " - f"factor {config['rotation_stabilization_threshold_factor']} " - ) - mylogger.info( - f"Donor threshold: {donor_threshold:.3e}, volume threshold: {volume_threshold:.3e}" - ) - mylogger.info( - f"Found broken rotation values: " - f"donor {int(donor_idx.shape[0])}, " - f"volume {int(volume_idx.shape[0])}" - ) - tvec_donor[donor_idx, i] = tvec_volume[donor_idx, i] - tvec_volume[volume_idx, i] = tvec_donor[volume_idx, i] - - donor_idx = torch.where(torch.abs(tvec_donor[:, i]) > donor_threshold)[0] - volume_idx = torch.where(torch.abs(tvec_volume[:, i]) > volume_threshold)[0] - mylogger.info( - f"After fill in these broken rotation values remain: " - f"donor {int(donor_idx.shape[0])}, " - f"volume {int(volume_idx.shape[0])}" - ) - tvec_donor[donor_idx, i] = 0.0 - tvec_volume[volume_idx, i] = 0.0 - - tvec_donor_volume = (tvec_donor + tvec_volume) / 2.0 - - mylogger.info("Translate acceptor data based on the average translation vector") - for frame_id in range(0, tvec_donor_volume.shape[0]): - acceptor[frame_id, ...] = tv.transforms.functional.affine( - img=acceptor[frame_id, ...].unsqueeze(0), - angle=0, - translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - mylogger.info("Translate donor data based on the average translation vector") - for frame_id in range(0, tvec_donor_volume.shape[0]): - donor[frame_id, ...] = tv.transforms.functional.affine( - img=donor[frame_id, ...].unsqueeze(0), - angle=0, - translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - mylogger.info("Translate oxygenation data based on the average translation vector") - for frame_id in range(0, tvec_donor_volume.shape[0]): - oxygenation[frame_id, ...] = tv.transforms.functional.affine( - img=oxygenation[frame_id, ...].unsqueeze(0), - angle=0, - translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - mylogger.info("Translate volume data based on the average translation vector") - for frame_id in range(0, tvec_donor_volume.shape[0]): - volume[frame_id, ...] = tv.transforms.functional.affine( - img=volume[frame_id, ...].unsqueeze(0), - angle=0, - translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - return (acceptor, donor, oxygenation, volume, tvec_donor_volume) diff --git a/new_pipeline/functions/regression.py b/new_pipeline/functions/regression.py deleted file mode 100644 index d4efac0..0000000 --- a/new_pipeline/functions/regression.py +++ /dev/null @@ -1,117 +0,0 @@ -import torch -import logging -from functions.regression_internal import regression_internal - - -@torch.no_grad() -def regression( - mylogger: logging.Logger, - target_camera_id: int, - regressor_camera_ids: list[int], - mask: torch.Tensor, - data: torch.Tensor, - data_filtered: torch.Tensor, - first_none_ramp_frame: int, -) -> tuple[torch.Tensor, torch.Tensor]: - - assert len(regressor_camera_ids) > 0 - - mylogger.info("Prepare the target signal - 1.0 (from data_filtered)") - target_signals_train: torch.Tensor = ( - data_filtered[target_camera_id, ..., first_none_ramp_frame:].clone() - 1.0 - ) - target_signals_train[target_signals_train < -1] = 0.0 - - # Check if everything is happy - assert target_signals_train.ndim == 3 - assert target_signals_train.ndim == data[target_camera_id, ...].ndim - assert target_signals_train.shape[0] == data[target_camera_id, ...].shape[0] - assert target_signals_train.shape[1] == data[target_camera_id, ...].shape[1] - assert (target_signals_train.shape[2] + first_none_ramp_frame) == data[ - target_camera_id, ... - ].shape[2] - - mylogger.info("Prepare the regressor signals (linear plus from data_filtered)") - - regressor_signals_train: torch.Tensor = torch.zeros( - ( - data_filtered.shape[1], - data_filtered.shape[2], - data_filtered.shape[3], - len(regressor_camera_ids) + 1, - ), - device=data_filtered.device, - dtype=data_filtered.dtype, - ) - - mylogger.info("Copy the regressor signals - 1.0") - for matrix_id, id in enumerate(regressor_camera_ids): - regressor_signals_train[..., matrix_id] = data_filtered[id, ...] - 1.0 - - regressor_signals_train[regressor_signals_train < -1] = 0.0 - - mylogger.info("Create the linear regressor") - trend = torch.arange( - 0, regressor_signals_train.shape[-2], device=data_filtered.device - ) / float(regressor_signals_train.shape[-2] - 1) - trend -= trend.mean() - trend = trend.unsqueeze(0).unsqueeze(0) - trend = trend.tile( - (regressor_signals_train.shape[0], regressor_signals_train.shape[1], 1) - ) - regressor_signals_train[..., -1] = trend - - regressor_signals_train = regressor_signals_train[:, :, first_none_ramp_frame:, :] - - mylogger.info("Calculating the regression coefficients") - coefficients, intercept = regression_internal( - input_regressor=regressor_signals_train, input_target=target_signals_train - ) - del regressor_signals_train - del target_signals_train - - mylogger.info("Prepare the target signal - 1.0 (from data)") - target_signals_perform: torch.Tensor = data[target_camera_id, ...].clone() - 1.0 - - mylogger.info("Prepare the regressor signals (linear plus from data)") - regressor_signals_perform: torch.Tensor = torch.zeros( - ( - data.shape[1], - data.shape[2], - data.shape[3], - len(regressor_camera_ids) + 1, - ), - device=data.device, - dtype=data.dtype, - ) - - mylogger.info("Copy the regressor signals - 1.0 ") - for matrix_id, id in enumerate(regressor_camera_ids): - regressor_signals_perform[..., matrix_id] = data[id] - 1.0 - - mylogger.info("Create the linear regressor") - trend = torch.arange( - 0, regressor_signals_perform.shape[-2], device=data[0].device - ) / float(regressor_signals_perform.shape[-2] - 1) - trend -= trend.mean() - trend = trend.unsqueeze(0).unsqueeze(0) - trend = trend.tile( - (regressor_signals_perform.shape[0], regressor_signals_perform.shape[1], 1) - ) - regressor_signals_perform[..., -1] = trend - - mylogger.info("Remove regressors") - target_signals_perform -= ( - regressor_signals_perform * coefficients.unsqueeze(-2) - ).sum(dim=-1) - - mylogger.info("Remove offset") - target_signals_perform -= intercept.unsqueeze(-1) - - mylogger.info("Remove masked pixels") - target_signals_perform[mask, :] = 0.0 - - mylogger.info("Add an offset of 1.0") - target_signals_perform += 1.0 - - return target_signals_perform, coefficients diff --git a/new_pipeline/functions/regression_internal.py b/new_pipeline/functions/regression_internal.py deleted file mode 100644 index 352d7ba..0000000 --- a/new_pipeline/functions/regression_internal.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch - - -def regression_internal( - input_regressor: torch.Tensor, input_target: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - - regressor_offset = input_regressor.mean(keepdim=True, dim=-2) - target_offset = input_target.mean(keepdim=True, dim=-1) - - regressor = input_regressor - regressor_offset - target = input_target - target_offset - - coefficients, _, _, _ = torch.linalg.lstsq(regressor, target, rcond=None) # None ? - - intercept = target_offset.squeeze(-1) - ( - coefficients * regressor_offset.squeeze(-2) - ).sum(dim=-1) - - return coefficients, intercept diff --git a/new_pipeline/stage_1_get_ref_image.py b/new_pipeline/stage_1_get_ref_image.py deleted file mode 100644 index 55435f4..0000000 --- a/new_pipeline/stage_1_get_ref_image.py +++ /dev/null @@ -1,126 +0,0 @@ -import os -import torch -import numpy as np - - -from functions.get_experiments import get_experiments -from functions.get_trials import get_trials -from functions.bandpass import bandpass -from functions.create_logger import create_logger -from functions.get_torch_device import get_torch_device -from functions.load_config import load_config -from functions.data_raw_loader import data_raw_loader - -mylogger = create_logger( - save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_1" -) - -config = load_config(mylogger=mylogger) - -if config["binning_enable"] and (config["binning_at_the_end"] is False): - device: torch.device = torch.device("cpu") -else: - device = get_torch_device(mylogger, config["force_to_cpu"]) - - -dtype_str: str = config["dtype"] -dtype: torch.dtype = getattr(torch, dtype_str) - -raw_data_path: str = os.path.join( - config["basic_path"], - config["recoding_data"], - config["mouse_identifier"], - config["raw_path"], -) - -mylogger.info(f"Using data path: {raw_data_path}") - -first_experiment_id: int = int(get_experiments(raw_data_path).min()) -first_trial_id: int = int(get_trials(raw_data_path, first_experiment_id).min()) - -meta_channels: list[str] -meta_mouse_markings: str -meta_recording_date: str -meta_stimulation_times: dict -meta_experiment_names: dict -meta_trial_recording_duration: float -meta_frame_time: float -meta_mouse: str -data: torch.Tensor - -if config["binning_enable"] and (config["binning_at_the_end"] is False): - force_to_cpu_memory: bool = True -else: - force_to_cpu_memory = False - -mylogger.info("Loading data") - -( - meta_channels, - meta_mouse_markings, - meta_recording_date, - meta_stimulation_times, - meta_experiment_names, - meta_trial_recording_duration, - meta_frame_time, - meta_mouse, - data, -) = data_raw_loader( - raw_data_path=raw_data_path, - mylogger=mylogger, - experiment_id=first_experiment_id, - trial_id=first_trial_id, - device=device, - force_to_cpu_memory=force_to_cpu_memory, - config=config, -) -mylogger.info("-==- Done -==-") - -output_path = config["ref_image_path"] -mylogger.info(f"Create directory {output_path} in the case it does not exist") -os.makedirs(output_path, exist_ok=True) - -mylogger.info("Reference images") -for i in range(0, len(meta_channels)): - temp_path: str = os.path.join(output_path, meta_channels[i] + ".npy") - mylogger.info(f"Extract and save: {temp_path}") - frame_id: int = data.shape[-2] // 2 - mylogger.info(f"Will use frame id: {frame_id}") - ref_image: np.ndarray = ( - data[:, :, frame_id, meta_channels.index(meta_channels[i])] - .clone() - .cpu() - .numpy() - ) - np.save(temp_path, ref_image) -mylogger.info("-==- Done -==-") - -sample_frequency: float = 1.0 / meta_frame_time -mylogger.info( - ( - f"Heartbeat power {config['lower_freqency_bandpass']} Hz" - f" - {config['upper_freqency_bandpass']} Hz," - f" sample-rate: {sample_frequency}," - f" skipping the first {config['skip_frames_in_the_beginning']} frames" - ) -) - -for i in range(0, len(meta_channels)): - temp_path = os.path.join(output_path, meta_channels[i] + "_var.npy") - mylogger.info(f"Extract and save: {temp_path}") - - heartbeat_ts: torch.Tensor = bandpass( - data=data[..., i], - device=data.device, - low_frequency=config["lower_freqency_bandpass"], - high_frequency=config["upper_freqency_bandpass"], - fs=sample_frequency, - filtfilt_chuck_size=10, - ) - - heartbeat_power = heartbeat_ts[..., config["skip_frames_in_the_beginning"] :].var( - dim=-1 - ) - np.save(temp_path, heartbeat_power) - -mylogger.info("-==- Done -==-") diff --git a/new_pipeline/stage_2_make_heartbeat_mask.py b/new_pipeline/stage_2_make_heartbeat_mask.py deleted file mode 100644 index e36516b..0000000 --- a/new_pipeline/stage_2_make_heartbeat_mask.py +++ /dev/null @@ -1,153 +0,0 @@ -import matplotlib.pyplot as plt # type:ignore -import matplotlib -import numpy as np -import torch -import os - -from matplotlib.widgets import Slider, Button # type:ignore -from functools import partial -from functions.gauss_smear_individual import gauss_smear_individual -from functions.create_logger import create_logger -from functions.get_torch_device import get_torch_device -from functions.load_config import load_config - -mylogger = create_logger( - save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_2" -) - -config = load_config(mylogger=mylogger) - -path: str = config["ref_image_path"] -use_channel: str = "donor" -spatial_width: float = 4.0 -temporal_width: float = 0.1 - -threshold: float = 0.05 - -heartbeat_mask_threshold_file: str = os.path.join(path, "heartbeat_mask_threshold.npy") -if os.path.isfile(heartbeat_mask_threshold_file): - mylogger.info(f"loading previous threshold file: {heartbeat_mask_threshold_file}") - threshold = float(np.load(heartbeat_mask_threshold_file)[0]) - -mylogger.info(f"initial threshold is {threshold}") - -image_ref_file: str = os.path.join(path, use_channel + ".npy") -image_var_file: str = os.path.join(path, use_channel + "_var.npy") -heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy") - -device = get_torch_device(mylogger, config["force_to_cpu"]) - - -def next_frame( - i: float, images: np.ndarray, image_handle: matplotlib.image.AxesImage -) -> None: - global threshold - threshold = i - - display_image: np.ndarray = images.copy() - display_image[..., 2] = display_image[..., 0] - mask: np.ndarray = np.where(images[..., 2] >= i, 1.0, np.nan)[..., np.newaxis] - display_image *= mask - display_image = np.nan_to_num(display_image, nan=1.0) - - image_handle.set_data(display_image) - return - - -def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None: - global threshold - global image_3color - global path - global mylogger - global heartbeat_mask_file - global heartbeat_mask_threshold_file - - mylogger.info(f"Threshold: {threshold}") - - mask: np.ndarray = image_3color[..., 2] >= threshold - mylogger.info(f"Save mask to: {heartbeat_mask_file}") - np.save(heartbeat_mask_file, mask) - mylogger.info(f"Save threshold to: {heartbeat_mask_threshold_file}") - np.save(heartbeat_mask_threshold_file, np.array([threshold])) - exit() - - -def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None: - exit() - - -mylogger.info(f"loading image reference file: {image_ref_file}") -image_ref: np.ndarray = np.load(image_ref_file) -image_ref /= image_ref.max() - -mylogger.info(f"loading image heartbeat power: {image_var_file}") -image_var: np.ndarray = np.load(image_var_file) -image_var /= image_var.max() - -mylogger.info("Smear the image heartbeat power patially") -temp, _ = gauss_smear_individual( - input=torch.tensor(image_var[..., np.newaxis], device=device), - spatial_width=spatial_width, - temporal_width=temporal_width, - use_matlab_mask=False, -) -temp /= temp.max() - -mylogger.info("-==- DONE -==-") - -image_3color = np.concatenate( - ( - np.zeros_like(image_ref[..., np.newaxis]), - image_ref[..., np.newaxis], - temp.cpu().numpy(), - ), - axis=-1, -) - -mylogger.info("Prepare image") - -display_image = image_3color.copy() -display_image[..., 2] = display_image[..., 0] -mask = np.where(image_3color[..., 2] >= threshold, 1.0, np.nan)[..., np.newaxis] -display_image *= mask -display_image = np.nan_to_num(display_image, nan=1.0) - -value_sort = np.sort(image_var.flatten()) -value_sort_max = value_sort[int(value_sort.shape[0] * 0.95)] -mylogger.info("-==- DONE -==-") - -mylogger.info("Create figure") - -fig: matplotlib.figure.Figure = plt.figure() - -image_handle = plt.imshow(display_image, vmin=0, vmax=1, cmap="hot") - -mylogger.info("Add controls") - -axfreq = fig.add_axes(rect=(0.4, 0.9, 0.3, 0.03)) -slice_slider = Slider( - ax=axfreq, - label="Threshold", - valmin=0, - valmax=value_sort_max, - valinit=threshold, - valstep=value_sort_max / 100.0, -) -axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04)) -button_accept = Button( - ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95" -) -button_accept.on_clicked(on_clicked_accept) # type: ignore - -axbutton_cancel = fig.add_axes(rect=(0.55, 0.85, 0.2, 0.04)) -button_cancel = Button( - ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95" -) -button_cancel.on_clicked(on_clicked_cancel) # type: ignore - -slice_slider.on_changed( - partial(next_frame, images=image_3color, image_handle=image_handle) -) - -mylogger.info("Display") -plt.show() diff --git a/new_pipeline/stage_3_refine_mask.py b/new_pipeline/stage_3_refine_mask.py deleted file mode 100644 index 83f9ecd..0000000 --- a/new_pipeline/stage_3_refine_mask.py +++ /dev/null @@ -1,157 +0,0 @@ -import os -import numpy as np - -import matplotlib.pyplot as plt # type:ignore -import matplotlib -from matplotlib.widgets import Button # type:ignore - -# pip install roipoly -from roipoly import RoiPoly # type:ignore - -from functions.create_logger import create_logger -from functions.get_torch_device import get_torch_device -from functions.load_config import load_config - - -def compose_image(image_3color: np.ndarray, mask: np.ndarray) -> np.ndarray: - display_image = image_3color.copy() - display_image[..., 2] = display_image[..., 0] - display_image[mask == 0, :] = 1.0 - return display_image - - -def on_clicked_accept(event: matplotlib.backend_bases.MouseEvent) -> None: - global mylogger - global refined_mask_file - global mask - - mylogger.info(f"Save mask to: {refined_mask_file}") - np.save(refined_mask_file, mask) - - exit() - - -def on_clicked_cancel(event: matplotlib.backend_bases.MouseEvent) -> None: - global mylogger - mylogger.info("Ended without saving the mask") - exit() - - -def on_clicked_add(event: matplotlib.backend_bases.MouseEvent) -> None: - global new_roi - global mask - global image_3color - global display_image - global mylogger - if len(new_roi.x) > 0: - mylogger.info("A ROI with the following coordiantes has been added to the mask") - for i in range(0, len(new_roi.x)): - mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}") - mylogger.info("") - new_mask = new_roi.get_mask(display_image[:, :, 0]) - mask[new_mask] = 0.0 - display_image = compose_image(image_3color=image_3color, mask=mask) - image_handle.set_data(display_image) - for line in ax_main.lines: - line.remove() - plt.draw() - - new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) - - -def on_clicked_remove(event: matplotlib.backend_bases.MouseEvent) -> None: - global new_roi - global mask - global image_3color - global display_image - if len(new_roi.x) > 0: - mylogger.info( - "A ROI with the following coordiantes has been removed from the mask" - ) - for i in range(0, len(new_roi.x)): - mylogger.info(f"{round(new_roi.x[i],1)} x {round(new_roi.y[i],1)}") - mylogger.info("") - new_mask = new_roi.get_mask(display_image[:, :, 0]) - mask[new_mask] = 1.0 - display_image = compose_image(image_3color=image_3color, mask=mask) - image_handle.set_data(display_image) - for line in ax_main.lines: - line.remove() - plt.draw() - new_roi = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) - - -mylogger = create_logger( - save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_3" -) - -config = load_config(mylogger=mylogger) - -device = get_torch_device(mylogger, config["force_to_cpu"]) - -path: str = config["ref_image_path"] -use_channel: str = "donor" - -image_ref_file: str = os.path.join(path, use_channel + ".npy") -heartbeat_mask_file: str = os.path.join(path, "heartbeat_mask.npy") -refined_mask_file: str = os.path.join(path, "mask_not_rotated.npy") - -mylogger.info(f"loading image reference file: {image_ref_file}") -image_ref: np.ndarray = np.load(image_ref_file) -image_ref /= image_ref.max() - -mylogger.info(f"loading heartbeat mask: {heartbeat_mask_file}") -mask: np.ndarray = np.load(heartbeat_mask_file) - -image_3color = np.concatenate( - ( - np.zeros_like(image_ref[..., np.newaxis]), - image_ref[..., np.newaxis], - np.zeros_like(image_ref[..., np.newaxis]), - ), - axis=-1, -) - -mylogger.info("-==- DONE -==-") - -fig, ax_main = plt.subplots() - -display_image = compose_image(image_3color=image_3color, mask=mask) -image_handle = ax_main.imshow(display_image, vmin=0, vmax=1, cmap="hot") - -mylogger.info("Add controls") - -axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04)) -button_accept = Button( - ax=axbutton_accept, label="Accept", image=None, color="0.85", hovercolor="0.95" -) -button_accept.on_clicked(on_clicked_accept) # type: ignore - -axbutton_cancel = fig.add_axes(rect=(0.5, 0.85, 0.2, 0.04)) -button_cancel = Button( - ax=axbutton_cancel, label="Cancel", image=None, color="0.85", hovercolor="0.95" -) -button_cancel.on_clicked(on_clicked_cancel) # type: ignore - -axbutton_addmask = fig.add_axes(rect=(0.3, 0.9, 0.2, 0.04)) -button_addmask = Button( - ax=axbutton_addmask, label="Add mask", image=None, color="0.85", hovercolor="0.95" -) -button_addmask.on_clicked(on_clicked_add) # type: ignore - -axbutton_removemask = fig.add_axes(rect=(0.5, 0.9, 0.2, 0.04)) -button_removemask = Button( - ax=axbutton_removemask, - label="Remove mask", - image=None, - color="0.85", - hovercolor="0.95", -) -button_removemask.on_clicked(on_clicked_remove) # type: ignore - -# ax_main.cla() - -mylogger.info("Display") -new_roi: RoiPoly = RoiPoly(ax=ax_main, color="r", close_fig=False, show_fig=False) - -plt.show() diff --git a/new_pipeline/stage_4_process.py b/new_pipeline/stage_4_process.py deleted file mode 100644 index b822856..0000000 --- a/new_pipeline/stage_4_process.py +++ /dev/null @@ -1,918 +0,0 @@ -import numpy as np -import torch -import torchvision as tv # type: ignore - -import os -import logging -import h5py # type: ignore - -from functions.create_logger import create_logger -from functions.get_torch_device import get_torch_device -from functions.load_config import load_config -from functions.get_experiments import get_experiments -from functions.get_trials import get_trials -from functions.binning import binning -from functions.ImageAlignment import ImageAlignment -from functions.align_refref import align_refref -from functions.perform_donor_volume_rotation import perform_donor_volume_rotation -from functions.perform_donor_volume_translation import perform_donor_volume_translation -from functions.bandpass import bandpass -from functions.gauss_smear_individual import gauss_smear_individual -from functions.regression import regression -from functions.data_raw_loader import data_raw_loader - - -@torch.no_grad() -def process_trial( - config: dict, - mylogger: logging.Logger, - experiment_id: int, - trial_id: int, - device: torch.device, -): - - mylogger.info("") - mylogger.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - mylogger.info("~ TRIAL START ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - mylogger.info("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - mylogger.info("") - - if device != torch.device("cpu"): - torch.cuda.empty_cache() - mylogger.info("Empty CUDA cache") - cuda_total_memory: int = torch.cuda.get_device_properties( - device.index - ).total_memory - else: - cuda_total_memory = 0 - - raw_data_path: str = os.path.join( - config["basic_path"], - config["recoding_data"], - config["mouse_identifier"], - config["raw_path"], - ) - - if config["binning_enable"] and (config["binning_at_the_end"] is False): - force_to_cpu_memory: bool = True - else: - force_to_cpu_memory = False - - meta_channels: list[str] - meta_mouse_markings: str - meta_recording_date: str - meta_stimulation_times: dict - meta_experiment_names: dict - meta_trial_recording_duration: float - meta_frame_time: float - meta_mouse: str - data: torch.Tensor - - ( - meta_channels, - meta_mouse_markings, - meta_recording_date, - meta_stimulation_times, - meta_experiment_names, - meta_trial_recording_duration, - meta_frame_time, - meta_mouse, - data, - ) = data_raw_loader( - raw_data_path=raw_data_path, - mylogger=mylogger, - experiment_id=experiment_id, - trial_id=trial_id, - device=device, - force_to_cpu_memory=force_to_cpu_memory, - config=config, - ) - experiment_name: str = f"Exp{experiment_id:03d}_Trial{trial_id:03d}" - - dtype_str = config["dtype"] - dtype_np: np.dtype = getattr(np, dtype_str) - - dtype: torch.dtype = data.dtype - - if device != torch.device("cpu"): - free_mem = cuda_total_memory - max( - [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] - ) - mylogger.info(f"CUDA memory: {free_mem//1024} MByte") - - mylogger.info(f"Data shape: {data.shape}") - mylogger.info("-==- Done -==-") - - mylogger.info("Finding limit values in the RAW data and mark them for masking") - limit: float = (2**16) - 1 - for i in range(0, data.shape[3]): - zero_pixel_mask: torch.Tensor = torch.any(data[..., i] >= limit, dim=-1) - data[zero_pixel_mask, :, i] = -100.0 - mylogger.info( - f"{meta_channels[i]}: " - f"found {int(zero_pixel_mask.type(dtype=dtype).sum())} pixel " - f"with limit values " - ) - mylogger.info("-==- Done -==-") - - mylogger.info("Reference images and mask") - - ref_image_path: str = config["ref_image_path"] - - ref_image_path_acceptor: str = os.path.join(ref_image_path, "acceptor.npy") - if os.path.isfile(ref_image_path_acceptor) is False: - mylogger.info(f"Could not load ref file: {ref_image_path_acceptor}") - assert os.path.isfile(ref_image_path_acceptor) - return - - mylogger.info(f"Loading ref file data: {ref_image_path_acceptor}") - ref_image_acceptor: torch.Tensor = torch.tensor( - np.load(ref_image_path_acceptor).astype(dtype_np), dtype=dtype, device=device - ) - - ref_image_path_donor: str = os.path.join(ref_image_path, "donor.npy") - if os.path.isfile(ref_image_path_donor) is False: - mylogger.info(f"Could not load ref file: {ref_image_path_donor}") - assert os.path.isfile(ref_image_path_donor) - return - - mylogger.info(f"Loading ref file data: {ref_image_path_donor}") - ref_image_donor: torch.Tensor = torch.tensor( - np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=device - ) - - ref_image_path_oxygenation: str = os.path.join(ref_image_path, "oxygenation.npy") - if os.path.isfile(ref_image_path_oxygenation) is False: - mylogger.info(f"Could not load ref file: {ref_image_path_oxygenation}") - assert os.path.isfile(ref_image_path_oxygenation) - return - - mylogger.info(f"Loading ref file data: {ref_image_path_oxygenation}") - ref_image_oxygenation: torch.Tensor = torch.tensor( - np.load(ref_image_path_oxygenation).astype(dtype_np), dtype=dtype, device=device - ) - - ref_image_path_volume: str = os.path.join(ref_image_path, "volume.npy") - if os.path.isfile(ref_image_path_volume) is False: - mylogger.info(f"Could not load ref file: {ref_image_path_volume}") - assert os.path.isfile(ref_image_path_volume) - return - - mylogger.info(f"Loading ref file data: {ref_image_path_volume}") - ref_image_volume: torch.Tensor = torch.tensor( - np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=device - ) - - refined_mask_file: str = os.path.join(ref_image_path, "mask_not_rotated.npy") - if os.path.isfile(refined_mask_file) is False: - mylogger.info(f"Could not load mask file: {refined_mask_file}") - assert os.path.isfile(refined_mask_file) - return - - mylogger.info(f"Loading mask file data: {refined_mask_file}") - mask: torch.Tensor = torch.tensor( - np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=device - ) - mylogger.info("-==- Done -==-") - - if config["binning_enable"] and (config["binning_at_the_end"] is False): - mylogger.info("Binning of data") - mylogger.info( - ( - f"kernel_size={int(config['binning_kernel_size'])}, " - f"stride={int(config['binning_stride'])}, " - f"divisor_override={int(config['binning_divisor_override'])}" - ) - ) - - data = binning( - data, - kernel_size=int(config["binning_kernel_size"]), - stride=int(config["binning_stride"]), - divisor_override=int(config["binning_divisor_override"]), - ).to(device=device) - ref_image_acceptor = ( - binning( - ref_image_acceptor.unsqueeze(-1).unsqueeze(-1), - kernel_size=int(config["binning_kernel_size"]), - stride=int(config["binning_stride"]), - divisor_override=int(config["binning_divisor_override"]), - ) - .squeeze(-1) - .squeeze(-1) - ) - ref_image_donor = ( - binning( - ref_image_donor.unsqueeze(-1).unsqueeze(-1), - kernel_size=int(config["binning_kernel_size"]), - stride=int(config["binning_stride"]), - divisor_override=int(config["binning_divisor_override"]), - ) - .squeeze(-1) - .squeeze(-1) - ) - ref_image_oxygenation = ( - binning( - ref_image_oxygenation.unsqueeze(-1).unsqueeze(-1), - kernel_size=int(config["binning_kernel_size"]), - stride=int(config["binning_stride"]), - divisor_override=int(config["binning_divisor_override"]), - ) - .squeeze(-1) - .squeeze(-1) - ) - ref_image_volume = ( - binning( - ref_image_volume.unsqueeze(-1).unsqueeze(-1), - kernel_size=int(config["binning_kernel_size"]), - stride=int(config["binning_stride"]), - divisor_override=int(config["binning_divisor_override"]), - ) - .squeeze(-1) - .squeeze(-1) - ) - mask = ( - binning( - mask.unsqueeze(-1).unsqueeze(-1), - kernel_size=int(config["binning_kernel_size"]), - stride=int(config["binning_stride"]), - divisor_override=int(config["binning_divisor_override"]), - ) - .squeeze(-1) - .squeeze(-1) - ) - mylogger.info(f"Data shape: {data.shape}") - mylogger.info("-==- Done -==-") - - mylogger.info("Preparing alignment") - image_alignment = ImageAlignment(default_dtype=dtype, device=device) - - mylogger.info("Re-order Raw data") - data = data.moveaxis(-2, 0).moveaxis(-1, 0) - mylogger.info(f"Data shape: {data.shape}") - mylogger.info("-==- Done -==-") - - mylogger.info("Alignment of the ref images and the mask") - mylogger.info("Ref image of donor stays fixed.") - mylogger.info("Ref image of volume and the mask doesn't need to be touched") - mylogger.info("Calculate translation and rotation between the reference images") - angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor = align_refref( - mylogger=mylogger, - ref_image_acceptor=ref_image_acceptor, - ref_image_donor=ref_image_donor, - image_alignment=image_alignment, - batch_size=config["alignment_batch_size"], - fill_value=-100.0, - ) - mylogger.info(f"Rotation: {round(float(angle_refref[0]),2)} degree") - mylogger.info( - f"Translation: {round(float(tvec_refref[0]),1)} x {round(float(tvec_refref[1]),1)} pixel" - ) - - if config["save_alignment"]: - temp_path: str = os.path.join( - config["export_path"], experiment_name + "_angle_refref.npy" - ) - mylogger.info(f"Save angle to {temp_path}") - np.save(temp_path, angle_refref.cpu()) - - temp_path = os.path.join( - config["export_path"], experiment_name + "_tvec_refref.npy" - ) - mylogger.info(f"Save translation vector to {temp_path}") - np.save(temp_path, tvec_refref.cpu()) - - mylogger.info("Moving & rotating the oxygenation ref image") - ref_image_oxygenation = tv.transforms.functional.affine( - img=ref_image_oxygenation.unsqueeze(0), - angle=-float(angle_refref), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=-100.0, - ) - - ref_image_oxygenation = tv.transforms.functional.affine( - img=ref_image_oxygenation, - angle=0, - translate=[tvec_refref[1], tvec_refref[0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=-100.0, - ).squeeze(0) - mylogger.info("-==- Done -==-") - - mylogger.info("Rotate and translate the acceptor and oxygenation data accordingly") - acceptor_index: int = config["required_order"].index("acceptor") - donor_index: int = config["required_order"].index("donor") - oxygenation_index: int = config["required_order"].index("oxygenation") - volume_index: int = config["required_order"].index("volume") - - mylogger.info("Rotate acceptor") - data[acceptor_index, ...] = tv.transforms.functional.affine( - img=data[acceptor_index, ...], - angle=-float(angle_refref), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=-100.0, - ) - - mylogger.info("Translate acceptor") - data[acceptor_index, ...] = tv.transforms.functional.affine( - img=data[acceptor_index, ...], - angle=0, - translate=[tvec_refref[1], tvec_refref[0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=-100.0, - ) - - mylogger.info("Rotate oxygenation") - data[oxygenation_index, ...] = tv.transforms.functional.affine( - img=data[oxygenation_index, ...], - angle=-float(angle_refref), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=-100.0, - ) - - mylogger.info("Translate oxygenation") - data[oxygenation_index, ...] = tv.transforms.functional.affine( - img=data[oxygenation_index, ...], - angle=0, - translate=[tvec_refref[1], tvec_refref[0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=-100.0, - ) - mylogger.info("-==- Done -==-") - - mylogger.info("Perform rotation between donor and volume and its ref images") - mylogger.info("for all frames and then rotate all the data accordingly") - perform_donor_volume_rotation - ( - data[acceptor_index, ...], - data[donor_index, ...], - data[oxygenation_index, ...], - data[volume_index, ...], - angle_donor_volume, - ) = perform_donor_volume_rotation( - mylogger=mylogger, - acceptor=data[acceptor_index, ...], - donor=data[donor_index, ...], - oxygenation=data[oxygenation_index, ...], - volume=data[volume_index, ...], - ref_image_donor=ref_image_donor, - ref_image_volume=ref_image_volume, - image_alignment=image_alignment, - batch_size=config["alignment_batch_size"], - fill_value=-100.0, - config=config, - ) - - mylogger.info( - f"angles: " - f"min {round(float(angle_donor_volume.min()),2)} " - f"max {round(float(angle_donor_volume.max()),2)} " - f"mean {round(float(angle_donor_volume.mean()),2)} " - ) - - if config["save_alignment"]: - temp_path = os.path.join( - config["export_path"], experiment_name + "_angle_donor_volume.npy" - ) - mylogger.info(f"Save angles to {temp_path}") - np.save(temp_path, angle_donor_volume.cpu()) - mylogger.info("-==- Done -==-") - - mylogger.info("Perform translation between donor and volume and its ref images") - mylogger.info("for all frames and then translate all the data accordingly") - ( - data[acceptor_index, ...], - data[donor_index, ...], - data[oxygenation_index, ...], - data[volume_index, ...], - tvec_donor_volume, - ) = perform_donor_volume_translation( - mylogger=mylogger, - acceptor=data[acceptor_index, ...], - donor=data[donor_index, ...], - oxygenation=data[oxygenation_index, ...], - volume=data[volume_index, ...], - ref_image_donor=ref_image_donor, - ref_image_volume=ref_image_volume, - image_alignment=image_alignment, - batch_size=config["alignment_batch_size"], - fill_value=-100.0, - config=config, - ) - - mylogger.info( - f"translation dim 0: " - f"min {round(float(tvec_donor_volume[:,0].min()),1)} " - f"max {round(float(tvec_donor_volume[:,0].max()),1)} " - f"mean {round(float(tvec_donor_volume[:,0].mean()),1)} " - ) - mylogger.info( - f"translation dim 1: " - f"min {round(float(tvec_donor_volume[:,1].min()),1)} " - f"max {round(float(tvec_donor_volume[:,1].max()),1)} " - f"mean {round(float(tvec_donor_volume[:,1].mean()),1)} " - ) - - if config["save_alignment"]: - temp_path = os.path.join( - config["export_path"], experiment_name + "_tvec_donor_volume.npy" - ) - mylogger.info(f"Save translation vector to {temp_path}") - np.save(temp_path, tvec_donor_volume.cpu()) - mylogger.info("-==- Done -==-") - - mylogger.info("Finding zeros values in the RAW data and mark them for masking") - for i in range(0, data.shape[0]): - zero_pixel_mask = torch.any(data[i, ...] == 0, dim=0) - data[i, :, zero_pixel_mask] = -100.0 - mylogger.info( - f"{config['required_order'][i]}: " - f"found {int(zero_pixel_mask.type(dtype=dtype).sum())} pixel " - f"with zeros " - ) - mylogger.info("-==- Done -==-") - - mylogger.info("Update mask with the new regions due to alignment") - - new_mask_area: torch.Tensor = torch.any(torch.any(data < -0.1, dim=0), dim=0).bool() - mask = (mask == 0).bool() - mask = torch.logical_or(mask, new_mask_area) - mask_negative: torch.Tensor = mask.clone() - mask_positve: torch.Tensor = torch.logical_not(mask) - del mask - - mylogger.info("Update the data with the new mask") - data *= mask_positve.unsqueeze(0).unsqueeze(0).type(dtype=dtype) - mylogger.info("-==- Done -==-") - - mylogger.info("Interpolate the 'in-between' frames for oxygenation and volume") - data[oxygenation_index, 1:, ...] = ( - data[oxygenation_index, 1:, ...] + data[oxygenation_index, :-1, ...] - ) / 2.0 - data[volume_index, 1:, ...] = ( - data[volume_index, 1:, ...] + data[volume_index, :-1, ...] - ) / 2.0 - mylogger.info("-==- Done -==-") - - sample_frequency: float = 1.0 / meta_frame_time - - mylogger.info("Extract heartbeat from volume signal") - heartbeat_ts: torch.Tensor = bandpass( - data=data[volume_index, ...].movedim(0, -1).clone(), - device=data.device, - low_frequency=config["lower_freqency_bandpass"], - high_frequency=config["upper_freqency_bandpass"], - fs=sample_frequency, - filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"], - ) - heartbeat_ts = heartbeat_ts.flatten(start_dim=0, end_dim=-2) - mask_flatten: torch.Tensor = mask_positve.flatten(start_dim=0, end_dim=-1) - - heartbeat_ts = heartbeat_ts[mask_flatten, :] - - heartbeat_ts = heartbeat_ts.movedim(0, -1) - heartbeat_ts -= heartbeat_ts.mean(dim=0, keepdim=True) - - volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False) - volume_heartbeat = volume_heartbeat[:, 0] - volume_heartbeat -= volume_heartbeat[ - config["skip_frames_in_the_beginning"] : - ].mean() - - del heartbeat_ts - - if device != torch.device("cpu"): - torch.cuda.empty_cache() - mylogger.info("Empty CUDA cache") - free_mem = cuda_total_memory - max( - [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] - ) - mylogger.info(f"CUDA memory: {free_mem//1024} MByte") - - if config["save_heartbeat"]: - temp_path = os.path.join( - config["export_path"], experiment_name + "_volume_heartbeat.npy" - ) - mylogger.info(f"Save volume heartbeat to {temp_path}") - np.save(temp_path, volume_heartbeat.cpu()) - mylogger.info("-==- Done -==-") - - volume_heartbeat = volume_heartbeat.unsqueeze(0).unsqueeze(0) - norm_volume_heartbeat = ( - volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] ** 2 - ).sum(dim=-1) - - heartbeat_coefficients: torch.Tensor = torch.zeros( - (data.shape[0], data.shape[-2], data.shape[-1]), - dtype=data.dtype, - device=data.device, - ) - for i in range(0, data.shape[0]): - y = bandpass( - data=data[i, ...].movedim(0, -1).clone(), - device=data.device, - low_frequency=config["lower_freqency_bandpass"], - high_frequency=config["upper_freqency_bandpass"], - fs=sample_frequency, - filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"], - )[..., config["skip_frames_in_the_beginning"] :] - y -= y.mean(dim=-1, keepdim=True) - - heartbeat_coefficients[i, ...] = ( - volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] * y - ).sum(dim=-1) / norm_volume_heartbeat - - heartbeat_coefficients[i, ...] *= mask_positve.type( - dtype=heartbeat_coefficients.dtype - ) - del y - - if config["save_heartbeat"]: - temp_path = os.path.join( - config["export_path"], experiment_name + "_heartbeat_coefficients.npy" - ) - mylogger.info(f"Save heartbeat coefficients to {temp_path}") - np.save(temp_path, heartbeat_coefficients.cpu()) - mylogger.info("-==- Done -==-") - - mylogger.info("Remove heart beat from data") - data -= heartbeat_coefficients.unsqueeze(1) * volume_heartbeat.unsqueeze(0).movedim( - -1, 1 - ) - mylogger.info("-==- Done -==-") - - donor_heartbeat_factor = heartbeat_coefficients[donor_index, ...].clone() - acceptor_heartbeat_factor = heartbeat_coefficients[acceptor_index, ...].clone() - del heartbeat_coefficients - - if device != torch.device("cpu"): - torch.cuda.empty_cache() - mylogger.info("Empty CUDA cache") - free_mem = cuda_total_memory - max( - [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] - ) - mylogger.info(f"CUDA memory: {free_mem//1024} MByte") - - mylogger.info("Calculate scaling factor for donor and acceptor") - donor_factor: torch.Tensor = ( - donor_heartbeat_factor + acceptor_heartbeat_factor - ) / (2 * donor_heartbeat_factor) - acceptor_factor: torch.Tensor = ( - donor_heartbeat_factor + acceptor_heartbeat_factor - ) / (2 * acceptor_heartbeat_factor) - - del donor_heartbeat_factor - del acceptor_heartbeat_factor - - if config["save_factors"]: - temp_path = os.path.join( - config["export_path"], experiment_name + "_donor_factor.npy" - ) - mylogger.info(f"Save donor factor to {temp_path}") - np.save(temp_path, donor_factor.cpu()) - - temp_path = os.path.join( - config["export_path"], experiment_name + "_acceptor_factor.npy" - ) - mylogger.info(f"Save acceptor factor to {temp_path}") - np.save(temp_path, acceptor_factor.cpu()) - mylogger.info("-==- Done -==-") - - mylogger.info("Scale acceptor to heart beat amplitude") - mylogger.info("Calculate mean") - mean_values_acceptor = data[ - acceptor_index, config["skip_frames_in_the_beginning"] :, ... - ].nanmean(dim=0, keepdim=True) - - mylogger.info("Remove mean") - data[acceptor_index, ...] -= mean_values_acceptor - mylogger.info("Apply acceptor_factor and mask") - data[acceptor_index, ...] *= acceptor_factor.unsqueeze(0) * mask_positve.unsqueeze( - 0 - ) - mylogger.info("Add mean") - data[acceptor_index, ...] += mean_values_acceptor - mylogger.info("-==- Done -==-") - - mylogger.info("Scale donor to heart beat amplitude") - mylogger.info("Calculate mean") - mean_values_donor = data[ - donor_index, config["skip_frames_in_the_beginning"] :, ... - ].nanmean(dim=0, keepdim=True) - mylogger.info("Remove mean") - data[donor_index, ...] -= mean_values_donor - mylogger.info("Apply donor_factor and mask") - data[donor_index, ...] *= donor_factor.unsqueeze(0) * mask_positve.unsqueeze(0) - mylogger.info("Add mean") - data[donor_index, ...] += mean_values_donor - mylogger.info("-==- Done -==-") - - mylogger.info("Divide by mean over time") - data /= data[:, config["skip_frames_in_the_beginning"] :, ...].nanmean( - dim=1, - keepdim=True, - ) - data = data.nan_to_num(nan=0.0) - mylogger.info("-==- Done -==-") - - mylogger.info("Preparation for regression -- Gauss smear") - spatial_width = float(config["gauss_smear_spatial_width"]) - - if config["binning_enable"] and (config["binning_at_the_end"] is False): - spatial_width /= float(config["binning_kernel_size"]) - - mylogger.info( - f"Mask -- " - f"spatial width: {spatial_width}, " - f"temporal width: {float(config['gauss_smear_temporal_width'])}, " - f"use matlab mode: {bool(config['gauss_smear_use_matlab_mask'])} " - ) - - input_mask = mask_positve.type(dtype=dtype).clone() - - filtered_mask: torch.Tensor - filtered_mask, _ = gauss_smear_individual( - input=input_mask, - spatial_width=spatial_width, - temporal_width=float(config["gauss_smear_temporal_width"]), - use_matlab_mask=bool(config["gauss_smear_use_matlab_mask"]), - epsilon=float(torch.finfo(input_mask.dtype).eps), - ) - - mylogger.info("creating a copy of the data") - data_filtered = data.clone().movedim(1, -1) - if device != torch.device("cpu"): - torch.cuda.empty_cache() - mylogger.info("Empty CUDA cache") - free_mem = cuda_total_memory - max( - [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] - ) - mylogger.info(f"CUDA memory: {free_mem//1024} MByte") - - overwrite_fft_gauss: None | torch.Tensor = None - for i in range(0, data_filtered.shape[0]): - mylogger.info( - f"{config['required_order'][i]} -- " - f"spatial width: {spatial_width}, " - f"temporal width: {float(config['gauss_smear_temporal_width'])}, " - f"use matlab mode: {bool(config['gauss_smear_use_matlab_mask'])} " - ) - data_filtered[i, ...] *= input_mask.unsqueeze(-1) - data_filtered[i, ...], overwrite_fft_gauss = gauss_smear_individual( - input=data_filtered[i, ...], - spatial_width=spatial_width, - temporal_width=float(config["gauss_smear_temporal_width"]), - overwrite_fft_gauss=overwrite_fft_gauss, - use_matlab_mask=bool(config["gauss_smear_use_matlab_mask"]), - epsilon=float(torch.finfo(input_mask.dtype).eps), - ) - - data_filtered[i, ...] /= filtered_mask + 1e-20 - data_filtered[i, ...] += 1.0 - input_mask.unsqueeze(-1) - - del filtered_mask - del overwrite_fft_gauss - del input_mask - mylogger.info("data_filtered is populated") - - if device != torch.device("cpu"): - torch.cuda.empty_cache() - mylogger.info("Empty CUDA cache") - free_mem = cuda_total_memory - max( - [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] - ) - mylogger.info(f"CUDA memory: {free_mem//1024} MByte") - mylogger.info("-==- Done -==-") - - mylogger.info("Preperation for Regression") - mylogger.info("Move time dimensions of data to the last dimension") - data = data.movedim(1, -1) - - mylogger.info("Regression Acceptor") - mylogger.info(f"Target: {config['target_camera_acceptor']}") - mylogger.info( - f"Regressors: constant, linear and {config['regressor_cameras_acceptor']}" - ) - target_id: int = config["required_order"].index(config["target_camera_acceptor"]) - regressor_id: list[int] = [] - for i in range(0, len(config["regressor_cameras_acceptor"])): - regressor_id.append( - config["required_order"].index(config["regressor_cameras_acceptor"][i]) - ) - - data_acceptor, coefficients_acceptor = regression( - mylogger=mylogger, - target_camera_id=target_id, - regressor_camera_ids=regressor_id, - mask=mask_negative, - data=data, - data_filtered=data_filtered, - first_none_ramp_frame=int(config["skip_frames_in_the_beginning"]), - ) - - if config["save_regression_coefficients"]: - temp_path = os.path.join( - config["export_path"], experiment_name + "_coefficients_acceptor.npy" - ) - mylogger.info(f"Save acceptor coefficients to {temp_path}") - np.save(temp_path, coefficients_acceptor.cpu()) - del coefficients_acceptor - - mylogger.info("-==- Done -==-") - - mylogger.info("Regression Donor") - mylogger.info(f"Target: {config['target_camera_donor']}") - mylogger.info( - f"Regressors: constant, linear and {config['regressor_cameras_donor']}" - ) - target_id = config["required_order"].index(config["target_camera_donor"]) - regressor_id = [] - for i in range(0, len(config["regressor_cameras_donor"])): - regressor_id.append( - config["required_order"].index(config["regressor_cameras_donor"][i]) - ) - - data_donor, coefficients_donor = regression( - mylogger=mylogger, - target_camera_id=target_id, - regressor_camera_ids=regressor_id, - mask=mask_negative, - data=data, - data_filtered=data_filtered, - first_none_ramp_frame=int(config["skip_frames_in_the_beginning"]), - ) - - if config["save_regression_coefficients"]: - temp_path = os.path.join( - config["export_path"], experiment_name + "_coefficients_donor.npy" - ) - mylogger.info(f"Save acceptor donor to {temp_path}") - np.save(temp_path, coefficients_donor.cpu()) - del coefficients_donor - mylogger.info("-==- Done -==-") - - del data - del data_filtered - - if device != torch.device("cpu"): - torch.cuda.empty_cache() - mylogger.info("Empty CUDA cache") - free_mem = cuda_total_memory - max( - [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] - ) - mylogger.info(f"CUDA memory: {free_mem//1024} MByte") - - mylogger.info("Calculate ratio sequence") - if config["classical_ratio_mode"]: - mylogger.info("via acceptor / donor") - ratio_sequence: torch.Tensor = data_acceptor / data_donor - mylogger.info("via / mean over time") - ratio_sequence /= ratio_sequence.mean(dim=-1, keepdim=True) - else: - mylogger.info("via 1.0 + acceptor - donor") - ratio_sequence = 1.0 + data_acceptor - data_donor - - mylogger.info("Remove nan") - ratio_sequence = torch.nan_to_num(ratio_sequence, nan=0.0) - mylogger.info("-==- Done -==-") - - if config["binning_enable"] and config["binning_at_the_end"]: - mylogger.info("Binning of data") - mylogger.info( - ( - f"kernel_size={int(config['binning_kernel_size'])}, " - f"stride={int(config['binning_stride'])}, " - "divisor_override=None" - ) - ) - - ratio_sequence = binning( - ratio_sequence.unsqueeze(-1), - kernel_size=int(config["binning_kernel_size"]), - stride=int(config["binning_stride"]), - divisor_override=None, - ).squeeze(-1) - - mask_positve = ( - binning( - mask_positve.unsqueeze(-1).unsqueeze(-1).type(dtype=dtype), - kernel_size=int(config["binning_kernel_size"]), - stride=int(config["binning_stride"]), - divisor_override=None, - ) - .squeeze(-1) - .squeeze(-1) - ) - mask_positve = (mask_positve > 0).type(torch.bool) - - if config["save_as_python"]: - temp_path = os.path.join( - config["export_path"], experiment_name + "_ratio_sequence.npz" - ) - mylogger.info(f"Save ratio_sequence and mask to {temp_path}") - np.savez_compressed( - temp_path, ratio_sequence=ratio_sequence.cpu(), mask=mask_positve.cpu() - ) - - if config["save_as_matlab"]: - temp_path = os.path.join( - config["export_path"], experiment_name + "_ratio_sequence.hd5" - ) - mylogger.info(f"Save ratio_sequence and mask to {temp_path}") - file_handle = h5py.File(temp_path, "w") - - mask_positve = mask_positve.movedim(0, -1) - ratio_sequence = ratio_sequence.movedim(1, -1).movedim(0, -1) - _ = file_handle.create_dataset( - "mask", - data=mask_positve.type(torch.uint8).cpu(), - compression="gzip", - compression_opts=9, - ) - _ = file_handle.create_dataset( - "ratio_sequence", - data=ratio_sequence.cpu(), - compression="gzip", - compression_opts=9, - ) - mylogger.info("Reminder: How to read with matlab:") - mylogger.info(f"mask = h5read('{temp_path}','/mask');") - mylogger.info(f"ratio_sequence = h5read('{temp_path}','/ratio_sequence');") - file_handle.close() - - del ratio_sequence - del mask_positve - del mask_negative - - mylogger.info("") - mylogger.info("***********************************************") - mylogger.info("* TRIAL END ***********************************") - mylogger.info("***********************************************") - mylogger.info("") - - return - - -mylogger = create_logger( - save_logging_messages=True, display_logging_messages=True, log_stage_name="stage_4" -) -config = load_config(mylogger=mylogger) - -if (config["save_as_python"] is False) and (config["save_as_matlab"] is False): - mylogger.info("No output will be created. ") - mylogger.info("Change save_as_python and/or save_as_matlab in the config file") - mylogger.info("ERROR: STOP!!!") - exit() - -device = get_torch_device(mylogger, config["force_to_cpu"]) - -mylogger.info(f"Create directory {config['export_path']} in the case it does not exist") -os.makedirs(config["export_path"], exist_ok=True) - -raw_data_path: str = os.path.join( - config["basic_path"], - config["recoding_data"], - config["mouse_identifier"], - config["raw_path"], -) - -if os.path.isdir(raw_data_path) is False: - mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!") - exit() - -experiments = get_experiments(raw_data_path) - -for experiment_counter in range(0, experiments.shape[0]): - experiment_id = int(experiments[experiment_counter]) - trials = get_trials(raw_data_path, experiment_id) - for trial_counter in range(0, trials.shape[0]): - trial_id = int(trials[trial_counter]) - - mylogger.info("") - mylogger.info( - f"======= EXPERIMENT ID: {experiment_id} ==== TRIAL ID: {trial_id} =======" - ) - mylogger.info("") - - process_trial( - config=config, - mylogger=mylogger, - experiment_id=experiment_id, - trial_id=trial_id, - device=device, - )