Add files via upload

This commit is contained in:
David Rotermund 2024-02-03 12:50:07 +01:00 committed by GitHub
parent 5f10647e83
commit faafa72031
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 133 additions and 0 deletions

View file

@ -0,0 +1,56 @@
import torch
import torchvision as tv # type: ignore
from functions.ImageAlignment import ImageAlignment
from functions.calculate_translation import calculate_translation
from functions.calculate_rotation import calculate_rotation
@torch.no_grad()
def align_refref(
ref_image_acceptor: torch.Tensor,
ref_image_donor: torch.Tensor,
image_alignment: ImageAlignment,
batch_size: int,
fill_value: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
angle_refref = calculate_rotation(
image_alignment,
ref_image_acceptor.unsqueeze(0),
ref_image_donor,
batch_size=batch_size,
)
ref_image_acceptor = tv.transforms.functional.affine(
img=ref_image_acceptor.unsqueeze(0),
angle=-float(angle_refref),
translate=[0, 0],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=fill_value,
)
tvec_refref = calculate_translation(
image_alignment,
ref_image_acceptor,
ref_image_donor,
batch_size=batch_size,
)
tvec_refref = tvec_refref[0, :]
ref_image_acceptor = tv.transforms.functional.affine(
img=ref_image_acceptor,
angle=0,
translate=[tvec_refref[1], tvec_refref[0]],
scale=1.0,
shear=0,
interpolation=tv.transforms.InterpolationMode.BILINEAR,
fill=fill_value,
)
ref_image_acceptor = ref_image_acceptor.squeeze(0)
return angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor

View file

@ -0,0 +1,40 @@
import torch
from functions.ImageAlignment import ImageAlignment
@torch.no_grad()
def calculate_rotation(
image_alignment: ImageAlignment,
input: torch.Tensor,
reference_image: torch.Tensor,
batch_size: int,
) -> torch.Tensor:
angle = torch.zeros((input.shape[0]))
data_loader = torch.utils.data.DataLoader(
torch.utils.data.TensorDataset(input),
batch_size=batch_size,
shuffle=False,
)
start_position: int = 0
for input_batch in data_loader:
assert len(input_batch) == 1
end_position = start_position + input_batch[0].shape[0]
angle_temp = image_alignment.dry_run_angle(
input=input_batch[0],
new_reference_image=reference_image,
)
assert angle_temp is not None
angle[start_position:end_position] = angle_temp
start_position += input_batch[0].shape[0]
angle = torch.where(angle >= 180, 360.0 - angle, angle)
angle = torch.where(angle <= -180, 360.0 + angle, angle)
return angle

View file

@ -0,0 +1,37 @@
import torch
from functions.ImageAlignment import ImageAlignment
@torch.no_grad()
def calculate_translation(
image_alignment: ImageAlignment,
input: torch.Tensor,
reference_image: torch.Tensor,
batch_size: int,
) -> torch.Tensor:
tvec = torch.zeros((input.shape[0], 2))
data_loader = torch.utils.data.DataLoader(
torch.utils.data.TensorDataset(input),
batch_size=batch_size,
shuffle=False,
)
start_position: int = 0
for input_batch in data_loader:
assert len(input_batch) == 1
end_position = start_position + input_batch[0].shape[0]
tvec_temp = image_alignment.dry_run_translation(
input=input_batch[0],
new_reference_image=reference_image,
)
assert tvec_temp is not None
tvec[start_position:end_position, :] = tvec_temp
start_position += input_batch[0].shape[0]
return tvec