2024-02-28 16:14:50 +01:00
|
|
|
import torch
|
|
|
|
import torchvision as tv # type: ignore
|
|
|
|
import logging
|
|
|
|
from functions.calculate_rotation import calculate_rotation
|
|
|
|
from functions.ImageAlignment import ImageAlignment
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def perform_donor_volume_rotation(
    mylogger: logging.Logger,
    acceptor: torch.Tensor,
    donor: torch.Tensor,
    oxygenation: torch.Tensor,
    volume: torch.Tensor,
    ref_image_donor: torch.Tensor,
    ref_image_volume: torch.Tensor,
    batch_size: int,
    config: dict,
    fill_value: float = 0,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Rotation-stabilize the four image stacks, retrying on the CPU on CUDA OOM.

    The work is delegated to ``perform_donor_volume_rotation_internal``. It is
    first attempted on the tensors as given; if that raises
    ``torch.cuda.OutOfMemoryError`` the computation is repeated on CPU copies
    of all tensor arguments and the result is written back into the original
    tensors.

    Args:
        mylogger: Logger used for progress messages.
        acceptor: Acceptor image stack.
        donor: Donor image stack.
        oxygenation: Oxygenation image stack.
        volume: Volume image stack.
        ref_image_donor: Reference image for the donor rotation estimate.
        ref_image_volume: Reference image for the volume rotation estimate.
        batch_size: Forwarded to the rotation estimation.
        config: Forwarded unchanged to the internal function.
        fill_value: Fill value for pixels rotated in from outside the image.

    Returns:
        ``(acceptor, donor, oxygenation, volume, angle_donor_volume)`` as
        produced by ``perform_donor_volume_rotation_internal``. In both the
        direct and the fallback path the four stacks are the input tensors,
        updated in place.
    """
    try:
        return perform_donor_volume_rotation_internal(
            mylogger=mylogger,
            acceptor=acceptor,
            donor=donor,
            oxygenation=oxygenation,
            volume=volume,
            ref_image_donor=ref_image_donor,
            ref_image_volume=ref_image_volume,
            batch_size=batch_size,
            config=config,
            fill_value=fill_value,
        )
    except torch.cuda.OutOfMemoryError:
        # Only note the failure here. The retry is done *after* the except
        # block: while the exception is being handled its traceback keeps the
        # frames of the failed attempt (and their GPU intermediates) alive,
        # which would make the recovery itself more likely to run out of
        # memory.
        mylogger.warning(
            "CUDA out of memory during rotation stabilization; retrying on the CPU"
        )

    # NOTE(review): the internal function writes into its inputs frame by
    # frame. If the OOM was raised in the middle of that rotation phase, the
    # tensors copied below are already partially rotated. Confirm whether that
    # can happen in practice for the data sizes used.
    (
        acceptor_cpu,
        donor_cpu,
        oxygenation_cpu,
        volume_cpu,
        angle_donor_volume_cpu,
    ) = perform_donor_volume_rotation_internal(
        mylogger=mylogger,
        acceptor=acceptor.cpu(),
        donor=donor.cpu(),
        oxygenation=oxygenation.cpu(),
        volume=volume.cpu(),
        ref_image_donor=ref_image_donor.cpu(),
        ref_image_volume=ref_image_volume.cpu(),
        batch_size=batch_size,
        config=config,
        fill_value=fill_value,
    )

    # Write the results into the existing device tensors instead of creating
    # new device copies: this needs no additional GPU memory for the large
    # stacks (we just ran out of it) and matches the direct path, which also
    # returns the input tensors modified in place.
    return (
        acceptor.copy_(acceptor_cpu),
        donor.copy_(donor_cpu),
        oxygenation.copy_(oxygenation_cpu),
        volume.copy_(volume_cpu),
        angle_donor_volume_cpu.to(device=acceptor.device),
    )
|
|
|
|
|
|
|
|
|
|
|
|
def _rotation_outlier_threshold(
    angles: torch.Tensor, border: float, factor: float
) -> torch.Tensor:
    """Return ``factor`` times the sorted ``|angles|`` entry at fraction ``border``."""
    sorted_magnitudes: torch.Tensor = torch.sort(torch.abs(angles))[0]
    # int(n * border) equals n for border == 1.0, which is one past the end;
    # clamp to the last (largest) element instead of raising IndexError.
    position = min(
        int(sorted_magnitudes.shape[0] * border),
        sorted_magnitudes.shape[0] - 1,
    )
    return sorted_magnitudes[position] * factor


def _rotate_frames_inplace(
    stack: torch.Tensor, angles: list[float], fill_value: float
) -> None:
    """Rotate ``stack[i]`` by ``angles[i]`` about the image center, in place."""
    for frame_id, angle in enumerate(angles):
        stack[frame_id, ...] = tv.transforms.functional.affine(
            img=stack[frame_id, ...].unsqueeze(0),
            angle=angle,
            translate=[0, 0],
            scale=1.0,
            shear=0,
            interpolation=tv.transforms.InterpolationMode.BILINEAR,
            fill=fill_value,
        ).squeeze(0)


@torch.no_grad()
def perform_donor_volume_rotation_internal(
    mylogger: logging.Logger,
    acceptor: torch.Tensor,
    donor: torch.Tensor,
    oxygenation: torch.Tensor,
    volume: torch.Tensor,
    ref_image_donor: torch.Tensor,
    ref_image_volume: torch.Tensor,
    batch_size: int,
    config: dict,
    fill_value: float = 0,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Estimate a per-frame rotation and undo it on all four image stacks.

    Steps:
      1. Estimate a rotation angle per frame for ``donor`` against
         ``ref_image_donor`` and for ``volume`` against ``ref_image_volume``
         via ``calculate_rotation``.
      2. Treat angles whose magnitude exceeds a threshold as broken. The
         threshold is the sorted ``|angle|`` value at the relative position
         ``config["rotation_stabilization_threshold_border"]`` multiplied by
         ``config["rotation_stabilization_threshold_factor"]``, computed
         separately for donor and volume.
      3. Replace broken donor angles by the volume angle of the same frame
         and vice versa; angles that are still above threshold are set to 0.
      4. Average the donor and volume angles and rotate every frame of
         ``acceptor``, ``donor``, ``oxygenation`` and ``volume`` by the
         negated average.

    Args:
        mylogger: Logger used for progress messages.
        acceptor: Acceptor image stack, indexed by frame along dim 0.
        donor: Donor image stack, indexed by frame along dim 0.
        oxygenation: Oxygenation image stack, indexed by frame along dim 0.
        volume: Volume image stack, indexed by frame along dim 0.
        ref_image_donor: Reference image for the donor rotation estimate.
        ref_image_volume: Reference image for the volume rotation estimate.
        batch_size: Forwarded to ``calculate_rotation``.
        config: Must provide ``rotation_stabilization_threshold_border`` and
            ``rotation_stabilization_threshold_factor`` (values convertible
            with ``float``).
        fill_value: Fill value for pixels rotated in from outside the image.

    Returns:
        ``(acceptor, donor, oxygenation, volume, angle_donor_volume)``. The
        four stacks are the input tensors, modified in place;
        ``angle_donor_volume`` holds the averaged angle per frame.
    """
    image_alignment = ImageAlignment(
        default_dtype=acceptor.dtype, device=acceptor.device
    )

    mylogger.info("Calculate rotation between donor data and donor ref image")
    angle_donor = calculate_rotation(
        input=donor,
        reference_image=ref_image_donor,
        image_alignment=image_alignment,
        batch_size=batch_size,
    )

    mylogger.info("Calculate rotation between volume data and volume ref image")
    angle_volume = calculate_rotation(
        input=volume,
        reference_image=ref_image_volume,
        image_alignment=image_alignment,
        batch_size=batch_size,
    )

    mylogger.info("Average over both rotations")

    border = float(config["rotation_stabilization_threshold_border"])
    factor = float(config["rotation_stabilization_threshold_factor"])
    donor_threshold: torch.Tensor = _rotation_outlier_threshold(
        angle_donor, border, factor
    )
    volume_threshold: torch.Tensor = _rotation_outlier_threshold(
        angle_volume, border, factor
    )

    donor_idx = torch.where(torch.abs(angle_donor) > donor_threshold)[0]
    volume_idx = torch.where(torch.abs(angle_volume) > volume_threshold)[0]
    mylogger.info(
        f"Border: {config['rotation_stabilization_threshold_border']}, "
        f"factor {config['rotation_stabilization_threshold_factor']} "
    )
    mylogger.info(
        f"Donor threshold: {donor_threshold:.3e}, volume threshold: {volume_threshold:.3e}"
    )
    mylogger.info(
        f"Found broken rotation values: "
        f"donor {int(donor_idx.shape[0])}, "
        f"volume {int(volume_idx.shape[0])}"
    )

    # Patch each estimate with the other one. The order matters: the second
    # assignment reads the already patched donor angles, so a frame that is
    # broken in both estimates ends up with the broken volume value in both
    # and is handled by the re-check below.
    angle_donor[donor_idx] = angle_volume[donor_idx]
    angle_volume[volume_idx] = angle_donor[volume_idx]

    # Whatever is still above threshold after patching cannot be repaired;
    # those frames are left unrotated.
    donor_idx = torch.where(torch.abs(angle_donor) > donor_threshold)[0]
    volume_idx = torch.where(torch.abs(angle_volume) > volume_threshold)[0]
    mylogger.info(
        f"After fill in these broken rotation values remain: "
        f"donor {int(donor_idx.shape[0])}, "
        f"volume {int(volume_idx.shape[0])}"
    )
    angle_donor[donor_idx] = 0.0
    angle_volume[volume_idx] = 0.0
    angle_donor_volume = (angle_donor + angle_volume) / 2.0

    # Convert the (negated) angles to Python floats once. Reading them
    # element by element from a device tensor inside each of the four loops
    # would force a device-to-host transfer per frame and stack.
    counter_angles = [-float(angle) for angle in angle_donor_volume.cpu()]

    for label, stack in (
        ("acceptor", acceptor),
        ("donor", donor),
        ("oxygenation", oxygenation),
        ("volume", volume),
    ):
        mylogger.info(f"Rotate {label} data based on the average rotation")
        _rotate_frames_inplace(stack, counter_angles, fill_value)

    return (acceptor, donor, oxygenation, volume, angle_donor_volume)
|