diff --git a/reproduction_effort/functions/align_refref.py b/reproduction_effort/functions/align_refref.py
new file mode 100644
index 0000000..401fc60
--- /dev/null
+++ b/reproduction_effort/functions/align_refref.py
@@ -0,0 +1,56 @@
+import torch
+import torchvision as tv  # type: ignore
+
+from functions.ImageAlignment import ImageAlignment
+from functions.calculate_translation import calculate_translation
+from functions.calculate_rotation import calculate_rotation
+
+
+@torch.no_grad()
+def align_refref(
+    ref_image_acceptor: torch.Tensor,
+    ref_image_donor: torch.Tensor,
+    image_alignment: ImageAlignment,
+    batch_size: int,
+    fill_value: int = 0,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    angle_refref = calculate_rotation(
+        image_alignment,
+        ref_image_acceptor.unsqueeze(0),
+        ref_image_donor,
+        batch_size=batch_size,
+    )
+
+    ref_image_acceptor = tv.transforms.functional.affine(
+        img=ref_image_acceptor.unsqueeze(0),
+        angle=-float(angle_refref),
+        translate=[0, 0],
+        scale=1.0,
+        shear=0,
+        interpolation=tv.transforms.InterpolationMode.BILINEAR,
+        fill=fill_value,
+    )
+
+    tvec_refref = calculate_translation(
+        image_alignment,
+        ref_image_acceptor,
+        ref_image_donor,
+        batch_size=batch_size,
+    )
+
+    tvec_refref = tvec_refref[0, :]
+
+    ref_image_acceptor = tv.transforms.functional.affine(
+        img=ref_image_acceptor,
+        angle=0,
+        translate=[tvec_refref[1], tvec_refref[0]],
+        scale=1.0,
+        shear=0,
+        interpolation=tv.transforms.InterpolationMode.BILINEAR,
+        fill=fill_value,
+    )
+
+    ref_image_acceptor = ref_image_acceptor.squeeze(0)
+
+    return angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor
diff --git a/reproduction_effort/functions/calculate_rotation.py b/reproduction_effort/functions/calculate_rotation.py
new file mode 100644
index 0000000..6a53afd
--- /dev/null
+++ b/reproduction_effort/functions/calculate_rotation.py
@@ -0,0 +1,40 @@
+import torch
+
+from functions.ImageAlignment import ImageAlignment
+
+
+@torch.no_grad()
+def calculate_rotation(
+    image_alignment: ImageAlignment,
+    input: torch.Tensor,
+    reference_image: torch.Tensor,
+    batch_size: int,
+) -> torch.Tensor:
+    angle = torch.zeros((input.shape[0]))
+
+    data_loader = torch.utils.data.DataLoader(
+        torch.utils.data.TensorDataset(input),
+        batch_size=batch_size,
+        shuffle=False,
+    )
+    start_position: int = 0
+    for input_batch in data_loader:
+        assert len(input_batch) == 1
+
+        end_position = start_position + input_batch[0].shape[0]
+
+        angle_temp = image_alignment.dry_run_angle(
+            input=input_batch[0],
+            new_reference_image=reference_image,
+        )
+
+        assert angle_temp is not None
+
+        angle[start_position:end_position] = angle_temp
+
+        start_position += input_batch[0].shape[0]
+
+    angle = torch.where(angle >= 180, 360.0 - angle, angle)
+    angle = torch.where(angle <= -180, 360.0 + angle, angle)
+
+    return angle
diff --git a/reproduction_effort/functions/calculate_translation.py b/reproduction_effort/functions/calculate_translation.py
new file mode 100644
index 0000000..9eadf59
--- /dev/null
+++ b/reproduction_effort/functions/calculate_translation.py
@@ -0,0 +1,37 @@
+import torch
+
+from functions.ImageAlignment import ImageAlignment
+
+
+@torch.no_grad()
+def calculate_translation(
+    image_alignment: ImageAlignment,
+    input: torch.Tensor,
+    reference_image: torch.Tensor,
+    batch_size: int,
+) -> torch.Tensor:
+    tvec = torch.zeros((input.shape[0], 2))
+
+    data_loader = torch.utils.data.DataLoader(
+        torch.utils.data.TensorDataset(input),
+        batch_size=batch_size,
+        shuffle=False,
+    )
+    start_position: int = 0
+    for input_batch in data_loader:
+        assert len(input_batch) == 1
+
+        end_position = start_position + input_batch[0].shape[0]
+
+        tvec_temp = image_alignment.dry_run_translation(
+            input=input_batch[0],
+            new_reference_image=reference_image,
+        )
+
+        assert tvec_temp is not None
+
+        tvec[start_position:end_position, :] = tvec_temp
+
+        start_position += input_batch[0].shape[0]
+
+    return tvec
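Usage sketch, not part of the diff: the two estimation helpers are chained by align_refref, which first estimates and removes the rotation between the acceptor and donor reference images and then the remaining translation. The wrapper below only illustrates the call signature and the order of the returned values. The function name align_reference_pair, the 2-D (H, W) image shape, and the batch_size default are assumptions made for illustration, and the ImageAlignment instance is taken as already constructed because its constructor is not shown in this diff.

import torch

from functions.ImageAlignment import ImageAlignment
from functions.align_refref import align_refref


def align_reference_pair(
    ref_acceptor: torch.Tensor,       # 2-D acceptor reference image (H, W); shape is an assumption
    ref_donor: torch.Tensor,          # 2-D donor reference image of the same size
    image_alignment: ImageAlignment,  # already-constructed helper; its constructor is not part of this diff
    batch_size: int = 100,            # hypothetical default
) -> torch.Tensor:
    # align_refref returns the estimated rotation angle (degrees), the translation
    # vector, the acceptor reference after both corrections, and the donor
    # reference (returned unchanged).
    angle, tvec, aligned_acceptor, _ = align_refref(
        ref_image_acceptor=ref_acceptor,
        ref_image_donor=ref_donor,
        image_alignment=image_alignment,
        batch_size=batch_size,
    )
    print(f"rotation: {float(angle):.3f} deg, translation: {tvec.tolist()}")
    return aligned_acceptor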