diff --git a/Anime.py b/Anime.py index 73f46a8..628624c 100644 --- a/Anime.py +++ b/Anime.py @@ -11,7 +11,7 @@ class Anime: def show( self, input: torch.Tensor | np.ndarray, - mask: torch.Tensor | np.ndarray | None, + mask: torch.Tensor | np.ndarray | None = None, vmin: float | None = None, vmax: float | None = None, cmap: str = "hot", @@ -60,10 +60,10 @@ class Anime: vmax=vmax, ) - if colorbar is True: + if colorbar: plt.colorbar() - if axis_off is True: + if axis_off: plt.axis("off") def next_frame(i: int) -> None: @@ -72,7 +72,7 @@ class Anime: image[mask_np] = float("NaN") image_handle.set_data(image) - if show_frame_count is True: + if show_frame_count: bar_length: int = 10 filled_length = int(round(bar_length * i / input_np.shape[0])) bar = "\u25A0" * filled_length + "\u25A1" * (bar_length - filled_length) diff --git a/DataContainer.py b/DataContainer.py index f7a7fd9..d606703 100644 --- a/DataContainer.py +++ b/DataContainer.py @@ -50,9 +50,10 @@ class DataContainer(torch.nn.Module): volume_eigenvalues: torch.Tensor | None = None volume_residuum: torch.Tensor | None = None - power_d_initial: torch.Tensor | None = None - power_d_final: torch.Tensor | None = None - power_d_amplitude: torch.Tensor | None = None + acceptor_scale: torch.Tensor | None = None + donor_scale: torch.Tensor | None = None + oxygenation_scale: torch.Tensor | None = None + volume_scale: torch.Tensor | None = None # ------- image_alignment: ImageAlignment @@ -96,7 +97,7 @@ class DataContainer(torch.nn.Module): self.logger = logging.getLogger("DataContainer") self.logger.setLevel(logging.DEBUG) - if save_logging_messages is True: + if save_logging_messages: time_format = "%b %-d %Y %H:%M:%S" logformat = "%(asctime)s %(message)s" file_formatter = logging.Formatter(fmt=logformat, datefmt=time_format) @@ -106,7 +107,7 @@ class DataContainer(torch.nn.Module): file_handler.setFormatter(file_formatter) self.logger.addHandler(file_handler) - if display_logging_messages is True: + if display_logging_messages: time_format = "%b %-d %Y %H:%M:%S" logformat = "%(asctime)s %(message)s" stream_formatter = logging.Formatter(fmt=logformat, datefmt=time_format) @@ -128,7 +129,7 @@ class DataContainer(torch.nn.Module): json_postfix: str = "_meta.txt" found_name_json: str = file_input_ref_image.replace(".npy", json_postfix) - assert os.path.isfile(found_name_json) is True + assert os.path.isfile(found_name_json) with open(found_name_json, "r") as file_handle: metadata = json.load(file_handle) @@ -203,9 +204,7 @@ class DataContainer(torch.nn.Module): f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part{part_id:03d}_meta.txt", ) - while (os.path.isfile(filename_np) is True) and ( - os.path.isfile(filename_meta) is True - ): + while (os.path.isfile(filename_np)) and (os.path.isfile(filename_meta)): self.logger.info(f"{self.level3} work in {filename_np}") # Check if channel asignment is still okay with open(filename_meta, "r") as file_handle: @@ -218,7 +217,7 @@ class DataContainer(torch.nn.Module): # Load the data... self.logger.info(f"{self.level3} np.load") - if mmap_mode is True: + if mmap_mode: temp: np.ndarray = np.load(filename_np, mmap_mode="r") else: temp = np.load(filename_np) @@ -275,7 +274,7 @@ class DataContainer(torch.nn.Module): dim=2, ) - if enable_secondary_data is True: + if enable_secondary_data: self.logger.info(f"{self.level3} organize oxygenation") if self.oxygenation is None: self.oxygenation = torch.tensor( @@ -344,13 +343,13 @@ class DataContainer(torch.nn.Module): self.acceptor = self.acceptor.moveaxis(-1, 0) self.donor = self.donor.moveaxis(-1, 0) - if enable_secondary_data is True: + if enable_secondary_data: assert self.oxygenation is not None assert self.volume is not None self.oxygenation = self.oxygenation.moveaxis(-1, 0) self.volume = self.volume.moveaxis(-1, 0) - if align is True: + if align: self.logger.info(f"{self.level3} move intra timeseries") self._move_intra_timeseries( enable_secondary_data=enable_secondary_data, @@ -494,7 +493,7 @@ class DataContainer(torch.nn.Module): fill=self.fill_value, ) - if enable_secondary_data is True: + if enable_secondary_data: assert self.volume is not None self.volume = tv.transforms.functional.affine( img=self.volume, @@ -522,7 +521,7 @@ class DataContainer(torch.nn.Module): shear=0, fill=self.fill_value, ) - if enable_secondary_data is True: + if enable_secondary_data: assert self.oxygenation is not None self.oxygenation = tv.transforms.functional.affine( img=self.oxygenation, @@ -557,7 +556,7 @@ class DataContainer(torch.nn.Module): fill=self.fill_value, ) - if enable_secondary_data is True: + if enable_secondary_data: assert self.oxygenation is not None self.oxygenation = tv.transforms.functional.affine( img=self.oxygenation, @@ -593,7 +592,7 @@ class DataContainer(torch.nn.Module): fill=self.fill_value, ) - if enable_secondary_data is True: + if enable_secondary_data: assert self.oxygenation is not None self.oxygenation = tv.transforms.functional.affine( img=self.oxygenation, @@ -786,7 +785,7 @@ class DataContainer(torch.nn.Module): else: self.donor_residuum += to_remove - if enable_secondary_data is True: + if enable_secondary_data: to_remove, _, _, _ = self.volume_svd_remove( lowrank_method=lowrank_method, lowrank_q=lowrank_q, @@ -814,7 +813,7 @@ class DataContainer(torch.nn.Module): self.donor -= self.donor.mean(dim=0, keepdim=True) self.acceptor -= self.acceptor.mean(dim=0, keepdim=True) - if enable_secondary_data is True: + if enable_secondary_data: assert self.volume is not None assert self.oxygenation is not None self.volume -= self.volume.mean(dim=0, keepdim=True) @@ -827,7 +826,7 @@ class DataContainer(torch.nn.Module): self.donor_residuum -= self.donor_residuum.mean(dim=0, keepdim=True) self.acceptor_residuum -= self.acceptor_residuum.mean(dim=0, keepdim=True) - if enable_secondary_data is True: + if enable_secondary_data: assert self.volume_residuum is not None assert self.oxygenation_residuum is not None self.volume_residuum -= self.volume_residuum.mean(dim=0, keepdim=True) @@ -858,7 +857,7 @@ class DataContainer(torch.nn.Module): self.donor -= self._calculate_linear_trend_data(self.donor) self.acceptor -= self._calculate_linear_trend_data(self.acceptor) - if enable_secondary_data is True: + if enable_secondary_data: assert self.volume is not None assert self.oxygenation is not None self.volume -= self._calculate_linear_trend_data(self.volume) @@ -877,7 +876,7 @@ class DataContainer(torch.nn.Module): self.acceptor_residuum ) - if enable_secondary_data is True: + if enable_secondary_data: assert self.volume_residuum is not None assert self.oxygenation_residuum is not None self.volume_residuum -= self._calculate_linear_trend_data( @@ -897,7 +896,7 @@ class DataContainer(torch.nn.Module): self.donor = self.donor[1:, :, :] self.acceptor = self.acceptor[1:, :, :] - if enable_secondary_data is True: + if enable_secondary_data: assert self.volume is not None assert self.oxygenation is not None self.volume = (self.volume[1:, :, :] + self.volume[:-1, :, :]) / 2.0 @@ -911,7 +910,7 @@ class DataContainer(torch.nn.Module): if self.acceptor_residuum is not None: self.acceptor_residuum = self.acceptor_residuum[1:, :, :] - if enable_secondary_data is True: + if enable_secondary_data: if self.volume_residuum is not None: self.volume_residuum = ( self.volume_residuum[1:, :, :] + self.volume_residuum[:-1, :, :] @@ -943,7 +942,6 @@ class DataContainer(torch.nn.Module): mmap_mode: bool = True, initital_mask: torch.Tensor | None = None, start_position_coefficients: int = 0, - calculate_amplitude: bool = False, ) -> None: self.logger.info(f"{self.level2} start load_data") self.load_data( @@ -962,12 +960,32 @@ class DataContainer(torch.nn.Module): pool = torch.nn.AvgPool2d((bin_size, bin_size), stride=(bin_size, bin_size)) self.donor = pool(self.donor) self.acceptor = pool(self.acceptor) - if enable_secondary_data is True: + if enable_secondary_data: assert self.volume is not None assert self.oxygenation is not None self.volume = pool(self.volume) self.oxygenation = pool(self.oxygenation) + if self.donor is not None: + self.donor_scale = self.donor.mean(dim=0, keepdim=True) + self.donor /= self.donor_scale + self.donor -= 1.0 + + if self.acceptor is not None: + self.acceptor_scale = self.acceptor.mean(dim=0, keepdim=True) + self.acceptor /= self.acceptor_scale + self.acceptor -= 1.0 + + if self.volume is not None: + self.volume_scale = self.volume.mean(dim=0, keepdim=True) + self.volume /= self.volume_scale + self.volume -= 1.0 + + if self.oxygenation is not None: + self.oxygenation_scale = self.oxygenation.mean(dim=0, keepdim=True) + self.oxygenation /= self.oxygenation_scale + self.oxygenation -= 1.0 + if initital_mask is not None: self.logger.info(f"{self.level2} initial mask is applied on the data") assert self.acceptor is not None @@ -979,26 +997,13 @@ class DataContainer(torch.nn.Module): self.acceptor *= initital_mask.unsqueeze(0) self.donor *= initital_mask.unsqueeze(0) - if enable_secondary_data is True: + if enable_secondary_data: assert self.oxygenation is not None assert self.volume is not None self.oxygenation *= initital_mask.unsqueeze(0) self.volume *= initital_mask.unsqueeze(0) - if calculate_amplitude is True: - ( - self.power_hb_low_initial, - self.power_hb_high_initial, - _, - ) = self.measure_heartbeat_frequency(use_input_source="donor") - self.power_d_initial = self.measure_heartbeat_power( - use_input_source="donor", - start_position_coefficients=start_position_coefficients, - power_hb_low=self.power_hb_low_initial, - power_hb_high=self.power_hb_high_initial, - ) - - if remove_heartbeat is True: + if remove_heartbeat: self.logger.info(f"{self.level2} remove the heart beat via SVD") self.remove_heartbeat( iterations=iterations, @@ -1008,19 +1013,19 @@ class DataContainer(torch.nn.Module): start_position_coefficients=start_position_coefficients, ) - if remove_mean is True: + if remove_mean: self.logger.info(f"{self.level2} remove mean") self.remove_mean_data(enable_secondary_data=enable_secondary_data) - if remove_linear is True: + if remove_linear: self.logger.info(f"{self.level2} remove linear trends") self.remove_linear_trend_data(enable_secondary_data=enable_secondary_data) - if remove_heartbeat is True: - if remove_heartbeat_mean is True: + if remove_heartbeat: + if remove_heartbeat_mean: self.logger.info(f"{self.level2} remove mean (heart beat signal)") self.remove_mean_residuum(enable_secondary_data=enable_secondary_data) - if remove_heartbeat_linear is True: + if remove_heartbeat_linear: self.logger.info( f"{self.level2} remove linear trends (heart beat signal)" ) @@ -1028,7 +1033,7 @@ class DataContainer(torch.nn.Module): enable_secondary_data=enable_secondary_data ) - if do_frame_shift is True: + if do_frame_shift: self.logger.info(f"{self.level2} frame shift") self.frame_shift(enable_secondary_data=enable_secondary_data) @@ -1141,7 +1146,7 @@ class DataContainer(torch.nn.Module): del o_norm del v_norm - if export_parameters is True: + if export_parameters: parameter_a_temp: torch.Tensor | None = torch.zeros_like(data_norm) parameter_d_temp: torch.Tensor | None = torch.zeros_like(data_norm) else: @@ -1149,7 +1154,7 @@ class DataContainer(torch.nn.Module): parameter_d_temp = None for mode_a in [True, False]: - if mode_a is True: + if mode_a: result = a.detach().clone() result_mean_correct = a_correction @@ -1178,7 +1183,7 @@ class DataContainer(torch.nn.Module): result -= data_selected * scale.unsqueeze(0) - if mode_a is True: + if mode_a: if i == 0: initial_scale_value_a = max( [max_scale_value_a, float(scale.max())] @@ -1198,7 +1203,7 @@ class DataContainer(torch.nn.Module): -1, idx.unsqueeze(-1), scale.unsqueeze(-1) ) - if mode_a is True: + if mode_a: result_a[:, chunk, :] = result.detach().clone() max_scale_value_a = max([max_scale_value_a, float(scale.max())]) if parameter_a_temp is not None: @@ -1214,7 +1219,7 @@ class DataContainer(torch.nn.Module): (parameter_d_temp, d_mean_full.squeeze(0).unsqueeze(-1)), dim=-1, ) - if export_parameters is True: + if export_parameters: if (parameter_a is None) and (parameter_a_temp is not None): parameter_a = torch.zeros( ( @@ -1366,7 +1371,7 @@ class DataContainer(torch.nn.Module): heartbeat_a = torch.sqrt(scale) heartbeat_d = 1.0 / (heartbeat_a + 1e-20) - if apply_to_data is True: + if apply_to_data: if self.donor is not None: self.donor *= heartbeat_d.unsqueeze(0) if self.volume is not None: @@ -1378,8 +1383,18 @@ class DataContainer(torch.nn.Module): if threshold is not None: self.logger.info(f"{self.level3} calculate mask") - mask = torch.where(hb_d.std(dim=0) > threshold, 1.0, 0.0) * torch.where( - hb_a.std(dim=0) > threshold, 1.0, 0.0 + assert self.donor_scale is not None + assert self.acceptor_scale is not None + temp_d = hb_d.std(dim=0) * self.donor_scale.squeeze(0) + temp_d -= temp_d.min() + temp_d /= temp_d.max() + + temp_a = hb_a.std(dim=0) * self.acceptor_scale.squeeze(0) + temp_a -= temp_a.min() + temp_a /= temp_a.max() + + mask = torch.where(temp_d > threshold, 1.0, 0.0) * torch.where( + temp_a > threshold, 1.0, 0.0 ) else: mask = None @@ -1506,7 +1521,7 @@ class DataContainer(torch.nn.Module): start_position: int = 0, start_position_coefficients: int = 100, fs: float = 100.0, - use_regression: bool | None = False, + use_regression: bool | None = None, # Heartbeat remove_heartbeat: bool = True, # i.e. use SVD low_frequency: float = 5, # Hz Butter Bandpass Heartbeat @@ -1520,7 +1535,7 @@ class DataContainer(torch.nn.Module): remove_heartbeat_mean: bool = False, remove_heartbeat_linear: bool = False, bin_size: int = 4, - do_frame_shift: bool = True, + do_frame_shift: bool | None = None, half_width_frequency_window: float = 3.0, # Hz (on side ) measure_heartbeat_frequency mmap_mode: bool = True, initital_mask_name: str | None = None, @@ -1529,18 +1544,18 @@ class DataContainer(torch.nn.Module): gaussian_blur_kernel_size: int | None = 3, gaussian_blur_sigma: float = 1.0, bin_size_post: int | None = None, - calculate_amplitude: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: self.logger.info(f"{self.level0} start automatic_load") if use_regression is None: use_regression = not remove_heartbeat + if do_frame_shift is None: + do_frame_shift = not remove_heartbeat + initital_mask: torch.Tensor | None = None - if (initital_mask_name is not None) and os.path.isfile( - initital_mask_name - ) is True: + if (initital_mask_name is not None) and os.path.isfile(initital_mask_name): initital_mask = torch.tensor( np.load(initital_mask_name), device=self.device, dtype=torch.float32 ) @@ -1567,7 +1582,6 @@ class DataContainer(torch.nn.Module): mmap_mode=mmap_mode, initital_mask=initital_mask, start_position_coefficients=start_position_coefficients, - calculate_amplitude=calculate_amplitude, ) heartbeat_a: torch.Tensor | None = None @@ -1576,7 +1590,7 @@ class DataContainer(torch.nn.Module): power_hb_low: torch.Tensor | None = None power_hb_high: torch.Tensor | None = None - if remove_heartbeat is True: + if remove_heartbeat: self.logger.info(f"{self.level1} remove heart beat (heartbeat_scale)") heartbeat_a, heartbeat_d, mask = self.heartbeat_scale( low_frequency=low_frequency, @@ -1590,6 +1604,7 @@ class DataContainer(torch.nn.Module): self.logger.info( f"{self.level1} measure heart rate (measure_heartbeat_frequency)" ) + assert self.volume is not None ( power_hb_low, power_hb_high, @@ -1603,7 +1618,7 @@ class DataContainer(torch.nn.Module): half_width_frequency_window=half_width_frequency_window, ) - if use_regression is True: + if use_regression: self.logger.info(f"{self.level1} use regression") ( result_a, @@ -1663,14 +1678,14 @@ class DataContainer(torch.nn.Module): result_d *= heartbeat_d.unsqueeze(0) if mask is not None: - if initital_mask_update is True: + if initital_mask_update: self.logger.info(f"{self.level1} update inital mask") if initital_mask is None: initital_mask = mask.clone() else: initital_mask *= mask - if (initital_mask_roi is True) and (initital_mask is not None): + if (initital_mask_roi) and (initital_mask is not None): self.logger.info(f"{self.level1} enter roi mask drawing modus") yes_choices = ["yes", "y"] contiue_roi: bool = True @@ -1678,7 +1693,7 @@ class DataContainer(torch.nn.Module): image: np.ndarray = (result_a - result_d)[0, ...].cpu().numpy() image[initital_mask.cpu().numpy() == 0] = float("NaN") - while contiue_roi is True: + while contiue_roi: user_input = input( "Mask: Do you want to remove more pixel (yes/no)? " ) @@ -1721,22 +1736,8 @@ class DataContainer(torch.nn.Module): self.logger.info(f"{self.level0} end automatic_load") - if self.power_d_initial is not None: - self.power_d_final = self.measure_heartbeat_power( - use_input_source="custom", - power_hb_low=self.power_hb_low_initial, - power_hb_high=self.power_hb_high_initial, - start_position_coefficients=start_position_coefficients, - custom_input=result_d, - ) - self.power_d_amplitude = self.power_d_final / self.power_d_initial - self.power_d_amplitude = torch.nan_to_num(self.power_d_amplitude, nan=0.0) - - result = result_a - result_d - - if self.power_d_amplitude is not None: - result *= self.power_d_amplitude.unsqueeze(0) - result += 1.0 + # result = (1.0 + result_a) / (1.0 + result_d) + result = 1.0 + result_a - result_d if (gaussian_blur_kernel_size is not None) and (gaussian_blur_kernel_size > 0): gaussian_blur = tv.transforms.GaussianBlur( @@ -1774,7 +1775,7 @@ if __name__ == "__main__": start_position_coefficients: int = 100 remove_heartbeat: bool = True # i.e. use SVD bin_size: int = 4 - calculate_amplitude: bool = False + threshold: float | None = 0.05 # Between 0 and 1.0 example_position_x: int = 280 example_position_y: int = 440 @@ -1820,13 +1821,15 @@ if __name__ == "__main__": gaussian_blur_kernel_size=gaussian_blur_kernel_size, gaussian_blur_sigma=gaussian_blur_sigma, bin_size_post=bin_size_post, - calculate_amplitude=calculate_amplitude, + threshold=threshold, ) - if show_example_timeseries is True: + if show_example_timeseries: plt.plot(result[:, example_position_x, example_position_y].cpu()) plt.show() - if play_movie is True: + if play_movie: ani = Anime() - ani.show(result, mask=mask, vmin_scale=0.5, vmax_scale=0.5) + ani.show( + result - 1.0, mask=mask, vmin_scale=0.5, vmax_scale=0.5 + ) # , vmin=0.98) # , vmin=1.0, vmax_scale=1.0) diff --git a/ImageAlignment.py b/ImageAlignment.py index cb18197..ab483b3 100644 --- a/ImageAlignment.py +++ b/ImageAlignment.py @@ -762,7 +762,7 @@ class ImageAlignment(torch.nn.Module): bgval: torch.Tensor | None = None, invert=False, ) -> torch.Tensor: - if invert is True: + if invert: if scale is not None: scale = 1.0 / scale if angle is not None: @@ -927,7 +927,7 @@ class ImageAlignment(torch.nn.Module): if succ2[pos] > succ[pos]: pick_rotated = True - if pick_rotated is True: + if pick_rotated: tvec[pos, :] = tvec2[pos, :] succ[pos] = succ2[pos] angle[pos] += 180