gevi/reproduction_effort/functions/adjust_factor.py
2024-02-16 09:59:49 +01:00

95 lines
3.1 KiB
Python

import torch
import math
def adjust_factor(
input_acceptor: torch.Tensor,
input_donor: torch.Tensor,
lower_frequency_heartbeat: float,
upper_frequency_heartbeat: float,
sample_frequency: float,
mask: torch.Tensor,
power_factors: None | list[float],
) -> tuple[float, float]:
number_of_active_pixel: torch.Tensor = mask.type(dtype=torch.float32).sum()
signal_acceptor: torch.Tensor = (input_acceptor * mask.unsqueeze(-1)).sum(
dim=0
).sum(dim=0) / number_of_active_pixel
signal_donor: torch.Tensor = (input_donor * mask.unsqueeze(-1)).sum(dim=0).sum(
dim=0
) / number_of_active_pixel
signal_acceptor_offset = signal_acceptor.mean()
signal_donor_offset = signal_donor.mean()
if power_factors is None:
signal_acceptor = signal_acceptor - signal_acceptor_offset
signal_donor = signal_donor - signal_donor_offset
blackman_window = torch.blackman_window(
window_length=signal_acceptor.shape[0],
periodic=True,
dtype=signal_acceptor.dtype,
device=signal_acceptor.device,
)
signal_acceptor *= blackman_window
signal_donor *= blackman_window
nfft: int = int(2 ** math.ceil(math.log2(signal_donor.shape[0])))
nfft = max([256, nfft])
signal_acceptor_fft: torch.Tensor = torch.fft.rfft(signal_acceptor, n=nfft)
signal_donor_fft: torch.Tensor = torch.fft.rfft(signal_donor, n=nfft)
frequency_axis: torch.Tensor = (
torch.fft.rfftfreq(nfft, device=signal_acceptor_fft.device)
* sample_frequency
)
signal_acceptor_power: torch.Tensor = torch.abs(signal_acceptor_fft) ** 2
signal_acceptor_power[1:-1] *= 2
signal_donor_power: torch.Tensor = torch.abs(signal_donor_fft) ** 2
signal_donor_power[1:-1] *= 2
if frequency_axis[-1] != (sample_frequency / 2.0):
signal_acceptor_power[-1] *= 2
signal_donor_power[-1] *= 2
signal_acceptor_power /= blackman_window.sum() ** 2
signal_donor_power /= blackman_window.sum() ** 2
idx = torch.where(
(frequency_axis >= lower_frequency_heartbeat)
* (frequency_axis <= upper_frequency_heartbeat)
)[0]
frequency_axis = frequency_axis[idx]
signal_acceptor_power = signal_acceptor_power[idx]
signal_donor_power = signal_donor_power[idx]
acceptor_range: float = float(
signal_acceptor_power.max() - signal_acceptor_power.min()
)
donor_range: float = float(signal_donor_power.max() - signal_donor_power.min())
else:
donor_range = float(power_factors[0])
acceptor_range = float(power_factors[1])
acceptor_correction_factor: float = float(
0.5
* (
1
+ (signal_acceptor_offset * math.sqrt(donor_range))
/ (signal_donor_offset * math.sqrt(acceptor_range))
)
)
donor_correction_factor: float = float(
acceptor_correction_factor / (2 * acceptor_correction_factor - 1)
)
return donor_correction_factor, acceptor_correction_factor