diff --git a/functions/align_refref.py b/functions/align_refref.py index 249665c..3208cf3 100644 --- a/functions/align_refref.py +++ b/functions/align_refref.py @@ -11,11 +11,14 @@ def align_refref( mylogger: logging.Logger, ref_image_acceptor: torch.Tensor, ref_image_donor: torch.Tensor, - image_alignment: ImageAlignment, batch_size: int, fill_value: float = 0, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + image_alignment = ImageAlignment( + default_dtype=ref_image_acceptor.dtype, device=ref_image_acceptor.device + ) + mylogger.info("Rotate ref image acceptor onto donor") angle_refref = calculate_rotation( image_alignment=image_alignment, diff --git a/functions/bandpass.py b/functions/bandpass.py index 171baf5..2659847 100644 --- a/functions/bandpass.py +++ b/functions/bandpass.py @@ -57,21 +57,49 @@ def chunk_iterator(array: torch.Tensor, chunk_size: int): @torch.no_grad() def bandpass( data: torch.Tensor, - device: torch.device, + low_frequency: float = 0.1, + high_frequency: float = 1.0, + fs=30.0, + filtfilt_chuck_size: int = 10, +) -> torch.Tensor: + + try: + return bandpass_internal( + data=data, + low_frequency=low_frequency, + high_frequency=high_frequency, + fs=fs, + filtfilt_chuck_size=filtfilt_chuck_size, + ) + + except torch.cuda.OutOfMemoryError: + + return bandpass_internal( + data=data.cpu(), + low_frequency=low_frequency, + high_frequency=high_frequency, + fs=fs, + filtfilt_chuck_size=filtfilt_chuck_size, + ).to(device=data.device) + + +@torch.no_grad() +def bandpass_internal( + data: torch.Tensor, low_frequency: float = 0.1, high_frequency: float = 1.0, fs=30.0, filtfilt_chuck_size: int = 10, ) -> torch.Tensor: butter_a, butter_b = butter_bandpass( - device=device, + device=data.device, low_frequency=low_frequency, high_frequency=high_frequency, fs=fs, ) index_full_dataset: torch.Tensor = torch.arange( - 0, data.shape[1], device=device, dtype=torch.int64 + 0, data.shape[1], device=data.device, dtype=torch.int64 ) for chunk in chunk_iterator(index_full_dataset, filtfilt_chuck_size): diff --git a/functions/binning.py b/functions/binning.py index ccfe657..f873433 100644 --- a/functions/binning.py +++ b/functions/binning.py @@ -1,6 +1,7 @@ import torch +@torch.no_grad() def binning( data: torch.Tensor, kernel_size: int = 4, @@ -8,6 +9,30 @@ def binning( divisor_override: int | None = 1, ) -> torch.Tensor: + try: + return binning_internal( + data=data, + kernel_size=kernel_size, + stride=stride, + divisor_override=divisor_override, + ) + except torch.cuda.OutOfMemoryError: + return binning_internal( + data=data.cpu(), + kernel_size=kernel_size, + stride=stride, + divisor_override=divisor_override, + ).to(device=data.device) + + +@torch.no_grad() +def binning_internal( + data: torch.Tensor, + kernel_size: int = 4, + stride: int = 4, + divisor_override: int | None = 1, +) -> torch.Tensor: + assert data.ndim == 4 return ( torch.nn.functional.avg_pool2d( diff --git a/functions/gauss_smear_individual.py b/functions/gauss_smear_individual.py index 36700e7..73dba65 100644 --- a/functions/gauss_smear_individual.py +++ b/functions/gauss_smear_individual.py @@ -11,6 +11,47 @@ def gauss_smear_individual( use_matlab_mask: bool = True, epsilon: float = float(torch.finfo(torch.float64).eps), ) -> tuple[torch.Tensor, torch.Tensor]: + try: + return gauss_smear_individual_core( + input=input, + spatial_width=spatial_width, + temporal_width=temporal_width, + overwrite_fft_gauss=overwrite_fft_gauss, + use_matlab_mask=use_matlab_mask, + epsilon=epsilon, + ) + except torch.cuda.OutOfMemoryError: + + if overwrite_fft_gauss is None: + overwrite_fft_gauss_cpu: None | torch.Tensor = None + else: + overwrite_fft_gauss_cpu = overwrite_fft_gauss.cpu() + + input_cpu: torch.Tensor = input.cpu() + + output, overwrite_fft_gauss = gauss_smear_individual_core( + input=input_cpu, + spatial_width=spatial_width, + temporal_width=temporal_width, + overwrite_fft_gauss=overwrite_fft_gauss_cpu, + use_matlab_mask=use_matlab_mask, + epsilon=epsilon, + ) + return ( + output.to(device=input.device), + overwrite_fft_gauss.to(device=input.device), + ) + + +@torch.no_grad() +def gauss_smear_individual_core( + input: torch.Tensor, + spatial_width: float, + temporal_width: float, + overwrite_fft_gauss: None | torch.Tensor = None, + use_matlab_mask: bool = True, + epsilon: float = float(torch.finfo(torch.float64).eps), +) -> tuple[torch.Tensor, torch.Tensor]: dim_x: int = int(2 * math.ceil(2 * spatial_width) + 1) dim_y: int = int(2 * math.ceil(2 * spatial_width) + 1) diff --git a/functions/perform_donor_volume_rotation.py b/functions/perform_donor_volume_rotation.py index 590630f..1d2f55b 100644 --- a/functions/perform_donor_volume_rotation.py +++ b/functions/perform_donor_volume_rotation.py @@ -14,7 +14,6 @@ def perform_donor_volume_rotation( volume: torch.Tensor, ref_image_donor: torch.Tensor, ref_image_volume: torch.Tensor, - image_alignment: ImageAlignment, batch_size: int, config: dict, fill_value: float = 0, @@ -25,6 +24,74 @@ def perform_donor_volume_rotation( torch.Tensor, torch.Tensor, ]: + try: + + return perform_donor_volume_rotation_internal( + mylogger=mylogger, + acceptor=acceptor, + donor=donor, + oxygenation=oxygenation, + volume=volume, + ref_image_donor=ref_image_donor, + ref_image_volume=ref_image_volume, + batch_size=batch_size, + config=config, + fill_value=fill_value, + ) + + except torch.cuda.OutOfMemoryError: + + ( + acceptor_cpu, + donor_cpu, + oxygenation_cpu, + volume_cpu, + angle_donor_volume_cpu, + ) = perform_donor_volume_rotation_internal( + mylogger=mylogger, + acceptor=acceptor.cpu(), + donor=donor.cpu(), + oxygenation=oxygenation.cpu(), + volume=volume.cpu(), + ref_image_donor=ref_image_donor.cpu(), + ref_image_volume=ref_image_volume.cpu(), + batch_size=batch_size, + config=config, + fill_value=fill_value, + ) + + return ( + acceptor_cpu.to(device=acceptor.device), + donor_cpu.to(device=acceptor.device), + oxygenation_cpu.to(device=acceptor.device), + volume_cpu.to(device=acceptor.device), + angle_donor_volume_cpu.to(device=acceptor.device), + ) + + +@torch.no_grad() +def perform_donor_volume_rotation_internal( + mylogger: logging.Logger, + acceptor: torch.Tensor, + donor: torch.Tensor, + oxygenation: torch.Tensor, + volume: torch.Tensor, + ref_image_donor: torch.Tensor, + ref_image_volume: torch.Tensor, + batch_size: int, + config: dict, + fill_value: float = 0, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + + image_alignment = ImageAlignment( + default_dtype=acceptor.dtype, device=acceptor.device + ) mylogger.info("Calculate rotation between donor data and donor ref image") diff --git a/functions/perform_donor_volume_translation.py b/functions/perform_donor_volume_translation.py index 7091add..72e94fa 100644 --- a/functions/perform_donor_volume_translation.py +++ b/functions/perform_donor_volume_translation.py @@ -15,7 +15,6 @@ def perform_donor_volume_translation( volume: torch.Tensor, ref_image_donor: torch.Tensor, ref_image_volume: torch.Tensor, - image_alignment: ImageAlignment, batch_size: int, config: dict, fill_value: float = 0, @@ -26,6 +25,74 @@ def perform_donor_volume_translation( torch.Tensor, torch.Tensor, ]: + try: + + return perform_donor_volume_translation_internal( + mylogger=mylogger, + acceptor=acceptor, + donor=donor, + oxygenation=oxygenation, + volume=volume, + ref_image_donor=ref_image_donor, + ref_image_volume=ref_image_volume, + batch_size=batch_size, + config=config, + fill_value=fill_value, + ) + + except torch.cuda.OutOfMemoryError: + + ( + acceptor_cpu, + donor_cpu, + oxygenation_cpu, + volume_cpu, + tvec_donor_volume_cpu, + ) = perform_donor_volume_translation_internal( + mylogger=mylogger, + acceptor=acceptor.cpu(), + donor=donor.cpu(), + oxygenation=oxygenation.cpu(), + volume=volume.cpu(), + ref_image_donor=ref_image_donor.cpu(), + ref_image_volume=ref_image_volume.cpu(), + batch_size=batch_size, + config=config, + fill_value=fill_value, + ) + + return ( + acceptor_cpu.to(device=acceptor.device), + donor_cpu.to(device=acceptor.device), + oxygenation_cpu.to(device=acceptor.device), + volume_cpu.to(device=acceptor.device), + tvec_donor_volume_cpu.to(device=acceptor.device), + ) + + +@torch.no_grad() +def perform_donor_volume_translation_internal( + mylogger: logging.Logger, + acceptor: torch.Tensor, + donor: torch.Tensor, + oxygenation: torch.Tensor, + volume: torch.Tensor, + ref_image_donor: torch.Tensor, + ref_image_volume: torch.Tensor, + batch_size: int, + config: dict, + fill_value: float = 0, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + + image_alignment = ImageAlignment( + default_dtype=acceptor.dtype, device=acceptor.device + ) mylogger.info("Calculate translation between donor data and donor ref image") tvec_donor = calculate_translation( diff --git a/functions/regression_internal.py b/functions/regression_internal.py index 352d7ba..dd94d3c 100644 --- a/functions/regression_internal.py +++ b/functions/regression_internal.py @@ -11,7 +11,14 @@ def regression_internal( regressor = input_regressor - regressor_offset target = input_target - target_offset - coefficients, _, _, _ = torch.linalg.lstsq(regressor, target, rcond=None) # None ? + try: + coefficients, _, _, _ = torch.linalg.lstsq(regressor, target, rcond=None) + except torch.cuda.OutOfMemoryError: + coefficients_cpu, _, _, _ = torch.linalg.lstsq( + regressor.cpu(), target.cpu(), rcond=None + ) + coefficients = coefficients_cpu.to(regressor.device, copy=True) + del coefficients_cpu intercept = target_offset.squeeze(-1) - ( coefficients * regressor_offset.squeeze(-2)