Add files via upload

This commit is contained in:
David Rotermund 2023-07-09 19:26:30 +02:00 committed by GitHub
parent 88015cf989
commit cc7de6dbbb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -50,6 +50,10 @@ class DataContainer(torch.nn.Module):
volume_eigenvalues: torch.Tensor | None = None volume_eigenvalues: torch.Tensor | None = None
volume_residuum: torch.Tensor | None = None volume_residuum: torch.Tensor | None = None
power_d_initial: torch.Tensor | None = None
power_d_final: torch.Tensor | None = None
power_d_amplitude: torch.Tensor | None = None
# ------- # -------
image_alignment: ImageAlignment image_alignment: ImageAlignment
@ -939,6 +943,7 @@ class DataContainer(torch.nn.Module):
mmap_mode: bool = True, mmap_mode: bool = True,
initital_mask: torch.Tensor | None = None, initital_mask: torch.Tensor | None = None,
start_position_coefficients: int = 0, start_position_coefficients: int = 0,
calculate_amplitude: bool = False,
) -> None: ) -> None:
self.logger.info(f"{self.level2} start load_data") self.logger.info(f"{self.level2} start load_data")
self.load_data( self.load_data(
@ -980,6 +985,19 @@ class DataContainer(torch.nn.Module):
self.oxygenation *= initital_mask.unsqueeze(0) self.oxygenation *= initital_mask.unsqueeze(0)
self.volume *= initital_mask.unsqueeze(0) self.volume *= initital_mask.unsqueeze(0)
if calculate_amplitude is True:
(
self.power_hb_low_initial,
self.power_hb_high_initial,
_,
) = self.measure_heartbeat_frequency(use_input_source="donor")
self.power_d_initial = self.measure_heartbeat_power(
use_input_source="donor",
start_position_coefficients=start_position_coefficients,
power_hb_low=self.power_hb_low_initial,
power_hb_high=self.power_hb_high_initial,
)
if remove_heartbeat is True: if remove_heartbeat is True:
self.logger.info(f"{self.level2} remove the heart beat via SVD") self.logger.info(f"{self.level2} remove the heart beat via SVD")
self.remove_heartbeat( self.remove_heartbeat(
@ -1425,6 +1443,7 @@ class DataContainer(torch.nn.Module):
start_position_coefficients: int = 0, start_position_coefficients: int = 0,
power_hb_low: torch.Tensor | None = None, power_hb_low: torch.Tensor | None = None,
power_hb_high: torch.Tensor | None = None, power_hb_high: torch.Tensor | None = None,
custom_input: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
if use_input_source == "donor": if use_input_source == "donor":
assert self.donor is not None assert self.donor is not None
@ -1438,6 +1457,9 @@ class DataContainer(torch.nn.Module):
assert self.volume is not None assert self.volume is not None
hb = self.volume[start_position_coefficients:, ...] hb = self.volume[start_position_coefficients:, ...]
elif use_input_source == "custom":
assert custom_input is not None
hb = custom_input[start_position_coefficients:, ...]
else: else:
assert self.oxygenation is not None assert self.oxygenation is not None
hb = self.oxygenation[start_position_coefficients:, ...] hb = self.oxygenation[start_position_coefficients:, ...]
@ -1484,9 +1506,9 @@ class DataContainer(torch.nn.Module):
start_position: int = 0, start_position: int = 0,
start_position_coefficients: int = 100, start_position_coefficients: int = 100,
fs: float = 100.0, fs: float = 100.0,
use_regression: bool | None = None, use_regression: bool | None = False,
# Heartbeat # Heartbeat
remove_heartbeat: bool = False, # i.e. use SVD remove_heartbeat: bool = True, # i.e. use SVD
low_frequency: float = 5, # Hz Butter Bandpass Heartbeat low_frequency: float = 5, # Hz Butter Bandpass Heartbeat
high_frequency: float = 15, # Hz Butter Bandpass Heartbeat high_frequency: float = 15, # Hz Butter Bandpass Heartbeat
threshold: float | None = 0.5, # For the mask threshold: float | None = 0.5, # For the mask
@ -1503,10 +1525,11 @@ class DataContainer(torch.nn.Module):
mmap_mode: bool = True, mmap_mode: bool = True,
initital_mask_name: str | None = None, initital_mask_name: str | None = None,
initital_mask_update: bool = True, initital_mask_update: bool = True,
initital_mask_roi: bool = True, initital_mask_roi: bool = False,
gaussian_blur_kernel_size: int | None = None, gaussian_blur_kernel_size: int | None = 3,
gaussian_blur_sigma: float = 1.0, gaussian_blur_sigma: float = 1.0,
bin_size_post: int | None = None, bin_size_post: int | None = None,
calculate_amplitude: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
self.logger.info(f"{self.level0} start automatic_load") self.logger.info(f"{self.level0} start automatic_load")
@ -1544,6 +1567,7 @@ class DataContainer(torch.nn.Module):
mmap_mode=mmap_mode, mmap_mode=mmap_mode,
initital_mask=initital_mask, initital_mask=initital_mask,
start_position_coefficients=start_position_coefficients, start_position_coefficients=start_position_coefficients,
calculate_amplitude=calculate_amplitude,
) )
heartbeat_a: torch.Tensor | None = None heartbeat_a: torch.Tensor | None = None
@ -1697,8 +1721,23 @@ class DataContainer(torch.nn.Module):
self.logger.info(f"{self.level0} end automatic_load") self.logger.info(f"{self.level0} end automatic_load")
if self.power_d_initial is not None:
self.power_d_final = self.measure_heartbeat_power(
use_input_source="custom",
power_hb_low=self.power_hb_low_initial,
power_hb_high=self.power_hb_high_initial,
start_position_coefficients=start_position_coefficients,
custom_input=result_d,
)
self.power_d_amplitude = self.power_d_final / self.power_d_initial
self.power_d_amplitude = torch.nan_to_num(self.power_d_amplitude, nan=0.0)
result = result_a - result_d result = result_a - result_d
if self.power_d_amplitude is not None:
result *= self.power_d_amplitude.unsqueeze(0)
result += 1.0
if (gaussian_blur_kernel_size is not None) and (gaussian_blur_kernel_size > 0): if (gaussian_blur_kernel_size is not None) and (gaussian_blur_kernel_size > 0):
gaussian_blur = tv.transforms.GaussianBlur( gaussian_blur = tv.transforms.GaussianBlur(
kernel_size=[gaussian_blur_kernel_size, gaussian_blur_kernel_size], kernel_size=[gaussian_blur_kernel_size, gaussian_blur_kernel_size],
@ -1735,6 +1774,7 @@ if __name__ == "__main__":
start_position_coefficients: int = 100 start_position_coefficients: int = 100
remove_heartbeat: bool = True # i.e. use SVD remove_heartbeat: bool = True # i.e. use SVD
bin_size: int = 4 bin_size: int = 4
calculate_amplitude: bool = False
example_position_x: int = 280 example_position_x: int = 280
example_position_y: int = 440 example_position_y: int = 440
@ -1780,6 +1820,7 @@ if __name__ == "__main__":
gaussian_blur_kernel_size=gaussian_blur_kernel_size, gaussian_blur_kernel_size=gaussian_blur_kernel_size,
gaussian_blur_sigma=gaussian_blur_sigma, gaussian_blur_sigma=gaussian_blur_sigma,
bin_size_post=bin_size_post, bin_size_post=bin_size_post,
calculate_amplitude=calculate_amplitude,
) )
if show_example_timeseries is True: if show_example_timeseries is True: