# Source: gevi/functions/bandpass.py (Python, 114 lines, 3.1 KiB; revision of 2024-02-28 16:14:50 +01:00)
import torchaudio as ta # type: ignore
import torch
@torch.no_grad()
def filtfilt(
    input: torch.Tensor,
    butter_a: torch.Tensor,
    butter_b: torch.Tensor,
) -> torch.Tensor:
    """Apply torchaudio's forward-backward filter with edge padding.

    Before filtering, the last dimension of ``input`` is extended on both
    sides by ``12 * len(butter_a)`` samples. Each extension is the mirrored
    neighbourhood of the edge, point-reflected about the edge sample
    (``2 * edge - mirrored``). The extension is cropped again afterwards, so
    the result has the same shape as ``input``.

    Args:
        input: Tensor whose last dimension is filtered.
        butter_a: 1-D tensor of denominator coefficients.
        butter_b: 1-D tensor of numerator coefficients, same length as
            ``butter_a``.

    Returns:
        A new tensor with the filtered data; ``input`` is not modified.

    Raises:
        AssertionError: If the coefficient tensors are not 1-D or differ in
            length.
        ValueError: If the last dimension of ``input`` is not longer than the
            padding length.
    """
    assert butter_a.ndim == 1
    assert butter_b.ndim == 1
    assert butter_a.shape[0] == butter_b.shape[0]

    padding_length = 12 * int(butter_a.shape[0])

    # Shorter signals would yield truncated padding slices and a result that
    # is cropped to the wrong length, so refuse them explicitly.
    signal_length = int(input.shape[-1])
    if signal_length <= padding_length:
        raise ValueError(
            f"The length of the last dimension ({signal_length}) must be "
            f"greater than the padding length ({padding_length})."
        )

    # No copy of `input` is needed: nothing below writes into it and
    # torch.cat allocates a fresh tensor for the padded signal.
    left_padding = 2 * input[..., 0].unsqueeze(-1) - input[
        ..., 1 : padding_length + 1
    ].flip(-1)
    right_padding = 2 * input[..., -1].unsqueeze(-1) - input[
        ..., -(padding_length + 1) : -1
    ].flip(-1)
    padded = torch.cat((left_padding, input, right_padding), dim=-1)

    # torchaudio expects the coefficients in the order (a, b).
    output = ta.functional.filtfilt(
        padded.unsqueeze(0), butter_a, butter_b, clamp=False
    ).squeeze(0)

    # Remove the samples that belong to the padding.
    return output[..., padding_length:-padding_length]
@torch.no_grad()
def butter_bandpass(
    device: torch.device,
    low_frequency: float = 0.1,
    high_frequency: float = 1.0,
    fs: float = 30.0,
    order: int = 4,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Design a Butterworth band-pass filter and return it as torch tensors.

    Args:
        device: Device on which the coefficient tensors are created.
        low_frequency: Lower critical frequency, in the same units as ``fs``.
        high_frequency: Upper critical frequency, in the same units as ``fs``.
        fs: Sampling frequency handed to SciPy.
        order: Filter order passed to ``scipy.signal.butter``.

    Returns:
        The tuple ``(butter_a, butter_b)`` of float32 tensors: denominator
        coefficients first, numerator coefficients second. Note that this is
        the reverse of the ``(b, a)`` order returned by SciPy.
    """
    # Import the submodule explicitly so the lookup below does not rely on
    # `scipy.signal` having been loaded as a side effect somewhere else.
    import scipy.signal  # type: ignore

    numerator_np, denominator_np = scipy.signal.butter(
        order, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs
    )
    butter_a = torch.tensor(denominator_np, device=device, dtype=torch.float32)
    butter_b = torch.tensor(numerator_np, device=device, dtype=torch.float32)
    return butter_a, butter_b
@torch.no_grad()
def chunk_iterator(array: torch.Tensor, chunk_size: int):
    """Yield consecutive pieces of ``array`` taken along its first dimension.

    Every piece holds ``chunk_size`` entries, except the last one, which
    holds whatever remains.
    """
    total = array.shape[0]
    yield from (
        array[start : start + chunk_size] for start in range(0, total, chunk_size)
    )
@torch.no_grad()
def bandpass(
    data: torch.Tensor,
    low_frequency: float = 0.1,
    high_frequency: float = 1.0,
    fs=30.0,
    filtfilt_chuck_size: int = 10,
) -> torch.Tensor:
    """Band-pass filter ``data``, retrying on the CPU if CUDA runs out of memory.

    The work is delegated to ``bandpass_internal``, which writes the filtered
    values into the tensor it receives and returns that tensor. If it raises
    ``torch.cuda.OutOfMemoryError``, the filtering is repeated on a CPU copy
    and the result is copied back into ``data``. In both cases ``data`` is
    modified in place and returned.

    Args:
        data: Tensor to filter.
        low_frequency: Lower critical frequency of the band-pass.
        high_frequency: Upper critical frequency of the band-pass.
        fs: Sampling frequency.
        filtfilt_chuck_size: Number of entries along dimension 1 that are
            filtered per step.

    Returns:
        ``data``, holding the filtered values.
    """
    filter_settings = {
        "low_frequency": low_frequency,
        "high_frequency": high_frequency,
        "fs": fs,
        "filtfilt_chuck_size": filtfilt_chuck_size,
    }

    try:
        return bandpass_internal(data=data, **filter_settings)
    except torch.cuda.OutOfMemoryError:
        # The recovery is done below, outside of the except clause. While the
        # handler is active, the exception keeps the failed stack frames (and
        # the GPU tensors they reference) alive, so that memory could not be
        # reused here.
        pass

    # NOTE(review): bandpass_internal overwrites `data` chunk by chunk. If the
    # error was raised after some chunks had already been written, those
    # chunks are filtered a second time by this fallback -- confirm whether
    # that case needs handling.
    filtered_on_cpu = bandpass_internal(data=data.cpu(), **filter_settings)

    # Write the result into the existing tensor instead of creating another
    # tensor of the same size on the device that just ran out of memory.
    data.copy_(filtered_on_cpu)
    return data
@torch.no_grad()
def bandpass_internal(
    data: torch.Tensor,
    low_frequency: float = 0.1,
    high_frequency: float = 1.0,
    fs=30.0,
    filtfilt_chuck_size: int = 10,
) -> torch.Tensor:
    """Band-pass filter ``data`` in place, a few entries of dimension 1 at a time.

    The filter coefficients come from ``butter_bandpass`` and are created on
    the device of ``data``. The entries along dimension 1 are visited in
    groups of ``filtfilt_chuck_size``; each group is passed to ``filtfilt``
    and the result is written back to the same positions of ``data``.

    Returns:
        The same tensor object that was passed in, now holding the filtered
        values.
    """
    coeff_a, coeff_b = butter_bandpass(
        device=data.device,
        low_frequency=low_frequency,
        high_frequency=high_frequency,
        fs=fs,
    )

    # Positions along dimension 1, stored on the same device as the data so
    # that they can be used directly for indexing.
    positions: torch.Tensor = torch.arange(
        0, data.shape[1], device=data.device, dtype=torch.int64
    )

    for selected in chunk_iterator(positions, filtfilt_chuck_size):
        # Replace the selected entries with their filtered version.
        data[:, selected, :] = filtfilt(
            data[:, selected, :],
            butter_a=coeff_a,
            butter_b=coeff_b,
        )

    return data