gevi/reproduction_effort/functions/align_refref.py

55 lines
1.6 KiB
Python
Raw Normal View History

2024-02-03 12:50:07 +01:00
import torch
import torchvision as tv # type: ignore
from functions.ImageAlignment import ImageAlignment
from functions.calculate_translation import calculate_translation
from functions.calculate_rotation import calculate_rotation
@torch.no_grad()
def align_refref(
    ref_image_acceptor: torch.Tensor,
    ref_image_donor: torch.Tensor,
    image_alignment: ImageAlignment,
    batch_size: int,
    fill_value: float = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Register the acceptor reference image onto the donor reference image.

    Registration runs in two passes. First a rotation is estimated with
    ``calculate_rotation`` and undone. Then a translation is estimated with
    ``calculate_translation`` on the de-rotated image and applied. Both warps
    are rigid: scale 1, no shear, bilinear resampling.

    Args:
        ref_image_acceptor: Acceptor reference image to be aligned.
            NOTE(review): presumably a single (H, W) frame, since a leading
            axis is added with ``unsqueeze(0)`` and removed again with
            ``squeeze(0)`` at the end. Confirm against the callers.
        ref_image_donor: Donor reference image used as the alignment target.
            It is returned unchanged.
        image_alignment: Alignment helper forwarded to both estimators.
        batch_size: Forwarded to both estimators.
        fill_value: Fill value passed to torchvision ``affine`` for pixels
            that fall outside the transformed image.

    Returns:
        A tuple ``(angle, tvec, aligned_acceptor, ref_image_donor)``:
        - ``angle``: the rotation exactly as returned by
          ``calculate_rotation``. The acceptor was rotated by its negation.
        - ``tvec``: row 0 of the translation returned by
          ``calculate_translation``.
        - ``aligned_acceptor``: the rotated and shifted acceptor, with the
          leading axis squeezed away.
        - ``ref_image_donor``: the unmodified donor input.
    """

    def _rigid_warp(image, angle, shift):
        # Shared rigid transform: only rotation and shift vary between passes.
        return tv.transforms.functional.affine(
            img=image,
            angle=angle,
            translate=shift,
            scale=1.0,
            shear=0,
            interpolation=tv.transforms.InterpolationMode.BILINEAR,
            fill=fill_value,
        )

    # Pass 1: estimate the acceptor's rotation relative to the donor, then undo it.
    angle_refref = calculate_rotation(
        image_alignment=image_alignment,
        input=ref_image_acceptor.unsqueeze(0),
        reference_image=ref_image_donor,
        batch_size=batch_size,
    )
    derotated = _rigid_warp(
        ref_image_acceptor.unsqueeze(0), -float(angle_refref), [0, 0]
    )

    # Pass 2: estimate the remaining shift on the de-rotated image and apply it.
    tvec_refref = calculate_translation(
        image_alignment=image_alignment,
        input=derotated,
        reference_image=ref_image_donor,
        batch_size=batch_size,
    )[0, :]
    # torchvision's ``translate`` is (horizontal, vertical). The index swap
    # suggests tvec is stored as (vertical, horizontal).
    # NOTE(review): confirm against calculate_translation.
    aligned = _rigid_warp(derotated, 0, [tvec_refref[1], tvec_refref[0]]).squeeze(0)

    return angle_refref, tvec_refref, aligned, ref_image_donor