Add files via upload

This commit is contained in:
David Rotermund 2024-02-28 18:55:09 +01:00 committed by GitHub
parent 63bec06690
commit 7a4f34bdc3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 38 additions and 23 deletions

View file

@ -111,7 +111,6 @@ for i in range(0, len(meta_channels)):
heartbeat_ts: torch.Tensor = bandpass(
data=data[..., i],
device=data.device,
low_frequency=config["lower_freqency_bandpass"],
high_frequency=config["upper_freqency_bandpass"],
fs=sample_frequency,

View file

@ -12,7 +12,6 @@ from functions.load_config import load_config
from functions.get_experiments import get_experiments
from functions.get_trials import get_trials
from functions.binning import binning
from functions.ImageAlignment import ImageAlignment
from functions.align_refref import align_refref
from functions.perform_donor_volume_rotation import perform_donor_volume_rotation
from functions.perform_donor_volume_translation import perform_donor_volume_translation
@ -127,7 +126,9 @@ def process_trial(
mylogger.info(f"Loading ref file data: {ref_image_path_acceptor}")
ref_image_acceptor: torch.Tensor = torch.tensor(
np.load(ref_image_path_acceptor).astype(dtype_np), dtype=dtype, device=device
np.load(ref_image_path_acceptor).astype(dtype_np),
dtype=dtype,
device=data.device,
)
ref_image_path_donor: str = os.path.join(ref_image_path, "donor.npy")
@ -138,7 +139,7 @@ def process_trial(
mylogger.info(f"Loading ref file data: {ref_image_path_donor}")
ref_image_donor: torch.Tensor = torch.tensor(
np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=device
np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=data.device
)
ref_image_path_oxygenation: str = os.path.join(ref_image_path, "oxygenation.npy")
@ -149,7 +150,9 @@ def process_trial(
mylogger.info(f"Loading ref file data: {ref_image_path_oxygenation}")
ref_image_oxygenation: torch.Tensor = torch.tensor(
np.load(ref_image_path_oxygenation).astype(dtype_np), dtype=dtype, device=device
np.load(ref_image_path_oxygenation).astype(dtype_np),
dtype=dtype,
device=data.device,
)
ref_image_path_volume: str = os.path.join(ref_image_path, "volume.npy")
@ -160,7 +163,7 @@ def process_trial(
mylogger.info(f"Loading ref file data: {ref_image_path_volume}")
ref_image_volume: torch.Tensor = torch.tensor(
np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=device
np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=data.device
)
refined_mask_file: str = os.path.join(ref_image_path, "mask_not_rotated.npy")
@ -171,7 +174,7 @@ def process_trial(
mylogger.info(f"Loading mask file data: {refined_mask_file}")
mask: torch.Tensor = torch.tensor(
np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=device
np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=data.device
)
mylogger.info("-==- Done -==-")
@ -190,7 +193,7 @@ def process_trial(
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=int(config["binning_divisor_override"]),
).to(device=device)
).to(device=data.device)
ref_image_acceptor = (
binning(
ref_image_acceptor.unsqueeze(-1).unsqueeze(-1),
@ -245,8 +248,6 @@ def process_trial(
mylogger.info("-==- Done -==-")
mylogger.info("Preparing alignment")
image_alignment = ImageAlignment(default_dtype=dtype, device=device)
mylogger.info("Re-order Raw data")
data = data.moveaxis(-2, 0).moveaxis(-1, 0)
mylogger.info(f"Data shape: {data.shape}")
@ -260,7 +261,6 @@ def process_trial(
mylogger=mylogger,
ref_image_acceptor=ref_image_acceptor,
ref_image_donor=ref_image_donor,
image_alignment=image_alignment,
batch_size=config["alignment_batch_size"],
fill_value=-100.0,
)
@ -372,7 +372,6 @@ def process_trial(
volume=data[volume_index, ...],
ref_image_donor=ref_image_donor,
ref_image_volume=ref_image_volume,
image_alignment=image_alignment,
batch_size=config["alignment_batch_size"],
fill_value=-100.0,
config=config,
@ -409,7 +408,6 @@ def process_trial(
volume=data[volume_index, ...],
ref_image_donor=ref_image_donor,
ref_image_volume=ref_image_volume,
image_alignment=image_alignment,
batch_size=config["alignment_batch_size"],
fill_value=-100.0,
config=config,
@ -474,7 +472,6 @@ def process_trial(
mylogger.info("Extract heartbeat from volume signal")
heartbeat_ts: torch.Tensor = bandpass(
data=data[volume_index, ...].movedim(0, -1).clone(),
device=data.device,
low_frequency=config["lower_freqency_bandpass"],
high_frequency=config["upper_freqency_bandpass"],
fs=sample_frequency,
@ -488,7 +485,16 @@ def process_trial(
heartbeat_ts = heartbeat_ts.movedim(0, -1)
heartbeat_ts -= heartbeat_ts.mean(dim=0, keepdim=True)
try:
volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False)
except torch.cuda.OutOfMemoryError:
mylogger.info("torch.cuda.OutOfMemoryError: Fallback to cpu")
volume_heartbeat_cpu, _, _ = torch.linalg.svd(
heartbeat_ts.cpu(), full_matrices=False
)
volume_heartbeat = volume_heartbeat_cpu.to(heartbeat_ts.data, copy=True)
del volume_heartbeat_cpu
volume_heartbeat = volume_heartbeat[:, 0]
volume_heartbeat -= volume_heartbeat[
config["skip_frames_in_the_beginning"] :
@ -525,7 +531,6 @@ def process_trial(
for i in range(0, data.shape[0]):
y = bandpass(
data=data[i, ...].movedim(0, -1).clone(),
device=data.device,
low_frequency=config["lower_freqency_bandpass"],
high_frequency=config["upper_freqency_bandpass"],
fs=sample_frequency,
@ -909,6 +914,7 @@ for experiment_counter in range(0, experiments.shape[0]):
)
mylogger.info("")
try:
process_trial(
config=config,
mylogger=mylogger,
@ -916,3 +922,13 @@ for experiment_counter in range(0, experiments.shape[0]):
trial_id=trial_id,
device=device,
)
except torch.cuda.OutOfMemoryError:
mylogger.info("WARNING: RUNNING IN FAILBACK MODE!!!!")
mylogger.info("Not enough GPU memory. Retry on CPU")
process_trial(
config=config,
mylogger=mylogger,
experiment_id=experiment_id,
trial_id=trial_id,
device=torch.device("cpu"),
)