# File: gevi/functions/ImageAlignment.py
# Repository viewer metadata (not part of the program): 1016 lines, 32 KiB, Python,
# last change 2024-02-28 16:14:50 +01:00.

import torch
import torchvision as tv # type: ignore
# The source code is based on:
# https://github.com/matejak/imreg_dft
# The original LICENSE:
# Copyright (c) 2014, Matěj Týč
# Copyright (c) 2011-2014, Christoph Gohlke
# Copyright (c) 2011-2014, The Regents of the University of California
# Produced at the Laboratory for Fluorescence Dynamics
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of the {organization} nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
class ImageAlignment(torch.nn.Module):
device: torch.device
default_dtype: torch.dtype
excess_const: float = 1.1
exponent: str = "inf"
success: torch.Tensor | None = None
# The factor that detmines how many
# sub-pixel we will shift
scale_factor: int = 4
reference_image: torch.Tensor | None = None
last_scale: torch.Tensor | None = None
last_angle: torch.Tensor | None = None
last_tvec: torch.Tensor | None = None
# Cache
image_reference_dft: torch.Tensor | None = None
filt: torch.Tensor
pcorr_shape: torch.Tensor
log_base: torch.Tensor
image_reference_logp: torch.Tensor
def __init__(
self,
device: torch.device | None = None,
default_dtype: torch.dtype | None = None,
) -> None:
super().__init__()
assert device is not None
assert default_dtype is not None
self.device = device
self.default_dtype = default_dtype
def set_new_reference_image(self, new_reference_image: torch.Tensor | None = None):
assert new_reference_image is not None
assert new_reference_image.ndim == 2
self.reference_image = (
new_reference_image.detach()
.clone()
.to(device=self.device)
.type(dtype=self.default_dtype)
)
self.image_reference_dft = None
def forward(
self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None
) -> torch.Tensor:
assert input.ndim == 3
if new_reference_image is not None:
self.set_new_reference_image(new_reference_image)
assert self.reference_image is not None
assert self.reference_image.ndim == 2
assert input.shape[-2] == self.reference_image.shape[-2]
assert input.shape[-1] == self.reference_image.shape[-1]
self.last_scale, self.last_angle, self.last_tvec, output = self.similarity(
self.reference_image,
input.to(device=self.device).type(dtype=self.default_dtype),
)
return output
def dry_run(
self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
assert input.ndim == 3
if new_reference_image is not None:
self.set_new_reference_image(new_reference_image)
assert self.reference_image is not None
assert self.reference_image.ndim == 2
assert input.shape[-2] == self.reference_image.shape[-2]
assert input.shape[-1] == self.reference_image.shape[-1]
images_todo = input.to(device=self.device).type(dtype=self.default_dtype)
image_reference = self.reference_image
assert image_reference.ndim == 2
assert images_todo.ndim == 3
bgval: torch.Tensor = self.get_borderval(img=images_todo, radius=5)
self.last_scale, self.last_angle, self.last_tvec = self._similarity(
image_reference,
images_todo,
bgval,
)
return self.last_scale, self.last_angle, self.last_tvec
def dry_run_translation(
self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None
) -> torch.Tensor:
assert input.ndim == 3
if new_reference_image is not None:
self.set_new_reference_image(new_reference_image)
assert self.reference_image is not None
assert self.reference_image.ndim == 2
assert input.shape[-2] == self.reference_image.shape[-2]
assert input.shape[-1] == self.reference_image.shape[-1]
images_todo = input.to(device=self.device).type(dtype=self.default_dtype)
image_reference = self.reference_image
assert image_reference.ndim == 2
assert images_todo.ndim == 3
tvec, _ = self._translation(image_reference, images_todo)
return tvec
# ---------------
def dry_run_angle(
self,
input: torch.Tensor,
new_reference_image: torch.Tensor | None = None,
) -> torch.Tensor:
assert input.ndim == 3
if new_reference_image is not None:
self.set_new_reference_image(new_reference_image)
constraints_dynamic_angle_0: torch.Tensor = torch.zeros(
(input.shape[0]), dtype=self.default_dtype, device=self.device
)
constraints_dynamic_angle_1: torch.Tensor | None = None
constraints_dynamic_scale_0: torch.Tensor = torch.ones(
(input.shape[0]), dtype=self.default_dtype, device=self.device
)
constraints_dynamic_scale_1: torch.Tensor | None = None
assert self.reference_image is not None
assert self.reference_image.ndim == 2
assert input.shape[-2] == self.reference_image.shape[-2]
assert input.shape[-1] == self.reference_image.shape[-1]
images_todo = input.to(device=self.device).type(dtype=self.default_dtype)
image_reference = self.reference_image
assert image_reference.ndim == 2
assert images_todo.ndim == 3
_, newangle = self._get_ang_scale(
image_reference,
images_todo,
constraints_dynamic_scale_0,
constraints_dynamic_scale_1,
constraints_dynamic_angle_0,
constraints_dynamic_angle_1,
)
return newangle
# ---------------
def _get_pcorr_shape(self, shape: torch.Size) -> tuple[int, int]:
ret = (int(max(shape[-2:]) * 1.0),) * 2
return ret
def _get_log_base(self, shape: torch.Size, new_r: torch.Tensor) -> torch.Tensor:
old_r = torch.tensor(
(float(shape[-2]) * self.excess_const) / 2.0,
dtype=self.default_dtype,
device=self.device,
)
log_base = torch.exp(torch.log(old_r) / new_r)
return log_base
def wrap_angle(
self, angles: torch.Tensor, ceil: float = 2 * torch.pi
) -> torch.Tensor:
angles += ceil / 2.0
angles %= ceil
angles -= ceil / 2.0
return angles
def get_borderval(
self, img: torch.Tensor, radius: int | None = None
) -> torch.Tensor:
assert img.ndim == 3
if radius is None:
mindim = min([int(img.shape[-2]), int(img.shape[-1])])
radius = max(1, mindim // 20)
mask = torch.zeros(
(int(img.shape[-2]), int(img.shape[-1])),
dtype=torch.bool,
device=self.device,
)
mask[:, :radius] = True
mask[:, -radius:] = True
mask[:radius, :] = True
mask[-radius:, :] = True
mean = torch.median(img[:, mask], dim=-1)[0]
return mean
def get_apofield(self, shape: torch.Size, aporad: int) -> torch.Tensor:
if aporad == 0:
return torch.ones(
shape[-2:],
dtype=self.default_dtype,
device=self.device,
)
assert int(shape[-2]) > aporad * 2
assert int(shape[-1]) > aporad * 2
apos = torch.hann_window(
aporad * 2, dtype=self.default_dtype, periodic=False, device=self.device
)
toapp_0 = torch.ones(
shape[-2],
dtype=self.default_dtype,
device=self.device,
)
toapp_0[:aporad] = apos[:aporad]
toapp_0[-aporad:] = apos[-aporad:]
toapp_1 = torch.ones(
shape[-1],
dtype=self.default_dtype,
device=self.device,
)
toapp_1[:aporad] = apos[:aporad]
toapp_1[-aporad:] = apos[-aporad:]
apofield = torch.outer(toapp_0, toapp_1)
return apofield
def _get_subarr(
self, array: torch.Tensor, center: torch.Tensor, rad: int
) -> torch.Tensor:
assert array.ndim == 3
assert center.ndim == 2
assert array.shape[0] == center.shape[0]
assert center.shape[1] == 2
dim = 1 + 2 * rad
subarr = torch.zeros(
(array.shape[0], dim, dim), dtype=self.default_dtype, device=self.device
)
corner = center - rad
idx_p = range(0, corner.shape[0])
for ii in range(0, dim):
yidx = corner[:, 0] + ii
yidx %= array.shape[-2]
for jj in range(0, dim):
xidx = corner[:, 1] + jj
xidx %= array.shape[-1]
subarr[:, ii, jj] = array[idx_p, yidx, xidx]
return subarr
def _argmax_2d(self, array: torch.Tensor) -> torch.Tensor:
assert array.ndim == 3
max_pos = array.reshape(
(array.shape[0], array.shape[1] * array.shape[2])
).argmax(dim=1)
pos_0 = max_pos // array.shape[2]
max_pos -= pos_0 * array.shape[2]
ret = torch.zeros(
(array.shape[0], 2), dtype=self.default_dtype, device=self.device
)
ret[:, 0] = pos_0
ret[:, 1] = max_pos
return ret.type(dtype=torch.int64)
def _apodize(self, what: torch.Tensor) -> torch.Tensor:
mindim = min([int(what.shape[-2]), int(what.shape[-1])])
aporad = int(mindim * 0.12)
apofield = self.get_apofield(what.shape, aporad).unsqueeze(0)
res = what * apofield
bg = self.get_borderval(what, aporad // 2).unsqueeze(-1).unsqueeze(-1)
res += bg * (1 - apofield)
return res
def _logpolar_filter(self, shape: torch.Size) -> torch.Tensor:
yy = torch.linspace(
-torch.pi / 2.0,
torch.pi / 2.0,
shape[-2],
dtype=self.default_dtype,
device=self.device,
).unsqueeze(1)
xx = torch.linspace(
-torch.pi / 2.0,
torch.pi / 2.0,
shape[-1],
dtype=self.default_dtype,
device=self.device,
).unsqueeze(0)
rads = torch.sqrt(yy**2 + xx**2)
filt = 1.0 - torch.cos(rads) ** 2
filt[torch.abs(rads) > torch.pi / 2] = 1
return filt
def _get_angles(self, shape: torch.Tensor) -> torch.Tensor:
ret = torch.zeros(
(int(shape[-2]), int(shape[-1])),
dtype=self.default_dtype,
device=self.device,
)
ret -= torch.linspace(
0,
torch.pi,
int(shape[-2] + 1),
dtype=self.default_dtype,
device=self.device,
)[:-1].unsqueeze(-1)
return ret
def _get_lograd(self, shape: torch.Tensor, log_base: torch.Tensor) -> torch.Tensor:
ret = torch.zeros(
(int(shape[-2]), int(shape[-1])),
dtype=self.default_dtype,
device=self.device,
)
ret += torch.pow(
log_base,
torch.arange(
0,
int(shape[-1]),
dtype=self.default_dtype,
device=self.device,
),
).unsqueeze(0)
return ret
def _logpolar(
self, image: torch.Tensor, shape: torch.Tensor, log_base: torch.Tensor
) -> torch.Tensor:
assert image.ndim == 3
imshape: torch.Tensor = torch.tensor(
image.shape[-2:],
dtype=self.default_dtype,
device=self.device,
)
center: torch.Tensor = imshape.clone() / 2
theta: torch.Tensor = self._get_angles(shape)
radius_x: torch.Tensor = self._get_lograd(shape, log_base)
radius_y: torch.Tensor = radius_x.clone()
ellipse_coef: torch.Tensor = imshape[0] / imshape[1]
radius_x /= ellipse_coef
y = radius_y * torch.sin(theta) + center[0]
y /= float(image.shape[-2])
y *= 2
y -= 1
x = radius_x * torch.cos(theta) + center[1]
x /= float(image.shape[-1])
x *= 2
x -= 1
idx_x = torch.where(torch.abs(x) <= 1.0, 1.0, 0.0)
idx_y = torch.where(torch.abs(y) <= 1.0, 1.0, 0.0)
normalized_coords = torch.cat(
(
x.unsqueeze(-1),
y.unsqueeze(-1),
),
dim=-1,
).unsqueeze(0)
output = torch.empty(
(int(image.shape[0]), int(y.shape[0]), int(y.shape[1])),
dtype=self.default_dtype,
device=self.device,
)
for id in range(0, int(image.shape[0])):
bgval: torch.Tensor = torch.quantile(image[id, :, :], q=1.0 / 100.0)
temp = torch.nn.functional.grid_sample(
image[id, :, :].unsqueeze(0).unsqueeze(0),
normalized_coords,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
)
output[id, :, :] = torch.where((idx_x * idx_y) == 0.0, bgval, temp)
return output
def _argmax_ext(self, array: torch.Tensor, exponent: float | str) -> torch.Tensor:
assert array.ndim == 3
if exponent == "inf":
ret = self._argmax_2d(array)
else:
assert isinstance(exponent, float) or isinstance(exponent, int)
col = (
torch.arange(
0, array.shape[-2], dtype=self.default_dtype, device=self.device
)
.unsqueeze(-1)
.unsqueeze(0)
)
row = (
torch.arange(
0, array.shape[-1], dtype=self.default_dtype, device=self.device
)
.unsqueeze(0)
.unsqueeze(0)
)
arr2 = torch.pow(array, float(exponent))
arrsum = arr2.sum(dim=-2).sum(dim=-1)
ret = torch.zeros(
(array.shape[0], 2), dtype=self.default_dtype, device=self.device
)
arrprody = (arr2 * col).sum(dim=-1).sum(dim=-1) / arrsum
arrprodx = (arr2 * row).sum(dim=-1).sum(dim=-1) / arrsum
ret[:, 0] = arrprody.squeeze(-1).squeeze(-1)
ret[:, 1] = arrprodx.squeeze(-1).squeeze(-1)
idx = torch.where(arrsum == 0.0)[0]
ret[idx, :] = 0.0
return ret
def _interpolate(
self, array: torch.Tensor, rough: torch.Tensor, rad: int = 2
) -> torch.Tensor:
assert array.ndim == 3
assert rough.ndim == 2
rough = torch.round(rough).type(torch.int64)
surroundings = self._get_subarr(array, rough, rad)
com = self._argmax_ext(surroundings, 1.0)
offset = com - rad
ret = rough + offset
ret += 0.5
ret %= (
torch.tensor(array.shape[-2:], dtype=self.default_dtype, device=self.device)
.type(dtype=torch.int64)
.unsqueeze(0)
)
ret -= 0.5
return ret
def _get_success(
self, array: torch.Tensor, coord: torch.Tensor, radius: int = 2
) -> torch.Tensor:
assert array.ndim == 3
assert coord.ndim == 2
assert array.shape[0] == coord.shape[0]
assert coord.shape[1] == 2
coord = torch.round(coord).type(dtype=torch.int64)
subarr = self._get_subarr(
array, coord, 2
) # Not my fault. They want a 2 there. Not radius
theval = subarr.sum(dim=-1).sum(dim=-1)
theval2 = array[range(0, coord.shape[0]), coord[:, 0], coord[:, 1]]
success = torch.sqrt(theval * theval2)
return success
def _get_constraint_mask(
self,
shape: torch.Size,
log_base: torch.Tensor,
constraints_scale_0: torch.Tensor,
constraints_scale_1: torch.Tensor | None,
constraints_angle_0: torch.Tensor,
constraints_angle_1: torch.Tensor | None,
) -> torch.Tensor:
assert constraints_scale_0 is not None
assert constraints_angle_0 is not None
assert constraints_scale_0.ndim == 1
assert constraints_angle_0.ndim == 1
assert constraints_scale_0.shape[0] == constraints_angle_0.shape[0]
mask: torch.Tensor = torch.ones(
(constraints_scale_0.shape[0], int(shape[-2]), int(shape[-1])),
device=self.device,
dtype=self.default_dtype,
)
scale: torch.Tensor = constraints_scale_0.clone()
if constraints_scale_1 is not None:
sigma: torch.Tensor | None = constraints_scale_1.clone()
else:
sigma = None
scales = torch.fft.ifftshift(
self._get_lograd(
torch.tensor(shape[-2:], device=self.device, dtype=self.default_dtype),
log_base,
)
)
scales *= log_base ** (-shape[-1] / 2.0)
scales = scales.unsqueeze(0) - (1.0 / scale).unsqueeze(-1).unsqueeze(-1)
if sigma is not None:
assert sigma.shape[0] == constraints_scale_0.shape[0]
for p_id in range(0, sigma.shape[0]):
if sigma[p_id] == 0:
ascales = torch.abs(scales[p_id, ...])
scale_min = ascales.min()
binary_mask = torch.where(ascales > scale_min, 0.0, 1.0)
mask[p_id, ...] *= binary_mask
else:
mask[p_id, ...] *= torch.exp(
-(torch.pow(scales[p_id, ...], 2)) / torch.pow(sigma[p_id], 2)
)
angle: torch.Tensor = constraints_angle_0.clone()
if constraints_angle_1 is not None:
sigma = constraints_angle_1.clone()
else:
sigma = None
angles = self._get_angles(
torch.tensor(shape[-2:], device=self.device, dtype=self.default_dtype)
)
angles = angles.unsqueeze(0) + torch.deg2rad(angle).unsqueeze(-1).unsqueeze(-1)
angles = torch.rad2deg(angles)
if sigma is not None:
assert sigma.shape[0] == constraints_scale_0.shape[0]
for p_id in range(0, sigma.shape[0]):
if sigma[p_id] == 0:
aangles = torch.abs(angles[p_id, ...])
angle_min = aangles.min()
binary_mask = torch.where(aangles > angle_min, 0.0, 1.0)
mask[p_id, ...] *= binary_mask
else:
mask *= torch.exp(
-(torch.pow(angles[p_id, ...], 2)) / torch.pow(sigma[p_id], 2)
)
mask = torch.fft.fftshift(mask, dim=(-2, -1))
return mask
def argmax_angscale(
self,
array: torch.Tensor,
log_base: torch.Tensor,
constraints_scale_0: torch.Tensor,
constraints_scale_1: torch.Tensor | None,
constraints_angle_0: torch.Tensor,
constraints_angle_1: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert array.ndim == 3
assert constraints_scale_0 is not None
assert constraints_angle_0 is not None
assert constraints_scale_0.ndim == 1
assert constraints_angle_0.ndim == 1
mask = self._get_constraint_mask(
array.shape[-2:],
log_base,
constraints_scale_0,
constraints_scale_1,
constraints_angle_0,
constraints_angle_1,
)
array_orig = array.clone()
array *= mask
ret = self._argmax_ext(array, self.exponent)
ret_final = self._interpolate(array, ret)
success = self._get_success(array_orig, ret_final, 0)
return ret_final, success
def argmax_translation(
self, array: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert array.ndim == 3
array_orig = array.clone()
ashape = torch.tensor(array.shape[-2:], device=self.device).type(
dtype=torch.int64
)
aporad = (ashape // 6).min()
mask2 = self.get_apofield(torch.Size(ashape), aporad).unsqueeze(0)
array *= mask2
tvec = self._argmax_ext(array, "inf")
tvec = self._interpolate(array_orig, tvec)
success = self._get_success(array_orig, tvec, 2)
return tvec, success
def transform_img(
self,
img: torch.Tensor,
scale: torch.Tensor | None = None,
angle: torch.Tensor | None = None,
tvec: torch.Tensor | None = None,
bgval: torch.Tensor | None = None,
) -> torch.Tensor:
assert img.ndim == 3
if scale is None:
scale = torch.ones(
(img.shape[0],), dtype=self.default_dtype, device=self.device
)
assert scale.ndim == 1
assert scale.shape[0] == img.shape[0]
if angle is None:
angle = torch.zeros(
(img.shape[0],), dtype=self.default_dtype, device=self.device
)
assert angle.ndim == 1
assert angle.shape[0] == img.shape[0]
if tvec is None:
tvec = torch.zeros(
(img.shape[0], 2), dtype=self.default_dtype, device=self.device
)
assert tvec.ndim == 2
assert tvec.shape[0] == img.shape[0]
assert tvec.shape[1] == 2
if bgval is None:
bgval = self.get_borderval(img)
assert bgval.ndim == 1
assert bgval.shape[0] == img.shape[0]
# Otherwise we need to decompose it and put it back together
assert torch.is_complex(img) is False
output = torch.zeros_like(img)
for pos in range(0, img.shape[0]):
image_processed = img[pos, :, :].unsqueeze(0).clone()
temp_shift = [
int(round(tvec[pos, 1].item() * self.scale_factor)),
int(round(tvec[pos, 0].item() * self.scale_factor)),
]
image_processed = torch.nn.functional.interpolate(
image_processed.unsqueeze(0),
scale_factor=self.scale_factor,
mode="bilinear",
).squeeze(0)
image_processed = tv.transforms.functional.affine(
img=image_processed,
angle=-float(angle[pos]),
translate=temp_shift,
scale=float(scale[pos]),
shear=[0, 0],
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=float(bgval[pos]),
center=None,
)
image_processed = torch.nn.functional.interpolate(
image_processed.unsqueeze(0),
scale_factor=1.0 / self.scale_factor,
mode="bilinear",
).squeeze(0)
image_processed = tv.transforms.functional.center_crop(
image_processed, img.shape[-2:]
)
output[pos, ...] = image_processed.squeeze(0)
return output
def transform_img_dict(
self,
img: torch.Tensor,
scale: torch.Tensor | None = None,
angle: torch.Tensor | None = None,
tvec: torch.Tensor | None = None,
bgval: torch.Tensor | None = None,
invert=False,
) -> torch.Tensor:
if invert:
if scale is not None:
scale = 1.0 / scale
if angle is not None:
angle *= -1
if tvec is not None:
tvec *= -1
res = self.transform_img(img, scale, angle, tvec, bgval=bgval)
return res
def _phase_correlation(
self, image_reference: torch.Tensor, images_todo: torch.Tensor, callback, *args
) -> tuple[torch.Tensor, torch.Tensor]:
assert image_reference.ndim == 3
assert image_reference.shape[0] == 1
assert images_todo.ndim == 3
assert callback is not None
image_reference_fft = torch.fft.fft2(image_reference, dim=(-2, -1))
images_todo_fft = torch.fft.fft2(images_todo, dim=(-2, -1))
eps = torch.abs(images_todo_fft).max(dim=-1)[0].max(dim=-1)[0] * 1e-15
cps = abs(
torch.fft.ifft2(
(image_reference_fft * images_todo_fft.conj())
/ (
torch.abs(image_reference_fft) * torch.abs(images_todo_fft)
+ eps.unsqueeze(-1).unsqueeze(-1)
),
dim=(-2, -1),
)
)
scps = torch.fft.fftshift(cps, dim=(-2, -1))
ret, success = callback(scps, *args)
ret[:, 0] -= image_reference_fft.shape[-2] // 2
ret[:, 1] -= image_reference_fft.shape[-1] // 2
return ret, success
def _translation(
self, im0: torch.Tensor, im1: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert im0.ndim == 2
ret, succ = self._phase_correlation(
im0.unsqueeze(0), im1, self.argmax_translation
)
return ret, succ
def _get_ang_scale(
self,
image_reference: torch.Tensor,
images_todo: torch.Tensor,
constraints_scale_0: torch.Tensor,
constraints_scale_1: torch.Tensor | None,
constraints_angle_0: torch.Tensor,
constraints_angle_1: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert image_reference.ndim == 2
assert images_todo.ndim == 3
assert image_reference.shape[-1] == images_todo.shape[-1]
assert image_reference.shape[-2] == images_todo.shape[-2]
assert constraints_scale_0.shape[0] == images_todo.shape[0]
assert constraints_angle_0.shape[0] == images_todo.shape[0]
if constraints_scale_1 is not None:
assert constraints_scale_1.shape[0] == images_todo.shape[0]
if constraints_angle_1 is not None:
assert constraints_angle_1.shape[0] == images_todo.shape[0]
if self.image_reference_dft is None:
image_reference_apod = self._apodize(image_reference.unsqueeze(0))
self.image_reference_dft = torch.fft.fftshift(
torch.fft.fft2(image_reference_apod, dim=(-2, -1)), dim=(-2, -1)
)
self.filt = self._logpolar_filter(image_reference.shape).unsqueeze(0)
self.image_reference_dft *= self.filt
self.pcorr_shape = torch.tensor(
self._get_pcorr_shape(image_reference.shape[-2:]),
dtype=self.default_dtype,
device=self.device,
)
self.log_base = self._get_log_base(
image_reference.shape,
self.pcorr_shape[1],
)
self.image_reference_logp = self._logpolar(
torch.abs(self.image_reference_dft), self.pcorr_shape, self.log_base
)
images_todo_apod = self._apodize(images_todo)
images_todo_dft = torch.fft.fftshift(
torch.fft.fft2(images_todo_apod, dim=(-2, -1)), dim=(-2, -1)
)
images_todo_dft *= self.filt
images_todo_lopg = self._logpolar(
torch.abs(images_todo_dft), self.pcorr_shape, self.log_base
)
temp, _ = self._phase_correlation(
self.image_reference_logp,
images_todo_lopg,
self.argmax_angscale,
self.log_base,
constraints_scale_0,
constraints_scale_1,
constraints_angle_0,
constraints_angle_1,
)
arg_ang = temp[:, 0].clone()
arg_rad = temp[:, 1].clone()
angle = -torch.pi * arg_ang / float(self.pcorr_shape[0])
angle = torch.rad2deg(angle)
angle = self.wrap_angle(angle, 360)
scale = torch.pow(self.log_base, arg_rad)
angle = -angle
scale = 1.0 / scale
assert torch.where(scale < 2)[0].shape[0] == scale.shape[0]
assert torch.where(scale > 0.5)[0].shape[0] == scale.shape[0]
return scale, angle
def translation(
self, im0: torch.Tensor, im1: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
angle = torch.zeros(
(im1.shape[0]), dtype=self.default_dtype, device=self.device
)
assert im1.ndim == 3
assert im0.shape[-2] == im1.shape[-2]
assert im0.shape[-1] == im1.shape[-1]
tvec, succ = self._translation(im0, im1)
tvec2, succ2 = self._translation(im0, torch.rot90(im1, k=2, dims=[-2, -1]))
assert tvec.shape[0] == tvec2.shape[0]
assert tvec.ndim == 2
assert tvec2.ndim == 2
assert tvec.shape[1] == 2
assert tvec2.shape[1] == 2
assert succ.shape[0] == succ2.shape[0]
assert succ.ndim == 1
assert succ2.ndim == 1
assert tvec.shape[0] == succ.shape[0]
assert angle.shape[0] == tvec.shape[0]
assert angle.ndim == 1
for pos in range(0, angle.shape[0]):
pick_rotated = False
if succ2[pos] > succ[pos]:
pick_rotated = True
if pick_rotated:
tvec[pos, :] = tvec2[pos, :]
succ[pos] = succ2[pos]
angle[pos] += 180
return tvec, succ, angle
def _similarity(
self,
image_reference: torch.Tensor,
images_todo: torch.Tensor,
bgval: torch.Tensor,
):
assert image_reference.ndim == 2
assert images_todo.ndim == 3
assert image_reference.shape[-1] == images_todo.shape[-1]
assert image_reference.shape[-2] == images_todo.shape[-2]
# We are going to iterate and precise scale and angle estimates
scale: torch.Tensor = torch.ones(
(images_todo.shape[0]), dtype=self.default_dtype, device=self.device
)
angle: torch.Tensor = torch.zeros(
(images_todo.shape[0]), dtype=self.default_dtype, device=self.device
)
constraints_dynamic_angle_0: torch.Tensor = torch.zeros(
(images_todo.shape[0]), dtype=self.default_dtype, device=self.device
)
constraints_dynamic_angle_1: torch.Tensor | None = None
constraints_dynamic_scale_0: torch.Tensor = torch.ones(
(images_todo.shape[0]), dtype=self.default_dtype, device=self.device
)
constraints_dynamic_scale_1: torch.Tensor | None = None
newscale, newangle = self._get_ang_scale(
image_reference,
images_todo,
constraints_dynamic_scale_0,
constraints_dynamic_scale_1,
constraints_dynamic_angle_0,
constraints_dynamic_angle_1,
)
scale *= newscale
angle += newangle
im2 = self.transform_img(images_todo, scale, angle, bgval=bgval)
tvec, self.success, res_angle = self.translation(image_reference, im2)
angle += res_angle
angle = self.wrap_angle(angle, 360)
return scale, angle, tvec
def similarity(
self,
image_reference: torch.Tensor,
images_todo: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
assert image_reference.ndim == 2
assert images_todo.ndim == 3
bgval: torch.Tensor = self.get_borderval(img=images_todo, radius=5)
scale, angle, tvec = self._similarity(
image_reference,
images_todo,
bgval,
)
im2 = self.transform_img_dict(
img=images_todo,
scale=scale,
angle=angle,
tvec=tvec,
bgval=bgval,
)
return scale, angle, tvec, im2