diff --git a/DataContainer.py b/DataContainer.py index a312d43..f7a7fd9 100644 --- a/DataContainer.py +++ b/DataContainer.py @@ -50,6 +50,10 @@ class DataContainer(torch.nn.Module): volume_eigenvalues: torch.Tensor | None = None volume_residuum: torch.Tensor | None = None + power_d_initial: torch.Tensor | None = None + power_d_final: torch.Tensor | None = None + power_d_amplitude: torch.Tensor | None = None + # ------- image_alignment: ImageAlignment @@ -939,6 +943,7 @@ class DataContainer(torch.nn.Module): mmap_mode: bool = True, initital_mask: torch.Tensor | None = None, start_position_coefficients: int = 0, + calculate_amplitude: bool = False, ) -> None: self.logger.info(f"{self.level2} start load_data") self.load_data( @@ -980,6 +985,19 @@ class DataContainer(torch.nn.Module): self.oxygenation *= initital_mask.unsqueeze(0) self.volume *= initital_mask.unsqueeze(0) + if calculate_amplitude is True: + ( + self.power_hb_low_initial, + self.power_hb_high_initial, + _, + ) = self.measure_heartbeat_frequency(use_input_source="donor") + self.power_d_initial = self.measure_heartbeat_power( + use_input_source="donor", + start_position_coefficients=start_position_coefficients, + power_hb_low=self.power_hb_low_initial, + power_hb_high=self.power_hb_high_initial, + ) + if remove_heartbeat is True: self.logger.info(f"{self.level2} remove the heart beat via SVD") self.remove_heartbeat( @@ -1425,6 +1443,7 @@ class DataContainer(torch.nn.Module): start_position_coefficients: int = 0, power_hb_low: torch.Tensor | None = None, power_hb_high: torch.Tensor | None = None, + custom_input: torch.Tensor | None = None, ) -> torch.Tensor: if use_input_source == "donor": assert self.donor is not None @@ -1438,6 +1457,9 @@ class DataContainer(torch.nn.Module): assert self.volume is not None hb = self.volume[start_position_coefficients:, ...] + elif use_input_source == "custom": + assert custom_input is not None + hb = custom_input[start_position_coefficients:, ...] else: assert self.oxygenation is not None hb = self.oxygenation[start_position_coefficients:, ...] @@ -1484,9 +1506,9 @@ class DataContainer(torch.nn.Module): start_position: int = 0, start_position_coefficients: int = 100, fs: float = 100.0, - use_regression: bool | None = None, + use_regression: bool | None = False, # Heartbeat - remove_heartbeat: bool = False, # i.e. use SVD + remove_heartbeat: bool = True, # i.e. use SVD low_frequency: float = 5, # Hz Butter Bandpass Heartbeat high_frequency: float = 15, # Hz Butter Bandpass Heartbeat threshold: float | None = 0.5, # For the mask @@ -1503,10 +1525,11 @@ class DataContainer(torch.nn.Module): mmap_mode: bool = True, initital_mask_name: str | None = None, initital_mask_update: bool = True, - initital_mask_roi: bool = True, - gaussian_blur_kernel_size: int | None = None, + initital_mask_roi: bool = False, + gaussian_blur_kernel_size: int | None = 3, gaussian_blur_sigma: float = 1.0, bin_size_post: int | None = None, + calculate_amplitude: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: self.logger.info(f"{self.level0} start automatic_load") @@ -1544,6 +1567,7 @@ class DataContainer(torch.nn.Module): mmap_mode=mmap_mode, initital_mask=initital_mask, start_position_coefficients=start_position_coefficients, + calculate_amplitude=calculate_amplitude, ) heartbeat_a: torch.Tensor | None = None @@ -1697,8 +1721,23 @@ class DataContainer(torch.nn.Module): self.logger.info(f"{self.level0} end automatic_load") + if self.power_d_initial is not None: + self.power_d_final = self.measure_heartbeat_power( + use_input_source="custom", + power_hb_low=self.power_hb_low_initial, + power_hb_high=self.power_hb_high_initial, + start_position_coefficients=start_position_coefficients, + custom_input=result_d, + ) + self.power_d_amplitude = self.power_d_final / self.power_d_initial + self.power_d_amplitude = torch.nan_to_num(self.power_d_amplitude, nan=0.0) + result = result_a - result_d + if self.power_d_amplitude is not None: + result *= self.power_d_amplitude.unsqueeze(0) + result += 1.0 + if (gaussian_blur_kernel_size is not None) and (gaussian_blur_kernel_size > 0): gaussian_blur = tv.transforms.GaussianBlur( kernel_size=[gaussian_blur_kernel_size, gaussian_blur_kernel_size], @@ -1735,6 +1774,7 @@ if __name__ == "__main__": start_position_coefficients: int = 100 remove_heartbeat: bool = True # i.e. use SVD bin_size: int = 4 + calculate_amplitude: bool = False example_position_x: int = 280 example_position_y: int = 440 @@ -1780,6 +1820,7 @@ if __name__ == "__main__": gaussian_blur_kernel_size=gaussian_blur_kernel_size, gaussian_blur_sigma=gaussian_blur_sigma, bin_size_post=bin_size_post, + calculate_amplitude=calculate_amplitude, ) if show_example_timeseries is True: