2024-02-26 18:57:33 +01:00
|
|
|
import torch
|
|
|
|
import torchvision as tv # type: ignore
|
|
|
|
import logging
|
|
|
|
|
|
|
|
from functions.calculate_translation import calculate_translation
|
|
|
|
from functions.ImageAlignment import ImageAlignment
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
def perform_donor_volume_translation(
|
|
|
|
mylogger: logging.Logger,
|
|
|
|
acceptor: torch.Tensor,
|
|
|
|
donor: torch.Tensor,
|
|
|
|
oxygenation: torch.Tensor,
|
|
|
|
volume: torch.Tensor,
|
|
|
|
ref_image_donor: torch.Tensor,
|
|
|
|
ref_image_volume: torch.Tensor,
|
|
|
|
image_alignment: ImageAlignment,
|
|
|
|
batch_size: int,
|
|
|
|
fill_value: float = 0,
|
|
|
|
) -> tuple[
|
|
|
|
torch.Tensor,
|
|
|
|
torch.Tensor,
|
|
|
|
torch.Tensor,
|
|
|
|
torch.Tensor,
|
|
|
|
torch.Tensor,
|
|
|
|
]:
|
|
|
|
|
|
|
|
mylogger.info("Calculate translation between donor data and donor ref image")
|
|
|
|
tvec_donor = calculate_translation(
|
|
|
|
input=donor,
|
|
|
|
reference_image=ref_image_donor,
|
|
|
|
image_alignment=image_alignment,
|
|
|
|
batch_size=batch_size,
|
|
|
|
)
|
|
|
|
|
|
|
|
mylogger.info("Calculate translation between volume data and volume ref image")
|
|
|
|
tvec_volume = calculate_translation(
|
|
|
|
input=volume,
|
|
|
|
reference_image=ref_image_volume,
|
|
|
|
image_alignment=image_alignment,
|
|
|
|
batch_size=batch_size,
|
|
|
|
)
|
|
|
|
|
|
|
|
mylogger.info("Average over both translations")
|
|
|
|
tvec_donor_volume = (tvec_donor + tvec_volume) / 2.0
|
|
|
|
|
2024-02-27 18:18:47 +01:00
|
|
|
tvec_donor_volume *= 0.0
|
|
|
|
|
2024-02-26 18:57:33 +01:00
|
|
|
mylogger.info("Translate acceptor data based on the average translation vector")
|
|
|
|
for frame_id in range(0, tvec_donor_volume.shape[0]):
|
|
|
|
acceptor[frame_id, ...] = tv.transforms.functional.affine(
|
|
|
|
img=acceptor[frame_id, ...].unsqueeze(0),
|
|
|
|
angle=0,
|
|
|
|
translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]],
|
|
|
|
scale=1.0,
|
|
|
|
shear=0,
|
|
|
|
interpolation=tv.transforms.InterpolationMode.BILINEAR,
|
|
|
|
fill=fill_value,
|
|
|
|
).squeeze(0)
|
|
|
|
|
|
|
|
mylogger.info("Translate donor data based on the average translation vector")
|
|
|
|
for frame_id in range(0, tvec_donor_volume.shape[0]):
|
|
|
|
donor[frame_id, ...] = tv.transforms.functional.affine(
|
|
|
|
img=donor[frame_id, ...].unsqueeze(0),
|
|
|
|
angle=0,
|
|
|
|
translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]],
|
|
|
|
scale=1.0,
|
|
|
|
shear=0,
|
|
|
|
interpolation=tv.transforms.InterpolationMode.BILINEAR,
|
|
|
|
fill=fill_value,
|
|
|
|
).squeeze(0)
|
|
|
|
|
|
|
|
mylogger.info("Translate oxygenation data based on the average translation vector")
|
|
|
|
for frame_id in range(0, tvec_donor_volume.shape[0]):
|
|
|
|
oxygenation[frame_id, ...] = tv.transforms.functional.affine(
|
|
|
|
img=oxygenation[frame_id, ...].unsqueeze(0),
|
|
|
|
angle=0,
|
|
|
|
translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]],
|
|
|
|
scale=1.0,
|
|
|
|
shear=0,
|
|
|
|
interpolation=tv.transforms.InterpolationMode.BILINEAR,
|
|
|
|
fill=fill_value,
|
|
|
|
).squeeze(0)
|
|
|
|
|
|
|
|
mylogger.info("Translate volume data based on the average translation vector")
|
|
|
|
for frame_id in range(0, tvec_donor_volume.shape[0]):
|
|
|
|
volume[frame_id, ...] = tv.transforms.functional.affine(
|
|
|
|
img=volume[frame_id, ...].unsqueeze(0),
|
|
|
|
angle=0,
|
|
|
|
translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]],
|
|
|
|
scale=1.0,
|
|
|
|
shear=0,
|
|
|
|
interpolation=tv.transforms.InterpolationMode.BILINEAR,
|
|
|
|
fill=fill_value,
|
|
|
|
).squeeze(0)
|
|
|
|
|
|
|
|
return (acceptor, donor, oxygenation, volume, tvec_donor_volume)
|