Add files via upload

This commit is contained in:
David Rotermund 2024-02-22 14:30:14 +01:00 committed by GitHub
parent a64b39745a
commit 0a0626f18b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -25,7 +25,7 @@ def preprocessing(
upper_frequency_heartbeat: float,
sample_frequency: float,
dtype: torch.dtype = torch.float32,
power_factors: None | list[float] = None,
power_factors: None | list[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
mask: torch.Tensor = make_mask(
@ -100,22 +100,38 @@ def preprocessing(
lower_frequency_heartbeat_selection = 0
upper_frequency_heartbeat_selection = 0
donor_correction_factor, acceptor_correction_factor = adjust_factor(
input_acceptor=results[0],
input_donor=results[1],
lower_frequency_heartbeat=lower_frequency_heartbeat_selection,
upper_frequency_heartbeat=upper_frequency_heartbeat_selection,
sample_frequency=sample_frequency,
mask=mask,
power_factors=power_factors,
)
donor_correction_factor: torch.Tensor | float
acceptor_correction_factor: torch.Tensor | float
if heart_rate is not None:
donor_correction_factor, acceptor_correction_factor = adjust_factor(
input_acceptor=results[0],
input_donor=results[1],
lower_frequency_heartbeat=lower_frequency_heartbeat_selection,
upper_frequency_heartbeat=upper_frequency_heartbeat_selection,
sample_frequency=sample_frequency,
mask=mask,
power_factors=power_factors,
)
results[0] = acceptor_correction_factor * (
results[0] - results[0].mean(dim=-1, keepdim=True)
) + results[0].mean(dim=-1, keepdim=True)
results[0] = acceptor_correction_factor * (
results[0] - results[0].mean(dim=-1, keepdim=True)
) + results[0].mean(dim=-1, keepdim=True)
results[1] = donor_correction_factor * (
results[1] - results[1].mean(dim=-1, keepdim=True)
) + results[1].mean(dim=-1, keepdim=True)
results[1] = donor_correction_factor * (
results[1] - results[1].mean(dim=-1, keepdim=True)
) + results[1].mean(dim=-1, keepdim=True)
else:
assert power_factors is not None
donor_correction_factor = power_factors[0]
acceptor_correction_factor = power_factors[1]
donor_factor: torch.Tensor = (
donor_correction_factor + acceptor_correction_factor
) / (2 * donor_correction_factor)
acceptor_factor: torch.Tensor = (
donor_correction_factor + acceptor_correction_factor
) / (2 * acceptor_correction_factor)
results[0] *= acceptor_factor * mask.unsqueeze(-1)
results[1] *= donor_factor * mask.unsqueeze(-1)
return results[0], results[1], mask