Add files via upload

This commit is contained in:
David Rotermund 2023-07-10 13:06:54 +02:00 committed by GitHub
parent cc7de6dbbb
commit 71c2784437
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 97 additions and 94 deletions

View file

@ -11,7 +11,7 @@ class Anime:
def show( def show(
self, self,
input: torch.Tensor | np.ndarray, input: torch.Tensor | np.ndarray,
mask: torch.Tensor | np.ndarray | None, mask: torch.Tensor | np.ndarray | None = None,
vmin: float | None = None, vmin: float | None = None,
vmax: float | None = None, vmax: float | None = None,
cmap: str = "hot", cmap: str = "hot",
@ -60,10 +60,10 @@ class Anime:
vmax=vmax, vmax=vmax,
) )
if colorbar is True: if colorbar:
plt.colorbar() plt.colorbar()
if axis_off is True: if axis_off:
plt.axis("off") plt.axis("off")
def next_frame(i: int) -> None: def next_frame(i: int) -> None:
@ -72,7 +72,7 @@ class Anime:
image[mask_np] = float("NaN") image[mask_np] = float("NaN")
image_handle.set_data(image) image_handle.set_data(image)
if show_frame_count is True: if show_frame_count:
bar_length: int = 10 bar_length: int = 10
filled_length = int(round(bar_length * i / input_np.shape[0])) filled_length = int(round(bar_length * i / input_np.shape[0]))
bar = "\u25A0" * filled_length + "\u25A1" * (bar_length - filled_length) bar = "\u25A0" * filled_length + "\u25A1" * (bar_length - filled_length)

View file

@ -50,9 +50,10 @@ class DataContainer(torch.nn.Module):
volume_eigenvalues: torch.Tensor | None = None volume_eigenvalues: torch.Tensor | None = None
volume_residuum: torch.Tensor | None = None volume_residuum: torch.Tensor | None = None
power_d_initial: torch.Tensor | None = None acceptor_scale: torch.Tensor | None = None
power_d_final: torch.Tensor | None = None donor_scale: torch.Tensor | None = None
power_d_amplitude: torch.Tensor | None = None oxygenation_scale: torch.Tensor | None = None
volume_scale: torch.Tensor | None = None
# ------- # -------
image_alignment: ImageAlignment image_alignment: ImageAlignment
@ -96,7 +97,7 @@ class DataContainer(torch.nn.Module):
self.logger = logging.getLogger("DataContainer") self.logger = logging.getLogger("DataContainer")
self.logger.setLevel(logging.DEBUG) self.logger.setLevel(logging.DEBUG)
if save_logging_messages is True: if save_logging_messages:
time_format = "%b %-d %Y %H:%M:%S" time_format = "%b %-d %Y %H:%M:%S"
logformat = "%(asctime)s %(message)s" logformat = "%(asctime)s %(message)s"
file_formatter = logging.Formatter(fmt=logformat, datefmt=time_format) file_formatter = logging.Formatter(fmt=logformat, datefmt=time_format)
@ -106,7 +107,7 @@ class DataContainer(torch.nn.Module):
file_handler.setFormatter(file_formatter) file_handler.setFormatter(file_formatter)
self.logger.addHandler(file_handler) self.logger.addHandler(file_handler)
if display_logging_messages is True: if display_logging_messages:
time_format = "%b %-d %Y %H:%M:%S" time_format = "%b %-d %Y %H:%M:%S"
logformat = "%(asctime)s %(message)s" logformat = "%(asctime)s %(message)s"
stream_formatter = logging.Formatter(fmt=logformat, datefmt=time_format) stream_formatter = logging.Formatter(fmt=logformat, datefmt=time_format)
@ -128,7 +129,7 @@ class DataContainer(torch.nn.Module):
json_postfix: str = "_meta.txt" json_postfix: str = "_meta.txt"
found_name_json: str = file_input_ref_image.replace(".npy", json_postfix) found_name_json: str = file_input_ref_image.replace(".npy", json_postfix)
assert os.path.isfile(found_name_json) is True assert os.path.isfile(found_name_json)
with open(found_name_json, "r") as file_handle: with open(found_name_json, "r") as file_handle:
metadata = json.load(file_handle) metadata = json.load(file_handle)
@ -203,9 +204,7 @@ class DataContainer(torch.nn.Module):
f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}_meta.txt", f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}_meta.txt",
) )
while (os.path.isfile(filename_np) is True) and ( while (os.path.isfile(filename_np)) and (os.path.isfile(filename_meta)):
os.path.isfile(filename_meta) is True
):
self.logger.info(f"{self.level3} work in {filename_np}") self.logger.info(f"{self.level3} work in {filename_np}")
# Check if channel asignment is still okay # Check if channel asignment is still okay
with open(filename_meta, "r") as file_handle: with open(filename_meta, "r") as file_handle:
@ -218,7 +217,7 @@ class DataContainer(torch.nn.Module):
# Load the data... # Load the data...
self.logger.info(f"{self.level3} np.load") self.logger.info(f"{self.level3} np.load")
if mmap_mode is True: if mmap_mode:
temp: np.ndarray = np.load(filename_np, mmap_mode="r") temp: np.ndarray = np.load(filename_np, mmap_mode="r")
else: else:
temp = np.load(filename_np) temp = np.load(filename_np)
@ -275,7 +274,7 @@ class DataContainer(torch.nn.Module):
dim=2, dim=2,
) )
if enable_secondary_data is True: if enable_secondary_data:
self.logger.info(f"{self.level3} organize oxygenation") self.logger.info(f"{self.level3} organize oxygenation")
if self.oxygenation is None: if self.oxygenation is None:
self.oxygenation = torch.tensor( self.oxygenation = torch.tensor(
@ -344,13 +343,13 @@ class DataContainer(torch.nn.Module):
self.acceptor = self.acceptor.moveaxis(-1, 0) self.acceptor = self.acceptor.moveaxis(-1, 0)
self.donor = self.donor.moveaxis(-1, 0) self.donor = self.donor.moveaxis(-1, 0)
if enable_secondary_data is True: if enable_secondary_data:
assert self.oxygenation is not None assert self.oxygenation is not None
assert self.volume is not None assert self.volume is not None
self.oxygenation = self.oxygenation.moveaxis(-1, 0) self.oxygenation = self.oxygenation.moveaxis(-1, 0)
self.volume = self.volume.moveaxis(-1, 0) self.volume = self.volume.moveaxis(-1, 0)
if align is True: if align:
self.logger.info(f"{self.level3} move intra timeseries") self.logger.info(f"{self.level3} move intra timeseries")
self._move_intra_timeseries( self._move_intra_timeseries(
enable_secondary_data=enable_secondary_data, enable_secondary_data=enable_secondary_data,
@ -494,7 +493,7 @@ class DataContainer(torch.nn.Module):
fill=self.fill_value, fill=self.fill_value,
) )
if enable_secondary_data is True: if enable_secondary_data:
assert self.volume is not None assert self.volume is not None
self.volume = tv.transforms.functional.affine( self.volume = tv.transforms.functional.affine(
img=self.volume, img=self.volume,
@ -522,7 +521,7 @@ class DataContainer(torch.nn.Module):
shear=0, shear=0,
fill=self.fill_value, fill=self.fill_value,
) )
if enable_secondary_data is True: if enable_secondary_data:
assert self.oxygenation is not None assert self.oxygenation is not None
self.oxygenation = tv.transforms.functional.affine( self.oxygenation = tv.transforms.functional.affine(
img=self.oxygenation, img=self.oxygenation,
@ -557,7 +556,7 @@ class DataContainer(torch.nn.Module):
fill=self.fill_value, fill=self.fill_value,
) )
if enable_secondary_data is True: if enable_secondary_data:
assert self.oxygenation is not None assert self.oxygenation is not None
self.oxygenation = tv.transforms.functional.affine( self.oxygenation = tv.transforms.functional.affine(
img=self.oxygenation, img=self.oxygenation,
@ -593,7 +592,7 @@ class DataContainer(torch.nn.Module):
fill=self.fill_value, fill=self.fill_value,
) )
if enable_secondary_data is True: if enable_secondary_data:
assert self.oxygenation is not None assert self.oxygenation is not None
self.oxygenation = tv.transforms.functional.affine( self.oxygenation = tv.transforms.functional.affine(
img=self.oxygenation, img=self.oxygenation,
@ -786,7 +785,7 @@ class DataContainer(torch.nn.Module):
else: else:
self.donor_residuum += to_remove self.donor_residuum += to_remove
if enable_secondary_data is True: if enable_secondary_data:
to_remove, _, _, _ = self.volume_svd_remove( to_remove, _, _, _ = self.volume_svd_remove(
lowrank_method=lowrank_method, lowrank_method=lowrank_method,
lowrank_q=lowrank_q, lowrank_q=lowrank_q,
@ -814,7 +813,7 @@ class DataContainer(torch.nn.Module):
self.donor -= self.donor.mean(dim=0, keepdim=True) self.donor -= self.donor.mean(dim=0, keepdim=True)
self.acceptor -= self.acceptor.mean(dim=0, keepdim=True) self.acceptor -= self.acceptor.mean(dim=0, keepdim=True)
if enable_secondary_data is True: if enable_secondary_data:
assert self.volume is not None assert self.volume is not None
assert self.oxygenation is not None assert self.oxygenation is not None
self.volume -= self.volume.mean(dim=0, keepdim=True) self.volume -= self.volume.mean(dim=0, keepdim=True)
@ -827,7 +826,7 @@ class DataContainer(torch.nn.Module):
self.donor_residuum -= self.donor_residuum.mean(dim=0, keepdim=True) self.donor_residuum -= self.donor_residuum.mean(dim=0, keepdim=True)
self.acceptor_residuum -= self.acceptor_residuum.mean(dim=0, keepdim=True) self.acceptor_residuum -= self.acceptor_residuum.mean(dim=0, keepdim=True)
if enable_secondary_data is True: if enable_secondary_data:
assert self.volume_residuum is not None assert self.volume_residuum is not None
assert self.oxygenation_residuum is not None assert self.oxygenation_residuum is not None
self.volume_residuum -= self.volume_residuum.mean(dim=0, keepdim=True) self.volume_residuum -= self.volume_residuum.mean(dim=0, keepdim=True)
@ -858,7 +857,7 @@ class DataContainer(torch.nn.Module):
self.donor -= self._calculate_linear_trend_data(self.donor) self.donor -= self._calculate_linear_trend_data(self.donor)
self.acceptor -= self._calculate_linear_trend_data(self.acceptor) self.acceptor -= self._calculate_linear_trend_data(self.acceptor)
if enable_secondary_data is True: if enable_secondary_data:
assert self.volume is not None assert self.volume is not None
assert self.oxygenation is not None assert self.oxygenation is not None
self.volume -= self._calculate_linear_trend_data(self.volume) self.volume -= self._calculate_linear_trend_data(self.volume)
@ -877,7 +876,7 @@ class DataContainer(torch.nn.Module):
self.acceptor_residuum self.acceptor_residuum
) )
if enable_secondary_data is True: if enable_secondary_data:
assert self.volume_residuum is not None assert self.volume_residuum is not None
assert self.oxygenation_residuum is not None assert self.oxygenation_residuum is not None
self.volume_residuum -= self._calculate_linear_trend_data( self.volume_residuum -= self._calculate_linear_trend_data(
@ -897,7 +896,7 @@ class DataContainer(torch.nn.Module):
self.donor = self.donor[1:, :, :] self.donor = self.donor[1:, :, :]
self.acceptor = self.acceptor[1:, :, :] self.acceptor = self.acceptor[1:, :, :]
if enable_secondary_data is True: if enable_secondary_data:
assert self.volume is not None assert self.volume is not None
assert self.oxygenation is not None assert self.oxygenation is not None
self.volume = (self.volume[1:, :, :] + self.volume[:-1, :, :]) / 2.0 self.volume = (self.volume[1:, :, :] + self.volume[:-1, :, :]) / 2.0
@ -911,7 +910,7 @@ class DataContainer(torch.nn.Module):
if self.acceptor_residuum is not None: if self.acceptor_residuum is not None:
self.acceptor_residuum = self.acceptor_residuum[1:, :, :] self.acceptor_residuum = self.acceptor_residuum[1:, :, :]
if enable_secondary_data is True: if enable_secondary_data:
if self.volume_residuum is not None: if self.volume_residuum is not None:
self.volume_residuum = ( self.volume_residuum = (
self.volume_residuum[1:, :, :] + self.volume_residuum[:-1, :, :] self.volume_residuum[1:, :, :] + self.volume_residuum[:-1, :, :]
@ -943,7 +942,6 @@ class DataContainer(torch.nn.Module):
mmap_mode: bool = True, mmap_mode: bool = True,
initital_mask: torch.Tensor | None = None, initital_mask: torch.Tensor | None = None,
start_position_coefficients: int = 0, start_position_coefficients: int = 0,
calculate_amplitude: bool = False,
) -> None: ) -> None:
self.logger.info(f"{self.level2} start load_data") self.logger.info(f"{self.level2} start load_data")
self.load_data( self.load_data(
@ -962,12 +960,32 @@ class DataContainer(torch.nn.Module):
pool = torch.nn.AvgPool2d((bin_size, bin_size), stride=(bin_size, bin_size)) pool = torch.nn.AvgPool2d((bin_size, bin_size), stride=(bin_size, bin_size))
self.donor = pool(self.donor) self.donor = pool(self.donor)
self.acceptor = pool(self.acceptor) self.acceptor = pool(self.acceptor)
if enable_secondary_data is True: if enable_secondary_data:
assert self.volume is not None assert self.volume is not None
assert self.oxygenation is not None assert self.oxygenation is not None
self.volume = pool(self.volume) self.volume = pool(self.volume)
self.oxygenation = pool(self.oxygenation) self.oxygenation = pool(self.oxygenation)
if self.donor is not None:
self.donor_scale = self.donor.mean(dim=0, keepdim=True)
self.donor /= self.donor_scale
self.donor -= 1.0
if self.acceptor is not None:
self.acceptor_scale = self.acceptor.mean(dim=0, keepdim=True)
self.acceptor /= self.acceptor_scale
self.acceptor -= 1.0
if self.volume is not None:
self.volume_scale = self.volume.mean(dim=0, keepdim=True)
self.volume /= self.volume_scale
self.volume -= 1.0
if self.oxygenation is not None:
self.oxygenation_scale = self.oxygenation.mean(dim=0, keepdim=True)
self.oxygenation /= self.oxygenation_scale
self.oxygenation -= 1.0
if initital_mask is not None: if initital_mask is not None:
self.logger.info(f"{self.level2} initial mask is applied on the data") self.logger.info(f"{self.level2} initial mask is applied on the data")
assert self.acceptor is not None assert self.acceptor is not None
@ -979,26 +997,13 @@ class DataContainer(torch.nn.Module):
self.acceptor *= initital_mask.unsqueeze(0) self.acceptor *= initital_mask.unsqueeze(0)
self.donor *= initital_mask.unsqueeze(0) self.donor *= initital_mask.unsqueeze(0)
if enable_secondary_data is True: if enable_secondary_data:
assert self.oxygenation is not None assert self.oxygenation is not None
assert self.volume is not None assert self.volume is not None
self.oxygenation *= initital_mask.unsqueeze(0) self.oxygenation *= initital_mask.unsqueeze(0)
self.volume *= initital_mask.unsqueeze(0) self.volume *= initital_mask.unsqueeze(0)
if calculate_amplitude is True: if remove_heartbeat:
(
self.power_hb_low_initial,
self.power_hb_high_initial,
_,
) = self.measure_heartbeat_frequency(use_input_source="donor")
self.power_d_initial = self.measure_heartbeat_power(
use_input_source="donor",
start_position_coefficients=start_position_coefficients,
power_hb_low=self.power_hb_low_initial,
power_hb_high=self.power_hb_high_initial,
)
if remove_heartbeat is True:
self.logger.info(f"{self.level2} remove the heart beat via SVD") self.logger.info(f"{self.level2} remove the heart beat via SVD")
self.remove_heartbeat( self.remove_heartbeat(
iterations=iterations, iterations=iterations,
@ -1008,19 +1013,19 @@ class DataContainer(torch.nn.Module):
start_position_coefficients=start_position_coefficients, start_position_coefficients=start_position_coefficients,
) )
if remove_mean is True: if remove_mean:
self.logger.info(f"{self.level2} remove mean") self.logger.info(f"{self.level2} remove mean")
self.remove_mean_data(enable_secondary_data=enable_secondary_data) self.remove_mean_data(enable_secondary_data=enable_secondary_data)
if remove_linear is True: if remove_linear:
self.logger.info(f"{self.level2} remove linear trends") self.logger.info(f"{self.level2} remove linear trends")
self.remove_linear_trend_data(enable_secondary_data=enable_secondary_data) self.remove_linear_trend_data(enable_secondary_data=enable_secondary_data)
if remove_heartbeat is True: if remove_heartbeat:
if remove_heartbeat_mean is True: if remove_heartbeat_mean:
self.logger.info(f"{self.level2} remove mean (heart beat signal)") self.logger.info(f"{self.level2} remove mean (heart beat signal)")
self.remove_mean_residuum(enable_secondary_data=enable_secondary_data) self.remove_mean_residuum(enable_secondary_data=enable_secondary_data)
if remove_heartbeat_linear is True: if remove_heartbeat_linear:
self.logger.info( self.logger.info(
f"{self.level2} remove linear trends (heart beat signal)" f"{self.level2} remove linear trends (heart beat signal)"
) )
@ -1028,7 +1033,7 @@ class DataContainer(torch.nn.Module):
enable_secondary_data=enable_secondary_data enable_secondary_data=enable_secondary_data
) )
if do_frame_shift is True: if do_frame_shift:
self.logger.info(f"{self.level2} frame shift") self.logger.info(f"{self.level2} frame shift")
self.frame_shift(enable_secondary_data=enable_secondary_data) self.frame_shift(enable_secondary_data=enable_secondary_data)
@ -1141,7 +1146,7 @@ class DataContainer(torch.nn.Module):
del o_norm del o_norm
del v_norm del v_norm
if export_parameters is True: if export_parameters:
parameter_a_temp: torch.Tensor | None = torch.zeros_like(data_norm) parameter_a_temp: torch.Tensor | None = torch.zeros_like(data_norm)
parameter_d_temp: torch.Tensor | None = torch.zeros_like(data_norm) parameter_d_temp: torch.Tensor | None = torch.zeros_like(data_norm)
else: else:
@ -1149,7 +1154,7 @@ class DataContainer(torch.nn.Module):
parameter_d_temp = None parameter_d_temp = None
for mode_a in [True, False]: for mode_a in [True, False]:
if mode_a is True: if mode_a:
result = a.detach().clone() result = a.detach().clone()
result_mean_correct = a_correction result_mean_correct = a_correction
@ -1178,7 +1183,7 @@ class DataContainer(torch.nn.Module):
result -= data_selected * scale.unsqueeze(0) result -= data_selected * scale.unsqueeze(0)
if mode_a is True: if mode_a:
if i == 0: if i == 0:
initial_scale_value_a = max( initial_scale_value_a = max(
[max_scale_value_a, float(scale.max())] [max_scale_value_a, float(scale.max())]
@ -1198,7 +1203,7 @@ class DataContainer(torch.nn.Module):
-1, idx.unsqueeze(-1), scale.unsqueeze(-1) -1, idx.unsqueeze(-1), scale.unsqueeze(-1)
) )
if mode_a is True: if mode_a:
result_a[:, chunk, :] = result.detach().clone() result_a[:, chunk, :] = result.detach().clone()
max_scale_value_a = max([max_scale_value_a, float(scale.max())]) max_scale_value_a = max([max_scale_value_a, float(scale.max())])
if parameter_a_temp is not None: if parameter_a_temp is not None:
@ -1214,7 +1219,7 @@ class DataContainer(torch.nn.Module):
(parameter_d_temp, d_mean_full.squeeze(0).unsqueeze(-1)), (parameter_d_temp, d_mean_full.squeeze(0).unsqueeze(-1)),
dim=-1, dim=-1,
) )
if export_parameters is True: if export_parameters:
if (parameter_a is None) and (parameter_a_temp is not None): if (parameter_a is None) and (parameter_a_temp is not None):
parameter_a = torch.zeros( parameter_a = torch.zeros(
( (
@ -1366,7 +1371,7 @@ class DataContainer(torch.nn.Module):
heartbeat_a = torch.sqrt(scale) heartbeat_a = torch.sqrt(scale)
heartbeat_d = 1.0 / (heartbeat_a + 1e-20) heartbeat_d = 1.0 / (heartbeat_a + 1e-20)
if apply_to_data is True: if apply_to_data:
if self.donor is not None: if self.donor is not None:
self.donor *= heartbeat_d.unsqueeze(0) self.donor *= heartbeat_d.unsqueeze(0)
if self.volume is not None: if self.volume is not None:
@ -1378,8 +1383,18 @@ class DataContainer(torch.nn.Module):
if threshold is not None: if threshold is not None:
self.logger.info(f"{self.level3} calculate mask") self.logger.info(f"{self.level3} calculate mask")
mask = torch.where(hb_d.std(dim=0) > threshold, 1.0, 0.0) * torch.where( assert self.donor_scale is not None
hb_a.std(dim=0) > threshold, 1.0, 0.0 assert self.acceptor_scale is not None
temp_d = hb_d.std(dim=0) * self.donor_scale.squeeze(0)
temp_d -= temp_d.min()
temp_d /= temp_d.max()
temp_a = hb_a.std(dim=0) * self.acceptor_scale.squeeze(0)
temp_a -= temp_a.min()
temp_a /= temp_a.max()
mask = torch.where(temp_d > threshold, 1.0, 0.0) * torch.where(
temp_a > threshold, 1.0, 0.0
) )
else: else:
mask = None mask = None
@ -1506,7 +1521,7 @@ class DataContainer(torch.nn.Module):
start_position: int = 0, start_position: int = 0,
start_position_coefficients: int = 100, start_position_coefficients: int = 100,
fs: float = 100.0, fs: float = 100.0,
use_regression: bool | None = False, use_regression: bool | None = None,
# Heartbeat # Heartbeat
remove_heartbeat: bool = True, # i.e. use SVD remove_heartbeat: bool = True, # i.e. use SVD
low_frequency: float = 5, # Hz Butter Bandpass Heartbeat low_frequency: float = 5, # Hz Butter Bandpass Heartbeat
@ -1520,7 +1535,7 @@ class DataContainer(torch.nn.Module):
remove_heartbeat_mean: bool = False, remove_heartbeat_mean: bool = False,
remove_heartbeat_linear: bool = False, remove_heartbeat_linear: bool = False,
bin_size: int = 4, bin_size: int = 4,
do_frame_shift: bool = True, do_frame_shift: bool | None = None,
half_width_frequency_window: float = 3.0, # Hz (on side ) measure_heartbeat_frequency half_width_frequency_window: float = 3.0, # Hz (on side ) measure_heartbeat_frequency
mmap_mode: bool = True, mmap_mode: bool = True,
initital_mask_name: str | None = None, initital_mask_name: str | None = None,
@ -1529,18 +1544,18 @@ class DataContainer(torch.nn.Module):
gaussian_blur_kernel_size: int | None = 3, gaussian_blur_kernel_size: int | None = 3,
gaussian_blur_sigma: float = 1.0, gaussian_blur_sigma: float = 1.0,
bin_size_post: int | None = None, bin_size_post: int | None = None,
calculate_amplitude: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
self.logger.info(f"{self.level0} start automatic_load") self.logger.info(f"{self.level0} start automatic_load")
if use_regression is None: if use_regression is None:
use_regression = not remove_heartbeat use_regression = not remove_heartbeat
if do_frame_shift is None:
do_frame_shift = not remove_heartbeat
initital_mask: torch.Tensor | None = None initital_mask: torch.Tensor | None = None
if (initital_mask_name is not None) and os.path.isfile( if (initital_mask_name is not None) and os.path.isfile(initital_mask_name):
initital_mask_name
) is True:
initital_mask = torch.tensor( initital_mask = torch.tensor(
np.load(initital_mask_name), device=self.device, dtype=torch.float32 np.load(initital_mask_name), device=self.device, dtype=torch.float32
) )
@ -1567,7 +1582,6 @@ class DataContainer(torch.nn.Module):
mmap_mode=mmap_mode, mmap_mode=mmap_mode,
initital_mask=initital_mask, initital_mask=initital_mask,
start_position_coefficients=start_position_coefficients, start_position_coefficients=start_position_coefficients,
calculate_amplitude=calculate_amplitude,
) )
heartbeat_a: torch.Tensor | None = None heartbeat_a: torch.Tensor | None = None
@ -1576,7 +1590,7 @@ class DataContainer(torch.nn.Module):
power_hb_low: torch.Tensor | None = None power_hb_low: torch.Tensor | None = None
power_hb_high: torch.Tensor | None = None power_hb_high: torch.Tensor | None = None
if remove_heartbeat is True: if remove_heartbeat:
self.logger.info(f"{self.level1} remove heart beat (heartbeat_scale)") self.logger.info(f"{self.level1} remove heart beat (heartbeat_scale)")
heartbeat_a, heartbeat_d, mask = self.heartbeat_scale( heartbeat_a, heartbeat_d, mask = self.heartbeat_scale(
low_frequency=low_frequency, low_frequency=low_frequency,
@ -1590,6 +1604,7 @@ class DataContainer(torch.nn.Module):
self.logger.info( self.logger.info(
f"{self.level1} measure heart rate (measure_heartbeat_frequency)" f"{self.level1} measure heart rate (measure_heartbeat_frequency)"
) )
assert self.volume is not None
( (
power_hb_low, power_hb_low,
power_hb_high, power_hb_high,
@ -1603,7 +1618,7 @@ class DataContainer(torch.nn.Module):
half_width_frequency_window=half_width_frequency_window, half_width_frequency_window=half_width_frequency_window,
) )
if use_regression is True: if use_regression:
self.logger.info(f"{self.level1} use regression") self.logger.info(f"{self.level1} use regression")
( (
result_a, result_a,
@ -1663,14 +1678,14 @@ class DataContainer(torch.nn.Module):
result_d *= heartbeat_d.unsqueeze(0) result_d *= heartbeat_d.unsqueeze(0)
if mask is not None: if mask is not None:
if initital_mask_update is True: if initital_mask_update:
self.logger.info(f"{self.level1} update inital mask") self.logger.info(f"{self.level1} update inital mask")
if initital_mask is None: if initital_mask is None:
initital_mask = mask.clone() initital_mask = mask.clone()
else: else:
initital_mask *= mask initital_mask *= mask
if (initital_mask_roi is True) and (initital_mask is not None): if (initital_mask_roi) and (initital_mask is not None):
self.logger.info(f"{self.level1} enter roi mask drawing modus") self.logger.info(f"{self.level1} enter roi mask drawing modus")
yes_choices = ["yes", "y"] yes_choices = ["yes", "y"]
contiue_roi: bool = True contiue_roi: bool = True
@ -1678,7 +1693,7 @@ class DataContainer(torch.nn.Module):
image: np.ndarray = (result_a - result_d)[0, ...].cpu().numpy() image: np.ndarray = (result_a - result_d)[0, ...].cpu().numpy()
image[initital_mask.cpu().numpy() == 0] = float("NaN") image[initital_mask.cpu().numpy() == 0] = float("NaN")
while contiue_roi is True: while contiue_roi:
user_input = input( user_input = input(
"Mask: Do you want to remove more pixel (yes/no)? " "Mask: Do you want to remove more pixel (yes/no)? "
) )
@ -1721,22 +1736,8 @@ class DataContainer(torch.nn.Module):
self.logger.info(f"{self.level0} end automatic_load") self.logger.info(f"{self.level0} end automatic_load")
if self.power_d_initial is not None: # result = (1.0 + result_a) / (1.0 + result_d)
self.power_d_final = self.measure_heartbeat_power( result = 1.0 + result_a - result_d
use_input_source="custom",
power_hb_low=self.power_hb_low_initial,
power_hb_high=self.power_hb_high_initial,
start_position_coefficients=start_position_coefficients,
custom_input=result_d,
)
self.power_d_amplitude = self.power_d_final / self.power_d_initial
self.power_d_amplitude = torch.nan_to_num(self.power_d_amplitude, nan=0.0)
result = result_a - result_d
if self.power_d_amplitude is not None:
result *= self.power_d_amplitude.unsqueeze(0)
result += 1.0
if (gaussian_blur_kernel_size is not None) and (gaussian_blur_kernel_size > 0): if (gaussian_blur_kernel_size is not None) and (gaussian_blur_kernel_size > 0):
gaussian_blur = tv.transforms.GaussianBlur( gaussian_blur = tv.transforms.GaussianBlur(
@ -1774,7 +1775,7 @@ if __name__ == "__main__":
start_position_coefficients: int = 100 start_position_coefficients: int = 100
remove_heartbeat: bool = True # i.e. use SVD remove_heartbeat: bool = True # i.e. use SVD
bin_size: int = 4 bin_size: int = 4
calculate_amplitude: bool = False threshold: float | None = 0.05 # Between 0 and 1.0
example_position_x: int = 280 example_position_x: int = 280
example_position_y: int = 440 example_position_y: int = 440
@ -1820,13 +1821,15 @@ if __name__ == "__main__":
gaussian_blur_kernel_size=gaussian_blur_kernel_size, gaussian_blur_kernel_size=gaussian_blur_kernel_size,
gaussian_blur_sigma=gaussian_blur_sigma, gaussian_blur_sigma=gaussian_blur_sigma,
bin_size_post=bin_size_post, bin_size_post=bin_size_post,
calculate_amplitude=calculate_amplitude, threshold=threshold,
) )
if show_example_timeseries is True: if show_example_timeseries:
plt.plot(result[:, example_position_x, example_position_y].cpu()) plt.plot(result[:, example_position_x, example_position_y].cpu())
plt.show() plt.show()
if play_movie is True: if play_movie:
ani = Anime() ani = Anime()
ani.show(result, mask=mask, vmin_scale=0.5, vmax_scale=0.5) ani.show(
result - 1.0, mask=mask, vmin_scale=0.5, vmax_scale=0.5
) # , vmin=0.98) # , vmin=1.0, vmax_scale=1.0)

View file

@ -762,7 +762,7 @@ class ImageAlignment(torch.nn.Module):
bgval: torch.Tensor | None = None, bgval: torch.Tensor | None = None,
invert=False, invert=False,
) -> torch.Tensor: ) -> torch.Tensor:
if invert is True: if invert:
if scale is not None: if scale is not None:
scale = 1.0 / scale scale = 1.0 / scale
if angle is not None: if angle is not None:
@ -927,7 +927,7 @@ class ImageAlignment(torch.nn.Module):
if succ2[pos] > succ[pos]: if succ2[pos] > succ[pos]:
pick_rotated = True pick_rotated = True
if pick_rotated is True: if pick_rotated:
tvec[pos, :] = tvec2[pos, :] tvec[pos, :] = tvec2[pos, :]
succ[pos] = succ2[pos] succ[pos] = succ2[pos]
angle[pos] += 180 angle[pos] += 180