Add files via upload

David Rotermund, 2024-08-10 20:56:23 +02:00 (committed by GitHub)
parent 2e5879cd62
commit 6599c9c81d
9 changed files with 772 additions and 188 deletions


@@ -99,7 +99,7 @@ def process_trial(
free_mem = cuda_total_memory - max(
[torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
)
mylogger.info(f"CUDA memory: {free_mem//1024} MByte")
mylogger.info(f"CUDA memory: {free_mem // 1024} MByte")
mylogger.info(f"Data shape: {data.shape}")
mylogger.info("-==- Done -==-")
@@ -266,9 +266,9 @@ def process_trial(
batch_size=config["alignment_batch_size"],
fill_value=-100.0,
)
mylogger.info(f"Rotation: {round(float(angle_refref[0]),2)} degree")
mylogger.info(f"Rotation: {round(float(angle_refref[0]), 2)} degree")
mylogger.info(
f"Translation: {round(float(tvec_refref[0]),1)} x {round(float(tvec_refref[1]),1)} pixel"
f"Translation: {round(float(tvec_refref[0]), 1)} x {round(float(tvec_refref[1]), 1)} pixel"
)
if config["save_alignment"]:
@@ -285,7 +285,7 @@ def process_trial(
np.save(temp_path, tvec_refref.cpu())
mylogger.info("Moving & rotating the oxygenation ref image")
ref_image_oxygenation = tv.transforms.functional.affine(
ref_image_oxygenation = tv.transforms.functional.affine( # type: ignore
img=ref_image_oxygenation.unsqueeze(0),
angle=-float(angle_refref),
translate=[0, 0],
@@ -295,7 +295,7 @@ def process_trial(
fill=-100.0,
)
ref_image_oxygenation = tv.transforms.functional.affine(
ref_image_oxygenation = tv.transforms.functional.affine( # type: ignore
img=ref_image_oxygenation,
angle=0,
translate=[tvec_refref[1], tvec_refref[0]],
@@ -313,8 +313,8 @@ def process_trial(
volume_index: int = config["required_order"].index("volume")
mylogger.info("Rotate acceptor")
data[acceptor_index, ...] = tv.transforms.functional.affine(
img=data[acceptor_index, ...],
data[acceptor_index, ...] = tv.transforms.functional.affine( # type: ignore
img=data[acceptor_index, ...], # type: ignore
angle=-float(angle_refref),
translate=[0, 0],
scale=1.0,
@@ -324,7 +324,7 @@ def process_trial(
)
mylogger.info("Translate acceptor")
data[acceptor_index, ...] = tv.transforms.functional.affine(
data[acceptor_index, ...] = tv.transforms.functional.affine( # type: ignore
img=data[acceptor_index, ...],
angle=0,
translate=[tvec_refref[1], tvec_refref[0]],
@@ -335,7 +335,7 @@ def process_trial(
)
mylogger.info("Rotate oxygenation")
data[oxygenation_index, ...] = tv.transforms.functional.affine(
data[oxygenation_index, ...] = tv.transforms.functional.affine( # type: ignore
img=data[oxygenation_index, ...],
angle=-float(angle_refref),
translate=[0, 0],
@@ -346,7 +346,7 @@ def process_trial(
)
mylogger.info("Translate oxygenation")
data[oxygenation_index, ...] = tv.transforms.functional.affine(
data[oxygenation_index, ...] = tv.transforms.functional.affine( # type: ignore
img=data[oxygenation_index, ...],
angle=0,
translate=[tvec_refref[1], tvec_refref[0]],
@@ -359,7 +359,7 @@ def process_trial(
mylogger.info("Perform rotation between donor and volume and its ref images")
mylogger.info("for all frames and then rotate all the data accordingly")
perform_donor_volume_rotation
(
data[acceptor_index, ...],
data[donor_index, ...],
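
The hunks above repeatedly apply the same two-step correction: rotate the image stack by the estimated reference angle, then translate it by the estimated offset, with fill=-100.0 marking out-of-frame pixels. A compact standalone sketch of that pattern (the function name, shear, and interpolation settings are assumptions; the hunks do not show them):

import torch
import torchvision as tv

def rotate_then_translate(
    img: torch.Tensor,              # (C, H, W) or (N, C, H, W)
    angle_deg: float,               # estimated rotation in degrees
    shift_xy: tuple[float, float],  # estimated translation (x, y) in pixels
) -> torch.Tensor:
    # step 1: undo the estimated rotation
    img = tv.transforms.functional.affine(
        img=img,
        angle=-angle_deg,
        translate=[0, 0],
        scale=1.0,
        shear=[0.0, 0.0],
        interpolation=tv.transforms.InterpolationMode.BILINEAR,
        fill=-100.0,
    )
    # step 2: apply the estimated translation
    return tv.transforms.functional.affine(
        img=img,
        angle=0.0,
        translate=[shift_xy[0], shift_xy[1]],
        scale=1.0,
        shear=[0.0, 0.0],
        interpolation=tv.transforms.InterpolationMode.BILINEAR,
        fill=-100.0,
    )
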
@@ -381,9 +381,9 @@ def process_trial(
mylogger.info(
f"angles: "
f"min {round(float(angle_donor_volume.min()),2)} "
f"max {round(float(angle_donor_volume.max()),2)} "
f"mean {round(float(angle_donor_volume.mean()),2)} "
f"min {round(float(angle_donor_volume.min()), 2)} "
f"max {round(float(angle_donor_volume.max()), 2)} "
f"mean {round(float(angle_donor_volume.mean()), 2)} "
)
if config["save_alignment"]:
@@ -417,15 +417,15 @@ def process_trial(
mylogger.info(
f"translation dim 0: "
f"min {round(float(tvec_donor_volume[:,0].min()),1)} "
f"max {round(float(tvec_donor_volume[:,0].max()),1)} "
f"mean {round(float(tvec_donor_volume[:,0].mean()),1)} "
f"min {round(float(tvec_donor_volume[:, 0].min()), 1)} "
f"max {round(float(tvec_donor_volume[:, 0].max()), 1)} "
f"mean {round(float(tvec_donor_volume[:, 0].mean()), 1)} "
)
mylogger.info(
f"translation dim 1: "
f"min {round(float(tvec_donor_volume[:,1].min()),1)} "
f"max {round(float(tvec_donor_volume[:,1].max()),1)} "
f"mean {round(float(tvec_donor_volume[:,1].mean()),1)} "
f"min {round(float(tvec_donor_volume[:, 1].min()), 1)} "
f"max {round(float(tvec_donor_volume[:, 1].max()), 1)} "
f"mean {round(float(tvec_donor_volume[:, 1].mean()), 1)} "
)
if config["save_alignment"]:
@@ -471,172 +471,183 @@ def process_trial(
sample_frequency: float = 1.0 / meta_frame_time
mylogger.info("Extract heartbeat from volume signal")
heartbeat_ts: torch.Tensor = bandpass(
data=data[volume_index, ...].movedim(0, -1).clone(),
low_frequency=config["lower_freqency_bandpass"],
high_frequency=config["upper_freqency_bandpass"],
fs=sample_frequency,
filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"],
)
heartbeat_ts = heartbeat_ts.flatten(start_dim=0, end_dim=-2)
mask_flatten: torch.Tensor = mask_positve.flatten(start_dim=0, end_dim=-1)
if config["gevi"]:
assert config["heartbeat_remove"]
heartbeat_ts = heartbeat_ts[mask_flatten, :]
heartbeat_ts = heartbeat_ts.movedim(0, -1)
heartbeat_ts -= heartbeat_ts.mean(dim=0, keepdim=True)
try:
volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False)
except torch.cuda.OutOfMemoryError:
mylogger.info("torch.cuda.OutOfMemoryError: Fallback to cpu")
volume_heartbeat_cpu, _, _ = torch.linalg.svd(
heartbeat_ts.cpu(), full_matrices=False
)
volume_heartbeat = volume_heartbeat_cpu.to(heartbeat_ts.data, copy=True)
del volume_heartbeat_cpu
volume_heartbeat = volume_heartbeat[:, 0]
volume_heartbeat -= volume_heartbeat[
config["skip_frames_in_the_beginning"] :
].mean()
del heartbeat_ts
if device != torch.device("cpu"):
torch.cuda.empty_cache()
mylogger.info("Empty CUDA cache")
free_mem = cuda_total_memory - max(
[torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
)
mylogger.info(f"CUDA memory: {free_mem//1024} MByte")
if config["save_heartbeat"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_volume_heartbeat.npy"
)
mylogger.info(f"Save volume heartbeat to {temp_path}")
np.save(temp_path, volume_heartbeat.cpu())
mylogger.info("-==- Done -==-")
volume_heartbeat = volume_heartbeat.unsqueeze(0).unsqueeze(0)
norm_volume_heartbeat = (
volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] ** 2
).sum(dim=-1)
heartbeat_coefficients: torch.Tensor = torch.zeros(
(data.shape[0], data.shape[-2], data.shape[-1]),
dtype=data.dtype,
device=data.device,
)
for i in range(0, data.shape[0]):
y = bandpass(
data=data[i, ...].movedim(0, -1).clone(),
if config["heartbeat_remove"]:
mylogger.info("Extract heartbeat from volume signal")
heartbeat_ts: torch.Tensor = bandpass(
data=data[volume_index, ...].movedim(0, -1).clone(),
low_frequency=config["lower_freqency_bandpass"],
high_frequency=config["upper_freqency_bandpass"],
fs=sample_frequency,
filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"],
)[..., config["skip_frames_in_the_beginning"] :]
y -= y.mean(dim=-1, keepdim=True)
heartbeat_coefficients[i, ...] = (
volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] * y
).sum(dim=-1) / norm_volume_heartbeat
heartbeat_coefficients[i, ...] *= mask_positve.type(
dtype=heartbeat_coefficients.dtype
)
del y
heartbeat_ts = heartbeat_ts.flatten(start_dim=0, end_dim=-2)
mask_flatten: torch.Tensor = mask_positve.flatten(start_dim=0, end_dim=-1)
if config["save_heartbeat"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_heartbeat_coefficients.npy"
heartbeat_ts = heartbeat_ts[mask_flatten, :]
heartbeat_ts = heartbeat_ts.movedim(0, -1)
heartbeat_ts -= heartbeat_ts.mean(dim=0, keepdim=True)
try:
volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False)
except torch.cuda.OutOfMemoryError:
mylogger.info("torch.cuda.OutOfMemoryError: Fallback to cpu")
volume_heartbeat_cpu, _, _ = torch.linalg.svd(
heartbeat_ts.cpu(), full_matrices=False
)
volume_heartbeat = volume_heartbeat_cpu.to(heartbeat_ts.data, copy=True)
del volume_heartbeat_cpu
volume_heartbeat = volume_heartbeat[:, 0]
volume_heartbeat -= volume_heartbeat[
config["skip_frames_in_the_beginning"] :
].mean()
del heartbeat_ts
if device != torch.device("cpu"):
torch.cuda.empty_cache()
mylogger.info("Empty CUDA cache")
free_mem = cuda_total_memory - max(
[
torch.cuda.memory_reserved(device),
torch.cuda.memory_allocated(device),
]
)
mylogger.info(f"CUDA memory: {free_mem // 1024} MByte")
if config["save_heartbeat"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_volume_heartbeat.npy"
)
mylogger.info(f"Save volume heartbeat to {temp_path}")
np.save(temp_path, volume_heartbeat.cpu())
mylogger.info("-==- Done -==-")
volume_heartbeat = volume_heartbeat.unsqueeze(0).unsqueeze(0)
norm_volume_heartbeat = (
volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] ** 2
).sum(dim=-1)
heartbeat_coefficients: torch.Tensor = torch.zeros(
(data.shape[0], data.shape[-2], data.shape[-1]),
dtype=data.dtype,
device=data.device,
)
mylogger.info(f"Save heartbeat coefficients to {temp_path}")
np.save(temp_path, heartbeat_coefficients.cpu())
mylogger.info("-==- Done -==-")
for i in range(0, data.shape[0]):
y = bandpass(
data=data[i, ...].movedim(0, -1).clone(),
low_frequency=config["lower_freqency_bandpass"],
high_frequency=config["upper_freqency_bandpass"],
fs=sample_frequency,
filtfilt_chuck_size=config["heartbeat_filtfilt_chuck_size"],
)[..., config["skip_frames_in_the_beginning"] :]
y -= y.mean(dim=-1, keepdim=True)
mylogger.info("Remove heart beat from data")
data -= heartbeat_coefficients.unsqueeze(1) * volume_heartbeat.unsqueeze(0).movedim(
-1, 1
)
mylogger.info("-==- Done -==-")
heartbeat_coefficients[i, ...] = (
volume_heartbeat[..., config["skip_frames_in_the_beginning"] :] * y
).sum(dim=-1) / norm_volume_heartbeat
donor_heartbeat_factor = heartbeat_coefficients[donor_index, ...].clone()
acceptor_heartbeat_factor = heartbeat_coefficients[acceptor_index, ...].clone()
del heartbeat_coefficients
heartbeat_coefficients[i, ...] *= mask_positve.type(
dtype=heartbeat_coefficients.dtype
)
del y
if device != torch.device("cpu"):
torch.cuda.empty_cache()
mylogger.info("Empty CUDA cache")
free_mem = cuda_total_memory - max(
[torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
if config["save_heartbeat"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_heartbeat_coefficients.npy"
)
mylogger.info(f"Save heartbeat coefficients to {temp_path}")
np.save(temp_path, heartbeat_coefficients.cpu())
mylogger.info("-==- Done -==-")
mylogger.info("Remove heart beat from data")
data -= heartbeat_coefficients.unsqueeze(1) * volume_heartbeat.unsqueeze(
0
).movedim(-1, 1)
mylogger.info("-==- Done -==-")
donor_heartbeat_factor = heartbeat_coefficients[donor_index, ...].clone()
acceptor_heartbeat_factor = heartbeat_coefficients[acceptor_index, ...].clone()
del heartbeat_coefficients
if device != torch.device("cpu"):
torch.cuda.empty_cache()
mylogger.info("Empty CUDA cache")
free_mem = cuda_total_memory - max(
[
torch.cuda.memory_reserved(device),
torch.cuda.memory_allocated(device),
]
)
mylogger.info(f"CUDA memory: {free_mem // 1024} MByte")
mylogger.info("Calculate scaling factor for donor and acceptor")
donor_factor: torch.Tensor = (
donor_heartbeat_factor + acceptor_heartbeat_factor
) / (2 * donor_heartbeat_factor)
acceptor_factor: torch.Tensor = (
donor_heartbeat_factor + acceptor_heartbeat_factor
) / (2 * acceptor_heartbeat_factor)
del donor_heartbeat_factor
del acceptor_heartbeat_factor
if config["save_factors"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_donor_factor.npy"
)
mylogger.info(f"Save donor factor to {temp_path}")
np.save(temp_path, donor_factor.cpu())
temp_path = os.path.join(
config["export_path"], experiment_name + "_acceptor_factor.npy"
)
mylogger.info(f"Save acceptor factor to {temp_path}")
np.save(temp_path, acceptor_factor.cpu())
mylogger.info("-==- Done -==-")
mylogger.info("Scale acceptor to heart beat amplitude")
mylogger.info("Calculate mean")
mean_values_acceptor = data[
acceptor_index, config["skip_frames_in_the_beginning"] :, ...
].nanmean(dim=0, keepdim=True)
mylogger.info("Remove mean")
data[acceptor_index, ...] -= mean_values_acceptor
mylogger.info("Apply acceptor_factor and mask")
data[acceptor_index, ...] *= acceptor_factor.unsqueeze(
0
) * mask_positve.unsqueeze(0)
mylogger.info("Add mean")
data[acceptor_index, ...] += mean_values_acceptor
mylogger.info("-==- Done -==-")
mylogger.info("Scale donor to heart beat amplitude")
mylogger.info("Calculate mean")
mean_values_donor = data[
donor_index, config["skip_frames_in_the_beginning"] :, ...
].nanmean(dim=0, keepdim=True)
mylogger.info("Remove mean")
data[donor_index, ...] -= mean_values_donor
mylogger.info("Apply donor_factor and mask")
data[donor_index, ...] *= donor_factor.unsqueeze(0) * mask_positve.unsqueeze(0)
mylogger.info("Add mean")
data[donor_index, ...] += mean_values_donor
mylogger.info("-==- Done -==-")
mylogger.info("Divide by mean over time")
data /= data[:, config["skip_frames_in_the_beginning"] :, ...].nanmean(
dim=1,
keepdim=True,
)
mylogger.info(f"CUDA memory: {free_mem//1024} MByte")
mylogger.info("Calculate scaling factor for donor and acceptor")
donor_factor: torch.Tensor = (
donor_heartbeat_factor + acceptor_heartbeat_factor
) / (2 * donor_heartbeat_factor)
acceptor_factor: torch.Tensor = (
donor_heartbeat_factor + acceptor_heartbeat_factor
) / (2 * acceptor_heartbeat_factor)
mylogger.info("-==- Done -==-")
del donor_heartbeat_factor
del acceptor_heartbeat_factor
if config["save_factors"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_donor_factor.npy"
)
mylogger.info(f"Save donor factor to {temp_path}")
np.save(temp_path, donor_factor.cpu())
temp_path = os.path.join(
config["export_path"], experiment_name + "_acceptor_factor.npy"
)
mylogger.info(f"Save acceptor factor to {temp_path}")
np.save(temp_path, acceptor_factor.cpu())
mylogger.info("-==- Done -==-")
mylogger.info("Scale acceptor to heart beat amplitude")
mylogger.info("Calculate mean")
mean_values_acceptor = data[
acceptor_index, config["skip_frames_in_the_beginning"] :, ...
].nanmean(dim=0, keepdim=True)
mylogger.info("Remove mean")
data[acceptor_index, ...] -= mean_values_acceptor
mylogger.info("Apply acceptor_factor and mask")
data[acceptor_index, ...] *= acceptor_factor.unsqueeze(0) * mask_positve.unsqueeze(
0
)
mylogger.info("Add mean")
data[acceptor_index, ...] += mean_values_acceptor
mylogger.info("-==- Done -==-")
mylogger.info("Scale donor to heart beat amplitude")
mylogger.info("Calculate mean")
mean_values_donor = data[
donor_index, config["skip_frames_in_the_beginning"] :, ...
].nanmean(dim=0, keepdim=True)
mylogger.info("Remove mean")
data[donor_index, ...] -= mean_values_donor
mylogger.info("Apply donor_factor and mask")
data[donor_index, ...] *= donor_factor.unsqueeze(0) * mask_positve.unsqueeze(0)
mylogger.info("Add mean")
data[donor_index, ...] += mean_values_donor
mylogger.info("-==- Done -==-")
mylogger.info("Divide by mean over time")
data /= data[:, config["skip_frames_in_the_beginning"] :, ...].nanmean(
dim=1,
keepdim=True,
)
data = data.nan_to_num(nan=0.0)
mylogger.info("-==- Done -==-")
mylogger.info("Preparation for regression -- Gauss smear")
spatial_width = float(config["gauss_smear_spatial_width"])
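
Restated for readability, the heartbeat block above (a) band-passes the volume channel around the heart rate, (b) takes the first left singular vector over time of the masked pixels as the heartbeat time course, and (c) regresses that time course out of every pixel of every channel. A minimal sketch of the same idea with hypothetical names and (channel, time, H, W) shapes, not the repository's exact code:

import torch

def remove_heartbeat(
    data: torch.Tensor,       # (C, T, H, W) raw channels
    data_bp: torch.Tensor,    # (C, T, H, W) band-passed around the heart rate
    volume_bp: torch.Tensor,  # (T, H, W) band-passed volume channel
    mask: torch.Tensor,       # (H, W) boolean brain mask
) -> torch.Tensor:
    T = volume_bp.shape[0]
    # heartbeat time course: dominant temporal SVD component of the masked volume pixels
    ts = volume_bp.reshape(T, -1)[:, mask.flatten()]
    ts = ts - ts.mean(dim=0, keepdim=True)
    u, _, _ = torch.linalg.svd(ts, full_matrices=False)
    heartbeat = u[:, 0] - u[:, 0].mean()                      # (T,)
    norm = (heartbeat**2).sum()
    # per-pixel regression coefficient for each channel, restricted to the mask
    data_bp = data_bp - data_bp.mean(dim=1, keepdim=True)
    coeff = torch.einsum("cthw,t->chw", data_bp, heartbeat) / norm
    coeff = coeff * mask.to(coeff.dtype)
    # subtract the fitted heartbeat component from the raw data
    return data - coeff.unsqueeze(1) * heartbeat.view(1, T, 1, 1)
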
@@ -669,7 +680,7 @@ def process_trial(
free_mem = cuda_total_memory - max(
[torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
)
mylogger.info(f"CUDA memory: {free_mem//1024} MByte")
mylogger.info(f"CUDA memory: {free_mem // 1024} MByte")
overwrite_fft_gauss: None | torch.Tensor = None
for i in range(0, data_filtered.shape[0]):
@@ -703,7 +714,7 @@ def process_trial(
free_mem = cuda_total_memory - max(
[torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
)
mylogger.info(f"CUDA memory: {free_mem//1024} MByte")
mylogger.info(f"CUDA memory: {free_mem // 1024} MByte")
mylogger.info("-==- Done -==-")
mylogger.info("Preperation for Regression")
@@ -747,6 +758,9 @@ def process_trial(
mylogger.info("-==- Done -==-")
else:
dual_signal_mode = False
target_id = config["required_order"].index("acceptor")
data_acceptor = data[target_id, ...].clone()
data_acceptor[mask_negative, :] = 0.0
if len(config["target_camera_donor"]) > 0:
mylogger.info("Regression Donor")
@@ -781,6 +795,9 @@ def process_trial(
mylogger.info("-==- Done -==-")
else:
dual_signal_mode = False
target_id = config["required_order"].index("donor")
data_donor = data[target_id, ...].clone()
data_donor[mask_negative, :] = 0.0
del data
del data_filtered
@@ -791,24 +808,119 @@ def process_trial(
free_mem = cuda_total_memory - max(
[torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)]
)
mylogger.info(f"CUDA memory: {free_mem//1024} MByte")
mylogger.info(f"CUDA memory: {free_mem // 1024} MByte")
# #####################
if config["gevi"]:
assert dual_signal_mode
else:
assert dual_signal_mode is False
if dual_signal_mode is False:
mylogger.info("mono signal model")
mylogger.info("Remove nan")
data_acceptor = torch.nan_to_num(data_acceptor, nan=0.0)
data_donor = torch.nan_to_num(data_donor, nan=0.0)
mylogger.info("-==- Done -==-")
if config["binning_enable"] and config["binning_at_the_end"]:
mylogger.info("Binning of data")
mylogger.info(
(
f"kernel_size={int(config['binning_kernel_size'])}, "
f"stride={int(config['binning_stride'])}, "
"divisor_override=None"
)
)
data_acceptor = binning(
data_acceptor.unsqueeze(-1),
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=None,
).squeeze(-1)
data_donor = binning(
data_donor.unsqueeze(-1),
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=None,
).squeeze(-1)
mask_positve = (
binning(
mask_positve.unsqueeze(-1).unsqueeze(-1).type(dtype=dtype),
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=None,
)
.squeeze(-1)
.squeeze(-1)
)
mask_positve = (mask_positve > 0).type(torch.bool)
if config["save_as_python"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_acceptor_donor.npz"
)
mylogger.info(f"Save data donor and acceptor and mask to {temp_path}")
np.savez_compressed(
temp_path,
data_acceptor=data_acceptor.cpu(),
data_donor=data_donor.cpu(),
mask=mask_positve.cpu(),
)
if config["save_as_matlab"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_acceptor_donor.hd5"
)
mylogger.info(f"Save data donor and acceptor and mask to {temp_path}")
file_handle = h5py.File(temp_path, "w")
mask_positve = mask_positve.movedim(0, -1)
data_acceptor = data_acceptor.movedim(1, -1).movedim(0, -1)
data_donor = data_donor.movedim(1, -1).movedim(0, -1)
_ = file_handle.create_dataset(
"mask",
data=mask_positve.type(torch.uint8).cpu(),
compression="gzip",
compression_opts=9,
)
_ = file_handle.create_dataset(
"data_acceptor",
data=data_acceptor.cpu(),
compression="gzip",
compression_opts=9,
)
_ = file_handle.create_dataset(
"data_donor",
data=data_donor.cpu(),
compression="gzip",
compression_opts=9,
)
mylogger.info("Reminder: How to read with matlab:")
mylogger.info(f"mask = h5read('{temp_path}','/mask');")
mylogger.info(f"data_acceptor = h5read('{temp_path}','/data_acceptor');")
mylogger.info(f"data_donor = h5read('{temp_path}','/data_donor');")
file_handle.close()
return
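
The export branch above writes either a compressed NumPy archive or an HDF5 file. A short sketch of how those exports could be read back in Python (dataset/key names are taken from the code above, file paths are hypothetical; note that the HDF5 arrays were transposed via movedim before saving, so their axis order differs from the .npz variant):

import numpy as np
import h5py

# NumPy export: keys as created by np.savez_compressed above
with np.load("experiment_acceptor_donor.npz") as f:          # hypothetical path
    data_acceptor = f["data_acceptor"]
    data_donor = f["data_donor"]
    mask = f["mask"]

# MATLAB/HDF5 export: dataset names as created above
with h5py.File("experiment_acceptor_donor.hd5", "r") as f:   # hypothetical path
    mask_h5 = f["mask"][...]
    data_acceptor_h5 = f["data_acceptor"][...]
    data_donor_h5 = f["data_donor"][...]
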
# #####################
mylogger.info("Calculate ratio sequence")
if dual_signal_mode:
if config["classical_ratio_mode"]:
mylogger.info("via acceptor / donor")
ratio_sequence: torch.Tensor = data_acceptor / data_donor
mylogger.info("via / mean over time")
ratio_sequence /= ratio_sequence.mean(dim=-1, keepdim=True)
else:
mylogger.info("via 1.0 + acceptor - donor")
ratio_sequence = 1.0 + data_acceptor - data_donor
if config["classical_ratio_mode"]:
mylogger.info("via acceptor / donor")
ratio_sequence: torch.Tensor = data_acceptor / data_donor
mylogger.info("via / mean over time")
ratio_sequence /= ratio_sequence.mean(dim=-1, keepdim=True)
else:
mylogger.info("mono signal model")
if len(config["target_camera_donor"]) > 0:
ratio_sequence = data_donor.clone()
else:
ratio_sequence = data_acceptor.clone()
mylogger.info("via 1.0 + acceptor - donor")
ratio_sequence = 1.0 + data_acceptor - data_donor
mylogger.info("Remove nan")
ratio_sequence = torch.nan_to_num(ratio_sequence, nan=0.0)
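
For reference, the classical ratio mode above condensed into a standalone helper (a sketch assuming (H, W, T) tensors with time on the last axis, matching dim=-1 in the code):

import torch

def classical_ratio(data_acceptor: torch.Tensor, data_donor: torch.Tensor) -> torch.Tensor:
    # acceptor / donor, normalized by its own mean over time (last axis)
    ratio = data_acceptor / data_donor
    ratio = ratio / ratio.mean(dim=-1, keepdim=True)
    return torch.nan_to_num(ratio, nan=0.0)
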