Add files via upload
This commit is contained in:
parent
63bec06690
commit
7a4f34bdc3
2 changed files with 38 additions and 23 deletions
|
@ -111,7 +111,6 @@ for i in range(0, len(meta_channels)):
|
|||
|
||||
heartbeat_ts: torch.Tensor = bandpass(
|
||||
data=data[..., i],
|
||||
device=data.device,
|
||||
low_frequency=config["lower_freqency_bandpass"],
|
||||
high_frequency=config["upper_freqency_bandpass"],
|
||||
fs=sample_frequency,
|
||||
|
|
|
@ -12,7 +12,6 @@ from functions.load_config import load_config
|
|||
from functions.get_experiments import get_experiments
|
||||
from functions.get_trials import get_trials
|
||||
from functions.binning import binning
|
||||
from functions.ImageAlignment import ImageAlignment
|
||||
from functions.align_refref import align_refref
|
||||
from functions.perform_donor_volume_rotation import perform_donor_volume_rotation
|
||||
from functions.perform_donor_volume_translation import perform_donor_volume_translation
|
||||
|
@ -127,7 +126,9 @@ def process_trial(
|
|||
|
||||
mylogger.info(f"Loading ref file data: {ref_image_path_acceptor}")
|
||||
ref_image_acceptor: torch.Tensor = torch.tensor(
|
||||
np.load(ref_image_path_acceptor).astype(dtype_np), dtype=dtype, device=device
|
||||
np.load(ref_image_path_acceptor).astype(dtype_np),
|
||||
dtype=dtype,
|
||||
device=data.device,
|
||||
)
|
||||
|
||||
ref_image_path_donor: str = os.path.join(ref_image_path, "donor.npy")
|
||||
|
@ -138,7 +139,7 @@ def process_trial(
|
|||
|
||||
mylogger.info(f"Loading ref file data: {ref_image_path_donor}")
|
||||
ref_image_donor: torch.Tensor = torch.tensor(
|
||||
np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=device
|
||||
np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=data.device
|
||||
)
|
||||
|
||||
ref_image_path_oxygenation: str = os.path.join(ref_image_path, "oxygenation.npy")
|
||||
|
@ -149,7 +150,9 @@ def process_trial(
|
|||
|
||||
mylogger.info(f"Loading ref file data: {ref_image_path_oxygenation}")
|
||||
ref_image_oxygenation: torch.Tensor = torch.tensor(
|
||||
np.load(ref_image_path_oxygenation).astype(dtype_np), dtype=dtype, device=device
|
||||
np.load(ref_image_path_oxygenation).astype(dtype_np),
|
||||
dtype=dtype,
|
||||
device=data.device,
|
||||
)
|
||||
|
||||
ref_image_path_volume: str = os.path.join(ref_image_path, "volume.npy")
|
||||
|
@ -160,7 +163,7 @@ def process_trial(
|
|||
|
||||
mylogger.info(f"Loading ref file data: {ref_image_path_volume}")
|
||||
ref_image_volume: torch.Tensor = torch.tensor(
|
||||
np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=device
|
||||
np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=data.device
|
||||
)
|
||||
|
||||
refined_mask_file: str = os.path.join(ref_image_path, "mask_not_rotated.npy")
|
||||
|
@ -171,7 +174,7 @@ def process_trial(
|
|||
|
||||
mylogger.info(f"Loading mask file data: {refined_mask_file}")
|
||||
mask: torch.Tensor = torch.tensor(
|
||||
np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=device
|
||||
np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=data.device
|
||||
)
|
||||
mylogger.info("-==- Done -==-")
|
||||
|
||||
|
@ -190,7 +193,7 @@ def process_trial(
|
|||
kernel_size=int(config["binning_kernel_size"]),
|
||||
stride=int(config["binning_stride"]),
|
||||
divisor_override=int(config["binning_divisor_override"]),
|
||||
).to(device=device)
|
||||
).to(device=data.device)
|
||||
ref_image_acceptor = (
|
||||
binning(
|
||||
ref_image_acceptor.unsqueeze(-1).unsqueeze(-1),
|
||||
|
@ -245,8 +248,6 @@ def process_trial(
|
|||
mylogger.info("-==- Done -==-")
|
||||
|
||||
mylogger.info("Preparing alignment")
|
||||
image_alignment = ImageAlignment(default_dtype=dtype, device=device)
|
||||
|
||||
mylogger.info("Re-order Raw data")
|
||||
data = data.moveaxis(-2, 0).moveaxis(-1, 0)
|
||||
mylogger.info(f"Data shape: {data.shape}")
|
||||
|
@ -260,7 +261,6 @@ def process_trial(
|
|||
mylogger=mylogger,
|
||||
ref_image_acceptor=ref_image_acceptor,
|
||||
ref_image_donor=ref_image_donor,
|
||||
image_alignment=image_alignment,
|
||||
batch_size=config["alignment_batch_size"],
|
||||
fill_value=-100.0,
|
||||
)
|
||||
|
@ -372,7 +372,6 @@ def process_trial(
|
|||
volume=data[volume_index, ...],
|
||||
ref_image_donor=ref_image_donor,
|
||||
ref_image_volume=ref_image_volume,
|
||||
image_alignment=image_alignment,
|
||||
batch_size=config["alignment_batch_size"],
|
||||
fill_value=-100.0,
|
||||
config=config,
|
||||
|
@ -409,7 +408,6 @@ def process_trial(
|
|||
volume=data[volume_index, ...],
|
||||
ref_image_donor=ref_image_donor,
|
||||
ref_image_volume=ref_image_volume,
|
||||
image_alignment=image_alignment,
|
||||
batch_size=config["alignment_batch_size"],
|
||||
fill_value=-100.0,
|
||||
config=config,
|
||||
|
@ -474,7 +472,6 @@ def process_trial(
|
|||
mylogger.info("Extract heartbeat from volume signal")
|
||||
heartbeat_ts: torch.Tensor = bandpass(
|
||||
data=data[volume_index, ...].movedim(0, -1).clone(),
|
||||
device=data.device,
|
||||
low_frequency=config["lower_freqency_bandpass"],
|
||||
high_frequency=config["upper_freqency_bandpass"],
|
||||
fs=sample_frequency,
|
||||
|
@ -488,7 +485,16 @@ def process_trial(
|
|||
heartbeat_ts = heartbeat_ts.movedim(0, -1)
|
||||
heartbeat_ts -= heartbeat_ts.mean(dim=0, keepdim=True)
|
||||
|
||||
try:
|
||||
volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False)
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
mylogger.info("torch.cuda.OutOfMemoryError: Fallback to cpu")
|
||||
volume_heartbeat_cpu, _, _ = torch.linalg.svd(
|
||||
heartbeat_ts.cpu(), full_matrices=False
|
||||
)
|
||||
volume_heartbeat = volume_heartbeat_cpu.to(heartbeat_ts.data, copy=True)
|
||||
del volume_heartbeat_cpu
|
||||
|
||||
volume_heartbeat = volume_heartbeat[:, 0]
|
||||
volume_heartbeat -= volume_heartbeat[
|
||||
config["skip_frames_in_the_beginning"] :
|
||||
|
@ -525,7 +531,6 @@ def process_trial(
|
|||
for i in range(0, data.shape[0]):
|
||||
y = bandpass(
|
||||
data=data[i, ...].movedim(0, -1).clone(),
|
||||
device=data.device,
|
||||
low_frequency=config["lower_freqency_bandpass"],
|
||||
high_frequency=config["upper_freqency_bandpass"],
|
||||
fs=sample_frequency,
|
||||
|
@ -909,6 +914,7 @@ for experiment_counter in range(0, experiments.shape[0]):
|
|||
)
|
||||
mylogger.info("")
|
||||
|
||||
try:
|
||||
process_trial(
|
||||
config=config,
|
||||
mylogger=mylogger,
|
||||
|
@ -916,3 +922,13 @@ for experiment_counter in range(0, experiments.shape[0]):
|
|||
trial_id=trial_id,
|
||||
device=device,
|
||||
)
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
mylogger.info("WARNING: RUNNING IN FAILBACK MODE!!!!")
|
||||
mylogger.info("Not enough GPU memory. Retry on CPU")
|
||||
process_trial(
|
||||
config=config,
|
||||
mylogger=mylogger,
|
||||
experiment_id=experiment_id,
|
||||
trial_id=trial_id,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue