diff --git a/reproduction_effort/functions/preprocessing.py b/reproduction_effort/functions/preprocessing.py index e1b16a3..27a1172 100644 --- a/reproduction_effort/functions/preprocessing.py +++ b/reproduction_effort/functions/preprocessing.py @@ -25,7 +25,7 @@ def preprocessing( upper_frequency_heartbeat: float, sample_frequency: float, dtype: torch.dtype = torch.float32, - power_factors: None | list[float] = None, + power_factors: None | list[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: mask: torch.Tensor = make_mask( @@ -100,22 +100,38 @@ def preprocessing( lower_frequency_heartbeat_selection = 0 upper_frequency_heartbeat_selection = 0 - donor_correction_factor, acceptor_correction_factor = adjust_factor( - input_acceptor=results[0], - input_donor=results[1], - lower_frequency_heartbeat=lower_frequency_heartbeat_selection, - upper_frequency_heartbeat=upper_frequency_heartbeat_selection, - sample_frequency=sample_frequency, - mask=mask, - power_factors=power_factors, - ) + donor_correction_factor: torch.Tensor | float + acceptor_correction_factor: torch.Tensor | float + if heart_rate is not None: + donor_correction_factor, acceptor_correction_factor = adjust_factor( + input_acceptor=results[0], + input_donor=results[1], + lower_frequency_heartbeat=lower_frequency_heartbeat_selection, + upper_frequency_heartbeat=upper_frequency_heartbeat_selection, + sample_frequency=sample_frequency, + mask=mask, + power_factors=power_factors, + ) - results[0] = acceptor_correction_factor * ( - results[0] - results[0].mean(dim=-1, keepdim=True) - ) + results[0].mean(dim=-1, keepdim=True) + results[0] = acceptor_correction_factor * ( + results[0] - results[0].mean(dim=-1, keepdim=True) + ) + results[0].mean(dim=-1, keepdim=True) - results[1] = donor_correction_factor * ( - results[1] - results[1].mean(dim=-1, keepdim=True) - ) + results[1].mean(dim=-1, keepdim=True) + results[1] = donor_correction_factor * ( + results[1] - results[1].mean(dim=-1, keepdim=True) + ) + results[1].mean(dim=-1, keepdim=True) + else: + assert power_factors is not None + donor_correction_factor = power_factors[0] + acceptor_correction_factor = power_factors[1] + donor_factor: torch.Tensor = ( + donor_correction_factor + acceptor_correction_factor + ) / (2 * donor_correction_factor) + acceptor_factor: torch.Tensor = ( + donor_correction_factor + acceptor_correction_factor + ) / (2 * acceptor_correction_factor) + + results[0] *= acceptor_factor * mask.unsqueeze(-1) + results[1] *= donor_factor * mask.unsqueeze(-1) return results[0], results[1], mask