From b37a79f487cc8737eb7b40d95dec2621ee1b682c Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Wed, 12 Jul 2023 14:02:35 +0200 Subject: [PATCH] Add files via upload --- 0_convert_avi_to_npy.py | 7 + Anime.py | 90 ++++ ImageAlignment.py | 1010 +++++++++++++++++++++++++++++++++++++++ run_svd.py | 57 +++ svd.py | 204 ++++++++ 5 files changed, 1368 insertions(+) create mode 100644 0_convert_avi_to_npy.py create mode 100644 Anime.py create mode 100644 ImageAlignment.py create mode 100644 run_svd.py create mode 100644 svd.py diff --git a/0_convert_avi_to_npy.py b/0_convert_avi_to_npy.py new file mode 100644 index 0000000..a9b32ec --- /dev/null +++ b/0_convert_avi_to_npy.py @@ -0,0 +1,7 @@ +from svd import convert_avi_to_npy + + +if __name__ == "__main__": + # Convert from avi to npy + filename: str = "example_data_crop" + convert_avi_to_npy(filename) diff --git a/Anime.py b/Anime.py new file mode 100644 index 0000000..628624c --- /dev/null +++ b/Anime.py @@ -0,0 +1,90 @@ +import numpy as np +import torch +import matplotlib.pyplot as plt +import matplotlib.animation + + +class Anime: + def __init__(self) -> None: + super().__init__() + + def show( + self, + input: torch.Tensor | np.ndarray, + mask: torch.Tensor | np.ndarray | None = None, + vmin: float | None = None, + vmax: float | None = None, + cmap: str = "hot", + axis_off: bool = True, + show_frame_count: bool = True, + interval: int = 100, + repeat: bool = False, + colorbar: bool = True, + vmin_scale: float | None = None, + vmax_scale: float | None = None, + ) -> None: + assert input.ndim == 3 + + if isinstance(input, torch.Tensor): + input_np: np.ndarray = input.cpu().numpy() + if mask is not None: + mask_np: np.ndarray | None = (mask == 0).cpu().numpy() + else: + mask_np = None + else: + input_np = input + if mask is not None: + mask_np = mask == 0 # type: ignore + else: + mask_np = None + + if vmin is None: + vmin = float(np.where(np.isfinite(input_np), input_np, 0.0).min()) + if vmax is None: + vmax = float(np.where(np.isfinite(input_np), input_np, 0.0).max()) + + if vmin_scale is not None: + vmin *= vmin_scale + + if vmax_scale is not None: + vmax *= vmax_scale + + fig = plt.figure() + image = np.nan_to_num(input_np[0, ...], copy=True, nan=0.0) + if mask_np is not None: + image[mask_np] = float("NaN") + image_handle = plt.imshow( + image, + cmap=cmap, + vmin=vmin, + vmax=vmax, + ) + + if colorbar: + plt.colorbar() + + if axis_off: + plt.axis("off") + + def next_frame(i: int) -> None: + image = np.nan_to_num(input_np[i, ...], copy=True, nan=0.0) + if mask_np is not None: + image[mask_np] = float("NaN") + + image_handle.set_data(image) + if show_frame_count: + bar_length: int = 10 + filled_length = int(round(bar_length * i / input_np.shape[0])) + bar = "\u25A0" * filled_length + "\u25A1" * (bar_length - filled_length) + plt.title(f"{bar} {i} of {int(input_np.shape[0]-1)}", loc="left") + return + + _ = matplotlib.animation.FuncAnimation( + fig, + next_frame, + frames=int(input.shape[0]), + interval=interval, + repeat=repeat, + ) + + plt.show() diff --git a/ImageAlignment.py b/ImageAlignment.py new file mode 100644 index 0000000..ab483b3 --- /dev/null +++ b/ImageAlignment.py @@ -0,0 +1,1010 @@ +import torch +import torchvision as tv + +# The source code is based on: +# https://github.com/matejak/imreg_dft + +# The original LICENSE: +# Copyright (c) 2014, Matěj Týč +# Copyright (c) 2011-2014, Christoph Gohlke +# Copyright (c) 2011-2014, The Regents of the University of California +# Produced at 
the Laboratory for Fluorescence Dynamics
+
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+
+# * Neither the name of the {organization} nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+class ImageAlignment(torch.nn.Module):
+    device: torch.device
+    default_dtype: torch.dtype
+    excess_const: float = 1.1
+    exponent: str = "inf"
+    success: torch.Tensor | None = None
+
+    # The upsampling factor that determines the
+    # sub-pixel resolution of the shifts we apply
+    scale_factor: int = 4
+
+    reference_image: torch.Tensor | None = None
+
+    last_scale: torch.Tensor | None = None
+    last_angle: torch.Tensor | None = None
+    last_tvec: torch.Tensor | None = None
+
+    # Cache
+    image_reference_dft: torch.Tensor | None = None
+    filt: torch.Tensor
+    pcorr_shape: torch.Tensor
+    log_base: torch.Tensor
+    image_reference_logp: torch.Tensor
+
+    def __init__(
+        self,
+        device: torch.device | None = None,
+        default_dtype: torch.dtype | None = None,
+    ) -> None:
+        super().__init__()
+
+        assert device is not None
+        assert default_dtype is not None
+        self.device = device
+        self.default_dtype = default_dtype
+
+    def set_new_reference_image(
+        self, new_reference_image: torch.Tensor | None = None
+    ) -> None:
+        assert new_reference_image is not None
+        assert new_reference_image.ndim == 2
+        self.reference_image = (
+            new_reference_image.detach()
+            .clone()
+            .to(device=self.device)
+            .type(dtype=self.default_dtype)
+        )
+        self.image_reference_dft = None
+
+    def forward(
+        self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        assert input.ndim == 3
+
+        if new_reference_image is not None:
+            self.set_new_reference_image(new_reference_image)
+
+        assert self.reference_image is not None
+        assert self.reference_image.ndim == 2
+        assert input.shape[-2] == self.reference_image.shape[-2]
+        assert input.shape[-1] == self.reference_image.shape[-1]
+
+        self.last_scale, self.last_angle, self.last_tvec, output = self.similarity(
+            self.reference_image,
+            input.to(device=self.device).type(dtype=self.default_dtype),
+        )
+
+        return output
+
+    def dry_run(
+        self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None,
torch.Tensor | None]: + assert input.ndim == 3 + + if new_reference_image is not None: + self.set_new_reference_image(new_reference_image) + + assert self.reference_image is not None + assert self.reference_image.ndim == 2 + assert input.shape[-2] == self.reference_image.shape[-2] + assert input.shape[-1] == self.reference_image.shape[-1] + + images_todo = input.to(device=self.device).type(dtype=self.default_dtype) + image_reference = self.reference_image + + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + + bgval: torch.Tensor = self.get_borderval(img=images_todo, radius=5) + + self.last_scale, self.last_angle, self.last_tvec = self._similarity( + image_reference, + images_todo, + bgval, + ) + + return self.last_scale, self.last_angle, self.last_tvec + + def dry_run_translation( + self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None + ) -> torch.Tensor: + assert input.ndim == 3 + + if new_reference_image is not None: + self.set_new_reference_image(new_reference_image) + + assert self.reference_image is not None + assert self.reference_image.ndim == 2 + assert input.shape[-2] == self.reference_image.shape[-2] + assert input.shape[-1] == self.reference_image.shape[-1] + + images_todo = input.to(device=self.device).type(dtype=self.default_dtype) + image_reference = self.reference_image + + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + + tvec, _ = self._translation(image_reference, images_todo) + + return tvec + + # --------------- + + def dry_run_angle( + self, + input: torch.Tensor, + new_reference_image: torch.Tensor | None = None, + ) -> torch.Tensor: + assert input.ndim == 3 + + if new_reference_image is not None: + self.set_new_reference_image(new_reference_image) + + constraints_dynamic_angle_0: torch.Tensor = torch.zeros( + (input.shape[0]), dtype=self.default_dtype, device=self.device + ) + constraints_dynamic_angle_1: torch.Tensor | None = None + constraints_dynamic_scale_0: torch.Tensor = torch.ones( + (input.shape[0]), dtype=self.default_dtype, device=self.device + ) + constraints_dynamic_scale_1: torch.Tensor | None = None + + assert self.reference_image is not None + assert self.reference_image.ndim == 2 + assert input.shape[-2] == self.reference_image.shape[-2] + assert input.shape[-1] == self.reference_image.shape[-1] + + images_todo = input.to(device=self.device).type(dtype=self.default_dtype) + image_reference = self.reference_image + + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + + _, newangle = self._get_ang_scale( + image_reference, + images_todo, + constraints_dynamic_scale_0, + constraints_dynamic_scale_1, + constraints_dynamic_angle_0, + constraints_dynamic_angle_1, + ) + + return newangle + + # --------------- + + def _get_pcorr_shape(self, shape: torch.Size) -> tuple[int, int]: + ret = (int(max(shape[-2:]) * 1.0),) * 2 + return ret + + def _get_log_base(self, shape: torch.Size, new_r: torch.Tensor) -> torch.Tensor: + old_r = torch.tensor( + (float(shape[-2]) * self.excess_const) / 2.0, + dtype=self.default_dtype, + device=self.device, + ) + log_base = torch.exp(torch.log(old_r) / new_r) + return log_base + + def wrap_angle( + self, angles: torch.Tensor, ceil: float = 2 * torch.pi + ) -> torch.Tensor: + angles += ceil / 2.0 + angles %= ceil + angles -= ceil / 2.0 + return angles + + def get_borderval( + self, img: torch.Tensor, radius: int | None = None + ) -> torch.Tensor: + assert img.ndim == 3 + if radius is None: + mindim = min([int(img.shape[-2]), int(img.shape[-1])]) + radius = 
max(1, mindim // 20) + mask = torch.zeros( + (int(img.shape[-2]), int(img.shape[-1])), + dtype=torch.bool, + device=self.device, + ) + mask[:, :radius] = True + mask[:, -radius:] = True + mask[:radius, :] = True + mask[-radius:, :] = True + + mean = torch.median(img[:, mask], dim=-1)[0] + return mean + + def get_apofield(self, shape: torch.Size, aporad: int) -> torch.Tensor: + if aporad == 0: + return torch.ones( + shape[-2:], + dtype=self.default_dtype, + device=self.device, + ) + + assert int(shape[-2]) > aporad * 2 + assert int(shape[-1]) > aporad * 2 + + apos = torch.hann_window( + aporad * 2, dtype=self.default_dtype, periodic=False, device=self.device + ) + + toapp_0 = torch.ones( + shape[-2], + dtype=self.default_dtype, + device=self.device, + ) + toapp_0[:aporad] = apos[:aporad] + toapp_0[-aporad:] = apos[-aporad:] + + toapp_1 = torch.ones( + shape[-1], + dtype=self.default_dtype, + device=self.device, + ) + toapp_1[:aporad] = apos[:aporad] + toapp_1[-aporad:] = apos[-aporad:] + + apofield = torch.outer(toapp_0, toapp_1) + + return apofield + + def _get_subarr( + self, array: torch.Tensor, center: torch.Tensor, rad: int + ) -> torch.Tensor: + assert array.ndim == 3 + assert center.ndim == 2 + assert array.shape[0] == center.shape[0] + assert center.shape[1] == 2 + + dim = 1 + 2 * rad + subarr = torch.zeros( + (array.shape[0], dim, dim), dtype=self.default_dtype, device=self.device + ) + + corner = center - rad + idx_p = range(0, corner.shape[0]) + for ii in range(0, dim): + yidx = corner[:, 0] + ii + yidx %= array.shape[-2] + for jj in range(0, dim): + xidx = corner[:, 1] + jj + xidx %= array.shape[-1] + subarr[:, ii, jj] = array[idx_p, yidx, xidx] + + return subarr + + def _argmax_2d(self, array: torch.Tensor) -> torch.Tensor: + assert array.ndim == 3 + + max_pos = array.reshape( + (array.shape[0], array.shape[1] * array.shape[2]) + ).argmax(dim=1) + pos_0 = max_pos // array.shape[2] + max_pos -= pos_0 * array.shape[2] + ret = torch.zeros( + (array.shape[0], 2), dtype=self.default_dtype, device=self.device + ) + ret[:, 0] = pos_0 + ret[:, 1] = max_pos + return ret.type(dtype=torch.int64) + + def _apodize(self, what: torch.Tensor) -> torch.Tensor: + mindim = min([int(what.shape[-2]), int(what.shape[-1])]) + aporad = int(mindim * 0.12) + + apofield = self.get_apofield(what.shape, aporad).unsqueeze(0) + + res = what * apofield + bg = self.get_borderval(what, aporad // 2).unsqueeze(-1).unsqueeze(-1) + res += bg * (1 - apofield) + return res + + def _logpolar_filter(self, shape: torch.Size) -> torch.Tensor: + yy = torch.linspace( + -torch.pi / 2.0, + torch.pi / 2.0, + shape[-2], + dtype=self.default_dtype, + device=self.device, + ).unsqueeze(1) + + xx = torch.linspace( + -torch.pi / 2.0, + torch.pi / 2.0, + shape[-1], + dtype=self.default_dtype, + device=self.device, + ).unsqueeze(0) + + rads = torch.sqrt(yy**2 + xx**2) + filt = 1.0 - torch.cos(rads) ** 2 + + filt[torch.abs(rads) > torch.pi / 2] = 1 + return filt + + def _get_angles(self, shape: torch.Tensor) -> torch.Tensor: + ret = torch.zeros( + (int(shape[-2]), int(shape[-1])), + dtype=self.default_dtype, + device=self.device, + ) + ret -= torch.linspace( + 0, + torch.pi, + int(shape[-2] + 1), + dtype=self.default_dtype, + device=self.device, + )[:-1].unsqueeze(-1) + + return ret + + def _get_lograd(self, shape: torch.Tensor, log_base: torch.Tensor) -> torch.Tensor: + ret = torch.zeros( + (int(shape[-2]), int(shape[-1])), + dtype=self.default_dtype, + device=self.device, + ) + ret += torch.pow( + log_base, + torch.arange( + 0, + 
int(shape[-1]), + dtype=self.default_dtype, + device=self.device, + ), + ).unsqueeze(0) + return ret + + def _logpolar( + self, image: torch.Tensor, shape: torch.Tensor, log_base: torch.Tensor + ) -> torch.Tensor: + assert image.ndim == 3 + + imshape: torch.Tensor = torch.tensor( + image.shape[-2:], + dtype=self.default_dtype, + device=self.device, + ) + + center: torch.Tensor = imshape.clone() / 2 + + theta: torch.Tensor = self._get_angles(shape) + radius_x: torch.Tensor = self._get_lograd(shape, log_base) + radius_y: torch.Tensor = radius_x.clone() + + ellipse_coef: torch.Tensor = imshape[0] / imshape[1] + radius_x /= ellipse_coef + + y = radius_y * torch.sin(theta) + center[0] + y /= float(image.shape[-2]) + y *= 2 + y -= 1 + + x = radius_x * torch.cos(theta) + center[1] + x /= float(image.shape[-1]) + x *= 2 + x -= 1 + + idx_x = torch.where(torch.abs(x) <= 1.0, 1.0, 0.0) + idx_y = torch.where(torch.abs(y) <= 1.0, 1.0, 0.0) + + normalized_coords = torch.cat( + ( + x.unsqueeze(-1), + y.unsqueeze(-1), + ), + dim=-1, + ).unsqueeze(0) + + output = torch.empty( + (int(image.shape[0]), int(y.shape[0]), int(y.shape[1])), + dtype=self.default_dtype, + device=self.device, + ) + + for id in range(0, int(image.shape[0])): + bgval: torch.Tensor = torch.quantile(image[id, :, :], q=1.0 / 100.0) + + temp = torch.nn.functional.grid_sample( + image[id, :, :].unsqueeze(0).unsqueeze(0), + normalized_coords, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + + output[id, :, :] = torch.where((idx_x * idx_y) == 0.0, bgval, temp) + + return output + + def _argmax_ext(self, array: torch.Tensor, exponent: float | str) -> torch.Tensor: + assert array.ndim == 3 + + if exponent == "inf": + ret = self._argmax_2d(array) + else: + assert isinstance(exponent, float) or isinstance(exponent, int) + + col = ( + torch.arange( + 0, array.shape[-2], dtype=self.default_dtype, device=self.device + ) + .unsqueeze(-1) + .unsqueeze(0) + ) + row = ( + torch.arange( + 0, array.shape[-1], dtype=self.default_dtype, device=self.device + ) + .unsqueeze(0) + .unsqueeze(0) + ) + + arr2 = torch.pow(array, float(exponent)) + arrsum = arr2.sum(dim=-2).sum(dim=-1) + + ret = torch.zeros( + (array.shape[0], 2), dtype=self.default_dtype, device=self.device + ) + + arrprody = (arr2 * col).sum(dim=-1).sum(dim=-1) / arrsum + arrprodx = (arr2 * row).sum(dim=-1).sum(dim=-1) / arrsum + + ret[:, 0] = arrprody.squeeze(-1).squeeze(-1) + ret[:, 1] = arrprodx.squeeze(-1).squeeze(-1) + + idx = torch.where(arrsum == 0.0)[0] + ret[idx, :] = 0.0 + return ret + + def _interpolate( + self, array: torch.Tensor, rough: torch.Tensor, rad: int = 2 + ) -> torch.Tensor: + assert array.ndim == 3 + assert rough.ndim == 2 + + rough = torch.round(rough).type(torch.int64) + + surroundings = self._get_subarr(array, rough, rad) + + com = self._argmax_ext(surroundings, 1.0) + + offset = com - rad + ret = rough + offset + + ret += 0.5 + ret %= ( + torch.tensor(array.shape[-2:], dtype=self.default_dtype, device=self.device) + .type(dtype=torch.int64) + .unsqueeze(0) + ) + ret -= 0.5 + return ret + + def _get_success( + self, array: torch.Tensor, coord: torch.Tensor, radius: int = 2 + ) -> torch.Tensor: + assert array.ndim == 3 + assert coord.ndim == 2 + assert array.shape[0] == coord.shape[0] + assert coord.shape[1] == 2 + + coord = torch.round(coord).type(dtype=torch.int64) + subarr = self._get_subarr( + array, coord, 2 + ) # Not my fault. They want a 2 there. 
Not radius + + theval = subarr.sum(dim=-1).sum(dim=-1) + + theval2 = array[range(0, coord.shape[0]), coord[:, 0], coord[:, 1]] + + success = torch.sqrt(theval * theval2) + return success + + def _get_constraint_mask( + self, + shape: torch.Size, + log_base: torch.Tensor, + constraints_scale_0: torch.Tensor, + constraints_scale_1: torch.Tensor | None, + constraints_angle_0: torch.Tensor, + constraints_angle_1: torch.Tensor | None, + ) -> torch.Tensor: + assert constraints_scale_0 is not None + assert constraints_angle_0 is not None + assert constraints_scale_0.ndim == 1 + assert constraints_angle_0.ndim == 1 + + assert constraints_scale_0.shape[0] == constraints_angle_0.shape[0] + + mask: torch.Tensor = torch.ones( + (constraints_scale_0.shape[0], int(shape[-2]), int(shape[-1])), + device=self.device, + dtype=self.default_dtype, + ) + + scale: torch.Tensor = constraints_scale_0.clone() + if constraints_scale_1 is not None: + sigma: torch.Tensor | None = constraints_scale_1.clone() + else: + sigma = None + + scales = torch.fft.ifftshift( + self._get_lograd( + torch.tensor(shape[-2:], device=self.device, dtype=self.default_dtype), + log_base, + ) + ) + + scales *= log_base ** (-shape[-1] / 2.0) + scales = scales.unsqueeze(0) - (1.0 / scale).unsqueeze(-1).unsqueeze(-1) + + if sigma is not None: + assert sigma.shape[0] == constraints_scale_0.shape[0] + + for p_id in range(0, sigma.shape[0]): + if sigma[p_id] == 0: + ascales = torch.abs(scales[p_id, ...]) + scale_min = ascales.min() + binary_mask = torch.where(ascales > scale_min, 0.0, 1.0) + mask[p_id, ...] *= binary_mask + else: + mask[p_id, ...] *= torch.exp( + -(torch.pow(scales[p_id, ...], 2)) / torch.pow(sigma[p_id], 2) + ) + + angle: torch.Tensor = constraints_angle_0.clone() + if constraints_angle_1 is not None: + sigma = constraints_angle_1.clone() + else: + sigma = None + + angles = self._get_angles( + torch.tensor(shape[-2:], device=self.device, dtype=self.default_dtype) + ) + + angles = angles.unsqueeze(0) + torch.deg2rad(angle).unsqueeze(-1).unsqueeze(-1) + + angles = torch.rad2deg(angles) + + if sigma is not None: + assert sigma.shape[0] == constraints_scale_0.shape[0] + + for p_id in range(0, sigma.shape[0]): + if sigma[p_id] == 0: + aangles = torch.abs(angles[p_id, ...]) + angle_min = aangles.min() + binary_mask = torch.where(aangles > angle_min, 0.0, 1.0) + mask[p_id, ...] 
*= binary_mask
+                else:
+                    mask[p_id, ...] *= torch.exp(
+                        -(torch.pow(angles[p_id, ...], 2)) / torch.pow(sigma[p_id], 2)
+                    )
+
+        mask = torch.fft.fftshift(mask, dim=(-2, -1))
+
+        return mask
+
+    def argmax_angscale(
+        self,
+        array: torch.Tensor,
+        log_base: torch.Tensor,
+        constraints_scale_0: torch.Tensor,
+        constraints_scale_1: torch.Tensor | None,
+        constraints_angle_0: torch.Tensor,
+        constraints_angle_1: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        assert array.ndim == 3
+        assert constraints_scale_0 is not None
+        assert constraints_angle_0 is not None
+        assert constraints_scale_0.ndim == 1
+        assert constraints_angle_0.ndim == 1
+
+        mask = self._get_constraint_mask(
+            array.shape[-2:],
+            log_base,
+            constraints_scale_0,
+            constraints_scale_1,
+            constraints_angle_0,
+            constraints_angle_1,
+        )
+
+        array_orig = array.clone()
+
+        array *= mask
+        ret = self._argmax_ext(array, self.exponent)
+
+        ret_final = self._interpolate(array, ret)
+
+        success = self._get_success(array_orig, ret_final, 0)
+
+        return ret_final, success
+
+    def argmax_translation(
+        self, array: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        assert array.ndim == 3
+
+        array_orig = array.clone()
+
+        ashape = torch.tensor(array.shape[-2:], device=self.device).type(
+            dtype=torch.int64
+        )
+
+        aporad = (ashape // 6).min()
+        mask2 = self.get_apofield(torch.Size(ashape), aporad).unsqueeze(0)
+        array *= mask2
+
+        tvec = self._argmax_ext(array, "inf")
+        tvec = self._interpolate(array_orig, tvec)
+
+        success = self._get_success(array_orig, tvec, 2)
+
+        return tvec, success
+
+    def transform_img(
+        self,
+        img: torch.Tensor,
+        scale: torch.Tensor | None = None,
+        angle: torch.Tensor | None = None,
+        tvec: torch.Tensor | None = None,
+        bgval: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        assert img.ndim == 3
+
+        if scale is None:
+            scale = torch.ones(
+                (img.shape[0],), dtype=self.default_dtype, device=self.device
+            )
+        assert scale.ndim == 1
+        assert scale.shape[0] == img.shape[0]
+
+        if angle is None:
+            angle = torch.zeros(
+                (img.shape[0],), dtype=self.default_dtype, device=self.device
+            )
+        assert angle.ndim == 1
+        assert angle.shape[0] == img.shape[0]
+
+        if tvec is None:
+            tvec = torch.zeros(
+                (img.shape[0], 2), dtype=self.default_dtype, device=self.device
+            )
+        assert tvec.ndim == 2
+        assert tvec.shape[0] == img.shape[0]
+        assert tvec.shape[1] == 2
+
+        if bgval is None:
+            bgval = self.get_borderval(img)
+        assert bgval.ndim == 1
+        assert bgval.shape[0] == img.shape[0]
+
+        # Otherwise we need to decompose it and put it back together
+        assert torch.is_complex(img) is False
+
+        output = torch.zeros_like(img)
+
+        for pos in range(0, img.shape[0]):
+            image_processed = img[pos, :, :].unsqueeze(0).clone()
+
+            temp_shift = [
+                int(round(tvec[pos, 1].item() * self.scale_factor)),
+                int(round(tvec[pos, 0].item() * self.scale_factor)),
+            ]
+
+            image_processed = torch.nn.functional.interpolate(
+                image_processed.unsqueeze(0),
+                scale_factor=self.scale_factor,
+                mode="bilinear",
+            ).squeeze(0)
+
+            image_processed = tv.transforms.functional.affine(
+                img=image_processed,
+                angle=-float(angle[pos]),
+                translate=temp_shift,
+                scale=float(scale[pos]),
+                shear=[0, 0],
+                interpolation=tv.transforms.InterpolationMode.BILINEAR,
+                fill=float(bgval[pos]),
+                center=None,
+            )
+
+            image_processed = torch.nn.functional.interpolate(
+                image_processed.unsqueeze(0),
+                scale_factor=1.0 / self.scale_factor,
+                mode="bilinear",
+            ).squeeze(0)
+
+            image_processed = tv.transforms.functional.center_crop(
+                image_processed, img.shape[-2:]
+ ) + + output[pos, ...] = image_processed.squeeze(0) + + return output + + def transform_img_dict( + self, + img: torch.Tensor, + scale: torch.Tensor | None = None, + angle: torch.Tensor | None = None, + tvec: torch.Tensor | None = None, + bgval: torch.Tensor | None = None, + invert=False, + ) -> torch.Tensor: + if invert: + if scale is not None: + scale = 1.0 / scale + if angle is not None: + angle *= -1 + if tvec is not None: + tvec *= -1 + + res = self.transform_img(img, scale, angle, tvec, bgval=bgval) + return res + + def _phase_correlation( + self, image_reference: torch.Tensor, images_todo: torch.Tensor, callback, *args + ) -> tuple[torch.Tensor, torch.Tensor]: + assert image_reference.ndim == 3 + assert image_reference.shape[0] == 1 + assert images_todo.ndim == 3 + + assert callback is not None + + image_reference_fft = torch.fft.fft2(image_reference, dim=(-2, -1)) + images_todo_fft = torch.fft.fft2(images_todo, dim=(-2, -1)) + + eps = torch.abs(images_todo_fft).max(dim=-1)[0].max(dim=-1)[0] * 1e-15 + + cps = abs( + torch.fft.ifft2( + (image_reference_fft * images_todo_fft.conj()) + / ( + torch.abs(image_reference_fft) * torch.abs(images_todo_fft) + + eps.unsqueeze(-1).unsqueeze(-1) + ) + ) + ) + + scps = torch.fft.fftshift(cps, dim=(-2, -1)) + + ret, success = callback(scps, *args) + + ret[:, 0] -= image_reference_fft.shape[-2] // 2 + ret[:, 1] -= image_reference_fft.shape[-1] // 2 + + return ret, success + + def _translation( + self, im0: torch.Tensor, im1: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + assert im0.ndim == 2 + ret, succ = self._phase_correlation( + im0.unsqueeze(0), im1, self.argmax_translation + ) + return ret, succ + + def _get_ang_scale( + self, + image_reference: torch.Tensor, + images_todo: torch.Tensor, + constraints_scale_0: torch.Tensor, + constraints_scale_1: torch.Tensor | None, + constraints_angle_0: torch.Tensor, + constraints_angle_1: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert image_reference.ndim == 2 + assert images_todo.ndim == 3 + assert image_reference.shape[-1] == images_todo.shape[-1] + assert image_reference.shape[-2] == images_todo.shape[-2] + assert constraints_scale_0.shape[0] == images_todo.shape[0] + assert constraints_angle_0.shape[0] == images_todo.shape[0] + + if constraints_scale_1 is not None: + assert constraints_scale_1.shape[0] == images_todo.shape[0] + + if constraints_angle_1 is not None: + assert constraints_angle_1.shape[0] == images_todo.shape[0] + + if self.image_reference_dft is None: + image_reference_apod = self._apodize(image_reference.unsqueeze(0)) + self.image_reference_dft = torch.fft.fftshift( + torch.fft.fft2(image_reference_apod, dim=(-2, -1)), dim=(-2, -1) + ) + self.filt = self._logpolar_filter(image_reference.shape).unsqueeze(0) + self.image_reference_dft *= self.filt + self.pcorr_shape = torch.tensor( + self._get_pcorr_shape(image_reference.shape[-2:]), + dtype=self.default_dtype, + device=self.device, + ) + self.log_base = self._get_log_base( + image_reference.shape, + self.pcorr_shape[1], + ) + self.image_reference_logp = self._logpolar( + torch.abs(self.image_reference_dft), self.pcorr_shape, self.log_base + ) + + images_todo_apod = self._apodize(images_todo) + images_todo_dft = torch.fft.fftshift( + torch.fft.fft2(images_todo_apod, dim=(-2, -1)), dim=(-2, -1) + ) + + images_todo_dft *= self.filt + + images_todo_lopg = self._logpolar( + torch.abs(images_todo_dft), self.pcorr_shape, self.log_base + ) + + temp, _ = self._phase_correlation( + self.image_reference_logp, + 
images_todo_lopg,
+            self.argmax_angscale,
+            self.log_base,
+            constraints_scale_0,
+            constraints_scale_1,
+            constraints_angle_0,
+            constraints_angle_1,
+        )
+
+        arg_ang = temp[:, 0].clone()
+        arg_rad = temp[:, 1].clone()
+
+        angle = -torch.pi * arg_ang / float(self.pcorr_shape[0])
+        angle = torch.rad2deg(angle)
+
+        angle = self.wrap_angle(angle, 360)
+
+        scale = torch.pow(self.log_base, arg_rad)
+
+        angle = -angle
+        scale = 1.0 / scale
+
+        assert torch.where(scale < 2)[0].shape[0] == scale.shape[0]
+        assert torch.where(scale > 0.5)[0].shape[0] == scale.shape[0]
+
+        return scale, angle
+
+    def translation(
+        self, im0: torch.Tensor, im1: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        angle = torch.zeros(
+            (im1.shape[0]), dtype=self.default_dtype, device=self.device
+        )
+        assert im1.ndim == 3
+        assert im0.shape[-2] == im1.shape[-2]
+        assert im0.shape[-1] == im1.shape[-1]
+
+        tvec, succ = self._translation(im0, im1)
+        tvec2, succ2 = self._translation(im0, torch.rot90(im1, k=2, dims=[-2, -1]))
+
+        assert tvec.shape[0] == tvec2.shape[0]
+        assert tvec.ndim == 2
+        assert tvec2.ndim == 2
+        assert tvec.shape[1] == 2
+        assert tvec2.shape[1] == 2
+        assert succ.shape[0] == succ2.shape[0]
+        assert succ.ndim == 1
+        assert succ2.ndim == 1
+        assert tvec.shape[0] == succ.shape[0]
+        assert angle.shape[0] == tvec.shape[0]
+        assert angle.ndim == 1
+
+        for pos in range(0, angle.shape[0]):
+            pick_rotated = False
+            if succ2[pos] > succ[pos]:
+                pick_rotated = True
+
+            if pick_rotated:
+                tvec[pos, :] = tvec2[pos, :]
+                succ[pos] = succ2[pos]
+                angle[pos] += 180
+
+        return tvec, succ, angle
+
+    def _similarity(
+        self,
+        image_reference: torch.Tensor,
+        images_todo: torch.Tensor,
+        bgval: torch.Tensor,
+    ):
+        assert image_reference.ndim == 2
+        assert images_todo.ndim == 3
+        assert image_reference.shape[-1] == images_todo.shape[-1]
+        assert image_reference.shape[-2] == images_todo.shape[-2]
+
+        # We are going to iteratively refine the scale and angle estimates
+        scale: torch.Tensor = torch.ones(
+            (images_todo.shape[0]), dtype=self.default_dtype, device=self.device
+        )
+        angle: torch.Tensor = torch.zeros(
+            (images_todo.shape[0]), dtype=self.default_dtype, device=self.device
+        )
+
+        constraints_dynamic_angle_0: torch.Tensor = torch.zeros(
+            (images_todo.shape[0]), dtype=self.default_dtype, device=self.device
+        )
+        constraints_dynamic_angle_1: torch.Tensor | None = None
+        constraints_dynamic_scale_0: torch.Tensor = torch.ones(
+            (images_todo.shape[0]), dtype=self.default_dtype, device=self.device
+        )
+        constraints_dynamic_scale_1: torch.Tensor | None = None
+
+        newscale, newangle = self._get_ang_scale(
+            image_reference,
+            images_todo,
+            constraints_dynamic_scale_0,
+            constraints_dynamic_scale_1,
+            constraints_dynamic_angle_0,
+            constraints_dynamic_angle_1,
+        )
+        scale *= newscale
+        angle += newangle
+
+        im2 = self.transform_img(images_todo, scale, angle, bgval=bgval)
+
+        tvec, self.success, res_angle = self.translation(image_reference, im2)
+
+        angle += res_angle
+
+        angle = self.wrap_angle(angle, 360)
+
+        return scale, angle, tvec
+
+    def similarity(
+        self,
+        image_reference: torch.Tensor,
+        images_todo: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        assert image_reference.ndim == 2
+        assert images_todo.ndim == 3
+
+        bgval: torch.Tensor = self.get_borderval(img=images_todo, radius=5)
+
+        scale, angle, tvec = self._similarity(
+            image_reference,
+            images_todo,
+            bgval,
+        )
+
+        im2 = self.transform_img_dict(
+            img=images_todo,
+            scale=scale,
+            angle=angle,
+            tvec=tvec,
+            bgval=bgval,
+        )
+
+        return scale, angle, tvec, im2
diff --git a/run_svd.py b/run_svd.py
new file mode 100644
index 0000000..e12ea03
--- /dev/null
+++ b/run_svd.py
@@ -0,0 +1,57 @@
+import torch
+import numpy as np
+from svd import calculate_svd, to_remove, temporal_filter, svd_denoise
+
+if __name__ == "__main__":
+    filename: str = "example_data_crop"
+    window_size: int = 2
+    kernel_size_pooling: int = 2
+    orig_freq: int = 30
+    new_freq: int = 3
+    filtfilt_chunk_size: int = 10
+    bp_low_frequency: float = 0.1
+    bp_high_frequency: float = 1.0
+
+    torch_device: torch.device = torch.device(
+        "cuda:0" if torch.cuda.is_available() else "cpu"
+    )
+
+    print("Load data")
+    input = np.load(filename + ".npy")
+    data = torch.tensor(input, device=torch_device)
+
+    print("Movement compensation [MISSING!!!!]")
+    print("(include ImageAlignment.py into processing chain)")
+
+    print("SVD")
+    whiten_mean, whiten_k, eigenvalues = calculate_svd(data)
+
+    print("Calculate to_remove")
+    data = torch.tensor(input, device=torch_device)
+    to_remove_data = to_remove(data, whiten_k, whiten_mean)
+
+    data -= to_remove_data
+    del to_remove_data
+
+    print("Apply temporal filter")
+    data = temporal_filter(
+        data,
+        device=torch_device,
+        orig_freq=orig_freq,
+        new_freq=new_freq,
+        filtfilt_chunk_size=filtfilt_chunk_size,
+        bp_low_frequency=bp_low_frequency,
+        bp_high_frequency=bp_high_frequency,
+    )
+
+    print("SVD Denoising")
+    data_out = svd_denoise(data, window_size=window_size)
+
+    print("Pooling")
+    average_pooling = torch.nn.AvgPool2d(
+        kernel_size=(kernel_size_pooling, kernel_size_pooling),
+        stride=(kernel_size_pooling, kernel_size_pooling),
+    )
+    data_out = average_pooling(data_out)
+
+    np.save(filename + "_decorrelated.npy", data_out.cpu())
diff --git a/svd.py b/svd.py
new file mode 100644
index 0000000..5c21120
--- /dev/null
+++ b/svd.py
@@ -0,0 +1,204 @@
+import torch
+import torchaudio as ta
+import cv2
+import numpy as np
+from tqdm import trange
+
+
+def convert_avi_to_npy(filename: str) -> None:
+    capture_from_file = cv2.VideoCapture(filename + ".avi")
+    avi_length = int(capture_from_file.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    # To torch and beyond
+    data: np.ndarray | None = None
+    for i in trange(0, avi_length):
+        read_ok, frame = capture_from_file.read()
+
+        assert read_ok
+
+        if data is None:
+            data = np.empty(
+                (avi_length, frame.shape[0], frame.shape[1]),
+                dtype=np.float32,
+            )
+        assert data is not None
+        # Convert each RGB frame to gray scale by averaging the channels
+        data[i, :, :] = frame.mean(axis=-1).astype(np.float32)
+    assert data is not None
+    capture_from_file.release()
+    np.save(filename + ".npy", data)
+
+
+@torch.no_grad()
+def to_remove(
+    data: torch.Tensor, whiten_k: torch.Tensor, whiten_mean: torch.Tensor
+) -> torch.Tensor:
+    # Use only the first (dominant) SVD component
+    whiten_k = whiten_k[:, :, 0]
+
+    data = (data - whiten_mean.unsqueeze(0)) * whiten_k.unsqueeze(0)
+    data_svd = data.sum(dim=-1).sum(dim=-1).unsqueeze(-1).unsqueeze(-1)
+
+    factor = (data * data_svd).sum(dim=0, keepdim=True) / (data_svd**2).sum(
+        dim=0, keepdim=True
+    )
+    to_remove = data_svd * factor
+    to_remove /= whiten_k.unsqueeze(0) + 1e-20
+    to_remove += whiten_mean.unsqueeze(0)
+
+    return to_remove
+
+
+@torch.no_grad()
+def calculate_svd(
+    input: torch.Tensor, lowrank_q: int = 6
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    selection = torch.flatten(
+        input.clone().movedim(0, -1),
+        start_dim=0,
+        end_dim=1,
+    )
+
+    whiten_mean = torch.mean(selection, dim=-1)
+    selection -= whiten_mean.unsqueeze(-1)
+    whiten_mean = whiten_mean.reshape((input.shape[1], input.shape[2]))
+
+    svd_u, svd_s, _ = torch.svd_lowrank(selection, q=lowrank_q)
+
+    # Fix the sign convention and scale the spatial modes by 1 / singular value
+    whiten_k = (
+        torch.sign(svd_u[0, :]).unsqueeze(0) * svd_u / (svd_s.unsqueeze(0) + 1e-20)
+    )
+    whiten_k = whiten_k.reshape((input.shape[1], input.shape[2], svd_s.shape[0]))
+    eigenvalues = svd_s
+
+    return whiten_mean, whiten_k, eigenvalues
+
+
+@torch.no_grad()
+def filtfilt(
+    input: torch.Tensor,
+    butter_a: torch.Tensor,
+    butter_b: torch.Tensor,
+) -> torch.Tensor:
+    assert butter_a.ndim == 1
+    assert butter_b.ndim == 1
+    assert butter_a.shape[0] == butter_b.shape[0]
+
+    process_data: torch.Tensor = input.movedim(0, -1).detach().clone()
+
+    # Mirror-pad both ends (as scipy.signal.filtfilt does) to reduce edge artifacts
+    padding_length = 12 * int(butter_a.shape[0])
+    left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[
+        ..., 1 : padding_length + 1
+    ].flip(-1)
+    right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[
+        ..., -(padding_length + 1) : -1
+    ].flip(-1)
+    process_data_padded = torch.cat((left_padding, process_data, right_padding), dim=-1)
+
+    output = ta.functional.filtfilt(
+        process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False
+    ).squeeze(0)
+    output = output[..., padding_length:-padding_length].movedim(-1, 0)
+    return output
+
+
+@torch.no_grad()
+def butter_bandpass(
+    device: torch.device,
+    low_frequency: float = 0.1,
+    high_frequency: float = 1.0,
+    fs: float = 30.0,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    import scipy.signal
+
+    butter_b_np, butter_a_np = scipy.signal.butter(
+        4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs
+    )
+    butter_a = torch.tensor(butter_a_np, device=device, dtype=torch.float32)
+    butter_b = torch.tensor(butter_b_np, device=device, dtype=torch.float32)
+    return butter_a, butter_b
+
+
+@torch.no_grad()
+def chunk_iterator(array: torch.Tensor, chunk_size: int):
+    for i in range(0, array.shape[0], chunk_size):
+        yield array[i : i + chunk_size]
+
+
+@torch.no_grad()
+def lowpass(
+    data: torch.Tensor,
+    device: torch.device,
+    low_frequency: float = 0.1,
+    high_frequency: float = 1.0,
+    fs: float = 30.0,
+    filtfilt_chunk_size: int = 10,
+) -> torch.Tensor:
+    # Note: despite the name, this applies a Butterworth band-pass filter,
+    # processing the data one chunk of rows at a time to limit memory use.
+    butter_a, butter_b = butter_bandpass(
+        device=device,
+        low_frequency=low_frequency,
+        high_frequency=high_frequency,
+        fs=fs,
+    )
+
+    index_full_dataset: torch.Tensor = torch.arange(
+        0, data.shape[1], device=device, dtype=torch.int64
+    )
+
+    for chunk in chunk_iterator(index_full_dataset, filtfilt_chunk_size):
+        temp_filtfilt = filtfilt(
+            data[:, chunk, :],
+            butter_a=butter_a,
+            butter_b=butter_b,
+        )
+        data[:, chunk, :] = temp_filtfilt
+
+    return data
+
+
+@torch.no_grad()
+def temporal_filter(
+    data: torch.Tensor,
+    device: torch.device,
+    orig_freq: int = 30,
+    new_freq: int = 3,
+    filtfilt_chunk_size: int = 10,
+    bp_low_frequency: float = 0.1,
+    bp_high_frequency: float = 1.0,
+) -> torch.Tensor:
+    # Resample in time, then band-pass filter at the new sampling rate
+    data = ta.functional.resample(
+        data.movedim(0, -1), orig_freq=orig_freq, new_freq=new_freq
+    ).movedim(-1, 0)
+
+    data = lowpass(
+        data,
+        device=device,
+        low_frequency=bp_low_frequency,
+        high_frequency=bp_high_frequency,
+        fs=float(new_freq),
+        filtfilt_chunk_size=filtfilt_chunk_size,
+    )
+
+    return data
+
+
+@torch.no_grad()
+def svd_denoise(data: torch.Tensor, window_size: int) -> torch.Tensor:
+    data_out = torch.zeros_like(data)
+
+    for x in trange(0, data.shape[1]):
+        for y in range(0, data.shape[2]):
+            if (
+                ((x - window_size) > 0)
+                and ((y - window_size) > 0)
+                and ((x + window_size) <= data.shape[1])
+                and ((y + window_size) <= data.shape[2])
+            ):
+                data_sel: torch.Tensor = data[
+                    :,
+                    x - window_size : x + window_size + 1,
+                    y -
window_size : y + window_size + 1, + ] + + whiten_mean, whiten_k, eigenvalues = calculate_svd(data_sel.clone()) + to_remove_data = to_remove(data_sel, whiten_k, whiten_mean) + data_out[:, x, y] = to_remove_data[:, window_size, window_size] + return data_out
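
Note on the processing chain: run_svd.py still skips the movement-compensation
step (see the "[MISSING!!!!]" print). A minimal sketch of how ImageAlignment
might be wired in right after loading the data; the choice of the first frame
as the reference image is an assumption, not part of this patch:

    import torch
    from ImageAlignment import ImageAlignment

    image_alignment = ImageAlignment(
        device=torch_device, default_dtype=torch.float32
    )
    # Register every frame against frame 0 (scale, rotation, translation),
    # returning the aligned movie
    data = image_alignment(data, new_reference_image=data[0, :, :].clone())

After this call, image_alignment.last_scale, last_angle, and last_tvec hold the
per-frame transform estimates, so they can be inspected before the SVD step;
dry_run() gives the same estimates without transforming the movie.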