1011 lines
32 KiB
Python
1011 lines
32 KiB
Python
|
import torch
|
||
|
import torchvision as tv
|
||
|
|
||
|
# The source code is based on:
# https://github.com/matejak/imreg_dft

# The original LICENSE:
# Copyright (c) 2014, Matěj Týč
# Copyright (c) 2011-2014, Christoph Gohlke
# Copyright (c) 2011-2014, The Regents of the University of California
# Produced at the Laboratory for Fluorescence Dynamics

# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# * Neither the name of the {organization} nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||
|
|
||
|
|
||
|
class ImageAlignment(torch.nn.Module):
|
||
|
device: torch.device
|
||
|
default_dtype: torch.dtype
|
||
|
excess_const: float = 1.1
|
||
|
exponent: str = "inf"
|
||
|
success: torch.Tensor | None = None
|
||
|
|
||
|
# The factor that detmines how many
|
||
|
# sub-pixel we will shift
|
||
|
scale_factor: int = 4
|
||
|
|
||
|
reference_image: torch.Tensor | None = None
|
||
|
|
||
|
last_scale: torch.Tensor | None = None
|
||
|
last_angle: torch.Tensor | None = None
|
||
|
last_tvec: torch.Tensor | None = None
|
||
|
|
||
|
# Cache
|
||
|
image_reference_dft: torch.Tensor | None = None
|
||
|
filt: torch.Tensor
|
||
|
pcorr_shape: torch.Tensor
|
||
|
log_base: torch.Tensor
|
||
|
image_reference_logp: torch.Tensor
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
device: torch.device | None = None,
|
||
|
default_dtype: torch.dtype | None = None,
|
||
|
) -> None:
|
||
|
super().__init__()
|
||
|
|
||
|
assert device is not None
|
||
|
assert default_dtype is not None
|
||
|
self.device = device
|
||
|
self.default_dtype = default_dtype
|
||
|
|
||
|
def set_new_reference_image(self, new_reference_image: torch.Tensor | None = None):
|
||
|
assert new_reference_image is not None
|
||
|
assert new_reference_image.ndim == 2
|
||
|
self.reference_image = (
|
||
|
new_reference_image.detach()
|
||
|
.clone()
|
||
|
.to(device=self.device)
|
||
|
.type(dtype=self.default_dtype)
|
||
|
)
|
||
|
self.image_reference_dft = None
|
||
|
|
||
|
def forward(
|
||
|
self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None
|
||
|
) -> torch.Tensor:
|
||
|
assert input.ndim == 3
|
||
|
|
||
|
if new_reference_image is not None:
|
||
|
self.set_new_reference_image(new_reference_image)
|
||
|
|
||
|
assert self.reference_image is not None
|
||
|
assert self.reference_image.ndim == 2
|
||
|
assert input.shape[-2] == self.reference_image.shape[-2]
|
||
|
assert input.shape[-1] == self.reference_image.shape[-1]
|
||
|
|
||
|
self.last_scale, self.last_angle, self.last_tvec, output = self.similarity(
|
||
|
self.reference_image,
|
||
|
input.to(device=self.device).type(dtype=self.default_dtype),
|
||
|
)
|
||
|
|
||
|
return output
|
||
|
|
||
|
def dry_run(
|
||
|
self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None
|
||
|
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
||
|
assert input.ndim == 3
|
||
|
|
||
|
if new_reference_image is not None:
|
||
|
self.set_new_reference_image(new_reference_image)
|
||
|
|
||
|
assert self.reference_image is not None
|
||
|
assert self.reference_image.ndim == 2
|
||
|
assert input.shape[-2] == self.reference_image.shape[-2]
|
||
|
assert input.shape[-1] == self.reference_image.shape[-1]
|
||
|
|
||
|
images_todo = input.to(device=self.device).type(dtype=self.default_dtype)
|
||
|
image_reference = self.reference_image
|
||
|
|
||
|
assert image_reference.ndim == 2
|
||
|
assert images_todo.ndim == 3
|
||
|
|
||
|
bgval: torch.Tensor = self.get_borderval(img=images_todo, radius=5)
|
||
|
|
||
|
self.last_scale, self.last_angle, self.last_tvec = self._similarity(
|
||
|
image_reference,
|
||
|
images_todo,
|
||
|
bgval,
|
||
|
)
|
||
|
|
||
|
return self.last_scale, self.last_angle, self.last_tvec
|
||
|
|
||
|
def dry_run_translation(
|
||
|
self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None
|
||
|
) -> torch.Tensor:
|
||
|
assert input.ndim == 3
|
||
|
|
||
|
if new_reference_image is not None:
|
||
|
self.set_new_reference_image(new_reference_image)
|
||
|
|
||
|
assert self.reference_image is not None
|
||
|
assert self.reference_image.ndim == 2
|
||
|
assert input.shape[-2] == self.reference_image.shape[-2]
|
||
|
assert input.shape[-1] == self.reference_image.shape[-1]
|
||
|
|
||
|
images_todo = input.to(device=self.device).type(dtype=self.default_dtype)
|
||
|
image_reference = self.reference_image
|
||
|
|
||
|
assert image_reference.ndim == 2
|
||
|
assert images_todo.ndim == 3
|
||
|
|
||
|
tvec, _ = self._translation(image_reference, images_todo)
|
||
|
|
||
|
return tvec
|
||
|
|
||
|
# ---------------
|
||
|
|
||
|
def dry_run_angle(
|
||
|
self,
|
||
|
input: torch.Tensor,
|
||
|
new_reference_image: torch.Tensor | None = None,
|
||
|
) -> torch.Tensor:
|
||
|
assert input.ndim == 3
|
||
|
|
||
|
if new_reference_image is not None:
|
||
|
self.set_new_reference_image(new_reference_image)
|
||
|
|
||
|
constraints_dynamic_angle_0: torch.Tensor = torch.zeros(
|
||
|
(input.shape[0]), dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
constraints_dynamic_angle_1: torch.Tensor | None = None
|
||
|
constraints_dynamic_scale_0: torch.Tensor = torch.ones(
|
||
|
(input.shape[0]), dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
constraints_dynamic_scale_1: torch.Tensor | None = None
|
||
|
|
||
|
assert self.reference_image is not None
|
||
|
assert self.reference_image.ndim == 2
|
||
|
assert input.shape[-2] == self.reference_image.shape[-2]
|
||
|
assert input.shape[-1] == self.reference_image.shape[-1]
|
||
|
|
||
|
images_todo = input.to(device=self.device).type(dtype=self.default_dtype)
|
||
|
image_reference = self.reference_image
|
||
|
|
||
|
assert image_reference.ndim == 2
|
||
|
assert images_todo.ndim == 3
|
||
|
|
||
|
_, newangle = self._get_ang_scale(
|
||
|
image_reference,
|
||
|
images_todo,
|
||
|
constraints_dynamic_scale_0,
|
||
|
constraints_dynamic_scale_1,
|
||
|
constraints_dynamic_angle_0,
|
||
|
constraints_dynamic_angle_1,
|
||
|
)
|
||
|
|
||
|
return newangle
|
||
|
|
||
|
# ---------------
|
||
|
|
||
|
def _get_pcorr_shape(self, shape: torch.Size) -> tuple[int, int]:
|
||
|
ret = (int(max(shape[-2:]) * 1.0),) * 2
|
||
|
return ret
|
||
|
|
||
|
def _get_log_base(self, shape: torch.Size, new_r: torch.Tensor) -> torch.Tensor:
|
||
|
old_r = torch.tensor(
|
||
|
(float(shape[-2]) * self.excess_const) / 2.0,
|
||
|
dtype=self.default_dtype,
|
||
|
device=self.device,
|
||
|
)
|
||
|
log_base = torch.exp(torch.log(old_r) / new_r)
|
||
|
return log_base
|
||
|
|
||
|
def wrap_angle(
|
||
|
self, angles: torch.Tensor, ceil: float = 2 * torch.pi
|
||
|
) -> torch.Tensor:
|
||
|
angles += ceil / 2.0
|
||
|
angles %= ceil
|
||
|
angles -= ceil / 2.0
|
||
|
return angles
|
||
|
|
||
|
def get_borderval(
|
||
|
self, img: torch.Tensor, radius: int | None = None
|
||
|
) -> torch.Tensor:
|
||
|
assert img.ndim == 3
|
||
|
if radius is None:
|
||
|
mindim = min([int(img.shape[-2]), int(img.shape[-1])])
|
||
|
radius = max(1, mindim // 20)
|
||
|
mask = torch.zeros(
|
||
|
(int(img.shape[-2]), int(img.shape[-1])),
|
||
|
dtype=torch.bool,
|
||
|
device=self.device,
|
||
|
)
|
||
|
mask[:, :radius] = True
|
||
|
mask[:, -radius:] = True
|
||
|
mask[:radius, :] = True
|
||
|
mask[-radius:, :] = True
|
||
|
|
||
|
mean = torch.median(img[:, mask], dim=-1)[0]
|
||
|
return mean
|
||
|
|
||
|
def get_apofield(self, shape: torch.Size, aporad: int) -> torch.Tensor:
|
||
|
if aporad == 0:
|
||
|
return torch.ones(
|
||
|
shape[-2:],
|
||
|
dtype=self.default_dtype,
|
||
|
device=self.device,
|
||
|
)
|
||
|
|
||
|
assert int(shape[-2]) > aporad * 2
|
||
|
assert int(shape[-1]) > aporad * 2
|
||
|
|
||
|
apos = torch.hann_window(
|
||
|
aporad * 2, dtype=self.default_dtype, periodic=False, device=self.device
|
||
|
)
|
||
|
|
||
|
toapp_0 = torch.ones(
|
||
|
shape[-2],
|
||
|
dtype=self.default_dtype,
|
||
|
device=self.device,
|
||
|
)
|
||
|
toapp_0[:aporad] = apos[:aporad]
|
||
|
toapp_0[-aporad:] = apos[-aporad:]
|
||
|
|
||
|
toapp_1 = torch.ones(
|
||
|
shape[-1],
|
||
|
dtype=self.default_dtype,
|
||
|
device=self.device,
|
||
|
)
|
||
|
toapp_1[:aporad] = apos[:aporad]
|
||
|
toapp_1[-aporad:] = apos[-aporad:]
|
||
|
|
||
|
apofield = torch.outer(toapp_0, toapp_1)
|
||
|
|
||
|
return apofield
|
||
|
|
||
|
def _get_subarr(
|
||
|
self, array: torch.Tensor, center: torch.Tensor, rad: int
|
||
|
) -> torch.Tensor:
|
||
|
assert array.ndim == 3
|
||
|
assert center.ndim == 2
|
||
|
assert array.shape[0] == center.shape[0]
|
||
|
assert center.shape[1] == 2
|
||
|
|
||
|
dim = 1 + 2 * rad
|
||
|
subarr = torch.zeros(
|
||
|
(array.shape[0], dim, dim), dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
|
||
|
corner = center - rad
|
||
|
idx_p = range(0, corner.shape[0])
|
||
|
for ii in range(0, dim):
|
||
|
yidx = corner[:, 0] + ii
|
||
|
yidx %= array.shape[-2]
|
||
|
for jj in range(0, dim):
|
||
|
xidx = corner[:, 1] + jj
|
||
|
xidx %= array.shape[-1]
|
||
|
subarr[:, ii, jj] = array[idx_p, yidx, xidx]
|
||
|
|
||
|
return subarr
|
||
|
|
||
|
def _argmax_2d(self, array: torch.Tensor) -> torch.Tensor:
|
||
|
assert array.ndim == 3
|
||
|
|
||
|
max_pos = array.reshape(
|
||
|
(array.shape[0], array.shape[1] * array.shape[2])
|
||
|
).argmax(dim=1)
|
||
|
pos_0 = max_pos // array.shape[2]
|
||
|
max_pos -= pos_0 * array.shape[2]
|
||
|
ret = torch.zeros(
|
||
|
(array.shape[0], 2), dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
ret[:, 0] = pos_0
|
||
|
ret[:, 1] = max_pos
|
||
|
return ret.type(dtype=torch.int64)
|
||
|
|
||
|
def _apodize(self, what: torch.Tensor) -> torch.Tensor:
|
||
|
mindim = min([int(what.shape[-2]), int(what.shape[-1])])
|
||
|
aporad = int(mindim * 0.12)
|
||
|
|
||
|
apofield = self.get_apofield(what.shape, aporad).unsqueeze(0)
|
||
|
|
||
|
res = what * apofield
|
||
|
bg = self.get_borderval(what, aporad // 2).unsqueeze(-1).unsqueeze(-1)
|
||
|
res += bg * (1 - apofield)
|
||
|
return res
|
||
|
|
||
|
def _logpolar_filter(self, shape: torch.Size) -> torch.Tensor:
|
||
|
yy = torch.linspace(
|
||
|
-torch.pi / 2.0,
|
||
|
torch.pi / 2.0,
|
||
|
shape[-2],
|
||
|
dtype=self.default_dtype,
|
||
|
device=self.device,
|
||
|
).unsqueeze(1)
|
||
|
|
||
|
xx = torch.linspace(
|
||
|
-torch.pi / 2.0,
|
||
|
torch.pi / 2.0,
|
||
|
shape[-1],
|
||
|
dtype=self.default_dtype,
|
||
|
device=self.device,
|
||
|
).unsqueeze(0)
|
||
|
|
||
|
rads = torch.sqrt(yy**2 + xx**2)
|
||
|
filt = 1.0 - torch.cos(rads) ** 2
|
||
|
|
||
|
filt[torch.abs(rads) > torch.pi / 2] = 1
|
||
|
return filt
|
||
|
|
||
|
def _get_angles(self, shape: torch.Tensor) -> torch.Tensor:
|
||
|
ret = torch.zeros(
|
||
|
(int(shape[-2]), int(shape[-1])),
|
||
|
dtype=self.default_dtype,
|
||
|
device=self.device,
|
||
|
)
|
||
|
ret -= torch.linspace(
|
||
|
0,
|
||
|
torch.pi,
|
||
|
int(shape[-2] + 1),
|
||
|
dtype=self.default_dtype,
|
||
|
device=self.device,
|
||
|
)[:-1].unsqueeze(-1)
|
||
|
|
||
|
return ret
|
||
|
|
||
|
def _get_lograd(self, shape: torch.Tensor, log_base: torch.Tensor) -> torch.Tensor:
|
||
|
ret = torch.zeros(
|
||
|
(int(shape[-2]), int(shape[-1])),
|
||
|
dtype=self.default_dtype,
|
||
|
device=self.device,
|
||
|
)
|
||
|
ret += torch.pow(
|
||
|
log_base,
|
||
|
torch.arange(
|
||
|
0,
|
||
|
int(shape[-1]),
|
||
|
dtype=self.default_dtype,
|
||
|
device=self.device,
|
||
|
),
|
||
|
).unsqueeze(0)
|
||
|
return ret
|
||
|
|
||
|
def _logpolar(
|
||
|
self, image: torch.Tensor, shape: torch.Tensor, log_base: torch.Tensor
|
||
|
) -> torch.Tensor:
|
||
|
assert image.ndim == 3
|
||
|
|
||
|
imshape: torch.Tensor = torch.tensor(
|
||
|
image.shape[-2:],
|
||
|
dtype=self.default_dtype,
|
||
|
device=self.device,
|
||
|
)
|
||
|
|
||
|
center: torch.Tensor = imshape.clone() / 2
|
||
|
|
||
|
theta: torch.Tensor = self._get_angles(shape)
|
||
|
radius_x: torch.Tensor = self._get_lograd(shape, log_base)
|
||
|
radius_y: torch.Tensor = radius_x.clone()
|
||
|
|
||
|
ellipse_coef: torch.Tensor = imshape[0] / imshape[1]
|
||
|
radius_x /= ellipse_coef
|
||
|
|
||
|
y = radius_y * torch.sin(theta) + center[0]
|
||
|
y /= float(image.shape[-2])
|
||
|
y *= 2
|
||
|
y -= 1
|
||
|
|
||
|
x = radius_x * torch.cos(theta) + center[1]
|
||
|
x /= float(image.shape[-1])
|
||
|
x *= 2
|
||
|
x -= 1
|
||
|
|
||
|
idx_x = torch.where(torch.abs(x) <= 1.0, 1.0, 0.0)
|
||
|
idx_y = torch.where(torch.abs(y) <= 1.0, 1.0, 0.0)
|
||
|
|
||
|
normalized_coords = torch.cat(
|
||
|
(
|
||
|
x.unsqueeze(-1),
|
||
|
y.unsqueeze(-1),
|
||
|
),
|
||
|
dim=-1,
|
||
|
).unsqueeze(0)
|
||
|
|
||
|
output = torch.empty(
|
||
|
(int(image.shape[0]), int(y.shape[0]), int(y.shape[1])),
|
||
|
dtype=self.default_dtype,
|
||
|
device=self.device,
|
||
|
)
|
||
|
|
||
|
for id in range(0, int(image.shape[0])):
|
||
|
bgval: torch.Tensor = torch.quantile(image[id, :, :], q=1.0 / 100.0)
|
||
|
|
||
|
temp = torch.nn.functional.grid_sample(
|
||
|
image[id, :, :].unsqueeze(0).unsqueeze(0),
|
||
|
normalized_coords,
|
||
|
mode="bilinear",
|
||
|
padding_mode="zeros",
|
||
|
align_corners=False,
|
||
|
)
|
||
|
|
||
|
output[id, :, :] = torch.where((idx_x * idx_y) == 0.0, bgval, temp)
|
||
|
|
||
|
return output
|
||
|
|
||
|
def _argmax_ext(self, array: torch.Tensor, exponent: float | str) -> torch.Tensor:
|
||
|
assert array.ndim == 3
|
||
|
|
||
|
if exponent == "inf":
|
||
|
ret = self._argmax_2d(array)
|
||
|
else:
|
||
|
assert isinstance(exponent, float) or isinstance(exponent, int)
|
||
|
|
||
|
col = (
|
||
|
torch.arange(
|
||
|
0, array.shape[-2], dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
.unsqueeze(-1)
|
||
|
.unsqueeze(0)
|
||
|
)
|
||
|
row = (
|
||
|
torch.arange(
|
||
|
0, array.shape[-1], dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
.unsqueeze(0)
|
||
|
.unsqueeze(0)
|
||
|
)
|
||
|
|
||
|
arr2 = torch.pow(array, float(exponent))
|
||
|
arrsum = arr2.sum(dim=-2).sum(dim=-1)
|
||
|
|
||
|
ret = torch.zeros(
|
||
|
(array.shape[0], 2), dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
|
||
|
arrprody = (arr2 * col).sum(dim=-1).sum(dim=-1) / arrsum
|
||
|
arrprodx = (arr2 * row).sum(dim=-1).sum(dim=-1) / arrsum
|
||
|
|
||
|
ret[:, 0] = arrprody.squeeze(-1).squeeze(-1)
|
||
|
ret[:, 1] = arrprodx.squeeze(-1).squeeze(-1)
|
||
|
|
||
|
idx = torch.where(arrsum == 0.0)[0]
|
||
|
ret[idx, :] = 0.0
|
||
|
return ret
|
||
|
|
||
|
def _interpolate(
|
||
|
self, array: torch.Tensor, rough: torch.Tensor, rad: int = 2
|
||
|
) -> torch.Tensor:
|
||
|
assert array.ndim == 3
|
||
|
assert rough.ndim == 2
|
||
|
|
||
|
rough = torch.round(rough).type(torch.int64)
|
||
|
|
||
|
surroundings = self._get_subarr(array, rough, rad)
|
||
|
|
||
|
com = self._argmax_ext(surroundings, 1.0)
|
||
|
|
||
|
offset = com - rad
|
||
|
ret = rough + offset
|
||
|
|
||
|
ret += 0.5
|
||
|
ret %= (
|
||
|
torch.tensor(array.shape[-2:], dtype=self.default_dtype, device=self.device)
|
||
|
.type(dtype=torch.int64)
|
||
|
.unsqueeze(0)
|
||
|
)
|
||
|
ret -= 0.5
|
||
|
return ret
|
||
|
|
||
|
def _get_success(
|
||
|
self, array: torch.Tensor, coord: torch.Tensor, radius: int = 2
|
||
|
) -> torch.Tensor:
|
||
|
assert array.ndim == 3
|
||
|
assert coord.ndim == 2
|
||
|
assert array.shape[0] == coord.shape[0]
|
||
|
assert coord.shape[1] == 2
|
||
|
|
||
|
coord = torch.round(coord).type(dtype=torch.int64)
|
||
|
subarr = self._get_subarr(
|
||
|
array, coord, 2
|
||
|
) # Not my fault. They want a 2 there. Not radius
|
||
|
|
||
|
theval = subarr.sum(dim=-1).sum(dim=-1)
|
||
|
|
||
|
theval2 = array[range(0, coord.shape[0]), coord[:, 0], coord[:, 1]]
|
||
|
|
||
|
success = torch.sqrt(theval * theval2)
|
||
|
return success
|
||
|
|
||
|
def _get_constraint_mask(
|
||
|
self,
|
||
|
shape: torch.Size,
|
||
|
log_base: torch.Tensor,
|
||
|
constraints_scale_0: torch.Tensor,
|
||
|
constraints_scale_1: torch.Tensor | None,
|
||
|
constraints_angle_0: torch.Tensor,
|
||
|
constraints_angle_1: torch.Tensor | None,
|
||
|
) -> torch.Tensor:
|
||
|
assert constraints_scale_0 is not None
|
||
|
assert constraints_angle_0 is not None
|
||
|
assert constraints_scale_0.ndim == 1
|
||
|
assert constraints_angle_0.ndim == 1
|
||
|
|
||
|
assert constraints_scale_0.shape[0] == constraints_angle_0.shape[0]
|
||
|
|
||
|
mask: torch.Tensor = torch.ones(
|
||
|
(constraints_scale_0.shape[0], int(shape[-2]), int(shape[-1])),
|
||
|
device=self.device,
|
||
|
dtype=self.default_dtype,
|
||
|
)
|
||
|
|
||
|
scale: torch.Tensor = constraints_scale_0.clone()
|
||
|
if constraints_scale_1 is not None:
|
||
|
sigma: torch.Tensor | None = constraints_scale_1.clone()
|
||
|
else:
|
||
|
sigma = None
|
||
|
|
||
|
scales = torch.fft.ifftshift(
|
||
|
self._get_lograd(
|
||
|
torch.tensor(shape[-2:], device=self.device, dtype=self.default_dtype),
|
||
|
log_base,
|
||
|
)
|
||
|
)
|
||
|
|
||
|
scales *= log_base ** (-shape[-1] / 2.0)
|
||
|
scales = scales.unsqueeze(0) - (1.0 / scale).unsqueeze(-1).unsqueeze(-1)
|
||
|
|
||
|
if sigma is not None:
|
||
|
assert sigma.shape[0] == constraints_scale_0.shape[0]
|
||
|
|
||
|
for p_id in range(0, sigma.shape[0]):
|
||
|
if sigma[p_id] == 0:
|
||
|
ascales = torch.abs(scales[p_id, ...])
|
||
|
scale_min = ascales.min()
|
||
|
binary_mask = torch.where(ascales > scale_min, 0.0, 1.0)
|
||
|
mask[p_id, ...] *= binary_mask
|
||
|
else:
|
||
|
mask[p_id, ...] *= torch.exp(
|
||
|
-(torch.pow(scales[p_id, ...], 2)) / torch.pow(sigma[p_id], 2)
|
||
|
)
|
||
|
|
||
|
angle: torch.Tensor = constraints_angle_0.clone()
|
||
|
if constraints_angle_1 is not None:
|
||
|
sigma = constraints_angle_1.clone()
|
||
|
else:
|
||
|
sigma = None
|
||
|
|
||
|
angles = self._get_angles(
|
||
|
torch.tensor(shape[-2:], device=self.device, dtype=self.default_dtype)
|
||
|
)
|
||
|
|
||
|
angles = angles.unsqueeze(0) + torch.deg2rad(angle).unsqueeze(-1).unsqueeze(-1)
|
||
|
|
||
|
angles = torch.rad2deg(angles)
|
||
|
|
||
|
if sigma is not None:
|
||
|
assert sigma.shape[0] == constraints_scale_0.shape[0]
|
||
|
|
||
|
for p_id in range(0, sigma.shape[0]):
|
||
|
if sigma[p_id] == 0:
|
||
|
aangles = torch.abs(angles[p_id, ...])
|
||
|
angle_min = aangles.min()
|
||
|
binary_mask = torch.where(aangles > angle_min, 0.0, 1.0)
|
||
|
mask[p_id, ...] *= binary_mask
|
||
|
else:
|
||
|
mask *= torch.exp(
|
||
|
-(torch.pow(angles[p_id, ...], 2)) / torch.pow(sigma[p_id], 2)
|
||
|
)
|
||
|
|
||
|
mask = torch.fft.fftshift(mask, dim=(-2, -1))
|
||
|
|
||
|
return mask
|
||
|
|
||
|
def argmax_angscale(
|
||
|
self,
|
||
|
array: torch.Tensor,
|
||
|
log_base: torch.Tensor,
|
||
|
constraints_scale_0: torch.Tensor,
|
||
|
constraints_scale_1: torch.Tensor | None,
|
||
|
constraints_angle_0: torch.Tensor,
|
||
|
constraints_angle_1: torch.Tensor | None,
|
||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||
|
assert array.ndim == 3
|
||
|
assert constraints_scale_0 is not None
|
||
|
assert constraints_angle_0 is not None
|
||
|
assert constraints_scale_0.ndim == 1
|
||
|
assert constraints_angle_0.ndim == 1
|
||
|
|
||
|
mask = self._get_constraint_mask(
|
||
|
array.shape[-2:],
|
||
|
log_base,
|
||
|
constraints_scale_0,
|
||
|
constraints_scale_1,
|
||
|
constraints_angle_0,
|
||
|
constraints_angle_1,
|
||
|
)
|
||
|
|
||
|
array_orig = array.clone()
|
||
|
|
||
|
array *= mask
|
||
|
ret = self._argmax_ext(array, self.exponent)
|
||
|
|
||
|
ret_final = self._interpolate(array, ret)
|
||
|
|
||
|
success = self._get_success(array_orig, ret_final, 0)
|
||
|
|
||
|
return ret_final, success
|
||
|
|
||
|
def argmax_translation(
|
||
|
self, array: torch.Tensor
|
||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||
|
assert array.ndim == 3
|
||
|
|
||
|
array_orig = array.clone()
|
||
|
|
||
|
ashape = torch.tensor(array.shape[-2:], device=self.device).type(
|
||
|
dtype=torch.int64
|
||
|
)
|
||
|
|
||
|
aporad = (ashape // 6).min()
|
||
|
mask2 = self.get_apofield(torch.Size(ashape), aporad).unsqueeze(0)
|
||
|
array *= mask2
|
||
|
|
||
|
tvec = self._argmax_ext(array, "inf")
|
||
|
tvec = self._interpolate(array_orig, tvec)
|
||
|
|
||
|
success = self._get_success(array_orig, tvec, 2)
|
||
|
|
||
|
return tvec, success
|
||
|
|
||
|
def transform_img(
|
||
|
self,
|
||
|
img: torch.Tensor,
|
||
|
scale: torch.Tensor | None = None,
|
||
|
angle: torch.Tensor | None = None,
|
||
|
tvec: torch.Tensor | None = None,
|
||
|
bgval: torch.Tensor | None = None,
|
||
|
) -> torch.Tensor:
|
||
|
assert img.ndim == 3
|
||
|
|
||
|
if scale is None:
|
||
|
scale = torch.ones(
|
||
|
(img.shape[0],), dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
assert scale.ndim == 1
|
||
|
assert scale.shape[0] == img.shape[0]
|
||
|
|
||
|
if angle is None:
|
||
|
angle = torch.zeros(
|
||
|
(img.shape[0],), dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
assert angle.ndim == 1
|
||
|
assert angle.shape[0] == img.shape[0]
|
||
|
|
||
|
if tvec is None:
|
||
|
tvec = torch.zeros(
|
||
|
(img.shape[0], 2), dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
assert tvec.ndim == 2
|
||
|
assert tvec.shape[0] == img.shape[0]
|
||
|
assert tvec.shape[1] == 2
|
||
|
|
||
|
if bgval is None:
|
||
|
bgval = self.get_borderval(img)
|
||
|
assert bgval.ndim == 1
|
||
|
assert bgval.shape[0] == img.shape[0]
|
||
|
|
||
|
# Otherwise we need to decompose it and put it back together
|
||
|
assert torch.is_complex(img) is False
|
||
|
|
||
|
output = torch.zeros_like(img)
|
||
|
|
||
|
for pos in range(0, img.shape[0]):
|
||
|
image_processed = img[pos, :, :].unsqueeze(0).clone()
|
||
|
|
||
|
temp_shift = [
|
||
|
int(round(tvec[pos, 1].item() * self.scale_factor)),
|
||
|
int(round(tvec[pos, 0].item() * self.scale_factor)),
|
||
|
]
|
||
|
|
||
|
image_processed = torch.nn.functional.interpolate(
|
||
|
image_processed.unsqueeze(0),
|
||
|
scale_factor=self.scale_factor,
|
||
|
mode="bilinear",
|
||
|
).squeeze(0)
|
||
|
|
||
|
image_processed = tv.transforms.functional.affine(
|
||
|
img=image_processed,
|
||
|
angle=-float(angle[pos]),
|
||
|
translate=temp_shift,
|
||
|
scale=float(scale[pos]),
|
||
|
shear=[0, 0],
|
||
|
interpolation=tv.transforms.InterpolationMode.BILINEAR,
|
||
|
fill=float(bgval[pos]),
|
||
|
center=None,
|
||
|
)
|
||
|
|
||
|
image_processed = torch.nn.functional.interpolate(
|
||
|
image_processed.unsqueeze(0),
|
||
|
scale_factor=1.0 / self.scale_factor,
|
||
|
mode="bilinear",
|
||
|
).squeeze(0)
|
||
|
|
||
|
image_processed = tv.transforms.functional.center_crop(
|
||
|
image_processed, img.shape[-2:]
|
||
|
)
|
||
|
|
||
|
output[pos, ...] = image_processed.squeeze(0)
|
||
|
|
||
|
return output
|
||
|
|
||
|
def transform_img_dict(
|
||
|
self,
|
||
|
img: torch.Tensor,
|
||
|
scale: torch.Tensor | None = None,
|
||
|
angle: torch.Tensor | None = None,
|
||
|
tvec: torch.Tensor | None = None,
|
||
|
bgval: torch.Tensor | None = None,
|
||
|
invert=False,
|
||
|
) -> torch.Tensor:
|
||
|
if invert:
|
||
|
if scale is not None:
|
||
|
scale = 1.0 / scale
|
||
|
if angle is not None:
|
||
|
angle *= -1
|
||
|
if tvec is not None:
|
||
|
tvec *= -1
|
||
|
|
||
|
res = self.transform_img(img, scale, angle, tvec, bgval=bgval)
|
||
|
return res
|
||
|
|
||
|
def _phase_correlation(
|
||
|
self, image_reference: torch.Tensor, images_todo: torch.Tensor, callback, *args
|
||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||
|
assert image_reference.ndim == 3
|
||
|
assert image_reference.shape[0] == 1
|
||
|
assert images_todo.ndim == 3
|
||
|
|
||
|
assert callback is not None
|
||
|
|
||
|
image_reference_fft = torch.fft.fft2(image_reference, dim=(-2, -1))
|
||
|
images_todo_fft = torch.fft.fft2(images_todo, dim=(-2, -1))
|
||
|
|
||
|
eps = torch.abs(images_todo_fft).max(dim=-1)[0].max(dim=-1)[0] * 1e-15
|
||
|
|
||
|
cps = abs(
|
||
|
torch.fft.ifft2(
|
||
|
(image_reference_fft * images_todo_fft.conj())
|
||
|
/ (
|
||
|
torch.abs(image_reference_fft) * torch.abs(images_todo_fft)
|
||
|
+ eps.unsqueeze(-1).unsqueeze(-1)
|
||
|
)
|
||
|
)
|
||
|
)
|
||
|
|
||
|
scps = torch.fft.fftshift(cps, dim=(-2, -1))
|
||
|
|
||
|
ret, success = callback(scps, *args)
|
||
|
|
||
|
ret[:, 0] -= image_reference_fft.shape[-2] // 2
|
||
|
ret[:, 1] -= image_reference_fft.shape[-1] // 2
|
||
|
|
||
|
return ret, success
|
||
|
|
||
|
def _translation(
|
||
|
self, im0: torch.Tensor, im1: torch.Tensor
|
||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||
|
assert im0.ndim == 2
|
||
|
ret, succ = self._phase_correlation(
|
||
|
im0.unsqueeze(0), im1, self.argmax_translation
|
||
|
)
|
||
|
return ret, succ
|
||
|
|
||
|
def _get_ang_scale(
|
||
|
self,
|
||
|
image_reference: torch.Tensor,
|
||
|
images_todo: torch.Tensor,
|
||
|
constraints_scale_0: torch.Tensor,
|
||
|
constraints_scale_1: torch.Tensor | None,
|
||
|
constraints_angle_0: torch.Tensor,
|
||
|
constraints_angle_1: torch.Tensor | None,
|
||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||
|
assert image_reference.ndim == 2
|
||
|
assert images_todo.ndim == 3
|
||
|
assert image_reference.shape[-1] == images_todo.shape[-1]
|
||
|
assert image_reference.shape[-2] == images_todo.shape[-2]
|
||
|
assert constraints_scale_0.shape[0] == images_todo.shape[0]
|
||
|
assert constraints_angle_0.shape[0] == images_todo.shape[0]
|
||
|
|
||
|
if constraints_scale_1 is not None:
|
||
|
assert constraints_scale_1.shape[0] == images_todo.shape[0]
|
||
|
|
||
|
if constraints_angle_1 is not None:
|
||
|
assert constraints_angle_1.shape[0] == images_todo.shape[0]
|
||
|
|
||
|
if self.image_reference_dft is None:
|
||
|
image_reference_apod = self._apodize(image_reference.unsqueeze(0))
|
||
|
self.image_reference_dft = torch.fft.fftshift(
|
||
|
torch.fft.fft2(image_reference_apod, dim=(-2, -1)), dim=(-2, -1)
|
||
|
)
|
||
|
self.filt = self._logpolar_filter(image_reference.shape).unsqueeze(0)
|
||
|
self.image_reference_dft *= self.filt
|
||
|
self.pcorr_shape = torch.tensor(
|
||
|
self._get_pcorr_shape(image_reference.shape[-2:]),
|
||
|
dtype=self.default_dtype,
|
||
|
device=self.device,
|
||
|
)
|
||
|
self.log_base = self._get_log_base(
|
||
|
image_reference.shape,
|
||
|
self.pcorr_shape[1],
|
||
|
)
|
||
|
self.image_reference_logp = self._logpolar(
|
||
|
torch.abs(self.image_reference_dft), self.pcorr_shape, self.log_base
|
||
|
)
|
||
|
|
||
|
images_todo_apod = self._apodize(images_todo)
|
||
|
images_todo_dft = torch.fft.fftshift(
|
||
|
torch.fft.fft2(images_todo_apod, dim=(-2, -1)), dim=(-2, -1)
|
||
|
)
|
||
|
|
||
|
images_todo_dft *= self.filt
|
||
|
|
||
|
images_todo_lopg = self._logpolar(
|
||
|
torch.abs(images_todo_dft), self.pcorr_shape, self.log_base
|
||
|
)
|
||
|
|
||
|
temp, _ = self._phase_correlation(
|
||
|
self.image_reference_logp,
|
||
|
images_todo_lopg,
|
||
|
self.argmax_angscale,
|
||
|
self.log_base,
|
||
|
constraints_scale_0,
|
||
|
constraints_scale_1,
|
||
|
constraints_angle_0,
|
||
|
constraints_angle_1,
|
||
|
)
|
||
|
|
||
|
arg_ang = temp[:, 0].clone()
|
||
|
arg_rad = temp[:, 1].clone()
|
||
|
|
||
|
angle = -torch.pi * arg_ang / float(self.pcorr_shape[0])
|
||
|
angle = torch.rad2deg(angle)
|
||
|
|
||
|
angle = self.wrap_angle(angle, 360)
|
||
|
|
||
|
scale = torch.pow(self.log_base, arg_rad)
|
||
|
|
||
|
angle = -angle
|
||
|
scale = 1.0 / scale
|
||
|
|
||
|
assert torch.where(scale < 2)[0].shape[0] == scale.shape[0]
|
||
|
assert torch.where(scale > 0.5)[0].shape[0] == scale.shape[0]
|
||
|
|
||
|
return scale, angle
|
||
|
|
||
|
def translation(
|
||
|
self, im0: torch.Tensor, im1: torch.Tensor
|
||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||
|
angle = torch.zeros(
|
||
|
(im1.shape[0]), dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
assert im1.ndim == 3
|
||
|
assert im0.shape[-2] == im1.shape[-2]
|
||
|
assert im0.shape[-1] == im1.shape[-1]
|
||
|
|
||
|
tvec, succ = self._translation(im0, im1)
|
||
|
tvec2, succ2 = self._translation(im0, torch.rot90(im1, k=2, dims=[-2, -1]))
|
||
|
|
||
|
assert tvec.shape[0] == tvec2.shape[0]
|
||
|
assert tvec.ndim == 2
|
||
|
assert tvec2.ndim == 2
|
||
|
assert tvec.shape[1] == 2
|
||
|
assert tvec2.shape[1] == 2
|
||
|
assert succ.shape[0] == succ2.shape[0]
|
||
|
assert succ.ndim == 1
|
||
|
assert succ2.ndim == 1
|
||
|
assert tvec.shape[0] == succ.shape[0]
|
||
|
assert angle.shape[0] == tvec.shape[0]
|
||
|
assert angle.ndim == 1
|
||
|
|
||
|
for pos in range(0, angle.shape[0]):
|
||
|
pick_rotated = False
|
||
|
if succ2[pos] > succ[pos]:
|
||
|
pick_rotated = True
|
||
|
|
||
|
if pick_rotated:
|
||
|
tvec[pos, :] = tvec2[pos, :]
|
||
|
succ[pos] = succ2[pos]
|
||
|
angle[pos] += 180
|
||
|
|
||
|
return tvec, succ, angle
|
||
|
|
||
|
def _similarity(
|
||
|
self,
|
||
|
image_reference: torch.Tensor,
|
||
|
images_todo: torch.Tensor,
|
||
|
bgval: torch.Tensor,
|
||
|
):
|
||
|
assert image_reference.ndim == 2
|
||
|
assert images_todo.ndim == 3
|
||
|
assert image_reference.shape[-1] == images_todo.shape[-1]
|
||
|
assert image_reference.shape[-2] == images_todo.shape[-2]
|
||
|
|
||
|
# We are going to iterate and precise scale and angle estimates
|
||
|
scale: torch.Tensor = torch.ones(
|
||
|
(images_todo.shape[0]), dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
angle: torch.Tensor = torch.zeros(
|
||
|
(images_todo.shape[0]), dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
|
||
|
constraints_dynamic_angle_0: torch.Tensor = torch.zeros(
|
||
|
(images_todo.shape[0]), dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
constraints_dynamic_angle_1: torch.Tensor | None = None
|
||
|
constraints_dynamic_scale_0: torch.Tensor = torch.ones(
|
||
|
(images_todo.shape[0]), dtype=self.default_dtype, device=self.device
|
||
|
)
|
||
|
constraints_dynamic_scale_1: torch.Tensor | None = None
|
||
|
|
||
|
newscale, newangle = self._get_ang_scale(
|
||
|
image_reference,
|
||
|
images_todo,
|
||
|
constraints_dynamic_scale_0,
|
||
|
constraints_dynamic_scale_1,
|
||
|
constraints_dynamic_angle_0,
|
||
|
constraints_dynamic_angle_1,
|
||
|
)
|
||
|
scale *= newscale
|
||
|
angle += newangle
|
||
|
|
||
|
im2 = self.transform_img(images_todo, scale, angle, bgval=bgval)
|
||
|
|
||
|
tvec, self.success, res_angle = self.translation(image_reference, im2)
|
||
|
|
||
|
angle += res_angle
|
||
|
|
||
|
angle = self.wrap_angle(angle, 360)
|
||
|
|
||
|
return scale, angle, tvec
|
||
|
|
||
|
def similarity(
|
||
|
self,
|
||
|
image_reference: torch.Tensor,
|
||
|
images_todo: torch.Tensor,
|
||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||
|
assert image_reference.ndim == 2
|
||
|
assert images_todo.ndim == 3
|
||
|
|
||
|
bgval: torch.Tensor = self.get_borderval(img=images_todo, radius=5)
|
||
|
|
||
|
scale, angle, tvec = self._similarity(
|
||
|
image_reference,
|
||
|
images_todo,
|
||
|
bgval,
|
||
|
)
|
||
|
|
||
|
im2 = self.transform_img_dict(
|
||
|
img=images_todo,
|
||
|
scale=scale,
|
||
|
angle=angle,
|
||
|
tvec=tvec,
|
||
|
bgval=bgval,
|
||
|
)
|
||
|
|
||
|
return scale, angle, tvec, im2
|