diff --git a/callum_config_M0134M.json b/callum_config_M0134M.json new file mode 100644 index 0000000..f283745 --- /dev/null +++ b/callum_config_M0134M.json @@ -0,0 +1,66 @@ +{ + "basic_path": "/data_1/fatma/GEVI/", + "recoding_data": "2024-11-18", + "mouse_identifier": "M0134M_SessionA", + "raw_path": "raw", + "export_path": "output_M0134M_SessionA", + "ref_image_path": "ref_images_M0134M_SessionA", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": false, + "save_oxyvol_as_matlab": false, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/callum_config_M3905F.json b/callum_config_M3905F.json new file mode 100644 index 0000000..8517026 --- /dev/null +++ b/callum_config_M3905F.json @@ -0,0 +1,66 @@ +{ + "basic_path": "/data_1/fatma/GEVI_GECI_ES", + "recoding_data": "session_B", + "mouse_identifier": "M3905F", + "raw_path": "raw", + "export_path": "output_M3905F_session_B", + "ref_image_path": "ref_images_M3905F_session_B", + "raw_path": "raw", + "heartbeat_remove": true, + "gevi": true, // true => gevi, false => geci + // Ratio Sequence + "classical_ratio_mode": true, // true: a/d false: 1+a-d + // Regression + "target_camera_acceptor": "acceptor", + "regressor_cameras_acceptor": [ + "oxygenation", + "volume" + ], + "target_camera_donor": "donor", + "regressor_cameras_donor": [ + "oxygenation", + "volume" + ], + // binning + "binning_enable": true, + "binning_at_the_end": false, + "binning_kernel_size": 4, + "binning_stride": 4, + "binning_divisor_override": 1, + // alignment + "alignment_batch_size": 200, + "rotation_stabilization_threshold_factor": 3.0, // >= 1.0 + "rotation_stabilization_threshold_border": 0.9, // <= 1.0 + // Heart beat detection + "lower_freqency_bandpass": 5.0, // Hz + "upper_freqency_bandpass": 14.0, // Hz + "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, + // LED Ramp on + "skip_frames_in_the_beginning": 100, // Frames + // PyTorch + "dtype": "float32", + "force_to_cpu": false, + // Save + "save_as_python": true, // produces .npz files (compressed) + "save_as_matlab": false, // produces .hd5 file (compressed) + // Save extra information + "save_alignment": false, + "save_heartbeat": false, + "save_factors": false, + "save_regression_coefficients": false, + "save_aligned_as_python": false, + "save_aligned_as_matlab": false, + "save_oxyvol_as_python": true, + "save_oxyvol_as_matlab": true, + // Not important parameter + "required_order": [ + "acceptor", + "donor", + "oxygenation", + "volume" + ] +} diff --git a/stage_2_make_heartbeat_mask.py b/stage_2_make_heartbeat_mask.py index 36b3282..dfa8c63 100644 --- a/stage_2_make_heartbeat_mask.py +++ b/stage_2_make_heartbeat_mask.py @@ -137,7 +137,7 @@ def main(*, config_filename: str = "config.json") -> None: valmin=0, valmax=value_sort_max, valinit=threshold, - valstep=value_sort_max / 100.0, + valstep=value_sort_max / 1000.0, ) axbutton_accept = fig.add_axes(rect=(0.3, 0.85, 0.2, 0.04)) button_accept = Button( diff --git a/stage_4_process.py b/stage_4_process.py index 2bdf57d..cb81601 100644 --- a/stage_4_process.py +++ b/stage_4_process.py @@ -1,4 +1,4 @@ -#%% +# %% import numpy as np import torch @@ -49,6 +49,11 @@ def process_trial( else: cuda_total_memory = 0 + mylogger.info("") + mylogger.info("(A) LOADING DATA, REFERENCE, AND MASK") + mylogger.info("-----------------------------------------------") + mylogger.info("") + raw_data_path: str = os.path.join( config["basic_path"], config["recoding_data"], @@ -183,6 +188,12 @@ def process_trial( mylogger.info("-==- Done -==-") if config["binning_enable"] and (config["binning_at_the_end"] is False): + + mylogger.info("") + mylogger.info("(B-OPTIONAL) BINNING") + mylogger.info("-----------------------------------------------") + mylogger.info("") + mylogger.info("Binning of data") mylogger.info( ( @@ -251,6 +262,11 @@ def process_trial( mylogger.info(f"Data shape: {data.shape}") mylogger.info("-==- Done -==-") + mylogger.info("") + mylogger.info("(C) ALIGNMENT OF SECOND TO FIRST CAMERA") + mylogger.info("-----------------------------------------------") + mylogger.info("") + mylogger.info("Preparing alignment") mylogger.info("Re-order Raw data") data = data.moveaxis(-2, 0).moveaxis(-1, 0) @@ -462,6 +478,81 @@ def process_trial( data *= mask_positve.unsqueeze(0).unsqueeze(0).type(dtype=dtype) mylogger.info("-==- Done -==-") + if config["save_aligned_as_python"]: + + temp_path = os.path.join( + config["export_path"], experiment_name + "_aligned.npz" + ) + mylogger.info(f"Save aligned data and mask to {temp_path}") + np.savez_compressed( + temp_path, + data=data.cpu(), + mask=mask_positve.cpu(), + acceptor_index=acceptor_index, + donor_index=donor_index, + oxygenation_index=oxygenation_index, + volume_index=volume_index, + ) + + if config["save_aligned_as_matlab"]: + temp_path = os.path.join( + config["export_path"], experiment_name + "_aligned.hd5" + ) + mylogger.info(f"Save aligned data and mask to {temp_path}") + file_handle = h5py.File(temp_path, "w") + + _ = file_handle.create_dataset( + "mask", + data=mask_positve.movedim(0, -1).type(torch.uint8).cpu(), + compression="gzip", + compression_opts=9, + ) + + _ = file_handle.create_dataset( + "data", + data=data.movedim(1, -1).movedim(0, -1).cpu(), + compression="gzip", + compression_opts=9, + ) + + _ = file_handle.create_dataset( + "acceptor_index", + data=torch.tensor((acceptor_index,)), + compression="gzip", + compression_opts=9, + ) + + _ = file_handle.create_dataset( + "donor_index", + data=torch.tensor((donor_index,)), + compression="gzip", + compression_opts=9, + ) + + _ = file_handle.create_dataset( + "oxygenation_index", + data=torch.tensor((oxygenation_index,)), + compression="gzip", + compression_opts=9, + ) + + _ = file_handle.create_dataset( + "volume_index", + data=torch.tensor((volume_index,)), + compression="gzip", + compression_opts=9, + ) + + mylogger.info("Reminder: How to read with matlab:") + mylogger.info(f"mask = h5read('{temp_path}','/mask');") + mylogger.info(f"data_acceptor = h5read('{temp_path}','/data');") + file_handle.close() + + mylogger.info("") + mylogger.info("(D) INTER-FRAME INTERPOLATION") + mylogger.info("-----------------------------------------------") + mylogger.info("") + mylogger.info("Interpolate the 'in-between' frames for oxygenation and volume") data[oxygenation_index, 1:, ...] = ( data[oxygenation_index, 1:, ...] + data[oxygenation_index, :-1, ...] @@ -477,6 +568,12 @@ def process_trial( assert config["heartbeat_remove"] if config["heartbeat_remove"]: + + mylogger.info("") + mylogger.info("(E-OPTIONAL) HEARTBEAT REMOVAL") + mylogger.info("-----------------------------------------------") + mylogger.info("") + mylogger.info("Extract heartbeat from volume signal") heartbeat_ts: torch.Tensor = bandpass( data=data[volume_index, ...].movedim(0, -1).clone(), @@ -574,8 +671,16 @@ def process_trial( mylogger.info("-==- Done -==-") if config["gevi"]: # UDO scaling performed! + + mylogger.info("") + mylogger.info("(F-OPTIONAL) DONOR/ACCEPTOR SCALING") + mylogger.info("-----------------------------------------------") + mylogger.info("") + donor_heartbeat_factor = heartbeat_coefficients[donor_index, ...].clone() - acceptor_heartbeat_factor = heartbeat_coefficients[acceptor_index, ...].clone() + acceptor_heartbeat_factor = heartbeat_coefficients[ + acceptor_index, ... + ].clone() del heartbeat_coefficients if device != torch.device("cpu"): @@ -590,16 +695,38 @@ def process_trial( mylogger.info(f"CUDA memory: {free_mem // 1024} MByte") mylogger.info("Calculate scaling factor for donor and acceptor") - donor_factor: torch.Tensor = ( - donor_heartbeat_factor + acceptor_heartbeat_factor - ) / (2 * donor_heartbeat_factor) - acceptor_factor: torch.Tensor = ( - donor_heartbeat_factor + acceptor_heartbeat_factor - ) / (2 * acceptor_heartbeat_factor) + # donor_factor: torch.Tensor = ( + # donor_heartbeat_factor + acceptor_heartbeat_factor + # ) / (2 * donor_heartbeat_factor) + # acceptor_factor: torch.Tensor = ( + # donor_heartbeat_factor + acceptor_heartbeat_factor + # ) / (2 * acceptor_heartbeat_factor) + donor_factor = torch.sqrt( + acceptor_heartbeat_factor / donor_heartbeat_factor + ) + acceptor_factor = 1 / donor_factor + + # import matplotlib.pyplot as plt + # plt.pcolor(donor_factor, vmin=0.5, vmax=2.0) + # plt.colorbar() + # plt.show() + # plt.pcolor(acceptor_factor, vmin=0.5, vmax=2.0) + # plt.colorbar() + # plt.show() + # TODO remove del donor_heartbeat_factor del acceptor_heartbeat_factor + # import matplotlib.pyplot as plt + # plt.pcolor(torch.std(data[acceptor_index, config["skip_frames_in_the_beginning"] :, ...], axis=0), vmin=0, vmax=500) + # plt.colorbar() + # plt.show() + # plt.pcolor(torch.std(data[donor_index, config["skip_frames_in_the_beginning"] :, ...], axis=0), vmin=0, vmax=500) + # plt.colorbar() + # plt.show() + # TODO remove + if config["save_factors"]: temp_path = os.path.join( config["export_path"], experiment_name + "_donor_factor.npy" @@ -614,38 +741,66 @@ def process_trial( np.save(temp_path, acceptor_factor.cpu()) mylogger.info("-==- Done -==-") - mylogger.info("Scale acceptor to heart beat amplitude") - mylogger.info("Calculate mean") + # TODO we have to calculate means first! + mylogger.info("Extract means for acceptor and donor first") mean_values_acceptor = data[ acceptor_index, config["skip_frames_in_the_beginning"] :, ... ].nanmean(dim=0, keepdim=True) + mean_values_donor = data[ + donor_index, config["skip_frames_in_the_beginning"] :, ... + ].nanmean(dim=0, keepdim=True) + mylogger.info("Scale acceptor to heart beat amplitude") mylogger.info("Remove mean") data[acceptor_index, ...] -= mean_values_acceptor mylogger.info("Apply acceptor_factor and mask") + # data[acceptor_index, ...] *= acceptor_factor.unsqueeze( + # 0 + # ) * mask_positve.unsqueeze(0) + acceptor_factor_correction = np.sqrt( + mean_values_acceptor / mean_values_donor + ) data[acceptor_index, ...] *= acceptor_factor.unsqueeze( 0 - ) * mask_positve.unsqueeze(0) + ) * acceptor_factor_correction * mask_positve.unsqueeze(0) mylogger.info("Add mean") data[acceptor_index, ...] += mean_values_acceptor mylogger.info("-==- Done -==-") mylogger.info("Scale donor to heart beat amplitude") - mylogger.info("Calculate mean") - mean_values_donor = data[ - donor_index, config["skip_frames_in_the_beginning"] :, ... - ].nanmean(dim=0, keepdim=True) mylogger.info("Remove mean") data[donor_index, ...] -= mean_values_donor mylogger.info("Apply donor_factor and mask") - data[donor_index, ...] *= donor_factor.unsqueeze(0) * mask_positve.unsqueeze(0) + # data[donor_index, ...] *= donor_factor.unsqueeze( + # 0 + # ) * mask_positve.unsqueeze(0) + donor_factor_correction = 1 / acceptor_factor_correction + data[donor_index, ...] *= donor_factor.unsqueeze( + 0 + ) * donor_factor_correction * mask_positve.unsqueeze(0) mylogger.info("Add mean") data[donor_index, ...] += mean_values_donor mylogger.info("-==- Done -==-") + + # import matplotlib.pyplot as plt + # plt.pcolor(mean_values_acceptor[0]) + # plt.colorbar() + # plt.show() + # plt.pcolor(mean_values_donor[0]) + # plt.colorbar() + # plt.show() + # TODO remove + + # TODO SCHNUGGEL else: mylogger.info("GECI does not require acceptor/donor scaling, skipping!") mylogger.info("-==- Done -==-") + mylogger.info("") + mylogger.info("(G) CONVERSION TO RELATIVE SIGNAL CHANGES (DIV/MEAN)") + mylogger.info("-----------------------------------------------") + mylogger.info("") + mylogger.info("Divide by mean over time") data /= data[:, config["skip_frames_in_the_beginning"] :, ...].nanmean( dim=1, @@ -653,6 +808,11 @@ def process_trial( ) mylogger.info("-==- Done -==-") + mylogger.info("") + mylogger.info("(H) CLEANING BY REGRESSION") + mylogger.info("-----------------------------------------------") + mylogger.info("") + data = data.nan_to_num(nan=0.0) mylogger.info("Preparation for regression -- Gauss smear") spatial_width = float(config["gauss_smear_spatial_width"]) @@ -805,6 +965,107 @@ def process_trial( data_donor = data[target_id, ...].clone() data_donor[mask_negative, :] = 0.0 + # TODO clean up ---> + if config["save_oxyvol_as_python"] or config["save_oxyvol_as_matlab"]: + + mylogger.info("") + mylogger.info("(I-OPTIONAL) SAVE OXY/VOL/MASK") + mylogger.info("-----------------------------------------------") + mylogger.info("") + + # extract oxy and vol + mylogger.info("Save Oxygenation/Volume/Mask") + data_oxygenation = data[oxygenation_index, ...].clone() + data_volume = data[volume_index, ...].clone() + data_mask = mask_positve.clone() + + # bin, if required... + if config["binning_enable"] and config["binning_at_the_end"]: + mylogger.info("Binning of data") + mylogger.info( + ( + f"kernel_size={int(config['binning_kernel_size'])}, " + f"stride={int(config['binning_stride'])}, " + "divisor_override=None" + ) + ) + + data_oxygenation = binning( + data_oxygenation.unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ).squeeze(-1) + + data_volume = binning( + data_volume.unsqueeze(-1), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ).squeeze(-1) + + data_mask = ( + binning( + data_mask.unsqueeze(-1).unsqueeze(-1).type(dtype=dtype), + kernel_size=int(config["binning_kernel_size"]), + stride=int(config["binning_stride"]), + divisor_override=None, + ) + .squeeze(-1) + .squeeze(-1) + ) + data_mask = (data_mask > 0).type(torch.bool) + + if config["save_oxyvol_as_python"]: + + # export it! + temp_path = os.path.join( + config["export_path"], experiment_name + "_oxygenation_volume.npz" + ) + mylogger.info(f"Save data oxygenation and volume to {temp_path}") + np.savez_compressed( + temp_path, + data_oxygenation=data_oxygenation.cpu(), + data_volume=data_volume.cpu(), + data_mask=data_mask.cpu(), + ) + + if config["save_oxyvol_as_matlab"]: + + temp_path = os.path.join( + config["export_path"], experiment_name + "_oxygenation_volume.hd5" + ) + mylogger.info(f"Save data oxygenation and volume to {temp_path}") + file_handle = h5py.File(temp_path, "w") + + data_mask = data_mask.movedim(0, -1) + data_oxygenation = data_oxygenation.movedim(1, -1).movedim(0, -1) + data_volume = data_volume.movedim(1, -1).movedim(0, -1) + _ = file_handle.create_dataset( + "data_mask", + data=data_mask.type(torch.uint8).cpu(), + compression="gzip", + compression_opts=9, + ) + _ = file_handle.create_dataset( + "data_oxygenation", + data=data_oxygenation.cpu(), + compression="gzip", + compression_opts=9, + ) + _ = file_handle.create_dataset( + "data_volume", + data=data_volume.cpu(), + compression="gzip", + compression_opts=9, + ) + mylogger.info("Reminder: How to read with matlab:") + mylogger.info(f"data_mask = h5read('{temp_path}','/data_mask');") + mylogger.info(f"data_oxygenation = h5read('{temp_path}','/data_oxygenation');") + mylogger.info(f"data_volume = h5read('{temp_path}','/data_volume');") + file_handle.close() + # TODO <------ clean up + del data del data_filtered @@ -825,6 +1086,11 @@ def process_trial( if dual_signal_mode is False: + mylogger.info("") + mylogger.info("(J1-OPTIONAL) SAVE ACC/DON/MASK (NO RATIO!+OPT BIN@END)") + mylogger.info("-----------------------------------------------") + mylogger.info("") + mylogger.info("mono signal model") mylogger.info("Remove nan") @@ -917,6 +1183,11 @@ def process_trial( return # ##################### + mylogger.info("") + mylogger.info("(J2-OPTIONAL) BUILD AND SAVE RATIO (+OPT BIN@END)") + mylogger.info("-----------------------------------------------") + mylogger.info("") + mylogger.info("Calculate ratio sequence") if config["classical_ratio_mode"]: @@ -1102,3 +1373,5 @@ def main( if __name__ == "__main__": argh.dispatch_command(main) + +# %%