diff --git a/reproduction_effort/functions/ImageAlignment.py b/reproduction_effort/functions/ImageAlignment.py index 76ec320..6472d02 100644 --- a/reproduction_effort/functions/ImageAlignment.py +++ b/reproduction_effort/functions/ImageAlignment.py @@ -310,12 +310,14 @@ class ImageAlignment(torch.nn.Module): (array.shape[0], array.shape[1] * array.shape[2]) ).argmax(dim=1) pos_0 = max_pos // array.shape[2] + max_pos -= pos_0 * array.shape[2] ret = torch.zeros( (array.shape[0], 2), dtype=self.default_dtype, device=self.device ) ret[:, 0] = pos_0 ret[:, 1] = max_pos + return ret.type(dtype=torch.int64) def _apodize(self, what: torch.Tensor) -> torch.Tensor: @@ -666,6 +668,7 @@ class ImageAlignment(torch.nn.Module): array *= mask2 tvec = self._argmax_ext(array, "inf") + tvec = self._interpolate(array_orig, tvec) success = self._get_success(array_orig, tvec, 2) @@ -793,7 +796,8 @@ class ImageAlignment(torch.nn.Module): / ( torch.abs(image_reference_fft) * torch.abs(images_todo_fft) + eps.unsqueeze(-1).unsqueeze(-1) - ) + ), + dim=(-2, -1), ) ) @@ -813,6 +817,7 @@ class ImageAlignment(torch.nn.Module): ret, succ = self._phase_correlation( im0.unsqueeze(0), im1, self.argmax_translation ) + return ret, succ def _get_ang_scale( diff --git a/reproduction_effort/functions/align_cameras.py b/reproduction_effort/functions/align_cameras.py index e89506c..be9b696 100644 --- a/reproduction_effort/functions/align_cameras.py +++ b/reproduction_effort/functions/align_cameras.py @@ -61,12 +61,12 @@ def align_cameras( # --- Calculate translation and rotation between the reference images --- angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor = align_refref( ref_image_acceptor=acceptor[ - acceptor.shape[2] // 2, + acceptor.shape[0] // 2, :, :, ], ref_image_donor=donor[ - donor.shape[2] // 2, + donor.shape[0] // 2, :, :, ], @@ -77,7 +77,7 @@ def align_cameras( ref_image_oxygenation = tv.transforms.functional.affine( img=oxygenation[ - oxygenation.shape[2] // 2, + oxygenation.shape[0] // 2, :, :, ].unsqueeze(0), @@ -102,7 +102,7 @@ def align_cameras( ref_image_oxygenation = ref_image_oxygenation.squeeze(0) ref_image_volume = volume[ - volume.shape[2] // 2, + volume.shape[0] // 2, :, :, ].clone() diff --git a/reproduction_effort/functions/align_refref.py b/reproduction_effort/functions/align_refref.py index 094b09c..7361849 100644 --- a/reproduction_effort/functions/align_refref.py +++ b/reproduction_effort/functions/align_refref.py @@ -16,9 +16,9 @@ def align_refref( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: angle_refref = calculate_rotation( - image_alignment, - ref_image_acceptor.unsqueeze(0), - ref_image_donor, + image_alignment=image_alignment, + input=ref_image_acceptor.unsqueeze(0), + reference_image=ref_image_donor, batch_size=batch_size, ) @@ -33,9 +33,9 @@ def align_refref( ) tvec_refref = calculate_translation( - image_alignment, - ref_image_acceptor, - ref_image_donor, + image_alignment=image_alignment, + input=ref_image_acceptor, + reference_image=ref_image_donor, batch_size=batch_size, ) @@ -49,8 +49,6 @@ def align_refref( shear=0, interpolation=tv.transforms.InterpolationMode.BILINEAR, fill=fill_value, - ) - - ref_image_acceptor = ref_image_acceptor.squeeze(0) + ).squeeze(0) return angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor