diff --git a/stage_1_get_ref_image.py b/stage_1_get_ref_image.py index 55435f4..637e324 100644 --- a/stage_1_get_ref_image.py +++ b/stage_1_get_ref_image.py @@ -111,7 +111,6 @@ for i in range(0, len(meta_channels)): heartbeat_ts: torch.Tensor = bandpass( data=data[..., i], - device=data.device, low_frequency=config["lower_freqency_bandpass"], high_frequency=config["upper_freqency_bandpass"], fs=sample_frequency, diff --git a/stage_4_process.py b/stage_4_process.py index b822856..c0569d4 100644 --- a/stage_4_process.py +++ b/stage_4_process.py @@ -12,7 +12,6 @@ from functions.load_config import load_config from functions.get_experiments import get_experiments from functions.get_trials import get_trials from functions.binning import binning -from functions.ImageAlignment import ImageAlignment from functions.align_refref import align_refref from functions.perform_donor_volume_rotation import perform_donor_volume_rotation from functions.perform_donor_volume_translation import perform_donor_volume_translation @@ -127,7 +126,9 @@ def process_trial( mylogger.info(f"Loading ref file data: {ref_image_path_acceptor}") ref_image_acceptor: torch.Tensor = torch.tensor( - np.load(ref_image_path_acceptor).astype(dtype_np), dtype=dtype, device=device + np.load(ref_image_path_acceptor).astype(dtype_np), + dtype=dtype, + device=data.device, ) ref_image_path_donor: str = os.path.join(ref_image_path, "donor.npy") @@ -138,7 +139,7 @@ def process_trial( mylogger.info(f"Loading ref file data: {ref_image_path_donor}") ref_image_donor: torch.Tensor = torch.tensor( - np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=device + np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=data.device ) ref_image_path_oxygenation: str = os.path.join(ref_image_path, "oxygenation.npy") @@ -149,7 +150,9 @@ def process_trial( mylogger.info(f"Loading ref file data: {ref_image_path_oxygenation}") ref_image_oxygenation: torch.Tensor = torch.tensor( - np.load(ref_image_path_oxygenation).astype(dtype_np), dtype=dtype, device=device + np.load(ref_image_path_oxygenation).astype(dtype_np), + dtype=dtype, + device=data.device, ) ref_image_path_volume: str = os.path.join(ref_image_path, "volume.npy") @@ -160,7 +163,7 @@ def process_trial( mylogger.info(f"Loading ref file data: {ref_image_path_volume}") ref_image_volume: torch.Tensor = torch.tensor( - np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=device + np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=data.device ) refined_mask_file: str = os.path.join(ref_image_path, "mask_not_rotated.npy") @@ -171,7 +174,7 @@ def process_trial( mylogger.info(f"Loading mask file data: {refined_mask_file}") mask: torch.Tensor = torch.tensor( - np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=device + np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=data.device ) mylogger.info("-==- Done -==-") @@ -190,7 +193,7 @@ def process_trial( kernel_size=int(config["binning_kernel_size"]), stride=int(config["binning_stride"]), divisor_override=int(config["binning_divisor_override"]), - ).to(device=device) + ).to(device=data.device) ref_image_acceptor = ( binning( ref_image_acceptor.unsqueeze(-1).unsqueeze(-1), @@ -245,8 +248,6 @@ def process_trial( mylogger.info("-==- Done -==-") mylogger.info("Preparing alignment") - image_alignment = ImageAlignment(default_dtype=dtype, device=device) - mylogger.info("Re-order Raw data") data = data.moveaxis(-2, 0).moveaxis(-1, 0) mylogger.info(f"Data shape: {data.shape}") @@ -260,7 +261,6 @@ def process_trial( mylogger=mylogger, ref_image_acceptor=ref_image_acceptor, ref_image_donor=ref_image_donor, - image_alignment=image_alignment, batch_size=config["alignment_batch_size"], fill_value=-100.0, ) @@ -372,7 +372,6 @@ def process_trial( volume=data[volume_index, ...], ref_image_donor=ref_image_donor, ref_image_volume=ref_image_volume, - image_alignment=image_alignment, batch_size=config["alignment_batch_size"], fill_value=-100.0, config=config, @@ -409,7 +408,6 @@ def process_trial( volume=data[volume_index, ...], ref_image_donor=ref_image_donor, ref_image_volume=ref_image_volume, - image_alignment=image_alignment, batch_size=config["alignment_batch_size"], fill_value=-100.0, config=config, @@ -474,7 +472,6 @@ def process_trial( mylogger.info("Extract heartbeat from volume signal") heartbeat_ts: torch.Tensor = bandpass( data=data[volume_index, ...].movedim(0, -1).clone(), - device=data.device, low_frequency=config["lower_freqency_bandpass"], high_frequency=config["upper_freqency_bandpass"], fs=sample_frequency, @@ -488,7 +485,16 @@ def process_trial( heartbeat_ts = heartbeat_ts.movedim(0, -1) heartbeat_ts -= heartbeat_ts.mean(dim=0, keepdim=True) - volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False) + try: + volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False) + except torch.cuda.OutOfMemoryError: + mylogger.info("torch.cuda.OutOfMemoryError: Fallback to cpu") + volume_heartbeat_cpu, _, _ = torch.linalg.svd( + heartbeat_ts.cpu(), full_matrices=False + ) + volume_heartbeat = volume_heartbeat_cpu.to(heartbeat_ts.data, copy=True) + del volume_heartbeat_cpu + volume_heartbeat = volume_heartbeat[:, 0] volume_heartbeat -= volume_heartbeat[ config["skip_frames_in_the_beginning"] : @@ -525,7 +531,6 @@ def process_trial( for i in range(0, data.shape[0]): y = bandpass( data=data[i, ...].movedim(0, -1).clone(), - device=data.device, low_frequency=config["lower_freqency_bandpass"], high_frequency=config["upper_freqency_bandpass"], fs=sample_frequency, @@ -909,10 +914,21 @@ for experiment_counter in range(0, experiments.shape[0]): ) mylogger.info("") - process_trial( - config=config, - mylogger=mylogger, - experiment_id=experiment_id, - trial_id=trial_id, - device=device, - ) + try: + process_trial( + config=config, + mylogger=mylogger, + experiment_id=experiment_id, + trial_id=trial_id, + device=device, + ) + except torch.cuda.OutOfMemoryError: + mylogger.info("WARNING: RUNNING IN FAILBACK MODE!!!!") + mylogger.info("Not enough GPU memory. Retry on CPU") + process_trial( + config=config, + mylogger=mylogger, + experiment_id=experiment_id, + trial_id=trial_id, + device=torch.device("cpu"), + )