diff --git a/new_pipeline/functions/perform_donor_volume_rotation.py b/new_pipeline/functions/perform_donor_volume_rotation.py index 1c9b804..590630f 100644 --- a/new_pipeline/functions/perform_donor_volume_rotation.py +++ b/new_pipeline/functions/perform_donor_volume_rotation.py @@ -16,6 +16,7 @@ def perform_donor_volume_rotation( ref_image_volume: torch.Tensor, image_alignment: ImageAlignment, batch_size: int, + config: dict, fill_value: float = 0, ) -> tuple[ torch.Tensor, @@ -43,8 +44,50 @@ def perform_donor_volume_rotation( ) mylogger.info("Average over both rotations") + + donor_threshold: torch.Tensor = torch.sort(torch.abs(angle_donor))[0] + donor_threshold = donor_threshold[ + int( + donor_threshold.shape[0] + * float(config["rotation_stabilization_threshold_border"]) + ) + ] * float(config["rotation_stabilization_threshold_factor"]) + + volume_threshold: torch.Tensor = torch.sort(torch.abs(angle_volume))[0] + volume_threshold = volume_threshold[ + int( + volume_threshold.shape[0] + * float(config["rotation_stabilization_threshold_border"]) + ) + ] * float(config["rotation_stabilization_threshold_factor"]) + + donor_idx = torch.where(torch.abs(angle_donor) > donor_threshold)[0] + volume_idx = torch.where(torch.abs(angle_volume) > volume_threshold)[0] + mylogger.info( + f"Border: {config['rotation_stabilization_threshold_border']}, " + f"factor {config['rotation_stabilization_threshold_factor']} " + ) + mylogger.info( + f"Donor threshold: {donor_threshold:.3e}, volume threshold: {volume_threshold:.3e}" + ) + mylogger.info( + f"Found broken rotation values: " + f"donor {int(donor_idx.shape[0])}, " + f"volume {int(volume_idx.shape[0])}" + ) + angle_donor[donor_idx] = angle_volume[donor_idx] + angle_volume[volume_idx] = angle_donor[volume_idx] + + donor_idx = torch.where(torch.abs(angle_donor) > donor_threshold)[0] + volume_idx = torch.where(torch.abs(angle_volume) > volume_threshold)[0] + mylogger.info( + f"After fill in these broken rotation values remain: " + f"donor {int(donor_idx.shape[0])}, " + f"volume {int(volume_idx.shape[0])}" + ) + angle_donor[donor_idx] = 0.0 + angle_volume[volume_idx] = 0.0 angle_donor_volume = (angle_donor + angle_volume) / 2.0 - angle_donor_volume *= 0.0 mylogger.info("Rotate acceptor data based on the average rotation") for frame_id in range(0, angle_donor_volume.shape[0]): diff --git a/new_pipeline/functions/perform_donor_volume_translation.py b/new_pipeline/functions/perform_donor_volume_translation.py index 0fce44e..7091add 100644 --- a/new_pipeline/functions/perform_donor_volume_translation.py +++ b/new_pipeline/functions/perform_donor_volume_translation.py @@ -17,6 +17,7 @@ def perform_donor_volume_translation( ref_image_volume: torch.Tensor, image_alignment: ImageAlignment, batch_size: int, + config: dict, fill_value: float = 0, ) -> tuple[ torch.Tensor, @@ -43,9 +44,53 @@ def perform_donor_volume_translation( ) mylogger.info("Average over both translations") - tvec_donor_volume = (tvec_donor + tvec_volume) / 2.0 - tvec_donor_volume *= 0.0 + for i in range(0, 2): + mylogger.info(f"Processing dimension {i}") + donor_threshold: torch.Tensor = torch.sort(torch.abs(tvec_donor[:, i]))[0] + donor_threshold = donor_threshold[ + int( + donor_threshold.shape[0] + * float(config["rotation_stabilization_threshold_border"]) + ) + ] * float(config["rotation_stabilization_threshold_factor"]) + + volume_threshold: torch.Tensor = torch.sort(torch.abs(tvec_volume[:, i]))[0] + volume_threshold = volume_threshold[ + int( + volume_threshold.shape[0] + * float(config["rotation_stabilization_threshold_border"]) + ) + ] * float(config["rotation_stabilization_threshold_factor"]) + + donor_idx = torch.where(torch.abs(tvec_donor[:, i]) > donor_threshold)[0] + volume_idx = torch.where(torch.abs(tvec_volume[:, i]) > volume_threshold)[0] + mylogger.info( + f"Border: {config['rotation_stabilization_threshold_border']}, " + f"factor {config['rotation_stabilization_threshold_factor']} " + ) + mylogger.info( + f"Donor threshold: {donor_threshold:.3e}, volume threshold: {volume_threshold:.3e}" + ) + mylogger.info( + f"Found broken rotation values: " + f"donor {int(donor_idx.shape[0])}, " + f"volume {int(volume_idx.shape[0])}" + ) + tvec_donor[donor_idx, i] = tvec_volume[donor_idx, i] + tvec_volume[volume_idx, i] = tvec_donor[volume_idx, i] + + donor_idx = torch.where(torch.abs(tvec_donor[:, i]) > donor_threshold)[0] + volume_idx = torch.where(torch.abs(tvec_volume[:, i]) > volume_threshold)[0] + mylogger.info( + f"After fill in these broken rotation values remain: " + f"donor {int(donor_idx.shape[0])}, " + f"volume {int(volume_idx.shape[0])}" + ) + tvec_donor[donor_idx, i] = 0.0 + tvec_volume[volume_idx, i] = 0.0 + + tvec_donor_volume = (tvec_donor + tvec_volume) / 2.0 mylogger.info("Translate acceptor data based on the average translation vector") for frame_id in range(0, tvec_donor_volume.shape[0]):