Add files via upload

David Rotermund 2024-02-28 18:55:37 +01:00 committed by GitHub
parent 7a4f34bdc3
commit dfa2beae76
7 changed files with 245 additions and 7 deletions

@@ -11,11 +11,14 @@ def align_refref(
     mylogger: logging.Logger,
     ref_image_acceptor: torch.Tensor,
     ref_image_donor: torch.Tensor,
-    image_alignment: ImageAlignment,
     batch_size: int,
     fill_value: float = 0,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    image_alignment = ImageAlignment(
+        default_dtype=ref_image_acceptor.dtype, device=ref_image_acceptor.device
+    )
+
     mylogger.info("Rotate ref image acceptor onto donor")
     angle_refref = calculate_rotation(
         image_alignment=image_alignment,
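
The hunk above moves construction of the ImageAlignment helper inside align_refref, deriving its dtype and device from ref_image_acceptor instead of requiring callers to pass a pre-built instance. A minimal, self-contained sketch of that idea, with a hypothetical FakeAligner standing in for ImageAlignment:

import torch


class FakeAligner:
    """Hypothetical stand-in for ImageAlignment; it only records dtype and device."""

    def __init__(self, default_dtype: torch.dtype, device: torch.device) -> None:
        self.default_dtype = default_dtype
        self.device = device


def align(ref_image: torch.Tensor) -> FakeAligner:
    # The helper is built inside the function, matched to the reference tensor,
    # instead of being threaded through the call chain as a parameter.
    return FakeAligner(default_dtype=ref_image.dtype, device=ref_image.device)


aligner = align(torch.zeros(8, 8))
print(aligner.default_dtype, aligner.device)  # torch.float32 cpu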

@@ -57,21 +57,49 @@ def chunk_iterator(array: torch.Tensor, chunk_size: int):
 @torch.no_grad()
 def bandpass(
     data: torch.Tensor,
-    device: torch.device,
     low_frequency: float = 0.1,
     high_frequency: float = 1.0,
     fs=30.0,
     filtfilt_chuck_size: int = 10,
 ) -> torch.Tensor:
+    try:
+        return bandpass_internal(
+            data=data,
+            low_frequency=low_frequency,
+            high_frequency=high_frequency,
+            fs=fs,
+            filtfilt_chuck_size=filtfilt_chuck_size,
+        )
+    except torch.cuda.OutOfMemoryError:
+        return bandpass_internal(
+            data=data.cpu(),
+            low_frequency=low_frequency,
+            high_frequency=high_frequency,
+            fs=fs,
+            filtfilt_chuck_size=filtfilt_chuck_size,
+        ).to(device=data.device)
+
+
+@torch.no_grad()
+def bandpass_internal(
+    data: torch.Tensor,
+    low_frequency: float = 0.1,
+    high_frequency: float = 1.0,
+    fs=30.0,
+    filtfilt_chuck_size: int = 10,
+) -> torch.Tensor:
     butter_a, butter_b = butter_bandpass(
-        device=device,
+        device=data.device,
         low_frequency=low_frequency,
         high_frequency=high_frequency,
         fs=fs,
     )
 
     index_full_dataset: torch.Tensor = torch.arange(
-        0, data.shape[1], device=device, dtype=torch.int64
+        0, data.shape[1], device=data.device, dtype=torch.int64
     )
 
     for chunk in chunk_iterator(index_full_dataset, filtfilt_chuck_size):
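
The same structure recurs throughout this commit: the public function runs the computation on the data's current device and, on torch.cuda.OutOfMemoryError (available in recent PyTorch releases), repeats it on the CPU and moves the result back. A minimal, self-contained sketch of that fallback pattern, with a hypothetical heavy_op in place of the real processing:

import torch


def heavy_op(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical memory-hungry computation, used only to illustrate the pattern."""
    return torch.fft.fft(x, dim=-1).abs()


def heavy_op_with_fallback(x: torch.Tensor) -> torch.Tensor:
    try:
        # First attempt runs wherever the input lives (CPU or GPU).
        return heavy_op(x)
    except torch.cuda.OutOfMemoryError:
        # The GPU ran out of memory: redo the work on a CPU copy,
        # then move the result back to the original device.
        return heavy_op(x.cpu()).to(device=x.device)


print(heavy_op_with_fallback(torch.randn(4, 256)).shape)  # torch.Size([4, 256])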

@@ -1,6 +1,7 @@
 import torch
 
 @torch.no_grad()
 def binning(
     data: torch.Tensor,
     kernel_size: int = 4,
@@ -8,6 +9,30 @@ def binning(
     divisor_override: int | None = 1,
 ) -> torch.Tensor:
+    try:
+        return binning_internal(
+            data=data,
+            kernel_size=kernel_size,
+            stride=stride,
+            divisor_override=divisor_override,
+        )
+    except torch.cuda.OutOfMemoryError:
+        return binning_internal(
+            data=data.cpu(),
+            kernel_size=kernel_size,
+            stride=stride,
+            divisor_override=divisor_override,
+        ).to(device=data.device)
+
+
+@torch.no_grad()
+def binning_internal(
+    data: torch.Tensor,
+    kernel_size: int = 4,
+    stride: int = 4,
+    divisor_override: int | None = 1,
+) -> torch.Tensor:
     assert data.ndim == 4
     return (
         torch.nn.functional.avg_pool2d(
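
binning_internal wraps torch.nn.functional.avg_pool2d, and its divisor_override parameter (default 1) presumably gets passed through; with divisor_override=1 each pooling window reports its sum rather than its mean. A short standalone check of that behaviour:

import torch

x = torch.arange(16.0).reshape(1, 1, 4, 4)

# divisor_override=1 divides each window sum by 1, i.e. the 2x2 blocks are summed.
summed = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2, divisor_override=1)
# Without it, avg_pool2d divides by the window size (here 4) and returns means.
mean = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)

print(summed.squeeze())  # [[10., 18.], [42., 50.]]
print(mean.squeeze())    # [[ 2.5,  4.5], [10.5, 12.5]]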

@@ -11,6 +11,47 @@ def gauss_smear_individual(
     use_matlab_mask: bool = True,
     epsilon: float = float(torch.finfo(torch.float64).eps),
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    try:
+        return gauss_smear_individual_core(
+            input=input,
+            spatial_width=spatial_width,
+            temporal_width=temporal_width,
+            overwrite_fft_gauss=overwrite_fft_gauss,
+            use_matlab_mask=use_matlab_mask,
+            epsilon=epsilon,
+        )
+    except torch.cuda.OutOfMemoryError:
+        if overwrite_fft_gauss is None:
+            overwrite_fft_gauss_cpu: None | torch.Tensor = None
+        else:
+            overwrite_fft_gauss_cpu = overwrite_fft_gauss.cpu()
+
+        input_cpu: torch.Tensor = input.cpu()
+
+        output, overwrite_fft_gauss = gauss_smear_individual_core(
+            input=input_cpu,
+            spatial_width=spatial_width,
+            temporal_width=temporal_width,
+            overwrite_fft_gauss=overwrite_fft_gauss_cpu,
+            use_matlab_mask=use_matlab_mask,
+            epsilon=epsilon,
+        )
+
+        return (
+            output.to(device=input.device),
+            overwrite_fft_gauss.to(device=input.device),
+        )
+
+
+@torch.no_grad()
+def gauss_smear_individual_core(
+    input: torch.Tensor,
+    spatial_width: float,
+    temporal_width: float,
+    overwrite_fft_gauss: None | torch.Tensor = None,
+    use_matlab_mask: bool = True,
+    epsilon: float = float(torch.finfo(torch.float64).eps),
+) -> tuple[torch.Tensor, torch.Tensor]:
 
     dim_x: int = int(2 * math.ceil(2 * spatial_width) + 1)
     dim_y: int = int(2 * math.ceil(2 * spatial_width) + 1)
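
The spatial kernel size in gauss_smear_individual_core follows dim = 2 * ceil(2 * spatial_width) + 1, so it is always odd and grows with spatial_width. A quick worked check of that arithmetic:

import math

for spatial_width in (1.0, 2.5, 4.0):
    dim = int(2 * math.ceil(2 * spatial_width) + 1)
    print(spatial_width, dim)  # 1.0 -> 5, 2.5 -> 11, 4.0 -> 17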

@@ -14,7 +14,6 @@ def perform_donor_volume_rotation(
     volume: torch.Tensor,
     ref_image_donor: torch.Tensor,
     ref_image_volume: torch.Tensor,
-    image_alignment: ImageAlignment,
     batch_size: int,
     config: dict,
     fill_value: float = 0,
@@ -25,6 +24,74 @@ def perform_donor_volume_rotation(
     torch.Tensor,
     torch.Tensor,
 ]:
+    try:
+        return perform_donor_volume_rotation_internal(
+            mylogger=mylogger,
+            acceptor=acceptor,
+            donor=donor,
+            oxygenation=oxygenation,
+            volume=volume,
+            ref_image_donor=ref_image_donor,
+            ref_image_volume=ref_image_volume,
+            batch_size=batch_size,
+            config=config,
+            fill_value=fill_value,
+        )
+    except torch.cuda.OutOfMemoryError:
+        (
+            acceptor_cpu,
+            donor_cpu,
+            oxygenation_cpu,
+            volume_cpu,
+            angle_donor_volume_cpu,
+        ) = perform_donor_volume_rotation_internal(
+            mylogger=mylogger,
+            acceptor=acceptor.cpu(),
+            donor=donor.cpu(),
+            oxygenation=oxygenation.cpu(),
+            volume=volume.cpu(),
+            ref_image_donor=ref_image_donor.cpu(),
+            ref_image_volume=ref_image_volume.cpu(),
+            batch_size=batch_size,
+            config=config,
+            fill_value=fill_value,
+        )
+
+        return (
+            acceptor_cpu.to(device=acceptor.device),
+            donor_cpu.to(device=acceptor.device),
+            oxygenation_cpu.to(device=acceptor.device),
+            volume_cpu.to(device=acceptor.device),
+            angle_donor_volume_cpu.to(device=acceptor.device),
+        )
+
+
+@torch.no_grad()
+def perform_donor_volume_rotation_internal(
+    mylogger: logging.Logger,
+    acceptor: torch.Tensor,
+    donor: torch.Tensor,
+    oxygenation: torch.Tensor,
+    volume: torch.Tensor,
+    ref_image_donor: torch.Tensor,
+    ref_image_volume: torch.Tensor,
+    batch_size: int,
+    config: dict,
+    fill_value: float = 0,
+) -> tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+    image_alignment = ImageAlignment(
+        default_dtype=acceptor.dtype, device=acceptor.device
+    )
+
     mylogger.info("Calculate rotation between donor data and donor ref image")

@@ -15,7 +15,6 @@ def perform_donor_volume_translation(
     volume: torch.Tensor,
     ref_image_donor: torch.Tensor,
     ref_image_volume: torch.Tensor,
-    image_alignment: ImageAlignment,
     batch_size: int,
     config: dict,
     fill_value: float = 0,
@@ -26,6 +25,74 @@ def perform_donor_volume_translation(
     torch.Tensor,
     torch.Tensor,
 ]:
+    try:
+        return perform_donor_volume_translation_internal(
+            mylogger=mylogger,
+            acceptor=acceptor,
+            donor=donor,
+            oxygenation=oxygenation,
+            volume=volume,
+            ref_image_donor=ref_image_donor,
+            ref_image_volume=ref_image_volume,
+            batch_size=batch_size,
+            config=config,
+            fill_value=fill_value,
+        )
+    except torch.cuda.OutOfMemoryError:
+        (
+            acceptor_cpu,
+            donor_cpu,
+            oxygenation_cpu,
+            volume_cpu,
+            tvec_donor_volume_cpu,
+        ) = perform_donor_volume_translation_internal(
+            mylogger=mylogger,
+            acceptor=acceptor.cpu(),
+            donor=donor.cpu(),
+            oxygenation=oxygenation.cpu(),
+            volume=volume.cpu(),
+            ref_image_donor=ref_image_donor.cpu(),
+            ref_image_volume=ref_image_volume.cpu(),
+            batch_size=batch_size,
+            config=config,
+            fill_value=fill_value,
+        )
+
+        return (
+            acceptor_cpu.to(device=acceptor.device),
+            donor_cpu.to(device=acceptor.device),
+            oxygenation_cpu.to(device=acceptor.device),
+            volume_cpu.to(device=acceptor.device),
+            tvec_donor_volume_cpu.to(device=acceptor.device),
+        )
+
+
+@torch.no_grad()
+def perform_donor_volume_translation_internal(
+    mylogger: logging.Logger,
+    acceptor: torch.Tensor,
+    donor: torch.Tensor,
+    oxygenation: torch.Tensor,
+    volume: torch.Tensor,
+    ref_image_donor: torch.Tensor,
+    ref_image_volume: torch.Tensor,
+    batch_size: int,
+    config: dict,
+    fill_value: float = 0,
+) -> tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+    image_alignment = ImageAlignment(
+        default_dtype=acceptor.dtype, device=acceptor.device
+    )
+
     mylogger.info("Calculate translation between donor data and donor ref image")
 
     tvec_donor = calculate_translation(

@@ -11,7 +11,14 @@ def regression_internal(
     regressor = input_regressor - regressor_offset
     target = input_target - target_offset
 
-    coefficients, _, _, _ = torch.linalg.lstsq(regressor, target, rcond=None)  # None ?
+    try:
+        coefficients, _, _, _ = torch.linalg.lstsq(regressor, target, rcond=None)
+    except torch.cuda.OutOfMemoryError:
+        coefficients_cpu, _, _, _ = torch.linalg.lstsq(
+            regressor.cpu(), target.cpu(), rcond=None
+        )
+        coefficients = coefficients_cpu.to(regressor.device, copy=True)
+        del coefficients_cpu
 
     intercept = target_offset.squeeze(-1) - (
         coefficients * regressor_offset.squeeze(-2)
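
As in the hunk above, torch.linalg.lstsq returns a four-element result (solution, residuals, rank, singular values); only the solution is kept, and the CPU fallback copies it back onto the regressor's device. A small, self-contained sketch of the same fallback, with hypothetical names and a toy system whose known coefficients can be verified:

import torch


def lstsq_with_cpu_fallback(regressor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the hunk above: solve on the current device,
    fall back to the CPU on a CUDA out-of-memory error, and return the solution
    on the regressor's original device."""
    try:
        coefficients, _, _, _ = torch.linalg.lstsq(regressor, target, rcond=None)
    except torch.cuda.OutOfMemoryError:
        coefficients_cpu, _, _, _ = torch.linalg.lstsq(
            regressor.cpu(), target.cpu(), rcond=None
        )
        coefficients = coefficients_cpu.to(regressor.device, copy=True)
    return coefficients


A = torch.randn(100, 3)
b = A @ torch.tensor([[1.0], [-2.0], [0.5]])
print(lstsq_with_cpu_fallback(A, b).squeeze())  # approximately [1.0, -2.0, 0.5]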