import logging

import torch
import torchvision as tv  # type: ignore

from functions.calculate_rotation import calculate_rotation
from functions.ImageAlignment import ImageAlignment


@torch.no_grad()
def perform_donor_volume_rotation(
    mylogger: logging.Logger,
    acceptor: torch.Tensor,
    donor: torch.Tensor,
    oxygenation: torch.Tensor,
    volume: torch.Tensor,
    ref_image_donor: torch.Tensor,
    ref_image_volume: torch.Tensor,
    batch_size: int,
    config: dict,
    fill_value: float = 0,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Rotation-correct all four channels, retrying on the CPU after a CUDA OOM.

    The work is delegated to ``perform_donor_volume_rotation_internal``. If
    that call raises ``torch.cuda.OutOfMemoryError`` it is repeated with CPU
    copies of all tensors and the results are moved back to
    ``acceptor.device``.

    Args:
        mylogger: Logger that receives the progress messages.
        acceptor: Acceptor data, indexed by frame along the first dimension.
        donor: Donor data, indexed by frame along the first dimension.
        oxygenation: Oxygenation data, indexed by frame along the first
            dimension.
        volume: Volume data, indexed by frame along the first dimension.
        ref_image_donor: Reference image handed to ``calculate_rotation``
            together with ``donor``.
        ref_image_volume: Reference image handed to ``calculate_rotation``
            together with ``volume``.
        batch_size: Forwarded unchanged to ``calculate_rotation``.
        config: Must contain ``rotation_stabilization_threshold_border`` and
            ``rotation_stabilization_threshold_factor`` (both are converted
            with ``float``).
        fill_value: Value used by torchvision for the area outside the
            rotated image.

    Returns:
        ``(acceptor, donor, oxygenation, volume, angle_donor_volume)``. On the
        regular path the four data tensors are the input tensors, modified in
        place. On the CPU fallback path they are new tensors on
        ``acceptor.device``.
    """
    try:
        return perform_donor_volume_rotation_internal(
            mylogger=mylogger,
            acceptor=acceptor,
            donor=donor,
            oxygenation=oxygenation,
            volume=volume,
            ref_image_donor=ref_image_donor,
            ref_image_volume=ref_image_volume,
            batch_size=batch_size,
            config=config,
            fill_value=fill_value,
        )
    except torch.cuda.OutOfMemoryError:
        # Only record the event here. While the except clause is active the
        # exception's traceback keeps the frames of the failed GPU attempt
        # (and the tensors they hold) alive, so the recovery itself runs
        # after the clause has been left.
        mylogger.info(
            "CUDA out of memory during rotation correction, retrying on the CPU"
        )

    # NOTE(review): if the OOM was raised after the internal function had
    # already started rotating frames in place, the CPU copies taken below
    # contain those partially rotated frames. Confirm whether this can happen
    # in practice; avoiding it would require cloning the inputs up front.
    target_device = acceptor.device
    (
        acceptor_cpu,
        donor_cpu,
        oxygenation_cpu,
        volume_cpu,
        angle_donor_volume_cpu,
    ) = perform_donor_volume_rotation_internal(
        mylogger=mylogger,
        acceptor=acceptor.cpu(),
        donor=donor.cpu(),
        oxygenation=oxygenation.cpu(),
        volume=volume.cpu(),
        ref_image_donor=ref_image_donor.cpu(),
        ref_image_volume=ref_image_volume.cpu(),
        batch_size=batch_size,
        config=config,
        fill_value=fill_value,
    )

    return (
        acceptor_cpu.to(device=target_device),
        donor_cpu.to(device=target_device),
        oxygenation_cpu.to(device=target_device),
        volume_cpu.to(device=target_device),
        angle_donor_volume_cpu.to(device=target_device),
    )


def _stabilization_threshold(
    angle: torch.Tensor, border: float, factor: float
) -> torch.Tensor:
    """Return ``factor`` times the ``border`` quantile of ``abs(angle)``."""
    sorted_abs_angle = torch.sort(torch.abs(angle))[0]
    number_of_frames = int(sorted_abs_angle.shape[0])
    # Clamp to the last element so that a border of 1.0 selects the maximum
    # instead of indexing one past the end.
    index = min(int(number_of_frames * border), number_of_frames - 1)
    return sorted_abs_angle[index] * factor


def _rotate_frames_inplace(
    data: torch.Tensor, angles: list[float], fill_value: float
) -> None:
    """Rotate ``data[i]`` by ``angles[i]`` (bilinear) and write it back in place."""
    for frame_id, angle in enumerate(angles):
        data[frame_id, ...] = tv.transforms.functional.affine(
            img=data[frame_id, ...].unsqueeze(0),
            angle=angle,
            translate=[0, 0],
            scale=1.0,
            shear=0,
            interpolation=tv.transforms.InterpolationMode.BILINEAR,
            fill=fill_value,
        ).squeeze(0)


@torch.no_grad()
def perform_donor_volume_rotation_internal(
    mylogger: logging.Logger,
    acceptor: torch.Tensor,
    donor: torch.Tensor,
    oxygenation: torch.Tensor,
    volume: torch.Tensor,
    ref_image_donor: torch.Tensor,
    ref_image_volume: torch.Tensor,
    batch_size: int,
    config: dict,
    fill_value: float = 0,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Estimate a per-frame rotation and undo it on all four channels in place.

    Steps:
        1. ``calculate_rotation`` is called for ``donor`` against
           ``ref_image_donor`` and for ``volume`` against ``ref_image_volume``.
        2. For each of the two angle series a threshold is computed: the
           ``rotation_stabilization_threshold_border`` quantile of the
           absolute angles times ``rotation_stabilization_threshold_factor``.
        3. Angles whose magnitude exceeds their threshold are replaced by the
           angle of the other series for the same frame; values that still
           exceed the threshold afterwards are set to 0.
        4. The two series are averaged and every frame of ``acceptor``,
           ``donor``, ``oxygenation`` and ``volume`` is rotated by the
           negated average angle (passed to torchvision's ``affine``).

    Args:
        mylogger: Logger that receives the progress messages.
        acceptor: Acceptor data; its dtype and device are also used to build
            the ``ImageAlignment`` helper. Modified in place.
        donor: Donor data. Modified in place.
        oxygenation: Oxygenation data. Modified in place.
        volume: Volume data. Modified in place.
        ref_image_donor: Reference image for the donor rotation estimate.
        ref_image_volume: Reference image for the volume rotation estimate.
        batch_size: Forwarded unchanged to ``calculate_rotation``.
        config: Must contain ``rotation_stabilization_threshold_border`` and
            ``rotation_stabilization_threshold_factor``.
        fill_value: Value used by torchvision for the area outside the
            rotated image.

    Returns:
        ``(acceptor, donor, oxygenation, volume, angle_donor_volume)`` where
        the first four are the (in place modified) inputs and the last one is
        the averaged angle per frame.
    """
    image_alignment = ImageAlignment(
        default_dtype=acceptor.dtype, device=acceptor.device
    )

    mylogger.info("Calculate rotation between donor data and donor ref image")
    angle_donor = calculate_rotation(
        input=donor,
        reference_image=ref_image_donor,
        image_alignment=image_alignment,
        batch_size=batch_size,
    )

    mylogger.info("Calculate rotation between volume data and volume ref image")
    angle_volume = calculate_rotation(
        input=volume,
        reference_image=ref_image_volume,
        image_alignment=image_alignment,
        batch_size=batch_size,
    )

    mylogger.info("Average over both rotations")

    border = float(config["rotation_stabilization_threshold_border"])
    factor = float(config["rotation_stabilization_threshold_factor"])
    donor_threshold = _stabilization_threshold(angle_donor, border, factor)
    volume_threshold = _stabilization_threshold(angle_volume, border, factor)

    donor_idx = torch.where(torch.abs(angle_donor) > donor_threshold)[0]
    volume_idx = torch.where(torch.abs(angle_volume) > volume_threshold)[0]
    mylogger.info(
        f"Border: {config['rotation_stabilization_threshold_border']}, "
        f"factor {config['rotation_stabilization_threshold_factor']} "
    )
    mylogger.info(
        f"Donor threshold: {donor_threshold:.3e}, volume threshold: {volume_threshold:.3e}"
    )
    mylogger.info(
        f"Found broken rotation values: "
        f"donor {int(donor_idx.shape[0])}, "
        f"volume {int(volume_idx.shape[0])}"
    )

    # Patch outliers of one series with the value of the other series. The
    # order matters: the volume patch already sees the patched donor values.
    angle_donor[donor_idx] = angle_volume[donor_idx]
    angle_volume[volume_idx] = angle_donor[volume_idx]

    donor_idx = torch.where(torch.abs(angle_donor) > donor_threshold)[0]
    volume_idx = torch.where(torch.abs(angle_volume) > volume_threshold)[0]
    mylogger.info(
        f"After fill in these broken rotation values remain: "
        f"donor {int(donor_idx.shape[0])}, "
        f"volume {int(volume_idx.shape[0])}"
    )
    # Whatever is still above the threshold could not be repaired: no rotation.
    angle_donor[donor_idx] = 0.0
    angle_volume[volume_idx] = 0.0
    angle_donor_volume = (angle_donor + angle_volume) / 2.0

    # Convert the correction angles to Python floats once; they are reused
    # for all four channels.
    correction_angles = [-float(angle) for angle in angle_donor_volume.tolist()]

    mylogger.info("Rotate acceptor data based on the average rotation")
    _rotate_frames_inplace(acceptor, correction_angles, fill_value)

    mylogger.info("Rotate donor data based on the average rotation")
    _rotate_frames_inplace(donor, correction_angles, fill_value)

    mylogger.info("Rotate oxygenation data based on the average rotation")
    _rotate_frames_inplace(oxygenation, correction_angles, fill_value)

    mylogger.info("Rotate volume data based on the average rotation")
    _rotate_frames_inplace(volume, correction_angles, fill_value)

    return (acceptor, donor, oxygenation, volume, angle_donor_volume)