diff --git a/functions/ImageAlignment.py b/functions/ImageAlignment.py new file mode 100644 index 0000000..6472d02 --- /dev/null +++ b/functions/ImageAlignment.py @@ -0,0 +1,1015 @@ +import torch +import torchvision as tv # type: ignore + +# The source code is based on: +# https://github.com/matejak/imreg_dft + +# The original LICENSE: +# Copyright (c) 2014, Matěj Týč +# Copyright (c) 2011-2014, Christoph Gohlke +# Copyright (c) 2011-2014, The Regents of the University of California +# Produced at the Laboratory for Fluorescence Dynamics + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# * Neither the name of the {organization} nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +class ImageAlignment(torch.nn.Module): + device: torch.device + default_dtype: torch.dtype + excess_const: float = 1.1 + exponent: str = "inf" + success: torch.Tensor | None = None + + # The factor that detmines how many + # sub-pixel we will shift + scale_factor: int = 4 + + reference_image: torch.Tensor | None = None + + last_scale: torch.Tensor | None = None + last_angle: torch.Tensor | None = None + last_tvec: torch.Tensor | None = None + + # Cache + image_reference_dft: torch.Tensor | None = None + filt: torch.Tensor + pcorr_shape: torch.Tensor + log_base: torch.Tensor + image_reference_logp: torch.Tensor + + def __init__( + self, + device: torch.device | None = None, + default_dtype: torch.dtype | None = None, + ) -> None: + super().__init__() + + assert device is not None + assert default_dtype is not None + self.device = device + self.default_dtype = default_dtype + + def set_new_reference_image(self, new_reference_image: torch.Tensor | None = None): + assert new_reference_image is not None + assert new_reference_image.ndim == 2 + self.reference_image = ( + new_reference_image.detach() + .clone() + .to(device=self.device) + .type(dtype=self.default_dtype) + ) + self.image_reference_dft = None + + def forward( + self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None + ) -> torch.Tensor: + assert input.ndim == 3 + + if new_reference_image is not None: + self.set_new_reference_image(new_reference_image) + + assert self.reference_image is not None + assert self.reference_image.ndim == 2 + assert input.shape[-2] == self.reference_image.shape[-2] + assert input.shape[-1] == self.reference_image.shape[-1] + + self.last_scale, self.last_angle, self.last_tvec, output = self.similarity( + self.reference_image, + input.to(device=self.device).type(dtype=self.default_dtype), + ) + + return output + + def dry_run( + self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + assert input.ndim == 3 + + if new_reference_image is not None: + self.set_new_reference_image(new_reference_image) + + assert self.reference_image is not None + assert self.reference_image.ndim == 2 + assert input.shape[-2] == self.reference_image.shape[-2] + assert input.shape[-1] == self.reference_image.shape[-1] + + images_todo = input.to(device=self.device).type(dtype=self.default_dtype) + image_reference = self.reference_image + + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + + bgval: torch.Tensor = self.get_borderval(img=images_todo, radius=5) + + self.last_scale, self.last_angle, self.last_tvec = self._similarity( + image_reference, + images_todo, + bgval, + ) + + return self.last_scale, self.last_angle, self.last_tvec + + def dry_run_translation( + self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None + ) -> torch.Tensor: + assert input.ndim == 3 + + if new_reference_image is not None: + self.set_new_reference_image(new_reference_image) + + assert self.reference_image is not None + assert self.reference_image.ndim == 2 + assert input.shape[-2] == self.reference_image.shape[-2] + assert input.shape[-1] == self.reference_image.shape[-1] + + images_todo = input.to(device=self.device).type(dtype=self.default_dtype) + image_reference = self.reference_image + + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + + tvec, _ = self._translation(image_reference, images_todo) + + return tvec + + # --------------- + + def dry_run_angle( + self, + input: torch.Tensor, + new_reference_image: torch.Tensor | None = None, + ) -> torch.Tensor: + assert input.ndim == 3 + + if new_reference_image is not None: + self.set_new_reference_image(new_reference_image) + + constraints_dynamic_angle_0: torch.Tensor = torch.zeros( + (input.shape[0]), dtype=self.default_dtype, device=self.device + ) + constraints_dynamic_angle_1: torch.Tensor | None = None + constraints_dynamic_scale_0: torch.Tensor = torch.ones( + (input.shape[0]), dtype=self.default_dtype, device=self.device + ) + constraints_dynamic_scale_1: torch.Tensor | None = None + + assert self.reference_image is not None + assert self.reference_image.ndim == 2 + assert input.shape[-2] == self.reference_image.shape[-2] + assert input.shape[-1] == self.reference_image.shape[-1] + + images_todo = input.to(device=self.device).type(dtype=self.default_dtype) + image_reference = self.reference_image + + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + + _, newangle = self._get_ang_scale( + image_reference, + images_todo, + constraints_dynamic_scale_0, + constraints_dynamic_scale_1, + constraints_dynamic_angle_0, + constraints_dynamic_angle_1, + ) + + return newangle + + # --------------- + + def _get_pcorr_shape(self, shape: torch.Size) -> tuple[int, int]: + ret = (int(max(shape[-2:]) * 1.0),) * 2 + return ret + + def _get_log_base(self, shape: torch.Size, new_r: torch.Tensor) -> torch.Tensor: + old_r = torch.tensor( + (float(shape[-2]) * self.excess_const) / 2.0, + dtype=self.default_dtype, + device=self.device, + ) + log_base = torch.exp(torch.log(old_r) / new_r) + return log_base + + def wrap_angle( + self, angles: torch.Tensor, ceil: float = 2 * torch.pi + ) -> torch.Tensor: + angles += ceil / 2.0 + angles %= ceil + angles -= ceil / 2.0 + return angles + + def get_borderval( + self, img: torch.Tensor, radius: int | None = None + ) -> torch.Tensor: + assert img.ndim == 3 + if radius is None: + mindim = min([int(img.shape[-2]), int(img.shape[-1])]) + radius = max(1, mindim // 20) + mask = torch.zeros( + (int(img.shape[-2]), int(img.shape[-1])), + dtype=torch.bool, + device=self.device, + ) + mask[:, :radius] = True + mask[:, -radius:] = True + mask[:radius, :] = True + mask[-radius:, :] = True + + mean = torch.median(img[:, mask], dim=-1)[0] + return mean + + def get_apofield(self, shape: torch.Size, aporad: int) -> torch.Tensor: + if aporad == 0: + return torch.ones( + shape[-2:], + dtype=self.default_dtype, + device=self.device, + ) + + assert int(shape[-2]) > aporad * 2 + assert int(shape[-1]) > aporad * 2 + + apos = torch.hann_window( + aporad * 2, dtype=self.default_dtype, periodic=False, device=self.device + ) + + toapp_0 = torch.ones( + shape[-2], + dtype=self.default_dtype, + device=self.device, + ) + toapp_0[:aporad] = apos[:aporad] + toapp_0[-aporad:] = apos[-aporad:] + + toapp_1 = torch.ones( + shape[-1], + dtype=self.default_dtype, + device=self.device, + ) + toapp_1[:aporad] = apos[:aporad] + toapp_1[-aporad:] = apos[-aporad:] + + apofield = torch.outer(toapp_0, toapp_1) + + return apofield + + def _get_subarr( + self, array: torch.Tensor, center: torch.Tensor, rad: int + ) -> torch.Tensor: + assert array.ndim == 3 + assert center.ndim == 2 + assert array.shape[0] == center.shape[0] + assert center.shape[1] == 2 + + dim = 1 + 2 * rad + subarr = torch.zeros( + (array.shape[0], dim, dim), dtype=self.default_dtype, device=self.device + ) + + corner = center - rad + idx_p = range(0, corner.shape[0]) + for ii in range(0, dim): + yidx = corner[:, 0] + ii + yidx %= array.shape[-2] + for jj in range(0, dim): + xidx = corner[:, 1] + jj + xidx %= array.shape[-1] + subarr[:, ii, jj] = array[idx_p, yidx, xidx] + + return subarr + + def _argmax_2d(self, array: torch.Tensor) -> torch.Tensor: + assert array.ndim == 3 + + max_pos = array.reshape( + (array.shape[0], array.shape[1] * array.shape[2]) + ).argmax(dim=1) + pos_0 = max_pos // array.shape[2] + + max_pos -= pos_0 * array.shape[2] + ret = torch.zeros( + (array.shape[0], 2), dtype=self.default_dtype, device=self.device + ) + ret[:, 0] = pos_0 + ret[:, 1] = max_pos + + return ret.type(dtype=torch.int64) + + def _apodize(self, what: torch.Tensor) -> torch.Tensor: + mindim = min([int(what.shape[-2]), int(what.shape[-1])]) + aporad = int(mindim * 0.12) + + apofield = self.get_apofield(what.shape, aporad).unsqueeze(0) + + res = what * apofield + bg = self.get_borderval(what, aporad // 2).unsqueeze(-1).unsqueeze(-1) + res += bg * (1 - apofield) + return res + + def _logpolar_filter(self, shape: torch.Size) -> torch.Tensor: + yy = torch.linspace( + -torch.pi / 2.0, + torch.pi / 2.0, + shape[-2], + dtype=self.default_dtype, + device=self.device, + ).unsqueeze(1) + + xx = torch.linspace( + -torch.pi / 2.0, + torch.pi / 2.0, + shape[-1], + dtype=self.default_dtype, + device=self.device, + ).unsqueeze(0) + + rads = torch.sqrt(yy**2 + xx**2) + filt = 1.0 - torch.cos(rads) ** 2 + + filt[torch.abs(rads) > torch.pi / 2] = 1 + return filt + + def _get_angles(self, shape: torch.Tensor) -> torch.Tensor: + ret = torch.zeros( + (int(shape[-2]), int(shape[-1])), + dtype=self.default_dtype, + device=self.device, + ) + ret -= torch.linspace( + 0, + torch.pi, + int(shape[-2] + 1), + dtype=self.default_dtype, + device=self.device, + )[:-1].unsqueeze(-1) + + return ret + + def _get_lograd(self, shape: torch.Tensor, log_base: torch.Tensor) -> torch.Tensor: + ret = torch.zeros( + (int(shape[-2]), int(shape[-1])), + dtype=self.default_dtype, + device=self.device, + ) + ret += torch.pow( + log_base, + torch.arange( + 0, + int(shape[-1]), + dtype=self.default_dtype, + device=self.device, + ), + ).unsqueeze(0) + return ret + + def _logpolar( + self, image: torch.Tensor, shape: torch.Tensor, log_base: torch.Tensor + ) -> torch.Tensor: + assert image.ndim == 3 + + imshape: torch.Tensor = torch.tensor( + image.shape[-2:], + dtype=self.default_dtype, + device=self.device, + ) + + center: torch.Tensor = imshape.clone() / 2 + + theta: torch.Tensor = self._get_angles(shape) + radius_x: torch.Tensor = self._get_lograd(shape, log_base) + radius_y: torch.Tensor = radius_x.clone() + + ellipse_coef: torch.Tensor = imshape[0] / imshape[1] + radius_x /= ellipse_coef + + y = radius_y * torch.sin(theta) + center[0] + y /= float(image.shape[-2]) + y *= 2 + y -= 1 + + x = radius_x * torch.cos(theta) + center[1] + x /= float(image.shape[-1]) + x *= 2 + x -= 1 + + idx_x = torch.where(torch.abs(x) <= 1.0, 1.0, 0.0) + idx_y = torch.where(torch.abs(y) <= 1.0, 1.0, 0.0) + + normalized_coords = torch.cat( + ( + x.unsqueeze(-1), + y.unsqueeze(-1), + ), + dim=-1, + ).unsqueeze(0) + + output = torch.empty( + (int(image.shape[0]), int(y.shape[0]), int(y.shape[1])), + dtype=self.default_dtype, + device=self.device, + ) + + for id in range(0, int(image.shape[0])): + bgval: torch.Tensor = torch.quantile(image[id, :, :], q=1.0 / 100.0) + + temp = torch.nn.functional.grid_sample( + image[id, :, :].unsqueeze(0).unsqueeze(0), + normalized_coords, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + + output[id, :, :] = torch.where((idx_x * idx_y) == 0.0, bgval, temp) + + return output + + def _argmax_ext(self, array: torch.Tensor, exponent: float | str) -> torch.Tensor: + assert array.ndim == 3 + + if exponent == "inf": + ret = self._argmax_2d(array) + else: + assert isinstance(exponent, float) or isinstance(exponent, int) + + col = ( + torch.arange( + 0, array.shape[-2], dtype=self.default_dtype, device=self.device + ) + .unsqueeze(-1) + .unsqueeze(0) + ) + row = ( + torch.arange( + 0, array.shape[-1], dtype=self.default_dtype, device=self.device + ) + .unsqueeze(0) + .unsqueeze(0) + ) + + arr2 = torch.pow(array, float(exponent)) + arrsum = arr2.sum(dim=-2).sum(dim=-1) + + ret = torch.zeros( + (array.shape[0], 2), dtype=self.default_dtype, device=self.device + ) + + arrprody = (arr2 * col).sum(dim=-1).sum(dim=-1) / arrsum + arrprodx = (arr2 * row).sum(dim=-1).sum(dim=-1) / arrsum + + ret[:, 0] = arrprody.squeeze(-1).squeeze(-1) + ret[:, 1] = arrprodx.squeeze(-1).squeeze(-1) + + idx = torch.where(arrsum == 0.0)[0] + ret[idx, :] = 0.0 + return ret + + def _interpolate( + self, array: torch.Tensor, rough: torch.Tensor, rad: int = 2 + ) -> torch.Tensor: + assert array.ndim == 3 + assert rough.ndim == 2 + + rough = torch.round(rough).type(torch.int64) + + surroundings = self._get_subarr(array, rough, rad) + + com = self._argmax_ext(surroundings, 1.0) + + offset = com - rad + ret = rough + offset + + ret += 0.5 + ret %= ( + torch.tensor(array.shape[-2:], dtype=self.default_dtype, device=self.device) + .type(dtype=torch.int64) + .unsqueeze(0) + ) + ret -= 0.5 + return ret + + def _get_success( + self, array: torch.Tensor, coord: torch.Tensor, radius: int = 2 + ) -> torch.Tensor: + assert array.ndim == 3 + assert coord.ndim == 2 + assert array.shape[0] == coord.shape[0] + assert coord.shape[1] == 2 + + coord = torch.round(coord).type(dtype=torch.int64) + subarr = self._get_subarr( + array, coord, 2 + ) # Not my fault. They want a 2 there. Not radius + + theval = subarr.sum(dim=-1).sum(dim=-1) + + theval2 = array[range(0, coord.shape[0]), coord[:, 0], coord[:, 1]] + + success = torch.sqrt(theval * theval2) + return success + + def _get_constraint_mask( + self, + shape: torch.Size, + log_base: torch.Tensor, + constraints_scale_0: torch.Tensor, + constraints_scale_1: torch.Tensor | None, + constraints_angle_0: torch.Tensor, + constraints_angle_1: torch.Tensor | None, + ) -> torch.Tensor: + assert constraints_scale_0 is not None + assert constraints_angle_0 is not None + assert constraints_scale_0.ndim == 1 + assert constraints_angle_0.ndim == 1 + + assert constraints_scale_0.shape[0] == constraints_angle_0.shape[0] + + mask: torch.Tensor = torch.ones( + (constraints_scale_0.shape[0], int(shape[-2]), int(shape[-1])), + device=self.device, + dtype=self.default_dtype, + ) + + scale: torch.Tensor = constraints_scale_0.clone() + if constraints_scale_1 is not None: + sigma: torch.Tensor | None = constraints_scale_1.clone() + else: + sigma = None + + scales = torch.fft.ifftshift( + self._get_lograd( + torch.tensor(shape[-2:], device=self.device, dtype=self.default_dtype), + log_base, + ) + ) + + scales *= log_base ** (-shape[-1] / 2.0) + scales = scales.unsqueeze(0) - (1.0 / scale).unsqueeze(-1).unsqueeze(-1) + + if sigma is not None: + assert sigma.shape[0] == constraints_scale_0.shape[0] + + for p_id in range(0, sigma.shape[0]): + if sigma[p_id] == 0: + ascales = torch.abs(scales[p_id, ...]) + scale_min = ascales.min() + binary_mask = torch.where(ascales > scale_min, 0.0, 1.0) + mask[p_id, ...] *= binary_mask + else: + mask[p_id, ...] *= torch.exp( + -(torch.pow(scales[p_id, ...], 2)) / torch.pow(sigma[p_id], 2) + ) + + angle: torch.Tensor = constraints_angle_0.clone() + if constraints_angle_1 is not None: + sigma = constraints_angle_1.clone() + else: + sigma = None + + angles = self._get_angles( + torch.tensor(shape[-2:], device=self.device, dtype=self.default_dtype) + ) + + angles = angles.unsqueeze(0) + torch.deg2rad(angle).unsqueeze(-1).unsqueeze(-1) + + angles = torch.rad2deg(angles) + + if sigma is not None: + assert sigma.shape[0] == constraints_scale_0.shape[0] + + for p_id in range(0, sigma.shape[0]): + if sigma[p_id] == 0: + aangles = torch.abs(angles[p_id, ...]) + angle_min = aangles.min() + binary_mask = torch.where(aangles > angle_min, 0.0, 1.0) + mask[p_id, ...] *= binary_mask + else: + mask *= torch.exp( + -(torch.pow(angles[p_id, ...], 2)) / torch.pow(sigma[p_id], 2) + ) + + mask = torch.fft.fftshift(mask, dim=(-2, -1)) + + return mask + + def argmax_angscale( + self, + array: torch.Tensor, + log_base: torch.Tensor, + constraints_scale_0: torch.Tensor, + constraints_scale_1: torch.Tensor | None, + constraints_angle_0: torch.Tensor, + constraints_angle_1: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert array.ndim == 3 + assert constraints_scale_0 is not None + assert constraints_angle_0 is not None + assert constraints_scale_0.ndim == 1 + assert constraints_angle_0.ndim == 1 + + mask = self._get_constraint_mask( + array.shape[-2:], + log_base, + constraints_scale_0, + constraints_scale_1, + constraints_angle_0, + constraints_angle_1, + ) + + array_orig = array.clone() + + array *= mask + ret = self._argmax_ext(array, self.exponent) + + ret_final = self._interpolate(array, ret) + + success = self._get_success(array_orig, ret_final, 0) + + return ret_final, success + + def argmax_translation( + self, array: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + assert array.ndim == 3 + + array_orig = array.clone() + + ashape = torch.tensor(array.shape[-2:], device=self.device).type( + dtype=torch.int64 + ) + + aporad = (ashape // 6).min() + mask2 = self.get_apofield(torch.Size(ashape), aporad).unsqueeze(0) + array *= mask2 + + tvec = self._argmax_ext(array, "inf") + + tvec = self._interpolate(array_orig, tvec) + + success = self._get_success(array_orig, tvec, 2) + + return tvec, success + + def transform_img( + self, + img: torch.Tensor, + scale: torch.Tensor | None = None, + angle: torch.Tensor | None = None, + tvec: torch.Tensor | None = None, + bgval: torch.Tensor | None = None, + ) -> torch.Tensor: + assert img.ndim == 3 + + if scale is None: + scale = torch.ones( + (img.shape[0],), dtype=self.default_dtype, device=self.device + ) + assert scale.ndim == 1 + assert scale.shape[0] == img.shape[0] + + if angle is None: + angle = torch.zeros( + (img.shape[0],), dtype=self.default_dtype, device=self.device + ) + assert angle.ndim == 1 + assert angle.shape[0] == img.shape[0] + + if tvec is None: + tvec = torch.zeros( + (img.shape[0], 2), dtype=self.default_dtype, device=self.device + ) + assert tvec.ndim == 2 + assert tvec.shape[0] == img.shape[0] + assert tvec.shape[1] == 2 + + if bgval is None: + bgval = self.get_borderval(img) + assert bgval.ndim == 1 + assert bgval.shape[0] == img.shape[0] + + # Otherwise we need to decompose it and put it back together + assert torch.is_complex(img) is False + + output = torch.zeros_like(img) + + for pos in range(0, img.shape[0]): + image_processed = img[pos, :, :].unsqueeze(0).clone() + + temp_shift = [ + int(round(tvec[pos, 1].item() * self.scale_factor)), + int(round(tvec[pos, 0].item() * self.scale_factor)), + ] + + image_processed = torch.nn.functional.interpolate( + image_processed.unsqueeze(0), + scale_factor=self.scale_factor, + mode="bilinear", + ).squeeze(0) + + image_processed = tv.transforms.functional.affine( + img=image_processed, + angle=-float(angle[pos]), + translate=temp_shift, + scale=float(scale[pos]), + shear=[0, 0], + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=float(bgval[pos]), + center=None, + ) + + image_processed = torch.nn.functional.interpolate( + image_processed.unsqueeze(0), + scale_factor=1.0 / self.scale_factor, + mode="bilinear", + ).squeeze(0) + + image_processed = tv.transforms.functional.center_crop( + image_processed, img.shape[-2:] + ) + + output[pos, ...] = image_processed.squeeze(0) + + return output + + def transform_img_dict( + self, + img: torch.Tensor, + scale: torch.Tensor | None = None, + angle: torch.Tensor | None = None, + tvec: torch.Tensor | None = None, + bgval: torch.Tensor | None = None, + invert=False, + ) -> torch.Tensor: + if invert: + if scale is not None: + scale = 1.0 / scale + if angle is not None: + angle *= -1 + if tvec is not None: + tvec *= -1 + + res = self.transform_img(img, scale, angle, tvec, bgval=bgval) + return res + + def _phase_correlation( + self, image_reference: torch.Tensor, images_todo: torch.Tensor, callback, *args + ) -> tuple[torch.Tensor, torch.Tensor]: + assert image_reference.ndim == 3 + assert image_reference.shape[0] == 1 + assert images_todo.ndim == 3 + + assert callback is not None + + image_reference_fft = torch.fft.fft2(image_reference, dim=(-2, -1)) + images_todo_fft = torch.fft.fft2(images_todo, dim=(-2, -1)) + + eps = torch.abs(images_todo_fft).max(dim=-1)[0].max(dim=-1)[0] * 1e-15 + + cps = abs( + torch.fft.ifft2( + (image_reference_fft * images_todo_fft.conj()) + / ( + torch.abs(image_reference_fft) * torch.abs(images_todo_fft) + + eps.unsqueeze(-1).unsqueeze(-1) + ), + dim=(-2, -1), + ) + ) + + scps = torch.fft.fftshift(cps, dim=(-2, -1)) + + ret, success = callback(scps, *args) + + ret[:, 0] -= image_reference_fft.shape[-2] // 2 + ret[:, 1] -= image_reference_fft.shape[-1] // 2 + + return ret, success + + def _translation( + self, im0: torch.Tensor, im1: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + assert im0.ndim == 2 + ret, succ = self._phase_correlation( + im0.unsqueeze(0), im1, self.argmax_translation + ) + + return ret, succ + + def _get_ang_scale( + self, + image_reference: torch.Tensor, + images_todo: torch.Tensor, + constraints_scale_0: torch.Tensor, + constraints_scale_1: torch.Tensor | None, + constraints_angle_0: torch.Tensor, + constraints_angle_1: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + assert image_reference.shape[-1] == images_todo.shape[-1] + assert image_reference.shape[-2] == images_todo.shape[-2] + assert constraints_scale_0.shape[0] == images_todo.shape[0] + assert constraints_angle_0.shape[0] == images_todo.shape[0] + + if constraints_scale_1 is not None: + assert constraints_scale_1.shape[0] == images_todo.shape[0] + + if constraints_angle_1 is not None: + assert constraints_angle_1.shape[0] == images_todo.shape[0] + + if self.image_reference_dft is None: + image_reference_apod = self._apodize(image_reference.unsqueeze(0)) + self.image_reference_dft = torch.fft.fftshift( + torch.fft.fft2(image_reference_apod, dim=(-2, -1)), dim=(-2, -1) + ) + self.filt = self._logpolar_filter(image_reference.shape).unsqueeze(0) + self.image_reference_dft *= self.filt + self.pcorr_shape = torch.tensor( + self._get_pcorr_shape(image_reference.shape[-2:]), + dtype=self.default_dtype, + device=self.device, + ) + self.log_base = self._get_log_base( + image_reference.shape, + self.pcorr_shape[1], + ) + self.image_reference_logp = self._logpolar( + torch.abs(self.image_reference_dft), self.pcorr_shape, self.log_base + ) + + images_todo_apod = self._apodize(images_todo) + images_todo_dft = torch.fft.fftshift( + torch.fft.fft2(images_todo_apod, dim=(-2, -1)), dim=(-2, -1) + ) + + images_todo_dft *= self.filt + + images_todo_lopg = self._logpolar( + torch.abs(images_todo_dft), self.pcorr_shape, self.log_base + ) + + temp, _ = self._phase_correlation( + self.image_reference_logp, + images_todo_lopg, + self.argmax_angscale, + self.log_base, + constraints_scale_0, + constraints_scale_1, + constraints_angle_0, + constraints_angle_1, + ) + + arg_ang = temp[:, 0].clone() + arg_rad = temp[:, 1].clone() + + angle = -torch.pi * arg_ang / float(self.pcorr_shape[0]) + angle = torch.rad2deg(angle) + + angle = self.wrap_angle(angle, 360) + + scale = torch.pow(self.log_base, arg_rad) + + angle = -angle + scale = 1.0 / scale + + assert torch.where(scale < 2)[0].shape[0] == scale.shape[0] + assert torch.where(scale > 0.5)[0].shape[0] == scale.shape[0] + + return scale, angle + + def translation( + self, im0: torch.Tensor, im1: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + angle = torch.zeros( + (im1.shape[0]), dtype=self.default_dtype, device=self.device + ) + assert im1.ndim == 3 + assert im0.shape[-2] == im1.shape[-2] + assert im0.shape[-1] == im1.shape[-1] + + tvec, succ = self._translation(im0, im1) + tvec2, succ2 = self._translation(im0, torch.rot90(im1, k=2, dims=[-2, -1])) + + assert tvec.shape[0] == tvec2.shape[0] + assert tvec.ndim == 2 + assert tvec2.ndim == 2 + assert tvec.shape[1] == 2 + assert tvec2.shape[1] == 2 + assert succ.shape[0] == succ2.shape[0] + assert succ.ndim == 1 + assert succ2.ndim == 1 + assert tvec.shape[0] == succ.shape[0] + assert angle.shape[0] == tvec.shape[0] + assert angle.ndim == 1 + + for pos in range(0, angle.shape[0]): + pick_rotated = False + if succ2[pos] > succ[pos]: + pick_rotated = True + + if pick_rotated: + tvec[pos, :] = tvec2[pos, :] + succ[pos] = succ2[pos] + angle[pos] += 180 + + return tvec, succ, angle + + def _similarity( + self, + image_reference: torch.Tensor, + images_todo: torch.Tensor, + bgval: torch.Tensor, + ): + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + assert image_reference.shape[-1] == images_todo.shape[-1] + assert image_reference.shape[-2] == images_todo.shape[-2] + + # We are going to iterate and precise scale and angle estimates + scale: torch.Tensor = torch.ones( + (images_todo.shape[0]), dtype=self.default_dtype, device=self.device + ) + angle: torch.Tensor = torch.zeros( + (images_todo.shape[0]), dtype=self.default_dtype, device=self.device + ) + + constraints_dynamic_angle_0: torch.Tensor = torch.zeros( + (images_todo.shape[0]), dtype=self.default_dtype, device=self.device + ) + constraints_dynamic_angle_1: torch.Tensor | None = None + constraints_dynamic_scale_0: torch.Tensor = torch.ones( + (images_todo.shape[0]), dtype=self.default_dtype, device=self.device + ) + constraints_dynamic_scale_1: torch.Tensor | None = None + + newscale, newangle = self._get_ang_scale( + image_reference, + images_todo, + constraints_dynamic_scale_0, + constraints_dynamic_scale_1, + constraints_dynamic_angle_0, + constraints_dynamic_angle_1, + ) + scale *= newscale + angle += newangle + + im2 = self.transform_img(images_todo, scale, angle, bgval=bgval) + + tvec, self.success, res_angle = self.translation(image_reference, im2) + + angle += res_angle + + angle = self.wrap_angle(angle, 360) + + return scale, angle, tvec + + def similarity( + self, + image_reference: torch.Tensor, + images_todo: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + + bgval: torch.Tensor = self.get_borderval(img=images_todo, radius=5) + + scale, angle, tvec = self._similarity( + image_reference, + images_todo, + bgval, + ) + + im2 = self.transform_img_dict( + img=images_todo, + scale=scale, + angle=angle, + tvec=tvec, + bgval=bgval, + ) + + return scale, angle, tvec, im2 diff --git a/functions/align_refref.py b/functions/align_refref.py new file mode 100644 index 0000000..249665c --- /dev/null +++ b/functions/align_refref.py @@ -0,0 +1,57 @@ +import torch +import torchvision as tv # type: ignore +import logging +from functions.ImageAlignment import ImageAlignment +from functions.calculate_translation import calculate_translation +from functions.calculate_rotation import calculate_rotation + + +@torch.no_grad() +def align_refref( + mylogger: logging.Logger, + ref_image_acceptor: torch.Tensor, + ref_image_donor: torch.Tensor, + image_alignment: ImageAlignment, + batch_size: int, + fill_value: float = 0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + mylogger.info("Rotate ref image acceptor onto donor") + angle_refref = calculate_rotation( + image_alignment=image_alignment, + input=ref_image_acceptor.unsqueeze(0), + reference_image=ref_image_donor, + batch_size=batch_size, + ) + + ref_image_acceptor = tv.transforms.functional.affine( + img=ref_image_acceptor.unsqueeze(0), + angle=-float(angle_refref), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ) + + mylogger.info("Translate ref image acceptor onto donor") + tvec_refref = calculate_translation( + image_alignment=image_alignment, + input=ref_image_acceptor, + reference_image=ref_image_donor, + batch_size=batch_size, + ) + + tvec_refref = tvec_refref[0, :] + + ref_image_acceptor = tv.transforms.functional.affine( + img=ref_image_acceptor, + angle=0, + translate=[tvec_refref[1], tvec_refref[0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + return angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor diff --git a/functions/bandpass.py b/functions/bandpass.py new file mode 100644 index 0000000..171baf5 --- /dev/null +++ b/functions/bandpass.py @@ -0,0 +1,85 @@ +import torchaudio as ta # type: ignore +import torch + + +@torch.no_grad() +def filtfilt( + input: torch.Tensor, + butter_a: torch.Tensor, + butter_b: torch.Tensor, +) -> torch.Tensor: + assert butter_a.ndim == 1 + assert butter_b.ndim == 1 + assert butter_a.shape[0] == butter_b.shape[0] + + process_data: torch.Tensor = input.detach().clone() + + padding_length = 12 * int(butter_a.shape[0]) + left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[ + ..., 1 : padding_length + 1 + ].flip(-1) + right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[ + ..., -(padding_length + 1) : -1 + ].flip(-1) + process_data_padded = torch.cat((left_padding, process_data, right_padding), dim=-1) + + output = ta.functional.filtfilt( + process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False + ).squeeze(0) + + output = output[..., padding_length:-padding_length] + return output + + +@torch.no_grad() +def butter_bandpass( + device: torch.device, + low_frequency: float = 0.1, + high_frequency: float = 1.0, + fs: float = 30.0, +) -> tuple[torch.Tensor, torch.Tensor]: + import scipy # type: ignore + + butter_b_np, butter_a_np = scipy.signal.butter( + 4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs + ) + butter_a = torch.tensor(butter_a_np, device=device, dtype=torch.float32) + butter_b = torch.tensor(butter_b_np, device=device, dtype=torch.float32) + return butter_a, butter_b + + +@torch.no_grad() +def chunk_iterator(array: torch.Tensor, chunk_size: int): + for i in range(0, array.shape[0], chunk_size): + yield array[i : i + chunk_size] + + +@torch.no_grad() +def bandpass( + data: torch.Tensor, + device: torch.device, + low_frequency: float = 0.1, + high_frequency: float = 1.0, + fs=30.0, + filtfilt_chuck_size: int = 10, +) -> torch.Tensor: + butter_a, butter_b = butter_bandpass( + device=device, + low_frequency=low_frequency, + high_frequency=high_frequency, + fs=fs, + ) + + index_full_dataset: torch.Tensor = torch.arange( + 0, data.shape[1], device=device, dtype=torch.int64 + ) + + for chunk in chunk_iterator(index_full_dataset, filtfilt_chuck_size): + temp_filtfilt = filtfilt( + data[:, chunk, :], + butter_a=butter_a, + butter_b=butter_b, + ) + data[:, chunk, :] = temp_filtfilt + + return data diff --git a/functions/binning.py b/functions/binning.py new file mode 100644 index 0000000..ccfe657 --- /dev/null +++ b/functions/binning.py @@ -0,0 +1,21 @@ +import torch + + +def binning( + data: torch.Tensor, + kernel_size: int = 4, + stride: int = 4, + divisor_override: int | None = 1, +) -> torch.Tensor: + + assert data.ndim == 4 + return ( + torch.nn.functional.avg_pool2d( + input=data.movedim(0, -1).movedim(0, -1), + kernel_size=kernel_size, + stride=stride, + divisor_override=divisor_override, + ) + .movedim(-1, 0) + .movedim(-1, 0) + ) diff --git a/functions/calculate_rotation.py b/functions/calculate_rotation.py new file mode 100644 index 0000000..6a53afd --- /dev/null +++ b/functions/calculate_rotation.py @@ -0,0 +1,40 @@ +import torch + +from functions.ImageAlignment import ImageAlignment + + +@torch.no_grad() +def calculate_rotation( + image_alignment: ImageAlignment, + input: torch.Tensor, + reference_image: torch.Tensor, + batch_size: int, +) -> torch.Tensor: + angle = torch.zeros((input.shape[0])) + + data_loader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(input), + batch_size=batch_size, + shuffle=False, + ) + start_position: int = 0 + for input_batch in data_loader: + assert len(input_batch) == 1 + + end_position = start_position + input_batch[0].shape[0] + + angle_temp = image_alignment.dry_run_angle( + input=input_batch[0], + new_reference_image=reference_image, + ) + + assert angle_temp is not None + + angle[start_position:end_position] = angle_temp + + start_position += input_batch[0].shape[0] + + angle = torch.where(angle >= 180, 360.0 - angle, angle) + angle = torch.where(angle <= -180, 360.0 + angle, angle) + + return angle diff --git a/functions/calculate_translation.py b/functions/calculate_translation.py new file mode 100644 index 0000000..9eadf59 --- /dev/null +++ b/functions/calculate_translation.py @@ -0,0 +1,37 @@ +import torch + +from functions.ImageAlignment import ImageAlignment + + +@torch.no_grad() +def calculate_translation( + image_alignment: ImageAlignment, + input: torch.Tensor, + reference_image: torch.Tensor, + batch_size: int, +) -> torch.Tensor: + tvec = torch.zeros((input.shape[0], 2)) + + data_loader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(input), + batch_size=batch_size, + shuffle=False, + ) + start_position: int = 0 + for input_batch in data_loader: + assert len(input_batch) == 1 + + end_position = start_position + input_batch[0].shape[0] + + tvec_temp = image_alignment.dry_run_translation( + input=input_batch[0], + new_reference_image=reference_image, + ) + + assert tvec_temp is not None + + tvec[start_position:end_position, :] = tvec_temp + + start_position += input_batch[0].shape[0] + + return tvec diff --git a/functions/create_logger.py b/functions/create_logger.py new file mode 100644 index 0000000..8fcfa8a --- /dev/null +++ b/functions/create_logger.py @@ -0,0 +1,37 @@ +import logging +import datetime +import os + + +def create_logger( + save_logging_messages: bool, display_logging_messages: bool, log_stage_name: str +): + now = datetime.datetime.now() + dt_string_filename = now.strftime("%Y_%m_%d_%H_%M_%S") + + logger = logging.getLogger("MyLittleLogger") + logger.setLevel(logging.DEBUG) + + if save_logging_messages: + time_format = "%b %-d %Y %H:%M:%S" + logformat = "%(asctime)s %(message)s" + file_formatter = logging.Formatter(fmt=logformat, datefmt=time_format) + os.makedirs("logs_" + log_stage_name, exist_ok=True) + file_handler = logging.FileHandler( + os.path.join("logs_" + log_stage_name, f"log_{dt_string_filename}.txt") + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(file_formatter) + logger.addHandler(file_handler) + + if display_logging_messages: + time_format = "%H:%M:%S" + logformat = "%(asctime)s %(message)s" + stream_formatter = logging.Formatter(fmt=logformat, datefmt=time_format) + + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(stream_formatter) + logger.addHandler(stream_handler) + + return logger diff --git a/functions/data_raw_loader.py b/functions/data_raw_loader.py new file mode 100644 index 0000000..67e55cf --- /dev/null +++ b/functions/data_raw_loader.py @@ -0,0 +1,339 @@ +import numpy as np +import torch +import os +import logging +import copy + +from functions.get_experiments import get_experiments +from functions.get_trials import get_trials +from functions.get_parts import get_parts +from functions.load_meta_data import load_meta_data + + +def data_raw_loader( + raw_data_path: str, + mylogger: logging.Logger, + experiment_id: int, + trial_id: int, + device: torch.device, + force_to_cpu_memory: bool, + config: dict, +) -> tuple[list[str], str, str, dict, dict, float, float, str, torch.Tensor]: + + meta_channels: list[str] = [] + meta_mouse_markings: str = "" + meta_recording_date: str = "" + meta_stimulation_times: dict = {} + meta_experiment_names: dict = {} + meta_trial_recording_duration: float = 0.0 + meta_frame_time: float = 0.0 + meta_mouse: str = "" + data: torch.Tensor = torch.zeros((1)) + + dtype_str = config["dtype"] + mylogger.info(f"Data precision will be {dtype_str}") + dtype: torch.dtype = getattr(torch, dtype_str) + dtype_np: np.dtype = getattr(np, dtype_str) + + if os.path.isdir(raw_data_path) is False: + mylogger.info(f"ERROR: could not find raw directory {raw_data_path}!!!!") + assert os.path.isdir(raw_data_path) + return ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) + + if (torch.where(get_experiments(raw_data_path) == experiment_id)[0].shape[0]) != 1: + mylogger.info(f"ERROR: could not find experiment id {experiment_id}!!!!") + assert ( + torch.where(get_experiments(raw_data_path) == experiment_id)[0].shape[0] + ) == 1 + return ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) + + if ( + torch.where(get_trials(raw_data_path, experiment_id) == trial_id)[0].shape[0] + ) != 1: + mylogger.info(f"ERROR: could not find trial id {trial_id}!!!!") + assert ( + torch.where(get_trials(raw_data_path, experiment_id) == trial_id)[0].shape[ + 0 + ] + ) == 1 + return ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) + + available_parts: torch.Tensor = get_parts(raw_data_path, experiment_id, trial_id) + if available_parts.shape[0] < 1: + mylogger.info("ERROR: could not find any part files") + assert available_parts.shape[0] >= 1 + + experiment_name = f"Exp{experiment_id:03d}_Trial{trial_id:03d}" + mylogger.info(f"Will work on: {experiment_name}") + + mylogger.info(f"We found {int(available_parts.shape[0])} parts.") + + first_run: bool = True + + mylogger.info("Compare meta data of all parts") + for id in range(0, available_parts.shape[0]): + part_id = available_parts[id] + + filename_meta: str = os.path.join( + raw_data_path, + f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}_meta.txt", + ) + + if os.path.isfile(filename_meta) is False: + mylogger.info(f"Could not load meta data... {filename_meta}") + assert os.path.isfile(filename_meta) + return ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) + + ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + ) = load_meta_data( + mylogger=mylogger, filename_meta=filename_meta, silent_mode=True + ) + + if first_run: + first_run = False + master_meta_channels: list[str] = copy.deepcopy(meta_channels) + master_meta_mouse_markings: str = meta_mouse_markings + master_meta_recording_date: str = meta_recording_date + master_meta_stimulation_times: dict = copy.deepcopy(meta_stimulation_times) + master_meta_experiment_names: dict = copy.deepcopy(meta_experiment_names) + master_meta_trial_recording_duration: float = meta_trial_recording_duration + master_meta_frame_time: float = meta_frame_time + master_meta_mouse: str = meta_mouse + + meta_channels_check = master_meta_channels == meta_channels + + # Check channel order + if meta_channels_check: + for channel_a, channel_b in zip(master_meta_channels, meta_channels): + if channel_a != channel_b: + meta_channels_check = False + + meta_mouse_markings_check = master_meta_mouse_markings == meta_mouse_markings + meta_recording_date_check = master_meta_recording_date == meta_recording_date + meta_stimulation_times_check = ( + master_meta_stimulation_times == meta_stimulation_times + ) + meta_experiment_names_check = ( + master_meta_experiment_names == meta_experiment_names + ) + meta_trial_recording_duration_check = ( + master_meta_trial_recording_duration == meta_trial_recording_duration + ) + meta_frame_time_check = master_meta_frame_time == meta_frame_time + meta_mouse_check = master_meta_mouse == meta_mouse + + if meta_channels_check is False: + mylogger.info(f"{filename_meta} failed: channels") + assert meta_channels_check + + if meta_mouse_markings_check is False: + mylogger.info(f"{filename_meta} failed: mouse_markings") + assert meta_mouse_markings_check + + if meta_recording_date_check is False: + mylogger.info(f"{filename_meta} failed: recording_date") + assert meta_recording_date_check + + if meta_stimulation_times_check is False: + mylogger.info(f"{filename_meta} failed: stimulation_times") + assert meta_stimulation_times_check + + if meta_experiment_names_check is False: + mylogger.info(f"{filename_meta} failed: experiment_names") + assert meta_experiment_names_check + + if meta_trial_recording_duration_check is False: + mylogger.info(f"{filename_meta} failed: trial_recording_duration") + assert meta_trial_recording_duration_check + + if meta_frame_time_check is False: + mylogger.info(f"{filename_meta} failed: frame_time_check") + assert meta_frame_time_check + + if meta_mouse_check is False: + mylogger.info(f"{filename_meta} failed: mouse") + assert meta_mouse_check + mylogger.info("-==- Done -==-") + + mylogger.info(f"Will use: {filename_meta} for meta data") + ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + ) = load_meta_data(mylogger=mylogger, filename_meta=filename_meta) + + ################# + # Meta data end # + ################# + + first_run = True + mylogger.info("Count the number of frames in the data of all parts") + frame_count: int = 0 + for id in range(0, available_parts.shape[0]): + part_id = available_parts[id] + + filename_data: str = os.path.join( + raw_data_path, + f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}.npy", + ) + + if os.path.isfile(filename_data) is False: + mylogger.info(f"Could not load data... {filename_data}") + assert os.path.isfile(filename_data) + return ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) + data_np: np.ndarray = np.load(filename_data, mmap_mode="r") + + if data_np.ndim != 4: + mylogger.info(f"ERROR: Data needs to have 4 dimensions {filename_data}") + assert data_np.ndim == 4 + + if first_run: + first_run = False + dim_0: int = int(data_np.shape[0]) + dim_1: int = int(data_np.shape[1]) + dim_3: int = int(data_np.shape[3]) + + frame_count += int(data_np.shape[2]) + + if int(data_np.shape[0]) != dim_0: + mylogger.info( + f"ERROR: Data dim 0 is broken {int(data_np.shape[0])} vs {dim_0} {filename_data}" + ) + assert int(data_np.shape[0]) == dim_0 + + if int(data_np.shape[1]) != dim_1: + mylogger.info( + f"ERROR: Data dim 1 is broken {int(data_np.shape[1])} vs {dim_1} {filename_data}" + ) + assert int(data_np.shape[1]) == dim_1 + + if int(data_np.shape[3]) != dim_3: + mylogger.info( + f"ERROR: Data dim 3 is broken {int(data_np.shape[3])} vs {dim_3} {filename_data}" + ) + assert int(data_np.shape[3]) == dim_3 + + mylogger.info( + f"{filename_data}: {int(data_np.shape[2])} frames -> {frame_count} frames total" + ) + + if force_to_cpu_memory: + mylogger.info("Using CPU memory for data") + data = torch.empty( + (dim_0, dim_1, frame_count, dim_3), dtype=dtype, device=torch.device("cpu") + ) + else: + mylogger.info("Using GPU memory for data") + data = torch.empty( + (dim_0, dim_1, frame_count, dim_3), dtype=dtype, device=device + ) + + start_position: int = 0 + end_position: int = 0 + for id in range(0, available_parts.shape[0]): + part_id = available_parts[id] + + filename_data = os.path.join( + raw_data_path, + f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}.npy", + ) + + mylogger.info(f"Will work on {filename_data}") + mylogger.info("Loading data file") + data_np = np.load(filename_data).astype(dtype_np) + + end_position = start_position + int(data_np.shape[2]) + + for i in range(0, len(config["required_order"])): + mylogger.info(f"Move raw data channel: {config['required_order'][i]}") + + idx = meta_channels.index(config["required_order"][i]) + data[..., start_position:end_position, i] = torch.tensor( + data_np[..., idx], dtype=dtype, device=data.device + ) + start_position = end_position + + if start_position != int(data.shape[2]): + mylogger.info("ERROR: data was not fulled fully!!!") + assert start_position == int(data.shape[2]) + + mylogger.info("-==- Done -==-") + + ################# + # Raw data end # + ################# + + return ( + meta_channels, + meta_mouse_markings, + meta_recording_date, + meta_stimulation_times, + meta_experiment_names, + meta_trial_recording_duration, + meta_frame_time, + meta_mouse, + data, + ) diff --git a/functions/gauss_smear_individual.py b/functions/gauss_smear_individual.py new file mode 100644 index 0000000..36700e7 --- /dev/null +++ b/functions/gauss_smear_individual.py @@ -0,0 +1,127 @@ +import torch +import math + + +@torch.no_grad() +def gauss_smear_individual( + input: torch.Tensor, + spatial_width: float, + temporal_width: float, + overwrite_fft_gauss: None | torch.Tensor = None, + use_matlab_mask: bool = True, + epsilon: float = float(torch.finfo(torch.float64).eps), +) -> tuple[torch.Tensor, torch.Tensor]: + + dim_x: int = int(2 * math.ceil(2 * spatial_width) + 1) + dim_y: int = int(2 * math.ceil(2 * spatial_width) + 1) + dim_t: int = int(2 * math.ceil(2 * temporal_width) + 1) + dims_xyt: torch.Tensor = torch.tensor( + [dim_x, dim_y, dim_t], dtype=torch.int64, device=input.device + ) + + if input.ndim == 2: + input = input.unsqueeze(-1) + + input_padded = torch.nn.functional.pad( + input.unsqueeze(0), + pad=( + dim_t, + dim_t, + dim_y, + dim_y, + dim_x, + dim_x, + ), + mode="replicate", + ).squeeze(0) + + if overwrite_fft_gauss is None: + center_x: int = int(math.ceil(input_padded.shape[0] / 2)) + center_y: int = int(math.ceil(input_padded.shape[1] / 2)) + center_z: int = int(math.ceil(input_padded.shape[2] / 2)) + grid_x: torch.Tensor = ( + torch.arange(0, input_padded.shape[0], device=input.device) - center_x + 1 + ) + grid_y: torch.Tensor = ( + torch.arange(0, input_padded.shape[1], device=input.device) - center_y + 1 + ) + grid_z: torch.Tensor = ( + torch.arange(0, input_padded.shape[2], device=input.device) - center_z + 1 + ) + + grid_x = grid_x.unsqueeze(-1).unsqueeze(-1) ** 2 + grid_y = grid_y.unsqueeze(0).unsqueeze(-1) ** 2 + grid_z = grid_z.unsqueeze(0).unsqueeze(0) ** 2 + + gauss_kernel: torch.Tensor = ( + (grid_x / (spatial_width**2)) + + (grid_y / (spatial_width**2)) + + (grid_z / (temporal_width**2)) + ) + + if use_matlab_mask: + filter_radius: torch.Tensor = (dims_xyt - 1) // 2 + + border_lower: list[int] = [ + center_x - int(filter_radius[0]) - 1, + center_y - int(filter_radius[1]) - 1, + center_z - int(filter_radius[2]) - 1, + ] + + border_upper: list[int] = [ + center_x + int(filter_radius[0]), + center_y + int(filter_radius[1]), + center_z + int(filter_radius[2]), + ] + + matlab_mask: torch.Tensor = torch.zeros_like(gauss_kernel) + matlab_mask[ + border_lower[0] : border_upper[0], + border_lower[1] : border_upper[1], + border_lower[2] : border_upper[2], + ] = 1.0 + + gauss_kernel = torch.exp(-gauss_kernel / 2.0) + if use_matlab_mask: + gauss_kernel = gauss_kernel * matlab_mask + + gauss_kernel[gauss_kernel < (epsilon * gauss_kernel.max())] = 0 + + sum_gauss_kernel: float = float(gauss_kernel.sum()) + + if sum_gauss_kernel != 0.0: + gauss_kernel = gauss_kernel / sum_gauss_kernel + + # FFT Shift + gauss_kernel = torch.cat( + (gauss_kernel[center_x - 1 :, :, :], gauss_kernel[: center_x - 1, :, :]), + dim=0, + ) + gauss_kernel = torch.cat( + (gauss_kernel[:, center_y - 1 :, :], gauss_kernel[:, : center_y - 1, :]), + dim=1, + ) + gauss_kernel = torch.cat( + (gauss_kernel[:, :, center_z - 1 :], gauss_kernel[:, :, : center_z - 1]), + dim=2, + ) + overwrite_fft_gauss = torch.fft.fftn(gauss_kernel) + input_padded_gauss_filtered: torch.Tensor = torch.real( + torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss) + ) + else: + input_padded_gauss_filtered = torch.real( + torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss) + ) + + start = dims_xyt + stop = ( + torch.tensor(input_padded.shape, device=dims_xyt.device, dtype=dims_xyt.dtype) + - dims_xyt + ) + + output = input_padded_gauss_filtered[ + start[0] : stop[0], start[1] : stop[1], start[2] : stop[2] + ] + + return (output, overwrite_fft_gauss) diff --git a/functions/get_experiments.py b/functions/get_experiments.py new file mode 100644 index 0000000..d92b936 --- /dev/null +++ b/functions/get_experiments.py @@ -0,0 +1,19 @@ +import torch +import os +import glob + + +@torch.no_grad() +def get_experiments(path: str) -> torch.Tensor: + filename_np: str = os.path.join( + path, + "Exp*_Part001.npy", + ) + + list_str = glob.glob(filename_np) + list_int: list[int] = [] + for i in range(0, len(list_str)): + list_int.append(int(list_str[i].split("Exp")[-1].split("_Trial")[0])) + list_int = sorted(list_int) + + return torch.tensor(list_int).unique() diff --git a/functions/get_parts.py b/functions/get_parts.py new file mode 100644 index 0000000..d68e1ae --- /dev/null +++ b/functions/get_parts.py @@ -0,0 +1,18 @@ +import torch +import os +import glob + + +@torch.no_grad() +def get_parts(path: str, experiment_id: int, trial_id: int) -> torch.Tensor: + filename_np: str = os.path.join( + path, + f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part*.npy", + ) + + list_str = glob.glob(filename_np) + list_int: list[int] = [] + for i in range(0, len(list_str)): + list_int.append(int(list_str[i].split("_Part")[-1].split(".npy")[0])) + list_int = sorted(list_int) + return torch.tensor(list_int).unique() diff --git a/functions/get_torch_device.py b/functions/get_torch_device.py new file mode 100644 index 0000000..9eec5e9 --- /dev/null +++ b/functions/get_torch_device.py @@ -0,0 +1,17 @@ +import torch +import logging + + +def get_torch_device(mylogger: logging.Logger, force_to_cpu: bool) -> torch.device: + + if torch.cuda.is_available(): + device_name: str = "cuda:0" + else: + device_name = "cpu" + + if force_to_cpu: + device_name = "cpu" + + mylogger.info(f"Using device: {device_name}") + device: torch.device = torch.device(device_name) + return device diff --git a/functions/get_trials.py b/functions/get_trials.py new file mode 100644 index 0000000..8c687d9 --- /dev/null +++ b/functions/get_trials.py @@ -0,0 +1,18 @@ +import torch +import os +import glob + + +@torch.no_grad() +def get_trials(path: str, experiment_id: int) -> torch.Tensor: + filename_np: str = os.path.join( + path, + f"Exp{experiment_id:03d}_Trial*_Part001.npy", + ) + + list_str = glob.glob(filename_np) + list_int: list[int] = [] + for i in range(0, len(list_str)): + list_int.append(int(list_str[i].split("_Trial")[-1].split("_Part")[0])) + list_int = sorted(list_int) + return torch.tensor(list_int).unique() diff --git a/functions/load_config.py b/functions/load_config.py new file mode 100644 index 0000000..c17fa40 --- /dev/null +++ b/functions/load_config.py @@ -0,0 +1,16 @@ +import json +import os +import logging + +from jsmin import jsmin # type:ignore + + +def load_config(mylogger: logging.Logger, filename: str = "config.json") -> dict: + mylogger.info("loading config file") + if os.path.isfile(filename) is False: + mylogger.info(f"{filename} is missing") + + with open(filename, "r") as file: + config = json.loads(jsmin(file.read())) + + return config diff --git a/functions/load_meta_data.py b/functions/load_meta_data.py new file mode 100644 index 0000000..2a893e0 --- /dev/null +++ b/functions/load_meta_data.py @@ -0,0 +1,63 @@ +import logging +import json + + +def load_meta_data( + mylogger: logging.Logger, filename_meta: str, silent_mode=False +) -> tuple[list[str], str, str, dict, dict, float, float, str]: + + if silent_mode is False: + mylogger.info("Loading meta data") + with open(filename_meta, "r") as file_handle: + metadata: dict = json.load(file_handle) + + channels: list[str] = metadata["channelKey"] + + if silent_mode is False: + mylogger.info(f"meta data: channel order: {channels}") + + mouse_markings: str = metadata["sessionMetaData"]["mouseMarkings"] + if silent_mode is False: + mylogger.info(f"meta data: mouse markings: {mouse_markings}") + + recording_date: str = metadata["sessionMetaData"]["date"] + if silent_mode is False: + mylogger.info(f"meta data: recording data: {recording_date}") + + stimulation_times: dict = metadata["sessionMetaData"]["stimulationTimes"] + if silent_mode is False: + mylogger.info(f"meta data: stimulation times: {stimulation_times}") + + experiment_names: dict = metadata["sessionMetaData"]["experimentNames"] + if silent_mode is False: + mylogger.info(f"meta data: experiment names: {experiment_names}") + + trial_recording_duration: float = float( + metadata["sessionMetaData"]["trialRecordingDuration"] + ) + if silent_mode is False: + mylogger.info( + f"meta data: trial recording duration: {trial_recording_duration} sec" + ) + + frame_time: float = float(metadata["sessionMetaData"]["frameTime"]) + if silent_mode is False: + mylogger.info( + f"meta data: frame time: {frame_time} sec ; frame rate: {1.0/frame_time}Hz" + ) + + mouse: str = metadata["sessionMetaData"]["mouse"] + if silent_mode is False: + mylogger.info(f"meta data: mouse: {mouse}") + mylogger.info("-==- Done -==-") + + return ( + channels, + mouse_markings, + recording_date, + stimulation_times, + experiment_names, + trial_recording_duration, + frame_time, + mouse, + ) diff --git a/functions/perform_donor_volume_rotation.py b/functions/perform_donor_volume_rotation.py new file mode 100644 index 0000000..590630f --- /dev/null +++ b/functions/perform_donor_volume_rotation.py @@ -0,0 +1,140 @@ +import torch +import torchvision as tv # type: ignore +import logging +from functions.calculate_rotation import calculate_rotation +from functions.ImageAlignment import ImageAlignment + + +@torch.no_grad() +def perform_donor_volume_rotation( + mylogger: logging.Logger, + acceptor: torch.Tensor, + donor: torch.Tensor, + oxygenation: torch.Tensor, + volume: torch.Tensor, + ref_image_donor: torch.Tensor, + ref_image_volume: torch.Tensor, + image_alignment: ImageAlignment, + batch_size: int, + config: dict, + fill_value: float = 0, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + + mylogger.info("Calculate rotation between donor data and donor ref image") + + angle_donor = calculate_rotation( + input=donor, + reference_image=ref_image_donor, + image_alignment=image_alignment, + batch_size=batch_size, + ) + + mylogger.info("Calculate rotation between volume data and volume ref image") + angle_volume = calculate_rotation( + input=volume, + reference_image=ref_image_volume, + image_alignment=image_alignment, + batch_size=batch_size, + ) + + mylogger.info("Average over both rotations") + + donor_threshold: torch.Tensor = torch.sort(torch.abs(angle_donor))[0] + donor_threshold = donor_threshold[ + int( + donor_threshold.shape[0] + * float(config["rotation_stabilization_threshold_border"]) + ) + ] * float(config["rotation_stabilization_threshold_factor"]) + + volume_threshold: torch.Tensor = torch.sort(torch.abs(angle_volume))[0] + volume_threshold = volume_threshold[ + int( + volume_threshold.shape[0] + * float(config["rotation_stabilization_threshold_border"]) + ) + ] * float(config["rotation_stabilization_threshold_factor"]) + + donor_idx = torch.where(torch.abs(angle_donor) > donor_threshold)[0] + volume_idx = torch.where(torch.abs(angle_volume) > volume_threshold)[0] + mylogger.info( + f"Border: {config['rotation_stabilization_threshold_border']}, " + f"factor {config['rotation_stabilization_threshold_factor']} " + ) + mylogger.info( + f"Donor threshold: {donor_threshold:.3e}, volume threshold: {volume_threshold:.3e}" + ) + mylogger.info( + f"Found broken rotation values: " + f"donor {int(donor_idx.shape[0])}, " + f"volume {int(volume_idx.shape[0])}" + ) + angle_donor[donor_idx] = angle_volume[donor_idx] + angle_volume[volume_idx] = angle_donor[volume_idx] + + donor_idx = torch.where(torch.abs(angle_donor) > donor_threshold)[0] + volume_idx = torch.where(torch.abs(angle_volume) > volume_threshold)[0] + mylogger.info( + f"After fill in these broken rotation values remain: " + f"donor {int(donor_idx.shape[0])}, " + f"volume {int(volume_idx.shape[0])}" + ) + angle_donor[donor_idx] = 0.0 + angle_volume[volume_idx] = 0.0 + angle_donor_volume = (angle_donor + angle_volume) / 2.0 + + mylogger.info("Rotate acceptor data based on the average rotation") + for frame_id in range(0, angle_donor_volume.shape[0]): + acceptor[frame_id, ...] = tv.transforms.functional.affine( + img=acceptor[frame_id, ...].unsqueeze(0), + angle=-float(angle_donor_volume[frame_id]), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + mylogger.info("Rotate donor data based on the average rotation") + for frame_id in range(0, angle_donor_volume.shape[0]): + donor[frame_id, ...] = tv.transforms.functional.affine( + img=donor[frame_id, ...].unsqueeze(0), + angle=-float(angle_donor_volume[frame_id]), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + mylogger.info("Rotate oxygenation data based on the average rotation") + for frame_id in range(0, angle_donor_volume.shape[0]): + oxygenation[frame_id, ...] = tv.transforms.functional.affine( + img=oxygenation[frame_id, ...].unsqueeze(0), + angle=-float(angle_donor_volume[frame_id]), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + mylogger.info("Rotate volume data based on the average rotation") + for frame_id in range(0, angle_donor_volume.shape[0]): + volume[frame_id, ...] = tv.transforms.functional.affine( + img=volume[frame_id, ...].unsqueeze(0), + angle=-float(angle_donor_volume[frame_id]), + translate=[0, 0], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + return (acceptor, donor, oxygenation, volume, angle_donor_volume) diff --git a/functions/perform_donor_volume_translation.py b/functions/perform_donor_volume_translation.py new file mode 100644 index 0000000..7091add --- /dev/null +++ b/functions/perform_donor_volume_translation.py @@ -0,0 +1,143 @@ +import torch +import torchvision as tv # type: ignore +import logging + +from functions.calculate_translation import calculate_translation +from functions.ImageAlignment import ImageAlignment + + +@torch.no_grad() +def perform_donor_volume_translation( + mylogger: logging.Logger, + acceptor: torch.Tensor, + donor: torch.Tensor, + oxygenation: torch.Tensor, + volume: torch.Tensor, + ref_image_donor: torch.Tensor, + ref_image_volume: torch.Tensor, + image_alignment: ImageAlignment, + batch_size: int, + config: dict, + fill_value: float = 0, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + + mylogger.info("Calculate translation between donor data and donor ref image") + tvec_donor = calculate_translation( + input=donor, + reference_image=ref_image_donor, + image_alignment=image_alignment, + batch_size=batch_size, + ) + + mylogger.info("Calculate translation between volume data and volume ref image") + tvec_volume = calculate_translation( + input=volume, + reference_image=ref_image_volume, + image_alignment=image_alignment, + batch_size=batch_size, + ) + + mylogger.info("Average over both translations") + + for i in range(0, 2): + mylogger.info(f"Processing dimension {i}") + donor_threshold: torch.Tensor = torch.sort(torch.abs(tvec_donor[:, i]))[0] + donor_threshold = donor_threshold[ + int( + donor_threshold.shape[0] + * float(config["rotation_stabilization_threshold_border"]) + ) + ] * float(config["rotation_stabilization_threshold_factor"]) + + volume_threshold: torch.Tensor = torch.sort(torch.abs(tvec_volume[:, i]))[0] + volume_threshold = volume_threshold[ + int( + volume_threshold.shape[0] + * float(config["rotation_stabilization_threshold_border"]) + ) + ] * float(config["rotation_stabilization_threshold_factor"]) + + donor_idx = torch.where(torch.abs(tvec_donor[:, i]) > donor_threshold)[0] + volume_idx = torch.where(torch.abs(tvec_volume[:, i]) > volume_threshold)[0] + mylogger.info( + f"Border: {config['rotation_stabilization_threshold_border']}, " + f"factor {config['rotation_stabilization_threshold_factor']} " + ) + mylogger.info( + f"Donor threshold: {donor_threshold:.3e}, volume threshold: {volume_threshold:.3e}" + ) + mylogger.info( + f"Found broken rotation values: " + f"donor {int(donor_idx.shape[0])}, " + f"volume {int(volume_idx.shape[0])}" + ) + tvec_donor[donor_idx, i] = tvec_volume[donor_idx, i] + tvec_volume[volume_idx, i] = tvec_donor[volume_idx, i] + + donor_idx = torch.where(torch.abs(tvec_donor[:, i]) > donor_threshold)[0] + volume_idx = torch.where(torch.abs(tvec_volume[:, i]) > volume_threshold)[0] + mylogger.info( + f"After fill in these broken rotation values remain: " + f"donor {int(donor_idx.shape[0])}, " + f"volume {int(volume_idx.shape[0])}" + ) + tvec_donor[donor_idx, i] = 0.0 + tvec_volume[volume_idx, i] = 0.0 + + tvec_donor_volume = (tvec_donor + tvec_volume) / 2.0 + + mylogger.info("Translate acceptor data based on the average translation vector") + for frame_id in range(0, tvec_donor_volume.shape[0]): + acceptor[frame_id, ...] = tv.transforms.functional.affine( + img=acceptor[frame_id, ...].unsqueeze(0), + angle=0, + translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + mylogger.info("Translate donor data based on the average translation vector") + for frame_id in range(0, tvec_donor_volume.shape[0]): + donor[frame_id, ...] = tv.transforms.functional.affine( + img=donor[frame_id, ...].unsqueeze(0), + angle=0, + translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + mylogger.info("Translate oxygenation data based on the average translation vector") + for frame_id in range(0, tvec_donor_volume.shape[0]): + oxygenation[frame_id, ...] = tv.transforms.functional.affine( + img=oxygenation[frame_id, ...].unsqueeze(0), + angle=0, + translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + mylogger.info("Translate volume data based on the average translation vector") + for frame_id in range(0, tvec_donor_volume.shape[0]): + volume[frame_id, ...] = tv.transforms.functional.affine( + img=volume[frame_id, ...].unsqueeze(0), + angle=0, + translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], + scale=1.0, + shear=0, + interpolation=tv.transforms.InterpolationMode.BILINEAR, + fill=fill_value, + ).squeeze(0) + + return (acceptor, donor, oxygenation, volume, tvec_donor_volume) diff --git a/functions/regression.py b/functions/regression.py new file mode 100644 index 0000000..d4efac0 --- /dev/null +++ b/functions/regression.py @@ -0,0 +1,117 @@ +import torch +import logging +from functions.regression_internal import regression_internal + + +@torch.no_grad() +def regression( + mylogger: logging.Logger, + target_camera_id: int, + regressor_camera_ids: list[int], + mask: torch.Tensor, + data: torch.Tensor, + data_filtered: torch.Tensor, + first_none_ramp_frame: int, +) -> tuple[torch.Tensor, torch.Tensor]: + + assert len(regressor_camera_ids) > 0 + + mylogger.info("Prepare the target signal - 1.0 (from data_filtered)") + target_signals_train: torch.Tensor = ( + data_filtered[target_camera_id, ..., first_none_ramp_frame:].clone() - 1.0 + ) + target_signals_train[target_signals_train < -1] = 0.0 + + # Check if everything is happy + assert target_signals_train.ndim == 3 + assert target_signals_train.ndim == data[target_camera_id, ...].ndim + assert target_signals_train.shape[0] == data[target_camera_id, ...].shape[0] + assert target_signals_train.shape[1] == data[target_camera_id, ...].shape[1] + assert (target_signals_train.shape[2] + first_none_ramp_frame) == data[ + target_camera_id, ... + ].shape[2] + + mylogger.info("Prepare the regressor signals (linear plus from data_filtered)") + + regressor_signals_train: torch.Tensor = torch.zeros( + ( + data_filtered.shape[1], + data_filtered.shape[2], + data_filtered.shape[3], + len(regressor_camera_ids) + 1, + ), + device=data_filtered.device, + dtype=data_filtered.dtype, + ) + + mylogger.info("Copy the regressor signals - 1.0") + for matrix_id, id in enumerate(regressor_camera_ids): + regressor_signals_train[..., matrix_id] = data_filtered[id, ...] - 1.0 + + regressor_signals_train[regressor_signals_train < -1] = 0.0 + + mylogger.info("Create the linear regressor") + trend = torch.arange( + 0, regressor_signals_train.shape[-2], device=data_filtered.device + ) / float(regressor_signals_train.shape[-2] - 1) + trend -= trend.mean() + trend = trend.unsqueeze(0).unsqueeze(0) + trend = trend.tile( + (regressor_signals_train.shape[0], regressor_signals_train.shape[1], 1) + ) + regressor_signals_train[..., -1] = trend + + regressor_signals_train = regressor_signals_train[:, :, first_none_ramp_frame:, :] + + mylogger.info("Calculating the regression coefficients") + coefficients, intercept = regression_internal( + input_regressor=regressor_signals_train, input_target=target_signals_train + ) + del regressor_signals_train + del target_signals_train + + mylogger.info("Prepare the target signal - 1.0 (from data)") + target_signals_perform: torch.Tensor = data[target_camera_id, ...].clone() - 1.0 + + mylogger.info("Prepare the regressor signals (linear plus from data)") + regressor_signals_perform: torch.Tensor = torch.zeros( + ( + data.shape[1], + data.shape[2], + data.shape[3], + len(regressor_camera_ids) + 1, + ), + device=data.device, + dtype=data.dtype, + ) + + mylogger.info("Copy the regressor signals - 1.0 ") + for matrix_id, id in enumerate(regressor_camera_ids): + regressor_signals_perform[..., matrix_id] = data[id] - 1.0 + + mylogger.info("Create the linear regressor") + trend = torch.arange( + 0, regressor_signals_perform.shape[-2], device=data[0].device + ) / float(regressor_signals_perform.shape[-2] - 1) + trend -= trend.mean() + trend = trend.unsqueeze(0).unsqueeze(0) + trend = trend.tile( + (regressor_signals_perform.shape[0], regressor_signals_perform.shape[1], 1) + ) + regressor_signals_perform[..., -1] = trend + + mylogger.info("Remove regressors") + target_signals_perform -= ( + regressor_signals_perform * coefficients.unsqueeze(-2) + ).sum(dim=-1) + + mylogger.info("Remove offset") + target_signals_perform -= intercept.unsqueeze(-1) + + mylogger.info("Remove masked pixels") + target_signals_perform[mask, :] = 0.0 + + mylogger.info("Add an offset of 1.0") + target_signals_perform += 1.0 + + return target_signals_perform, coefficients diff --git a/functions/regression_internal.py b/functions/regression_internal.py new file mode 100644 index 0000000..352d7ba --- /dev/null +++ b/functions/regression_internal.py @@ -0,0 +1,20 @@ +import torch + + +def regression_internal( + input_regressor: torch.Tensor, input_target: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + + regressor_offset = input_regressor.mean(keepdim=True, dim=-2) + target_offset = input_target.mean(keepdim=True, dim=-1) + + regressor = input_regressor - regressor_offset + target = input_target - target_offset + + coefficients, _, _, _ = torch.linalg.lstsq(regressor, target, rcond=None) # None ? + + intercept = target_offset.squeeze(-1) - ( + coefficients * regressor_offset.squeeze(-2) + ).sum(dim=-1) + + return coefficients, intercept