Add files via upload
This commit is contained in:
parent
63bec06690
commit
7a4f34bdc3
2 changed files with 38 additions and 23 deletions
|
@ -111,7 +111,6 @@ for i in range(0, len(meta_channels)):
|
||||||
|
|
||||||
heartbeat_ts: torch.Tensor = bandpass(
|
heartbeat_ts: torch.Tensor = bandpass(
|
||||||
data=data[..., i],
|
data=data[..., i],
|
||||||
device=data.device,
|
|
||||||
low_frequency=config["lower_freqency_bandpass"],
|
low_frequency=config["lower_freqency_bandpass"],
|
||||||
high_frequency=config["upper_freqency_bandpass"],
|
high_frequency=config["upper_freqency_bandpass"],
|
||||||
fs=sample_frequency,
|
fs=sample_frequency,
|
||||||
|
|
|
@ -12,7 +12,6 @@ from functions.load_config import load_config
|
||||||
from functions.get_experiments import get_experiments
|
from functions.get_experiments import get_experiments
|
||||||
from functions.get_trials import get_trials
|
from functions.get_trials import get_trials
|
||||||
from functions.binning import binning
|
from functions.binning import binning
|
||||||
from functions.ImageAlignment import ImageAlignment
|
|
||||||
from functions.align_refref import align_refref
|
from functions.align_refref import align_refref
|
||||||
from functions.perform_donor_volume_rotation import perform_donor_volume_rotation
|
from functions.perform_donor_volume_rotation import perform_donor_volume_rotation
|
||||||
from functions.perform_donor_volume_translation import perform_donor_volume_translation
|
from functions.perform_donor_volume_translation import perform_donor_volume_translation
|
||||||
|
@ -127,7 +126,9 @@ def process_trial(
|
||||||
|
|
||||||
mylogger.info(f"Loading ref file data: {ref_image_path_acceptor}")
|
mylogger.info(f"Loading ref file data: {ref_image_path_acceptor}")
|
||||||
ref_image_acceptor: torch.Tensor = torch.tensor(
|
ref_image_acceptor: torch.Tensor = torch.tensor(
|
||||||
np.load(ref_image_path_acceptor).astype(dtype_np), dtype=dtype, device=device
|
np.load(ref_image_path_acceptor).astype(dtype_np),
|
||||||
|
dtype=dtype,
|
||||||
|
device=data.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
ref_image_path_donor: str = os.path.join(ref_image_path, "donor.npy")
|
ref_image_path_donor: str = os.path.join(ref_image_path, "donor.npy")
|
||||||
|
@ -138,7 +139,7 @@ def process_trial(
|
||||||
|
|
||||||
mylogger.info(f"Loading ref file data: {ref_image_path_donor}")
|
mylogger.info(f"Loading ref file data: {ref_image_path_donor}")
|
||||||
ref_image_donor: torch.Tensor = torch.tensor(
|
ref_image_donor: torch.Tensor = torch.tensor(
|
||||||
np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=device
|
np.load(ref_image_path_donor).astype(dtype_np), dtype=dtype, device=data.device
|
||||||
)
|
)
|
||||||
|
|
||||||
ref_image_path_oxygenation: str = os.path.join(ref_image_path, "oxygenation.npy")
|
ref_image_path_oxygenation: str = os.path.join(ref_image_path, "oxygenation.npy")
|
||||||
|
@ -149,7 +150,9 @@ def process_trial(
|
||||||
|
|
||||||
mylogger.info(f"Loading ref file data: {ref_image_path_oxygenation}")
|
mylogger.info(f"Loading ref file data: {ref_image_path_oxygenation}")
|
||||||
ref_image_oxygenation: torch.Tensor = torch.tensor(
|
ref_image_oxygenation: torch.Tensor = torch.tensor(
|
||||||
np.load(ref_image_path_oxygenation).astype(dtype_np), dtype=dtype, device=device
|
np.load(ref_image_path_oxygenation).astype(dtype_np),
|
||||||
|
dtype=dtype,
|
||||||
|
device=data.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
ref_image_path_volume: str = os.path.join(ref_image_path, "volume.npy")
|
ref_image_path_volume: str = os.path.join(ref_image_path, "volume.npy")
|
||||||
|
@ -160,7 +163,7 @@ def process_trial(
|
||||||
|
|
||||||
mylogger.info(f"Loading ref file data: {ref_image_path_volume}")
|
mylogger.info(f"Loading ref file data: {ref_image_path_volume}")
|
||||||
ref_image_volume: torch.Tensor = torch.tensor(
|
ref_image_volume: torch.Tensor = torch.tensor(
|
||||||
np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=device
|
np.load(ref_image_path_volume).astype(dtype_np), dtype=dtype, device=data.device
|
||||||
)
|
)
|
||||||
|
|
||||||
refined_mask_file: str = os.path.join(ref_image_path, "mask_not_rotated.npy")
|
refined_mask_file: str = os.path.join(ref_image_path, "mask_not_rotated.npy")
|
||||||
|
@ -171,7 +174,7 @@ def process_trial(
|
||||||
|
|
||||||
mylogger.info(f"Loading mask file data: {refined_mask_file}")
|
mylogger.info(f"Loading mask file data: {refined_mask_file}")
|
||||||
mask: torch.Tensor = torch.tensor(
|
mask: torch.Tensor = torch.tensor(
|
||||||
np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=device
|
np.load(refined_mask_file).astype(dtype_np), dtype=dtype, device=data.device
|
||||||
)
|
)
|
||||||
mylogger.info("-==- Done -==-")
|
mylogger.info("-==- Done -==-")
|
||||||
|
|
||||||
|
@ -190,7 +193,7 @@ def process_trial(
|
||||||
kernel_size=int(config["binning_kernel_size"]),
|
kernel_size=int(config["binning_kernel_size"]),
|
||||||
stride=int(config["binning_stride"]),
|
stride=int(config["binning_stride"]),
|
||||||
divisor_override=int(config["binning_divisor_override"]),
|
divisor_override=int(config["binning_divisor_override"]),
|
||||||
).to(device=device)
|
).to(device=data.device)
|
||||||
ref_image_acceptor = (
|
ref_image_acceptor = (
|
||||||
binning(
|
binning(
|
||||||
ref_image_acceptor.unsqueeze(-1).unsqueeze(-1),
|
ref_image_acceptor.unsqueeze(-1).unsqueeze(-1),
|
||||||
|
@ -245,8 +248,6 @@ def process_trial(
|
||||||
mylogger.info("-==- Done -==-")
|
mylogger.info("-==- Done -==-")
|
||||||
|
|
||||||
mylogger.info("Preparing alignment")
|
mylogger.info("Preparing alignment")
|
||||||
image_alignment = ImageAlignment(default_dtype=dtype, device=device)
|
|
||||||
|
|
||||||
mylogger.info("Re-order Raw data")
|
mylogger.info("Re-order Raw data")
|
||||||
data = data.moveaxis(-2, 0).moveaxis(-1, 0)
|
data = data.moveaxis(-2, 0).moveaxis(-1, 0)
|
||||||
mylogger.info(f"Data shape: {data.shape}")
|
mylogger.info(f"Data shape: {data.shape}")
|
||||||
|
@ -260,7 +261,6 @@ def process_trial(
|
||||||
mylogger=mylogger,
|
mylogger=mylogger,
|
||||||
ref_image_acceptor=ref_image_acceptor,
|
ref_image_acceptor=ref_image_acceptor,
|
||||||
ref_image_donor=ref_image_donor,
|
ref_image_donor=ref_image_donor,
|
||||||
image_alignment=image_alignment,
|
|
||||||
batch_size=config["alignment_batch_size"],
|
batch_size=config["alignment_batch_size"],
|
||||||
fill_value=-100.0,
|
fill_value=-100.0,
|
||||||
)
|
)
|
||||||
|
@ -372,7 +372,6 @@ def process_trial(
|
||||||
volume=data[volume_index, ...],
|
volume=data[volume_index, ...],
|
||||||
ref_image_donor=ref_image_donor,
|
ref_image_donor=ref_image_donor,
|
||||||
ref_image_volume=ref_image_volume,
|
ref_image_volume=ref_image_volume,
|
||||||
image_alignment=image_alignment,
|
|
||||||
batch_size=config["alignment_batch_size"],
|
batch_size=config["alignment_batch_size"],
|
||||||
fill_value=-100.0,
|
fill_value=-100.0,
|
||||||
config=config,
|
config=config,
|
||||||
|
@ -409,7 +408,6 @@ def process_trial(
|
||||||
volume=data[volume_index, ...],
|
volume=data[volume_index, ...],
|
||||||
ref_image_donor=ref_image_donor,
|
ref_image_donor=ref_image_donor,
|
||||||
ref_image_volume=ref_image_volume,
|
ref_image_volume=ref_image_volume,
|
||||||
image_alignment=image_alignment,
|
|
||||||
batch_size=config["alignment_batch_size"],
|
batch_size=config["alignment_batch_size"],
|
||||||
fill_value=-100.0,
|
fill_value=-100.0,
|
||||||
config=config,
|
config=config,
|
||||||
|
@ -474,7 +472,6 @@ def process_trial(
|
||||||
mylogger.info("Extract heartbeat from volume signal")
|
mylogger.info("Extract heartbeat from volume signal")
|
||||||
heartbeat_ts: torch.Tensor = bandpass(
|
heartbeat_ts: torch.Tensor = bandpass(
|
||||||
data=data[volume_index, ...].movedim(0, -1).clone(),
|
data=data[volume_index, ...].movedim(0, -1).clone(),
|
||||||
device=data.device,
|
|
||||||
low_frequency=config["lower_freqency_bandpass"],
|
low_frequency=config["lower_freqency_bandpass"],
|
||||||
high_frequency=config["upper_freqency_bandpass"],
|
high_frequency=config["upper_freqency_bandpass"],
|
||||||
fs=sample_frequency,
|
fs=sample_frequency,
|
||||||
|
@ -488,7 +485,16 @@ def process_trial(
|
||||||
heartbeat_ts = heartbeat_ts.movedim(0, -1)
|
heartbeat_ts = heartbeat_ts.movedim(0, -1)
|
||||||
heartbeat_ts -= heartbeat_ts.mean(dim=0, keepdim=True)
|
heartbeat_ts -= heartbeat_ts.mean(dim=0, keepdim=True)
|
||||||
|
|
||||||
|
try:
|
||||||
volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False)
|
volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False)
|
||||||
|
except torch.cuda.OutOfMemoryError:
|
||||||
|
mylogger.info("torch.cuda.OutOfMemoryError: Fallback to cpu")
|
||||||
|
volume_heartbeat_cpu, _, _ = torch.linalg.svd(
|
||||||
|
heartbeat_ts.cpu(), full_matrices=False
|
||||||
|
)
|
||||||
|
volume_heartbeat = volume_heartbeat_cpu.to(heartbeat_ts.data, copy=True)
|
||||||
|
del volume_heartbeat_cpu
|
||||||
|
|
||||||
volume_heartbeat = volume_heartbeat[:, 0]
|
volume_heartbeat = volume_heartbeat[:, 0]
|
||||||
volume_heartbeat -= volume_heartbeat[
|
volume_heartbeat -= volume_heartbeat[
|
||||||
config["skip_frames_in_the_beginning"] :
|
config["skip_frames_in_the_beginning"] :
|
||||||
|
@ -525,7 +531,6 @@ def process_trial(
|
||||||
for i in range(0, data.shape[0]):
|
for i in range(0, data.shape[0]):
|
||||||
y = bandpass(
|
y = bandpass(
|
||||||
data=data[i, ...].movedim(0, -1).clone(),
|
data=data[i, ...].movedim(0, -1).clone(),
|
||||||
device=data.device,
|
|
||||||
low_frequency=config["lower_freqency_bandpass"],
|
low_frequency=config["lower_freqency_bandpass"],
|
||||||
high_frequency=config["upper_freqency_bandpass"],
|
high_frequency=config["upper_freqency_bandpass"],
|
||||||
fs=sample_frequency,
|
fs=sample_frequency,
|
||||||
|
@ -909,6 +914,7 @@ for experiment_counter in range(0, experiments.shape[0]):
|
||||||
)
|
)
|
||||||
mylogger.info("")
|
mylogger.info("")
|
||||||
|
|
||||||
|
try:
|
||||||
process_trial(
|
process_trial(
|
||||||
config=config,
|
config=config,
|
||||||
mylogger=mylogger,
|
mylogger=mylogger,
|
||||||
|
@ -916,3 +922,13 @@ for experiment_counter in range(0, experiments.shape[0]):
|
||||||
trial_id=trial_id,
|
trial_id=trial_id,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
except torch.cuda.OutOfMemoryError:
|
||||||
|
mylogger.info("WARNING: RUNNING IN FAILBACK MODE!!!!")
|
||||||
|
mylogger.info("Not enough GPU memory. Retry on CPU")
|
||||||
|
process_trial(
|
||||||
|
config=config,
|
||||||
|
mylogger=mylogger,
|
||||||
|
experiment_id=experiment_id,
|
||||||
|
trial_id=trial_id,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in a new issue