diff --git a/new_pipeline/config.json b/new_pipeline/config.json index a9e4f2c..2201c87 100644 --- a/new_pipeline/config.json +++ b/new_pipeline/config.json @@ -23,6 +23,10 @@ "lower_freqency_bandpass": 5.0, // Hz "upper_freqency_bandpass": 14.0, // Hz "heartbeat_filtfilt_chuck_size": 10, + // Gauss smear + "gauss_smear_spatial_width": 8, + "gauss_smear_temporal_width": 0.1, + "gauss_smear_use_matlab_mask": false, // LED Ramp on "skip_frames_in_the_beginning": 100, // Frames // PyTorch diff --git a/new_pipeline/stage_4_process.py b/new_pipeline/stage_4_process.py index c84b380..abd8a9e 100644 --- a/new_pipeline/stage_4_process.py +++ b/new_pipeline/stage_4_process.py @@ -19,6 +19,7 @@ from functions.align_refref import align_refref from functions.perform_donor_volume_rotation import perform_donor_volume_rotation from functions.perform_donor_volume_translation import perform_donor_volume_translation from functions.bandpass import bandpass +from functions.gauss_smear_individual import gauss_smear_individual import matplotlib.pyplot as plt @@ -600,6 +601,84 @@ def process_trial( data[donor_index, ...] += mean_values_donor mylogger.info("-==- Done -==-") + mylogger.info("Divide by mean over time") + data /= data[:, config["skip_frames_in_the_beginning"] :, ...].nanmean( + dim=1, + keepdim=True, + ) + data = data.nan_to_num(nan=0.0) + mylogger.info("-==- Done -==-") + + mylogger.info("Preparation for regression -- Gauss smear") + spatial_width = float(config["gauss_smear_spatial_width"]) + + if config["binning_enable"] and config["binning_before_alignment"]: + spatial_width /= float(config["binning_kernel_size"]) + + mylogger.info( + f"Mask -- " + f"spatial width: {spatial_width}, " + f"temporal width: {float(config['gauss_smear_temporal_width'])}, " + f"use matlab mode: {bool(config['gauss_smear_use_matlab_mask'])} " + ) + + input_mask = mask_positve.type(dtype=dtype).clone() + + filtered_mask: torch.Tensor + filtered_mask, _ = gauss_smear_individual( + input=input_mask, + spatial_width=spatial_width, + temporal_width=float(config["gauss_smear_temporal_width"]), + use_matlab_mask=bool(config["gauss_smear_use_matlab_mask"]), + epsilon=float(torch.finfo(input_mask.dtype).eps), + ) + + mylogger.info("creating a copy of the data") + data_filtered = data.clone().movedim(1, -1) + if device != torch.device("cpu"): + free_mem = cuda_total_memory - max( + [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] + ) + mylogger.info( + f"CUDA memory after reserving RAM for data_filtered: {free_mem//1024} MByte" + ) + + overwrite_fft_gauss: None | torch.Tensor = None + for i in range(0, data_filtered.shape[0]): + mylogger.info( + f"{config['required_order'][i]} -- " + f"spatial width: {spatial_width}, " + f"temporal width: {float(config['gauss_smear_temporal_width'])}, " + f"use matlab mode: {bool(config['gauss_smear_use_matlab_mask'])} " + ) + data_filtered[i, ...] *= input_mask.unsqueeze(-1) + data_filtered[i, ...], overwrite_fft_gauss = gauss_smear_individual( + input=data_filtered[i, ...], + spatial_width=spatial_width, + temporal_width=float(config["gauss_smear_temporal_width"]), + overwrite_fft_gauss=overwrite_fft_gauss, + use_matlab_mask=bool(config["gauss_smear_use_matlab_mask"]), + epsilon=float(torch.finfo(input_mask.dtype).eps), + ) + + data_filtered[i, ...] /= filtered_mask + 1e-20 + data_filtered[i, ...] += 1.0 - input_mask.unsqueeze(-1) + + del filtered_mask + del overwrite_fft_gauss + del input_mask + mylogger.info("data_filtered is populated") + + if device != torch.device("cpu"): + free_mem = cuda_total_memory - max( + [torch.cuda.memory_reserved(device), torch.cuda.memory_allocated(device)] + ) + mylogger.info( + f"CUDA memory after data_filtered is populated: {free_mem//1024} MByte" + ) + + mylogger.info("-==- Done -==-") + exit() return