Update stage_4_process.py

This commit is contained in:
David Rotermund 2025-03-06 17:10:02 +01:00 committed by GitHub
parent 572afe55ce
commit a9c1611610
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1220,6 +1220,21 @@ def process_trial(
divisor_override=None,
).squeeze(-1)
if config["save_gevi_with_donor_acceptor"]:
data_acceptor = binning(
data_acceptor.unsqueeze(-1),
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=None,
).squeeze(-1)
data_donor = binning(
data_donor.unsqueeze(-1),
kernel_size=int(config["binning_kernel_size"]),
stride=int(config["binning_stride"]),
divisor_override=None,
).squeeze(-1)
mask_positve = (
binning(
mask_positve.unsqueeze(-1).unsqueeze(-1).type(dtype=dtype),
@ -1237,10 +1252,15 @@ def process_trial(
config["export_path"], experiment_name + "_ratio_sequence.npz"
)
mylogger.info(f"Save ratio_sequence and mask to {temp_path}")
np.savez_compressed(
temp_path, ratio_sequence=ratio_sequence.cpu(), mask=mask_positve.cpu()
)
if config["save_gevi_with_donor_acceptor"]:
np.savez_compressed(
temp_path, ratio_sequence=ratio_sequence.cpu(), mask=mask_positve.cpu(), data_acceptor=data_acceptor.cpu(), data_donor=data_donor.cpu()
)
else:
np.savez_compressed(
temp_path, ratio_sequence=ratio_sequence.cpu(), mask=mask_positve.cpu()
)
if config["save_as_matlab"]:
temp_path = os.path.join(
config["export_path"], experiment_name + "_ratio_sequence.hd5"
@ -1262,9 +1282,25 @@ def process_trial(
compression="gzip",
compression_opts=9,
)
if config["save_gevi_with_donor_acceptor"]:
_ = file_handle.create_dataset(
"data_acceptor",
data=data_acceptor.cpu(),
compression="gzip",
compression_opts=9,
)
_ = file_handle.create_dataset(
"data_donor",
data=data_donor.cpu(),
compression="gzip",
compression_opts=9,
)
mylogger.info("Reminder: How to read with matlab:")
mylogger.info(f"mask = h5read('{temp_path}','/mask');")
mylogger.info(f"ratio_sequence = h5read('{temp_path}','/ratio_sequence');")
if config["save_gevi_with_donor_acceptor"]:
mylogger.info(f"data_donor = h5read('{temp_path}','/data_donor');")
mylogger.info(f"data_acceptor = h5read('{temp_path}','/data_acceptor');")
file_handle.close()
del ratio_sequence