diff --git a/reproduction_effort/functions/adjust_factor.py b/reproduction_effort/functions/adjust_factor.py index 52c902f..2adc4e3 100644 --- a/reproduction_effort/functions/adjust_factor.py +++ b/reproduction_effort/functions/adjust_factor.py @@ -9,6 +9,7 @@ def adjust_factor( upper_frequency_heartbeat: float, sample_frequency: float, mask: torch.Tensor, + power_factors: None | list[float], ) -> tuple[float, float]: number_of_active_pixel: torch.Tensor = mask.type(dtype=torch.float32).sum() @@ -23,60 +24,67 @@ def adjust_factor( signal_acceptor_offset = signal_acceptor.mean() signal_donor_offset = signal_donor.mean() - signal_acceptor = signal_acceptor - signal_acceptor_offset - signal_donor = signal_donor - signal_donor_offset + if power_factors is None: + signal_acceptor = signal_acceptor - signal_acceptor_offset + signal_donor = signal_donor - signal_donor_offset - blackman_window = torch.blackman_window( - window_length=signal_acceptor.shape[0], - periodic=True, - dtype=signal_acceptor.dtype, - device=signal_acceptor.device, - ) + blackman_window = torch.blackman_window( + window_length=signal_acceptor.shape[0], + periodic=True, + dtype=signal_acceptor.dtype, + device=signal_acceptor.device, + ) - signal_acceptor *= blackman_window - signal_donor *= blackman_window - nfft: int = int(2 ** math.ceil(math.log2(signal_donor.shape[0]))) - nfft = max([256, nfft]) + signal_acceptor *= blackman_window + signal_donor *= blackman_window + nfft: int = int(2 ** math.ceil(math.log2(signal_donor.shape[0]))) + nfft = max([256, nfft]) - signal_acceptor_fft: torch.Tensor = torch.fft.rfft(signal_acceptor, n=nfft) - signal_donor_fft: torch.Tensor = torch.fft.rfft(signal_donor, n=nfft) + signal_acceptor_fft: torch.Tensor = torch.fft.rfft(signal_acceptor, n=nfft) + signal_donor_fft: torch.Tensor = torch.fft.rfft(signal_donor, n=nfft) - frequency_axis: torch.Tensor = ( - torch.fft.rfftfreq(nfft, device=signal_acceptor_fft.device) * sample_frequency - ) + frequency_axis: torch.Tensor = ( + torch.fft.rfftfreq(nfft, device=signal_acceptor_fft.device) + * sample_frequency + ) - signal_acceptor_power: torch.Tensor = torch.abs(signal_acceptor_fft) ** 2 - signal_acceptor_power[1:-1] *= 2 + signal_acceptor_power: torch.Tensor = torch.abs(signal_acceptor_fft) ** 2 + signal_acceptor_power[1:-1] *= 2 - signal_donor_power: torch.Tensor = torch.abs(signal_donor_fft) ** 2 - signal_donor_power[1:-1] *= 2 + signal_donor_power: torch.Tensor = torch.abs(signal_donor_fft) ** 2 + signal_donor_power[1:-1] *= 2 - if frequency_axis[-1] != (sample_frequency / 2.0): - signal_acceptor_power[-1] *= 2 - signal_donor_power[-1] *= 2 + if frequency_axis[-1] != (sample_frequency / 2.0): + signal_acceptor_power[-1] *= 2 + signal_donor_power[-1] *= 2 - signal_acceptor_power /= blackman_window.sum() ** 2 - signal_donor_power /= blackman_window.sum() ** 2 + signal_acceptor_power /= blackman_window.sum() ** 2 + signal_donor_power /= blackman_window.sum() ** 2 - idx = torch.where( - (frequency_axis >= lower_frequency_heartbeat) - * (frequency_axis <= upper_frequency_heartbeat) - )[0] + idx = torch.where( + (frequency_axis >= lower_frequency_heartbeat) + * (frequency_axis <= upper_frequency_heartbeat) + )[0] - frequency_axis = frequency_axis[idx] - signal_acceptor_power = signal_acceptor_power[idx] - signal_donor_power = signal_donor_power[idx] + frequency_axis = frequency_axis[idx] + signal_acceptor_power = signal_acceptor_power[idx] + signal_donor_power = signal_donor_power[idx] - acceptor_range = signal_acceptor_power.max() - signal_acceptor_power.min() + acceptor_range: float = float( + signal_acceptor_power.max() - signal_acceptor_power.min() + ) - donor_range = signal_donor_power.max() - signal_donor_power.min() + donor_range: float = float(signal_donor_power.max() - signal_donor_power.min()) + else: + donor_range = float(power_factors[0]) + acceptor_range = float(power_factors[1]) acceptor_correction_factor: float = float( 0.5 * ( 1 - + (signal_acceptor_offset * torch.sqrt(donor_range)) - / (signal_donor_offset * torch.sqrt(acceptor_range)) + + (signal_acceptor_offset * math.sqrt(donor_range)) + / (signal_donor_offset * math.sqrt(acceptor_range)) ) ) diff --git a/reproduction_effort/functions/preprocessing.py b/reproduction_effort/functions/preprocessing.py index 7d953fc..e1b16a3 100644 --- a/reproduction_effort/functions/preprocessing.py +++ b/reproduction_effort/functions/preprocessing.py @@ -25,6 +25,7 @@ def preprocessing( upper_frequency_heartbeat: float, sample_frequency: float, dtype: torch.dtype = torch.float32, + power_factors: None | list[float] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: mask: torch.Tensor = make_mask( @@ -44,20 +45,24 @@ def preprocessing( ) # Interpolate in-between images - interpolate_along_time(camera_sequence) + if power_factors is None: + interpolate_along_time(camera_sequence) camera_sequence_filtered: list[torch.Tensor] = [] for id in range(0, len(camera_sequence)): camera_sequence_filtered.append(camera_sequence[id].clone()) - idx_volume: int = cameras.index("volume") - heart_rate: float = heart_beat_frequency( - input=camera_sequence_filtered[idx_volume], - lower_frequency_heartbeat=lower_frequency_heartbeat, - upper_frequency_heartbeat=upper_frequency_heartbeat, - sample_frequency=sample_frequency, - mask=mask, - ) + if power_factors is None: + idx_volume: int = cameras.index("volume") + heart_rate: None | float = heart_beat_frequency( + input=camera_sequence_filtered[idx_volume], + lower_frequency_heartbeat=lower_frequency_heartbeat, + upper_frequency_heartbeat=upper_frequency_heartbeat, + sample_frequency=sample_frequency, + mask=mask, + ) + else: + heart_rate = None camera_sequence_filtered = gauss_smear( camera_sequence_filtered, @@ -88,8 +93,12 @@ def preprocessing( ) results.append(output) - lower_frequency_heartbeat_selection: float = heart_rate - 3 - upper_frequency_heartbeat_selection: float = heart_rate + 3 + if heart_rate is not None: + lower_frequency_heartbeat_selection: float = heart_rate - 3 + upper_frequency_heartbeat_selection: float = heart_rate + 3 + else: + lower_frequency_heartbeat_selection = 0 + upper_frequency_heartbeat_selection = 0 donor_correction_factor, acceptor_correction_factor = adjust_factor( input_acceptor=results[0], @@ -98,6 +107,7 @@ def preprocessing( upper_frequency_heartbeat=upper_frequency_heartbeat_selection, sample_frequency=sample_frequency, mask=mask, + power_factors=power_factors, ) results[0] = acceptor_correction_factor * (