# gevi/functions/perform_donor_volume_translation.py
# (source listing metadata: 211 lines, 7.2 KiB, Python; last change 2024-02-28 16:14:50 +01:00)
import torch
import torchvision as tv # type: ignore
import logging
from functions.calculate_translation import calculate_translation
from functions.ImageAlignment import ImageAlignment
@torch.no_grad()
def perform_donor_volume_translation(
    mylogger: logging.Logger,
    acceptor: torch.Tensor,
    donor: torch.Tensor,
    oxygenation: torch.Tensor,
    volume: torch.Tensor,
    ref_image_donor: torch.Tensor,
    ref_image_volume: torch.Tensor,
    batch_size: int,
    config: dict,
    fill_value: float = 0,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Translate all data stacks by the averaged donor/volume translation.

    Wrapper around ``perform_donor_volume_translation_internal``. The work is
    first attempted on the tensors' own device. If CUDA raises
    ``torch.cuda.OutOfMemoryError``, the computation is repeated on CPU copies
    and the results are written back into the caller's tensors.

    Args:
        mylogger: Logger for progress and fallback messages.
        acceptor, donor, oxygenation, volume: Data stacks. They are indexed
            per frame along dim 0 and translated in place.
        ref_image_donor, ref_image_volume: Reference images used to estimate
            the donor and volume translations.
        batch_size: Forwarded to the translation estimation.
        config: Must provide ``rotation_stabilization_threshold_border`` and
            ``rotation_stabilization_threshold_factor``.
        fill_value: Value used for pixels shifted in from outside the image.

    Returns:
        ``(acceptor, donor, oxygenation, volume, tvec_donor_volume)``. The four
        data tensors are the caller's tensors, updated in place. The averaged
        translation vectors are on ``acceptor.device``.
    """
    try:
        return perform_donor_volume_translation_internal(
            mylogger=mylogger,
            acceptor=acceptor,
            donor=donor,
            oxygenation=oxygenation,
            volume=volume,
            ref_image_donor=ref_image_donor,
            ref_image_volume=ref_image_volume,
            batch_size=batch_size,
            config=config,
            fill_value=fill_value,
        )
    except torch.cuda.OutOfMemoryError:
        mylogger.warning(
            "CUDA out of memory during donor/volume translation; retrying on CPU"
        )

    # The retry deliberately runs outside the ``except`` clause. While the
    # handler is active, the exception traceback keeps the failed attempt's
    # frames, and their GPU tensors, alive.
    #
    # NOTE(review): the internal routine writes translated frames back into
    # its inputs in place. If the OOM struck after some frames were already
    # written, this retry starts from partially translated data. Confirm that
    # OOM can only occur before the write-back phase.
    (
        acceptor_cpu,
        donor_cpu,
        oxygenation_cpu,
        volume_cpu,
        tvec_donor_volume_cpu,
    ) = perform_donor_volume_translation_internal(
        mylogger=mylogger,
        acceptor=acceptor.cpu(),
        donor=donor.cpu(),
        oxygenation=oxygenation.cpu(),
        volume=volume.cpu(),
        ref_image_donor=ref_image_donor.cpu(),
        ref_image_volume=ref_image_volume.cpu(),
        batch_size=batch_size,
        config=config,
        fill_value=fill_value,
    )

    # Copy into the existing buffers instead of allocating new device tensors.
    # Right after an OOM, four new full-size copies (while the caller still
    # holds the originals) would very likely fail again. This also keeps the
    # in-place semantics of the fast path.
    acceptor.copy_(acceptor_cpu)
    donor.copy_(donor_cpu)
    oxygenation.copy_(oxygenation_cpu)
    volume.copy_(volume_cpu)
    return (
        acceptor,
        donor,
        oxygenation,
        volume,
        tvec_donor_volume_cpu.to(device=acceptor.device),
    )
def _translation_threshold(
    shifts: torch.Tensor, border: float, factor: float
) -> torch.Tensor:
    """Return ``factor`` times the ``border``-quantile of ``abs(shifts)``.

    The quantile is read from the ascending-sorted absolute values at index
    ``int(len(shifts) * border)``. The index is clamped to the last element,
    so ``border == 1.0`` selects the maximum instead of indexing past the end.
    An empty input yields a 0-d zero threshold, so nothing gets flagged.
    """
    sorted_abs = torch.sort(torch.abs(shifts))[0]
    n_values = int(sorted_abs.shape[0])
    if n_values == 0:
        return torch.zeros((), dtype=shifts.dtype, device=shifts.device)
    index = min(int(n_values * border), n_values - 1)
    return sorted_abs[index] * factor


def _translate_stack(
    stack: torch.Tensor,
    shifts: list[list[float]],
    fill_value: float,
) -> None:
    """Shift frame ``k`` of ``stack`` in place by row ``k`` of ``shifts``.

    torchvision's ``affine`` takes ``translate`` as (horizontal, vertical), so
    column 1 of each row is passed first and column 0 second.
    """
    for frame_id, row in enumerate(shifts):
        stack[frame_id, ...] = tv.transforms.functional.affine(
            img=stack[frame_id, ...].unsqueeze(0),
            angle=0,
            translate=[row[1], row[0]],
            scale=1.0,
            shear=0,
            interpolation=tv.transforms.InterpolationMode.BILINEAR,
            fill=fill_value,
        ).squeeze(0)


@torch.no_grad()
def perform_donor_volume_translation_internal(
    mylogger: logging.Logger,
    acceptor: torch.Tensor,
    donor: torch.Tensor,
    oxygenation: torch.Tensor,
    volume: torch.Tensor,
    ref_image_donor: torch.Tensor,
    ref_image_volume: torch.Tensor,
    batch_size: int,
    config: dict,
    fill_value: float = 0,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Estimate donor and volume translations, average them, and apply them.

    Steps:
      1. Estimate per-frame translations of ``donor`` against
         ``ref_image_donor`` and of ``volume`` against ``ref_image_volume``.
      2. For each of the first two translation components, flag values whose
         magnitude exceeds ``factor`` times the ``border``-quantile of that
         channel. Replace them with the other channel's original value. Zero
         whatever is still above threshold after the replacement.
      3. Average the two cleaned translations and shift every frame of
         acceptor, donor, oxygenation and volume by the result, in place.

    Args:
        mylogger: Logger for progress messages.
        acceptor, donor, oxygenation, volume: Data stacks indexed per frame
            along dim 0. They are modified in place.
        ref_image_donor, ref_image_volume: Reference images for the estimation.
        batch_size: Forwarded to ``calculate_translation``.
        config: Reads ``rotation_stabilization_threshold_border`` and
            ``rotation_stabilization_threshold_factor`` (converted with ``float``).
        fill_value: Fill for pixels shifted in from outside the image.

    Returns:
        ``(acceptor, donor, oxygenation, volume, tvec_donor_volume)``, where
        the data tensors are the inputs themselves. ``tvec_donor_volume`` is
        indexed ``[frame, component]``; presumably column 0 is vertical and
        column 1 horizontal (see the ``translate`` order) — TODO confirm.
    """
    image_alignment = ImageAlignment(
        default_dtype=acceptor.dtype, device=acceptor.device
    )

    mylogger.info("Calculate translation between donor data and donor ref image")
    tvec_donor = calculate_translation(
        input=donor,
        reference_image=ref_image_donor,
        image_alignment=image_alignment,
        batch_size=batch_size,
    )
    mylogger.info("Calculate translation between volume data and volume ref image")
    tvec_volume = calculate_translation(
        input=volume,
        reference_image=ref_image_volume,
        image_alignment=image_alignment,
        batch_size=batch_size,
    )

    mylogger.info("Average over both translations")
    border = float(config["rotation_stabilization_threshold_border"])
    factor = float(config["rotation_stabilization_threshold_factor"])
    for i in range(0, 2):
        mylogger.info(f"Processing dimension {i}")
        donor_threshold = _translation_threshold(tvec_donor[:, i], border, factor)
        volume_threshold = _translation_threshold(tvec_volume[:, i], border, factor)

        donor_idx = torch.where(torch.abs(tvec_donor[:, i]) > donor_threshold)[0]
        volume_idx = torch.where(torch.abs(tvec_volume[:, i]) > volume_threshold)[0]
        mylogger.info(
            f"Border: {config['rotation_stabilization_threshold_border']}, "
            f"factor {config['rotation_stabilization_threshold_factor']} "
        )
        mylogger.info(
            f"Donor threshold: {donor_threshold:.3e}, volume threshold: {volume_threshold:.3e}"
        )
        mylogger.info(
            f"Found broken rotation values: "
            f"donor {int(donor_idx.shape[0])}, "
            f"volume {int(volume_idx.shape[0])}"
        )

        # Snapshot the donor column before overwriting it. Without this, the
        # volume fill would read donor values that were just replaced by
        # volume values, and frames flagged in both channels would be
        # "filled" with their own broken volume value.
        donor_before_fill = tvec_donor[:, i].clone()
        tvec_donor[donor_idx, i] = tvec_volume[donor_idx, i]
        tvec_volume[volume_idx, i] = donor_before_fill[volume_idx]

        # Anything still above its channel's threshold cannot be repaired and
        # is dropped, i.e. set to zero shift.
        donor_idx = torch.where(torch.abs(tvec_donor[:, i]) > donor_threshold)[0]
        volume_idx = torch.where(torch.abs(tvec_volume[:, i]) > volume_threshold)[0]
        mylogger.info(
            f"After fill in these broken rotation values remain: "
            f"donor {int(donor_idx.shape[0])}, "
            f"volume {int(volume_idx.shape[0])}"
        )
        tvec_donor[donor_idx, i] = 0.0
        tvec_volume[volume_idx, i] = 0.0

    tvec_donor_volume = (tvec_donor + tvec_volume) / 2.0

    # One host transfer for all shifts, instead of one device sync per frame
    # and per stack. torchvision converts each translate entry with float()
    # anyway.
    shifts = tvec_donor_volume.tolist()

    mylogger.info("Translate acceptor data based on the average translation vector")
    _translate_stack(acceptor, shifts, fill_value)
    mylogger.info("Translate donor data based on the average translation vector")
    _translate_stack(donor, shifts, fill_value)
    mylogger.info("Translate oxygenation data based on the average translation vector")
    _translate_stack(oxygenation, shifts, fill_value)
    mylogger.info("Translate volume data based on the average translation vector")
    _translate_stack(volume, shifts, fill_value)

    return (acceptor, donor, oxygenation, volume, tvec_donor_volume)