Add files via upload
This commit is contained in:
parent
256be3a3c7
commit
b37a79f487
5 changed files with 1368 additions and 0 deletions
7
0_convert_avi_to_npy.py
Normal file
7
0_convert_avi_to_npy.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
from svd import convert_avi_to_npy
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Convert from avi to npy
|
||||||
|
filename: str = "example_data_crop"
|
||||||
|
convert_avi_to_npy(filename)
|
90
Anime.py
Normal file
90
Anime.py
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.animation
|
||||||
|
|
||||||
|
|
||||||
|
class Anime:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def show(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor | np.ndarray,
|
||||||
|
mask: torch.Tensor | np.ndarray | None = None,
|
||||||
|
vmin: float | None = None,
|
||||||
|
vmax: float | None = None,
|
||||||
|
cmap: str = "hot",
|
||||||
|
axis_off: bool = True,
|
||||||
|
show_frame_count: bool = True,
|
||||||
|
interval: int = 100,
|
||||||
|
repeat: bool = False,
|
||||||
|
colorbar: bool = True,
|
||||||
|
vmin_scale: float | None = None,
|
||||||
|
vmax_scale: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
assert input.ndim == 3
|
||||||
|
|
||||||
|
if isinstance(input, torch.Tensor):
|
||||||
|
input_np: np.ndarray = input.cpu().numpy()
|
||||||
|
if mask is not None:
|
||||||
|
mask_np: np.ndarray | None = (mask == 0).cpu().numpy()
|
||||||
|
else:
|
||||||
|
mask_np = None
|
||||||
|
else:
|
||||||
|
input_np = input
|
||||||
|
if mask is not None:
|
||||||
|
mask_np = mask == 0 # type: ignore
|
||||||
|
else:
|
||||||
|
mask_np = None
|
||||||
|
|
||||||
|
if vmin is None:
|
||||||
|
vmin = float(np.where(np.isfinite(input_np), input_np, 0.0).min())
|
||||||
|
if vmax is None:
|
||||||
|
vmax = float(np.where(np.isfinite(input_np), input_np, 0.0).max())
|
||||||
|
|
||||||
|
if vmin_scale is not None:
|
||||||
|
vmin *= vmin_scale
|
||||||
|
|
||||||
|
if vmax_scale is not None:
|
||||||
|
vmax *= vmax_scale
|
||||||
|
|
||||||
|
fig = plt.figure()
|
||||||
|
image = np.nan_to_num(input_np[0, ...], copy=True, nan=0.0)
|
||||||
|
if mask_np is not None:
|
||||||
|
image[mask_np] = float("NaN")
|
||||||
|
image_handle = plt.imshow(
|
||||||
|
image,
|
||||||
|
cmap=cmap,
|
||||||
|
vmin=vmin,
|
||||||
|
vmax=vmax,
|
||||||
|
)
|
||||||
|
|
||||||
|
if colorbar:
|
||||||
|
plt.colorbar()
|
||||||
|
|
||||||
|
if axis_off:
|
||||||
|
plt.axis("off")
|
||||||
|
|
||||||
|
def next_frame(i: int) -> None:
|
||||||
|
image = np.nan_to_num(input_np[i, ...], copy=True, nan=0.0)
|
||||||
|
if mask_np is not None:
|
||||||
|
image[mask_np] = float("NaN")
|
||||||
|
|
||||||
|
image_handle.set_data(image)
|
||||||
|
if show_frame_count:
|
||||||
|
bar_length: int = 10
|
||||||
|
filled_length = int(round(bar_length * i / input_np.shape[0]))
|
||||||
|
bar = "\u25A0" * filled_length + "\u25A1" * (bar_length - filled_length)
|
||||||
|
plt.title(f"{bar} {i} of {int(input_np.shape[0]-1)}", loc="left")
|
||||||
|
return
|
||||||
|
|
||||||
|
_ = matplotlib.animation.FuncAnimation(
|
||||||
|
fig,
|
||||||
|
next_frame,
|
||||||
|
frames=int(input.shape[0]),
|
||||||
|
interval=interval,
|
||||||
|
repeat=repeat,
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.show()
|
1010
ImageAlignment.py
Normal file
1010
ImageAlignment.py
Normal file
File diff suppressed because it is too large
Load diff
57
run_svd.py
Normal file
57
run_svd.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from svd import calculate_svd, to_remove, temporal_filter, svd_denoise
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
filename: str = "example_data_crop"
|
||||||
|
window_size: int = 2
|
||||||
|
kernel_size_pooling: int = 2
|
||||||
|
orig_freq: int = 30
|
||||||
|
new_freq: int = 3
|
||||||
|
filtfilt_chuck_size: int = 10
|
||||||
|
bp_low_frequency: float = 0.1
|
||||||
|
bp_high_frequency: float = 1.0
|
||||||
|
|
||||||
|
torch_device: torch.device = torch.device(
|
||||||
|
"cuda:0" if torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Load data")
|
||||||
|
input = np.load(filename + str(".npy"))
|
||||||
|
data = torch.tensor(input, device=torch_device)
|
||||||
|
|
||||||
|
print("Movement compensation [MISSING!!!!]")
|
||||||
|
print("(include ImageAlignment.py into processing chain)")
|
||||||
|
|
||||||
|
print("SVD")
|
||||||
|
whiten_mean, whiten_k, eigenvalues = calculate_svd(data)
|
||||||
|
|
||||||
|
print("Calculate to_remove")
|
||||||
|
data = torch.tensor(input, device=torch_device)
|
||||||
|
to_remove_data = to_remove(data, whiten_k, whiten_mean)
|
||||||
|
|
||||||
|
data -= to_remove_data
|
||||||
|
del to_remove_data
|
||||||
|
|
||||||
|
print("apply temporal filter")
|
||||||
|
data = temporal_filter(
|
||||||
|
data,
|
||||||
|
device=torch_device,
|
||||||
|
orig_freq=orig_freq,
|
||||||
|
new_freq=new_freq,
|
||||||
|
filtfilt_chuck_size=filtfilt_chuck_size,
|
||||||
|
bp_low_frequency=bp_low_frequency,
|
||||||
|
bp_high_frequency=bp_high_frequency,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("SVD Denosing")
|
||||||
|
data_out = svd_denoise(data, window_size=window_size)
|
||||||
|
|
||||||
|
print("Pooling")
|
||||||
|
avage_pooling = torch.nn.AvgPool2d(
|
||||||
|
kernel_size=(kernel_size_pooling, kernel_size_pooling),
|
||||||
|
stride=(kernel_size_pooling, kernel_size_pooling),
|
||||||
|
)
|
||||||
|
data_out = avage_pooling(data_out)
|
||||||
|
|
||||||
|
np.save(filename + str("_decorrelated.npy"), data_out.cpu())
|
204
svd.py
Normal file
204
svd.py
Normal file
|
@ -0,0 +1,204 @@
|
||||||
|
import torch
|
||||||
|
import torchaudio as ta
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import trange
|
||||||
|
|
||||||
|
|
||||||
|
def convert_avi_to_npy(filename: str) -> None:
|
||||||
|
capture_from_file = cv2.VideoCapture(filename + str(".avi"))
|
||||||
|
avi_length = int(capture_from_file.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
|
||||||
|
# To torch and beyond
|
||||||
|
data: np.ndarray | None = None
|
||||||
|
for i in trange(0, avi_length):
|
||||||
|
read_ok, frame = capture_from_file.read()
|
||||||
|
|
||||||
|
assert read_ok
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
data = np.empty(
|
||||||
|
(avi_length, frame.shape[0], frame.shape[1]),
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
assert data is not None
|
||||||
|
data[i, :, :] = frame.mean(axis=-1).astype(np.float32)
|
||||||
|
assert data is not None
|
||||||
|
np.save(filename + str(".npy"), data)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def to_remove(
|
||||||
|
data: torch.Tensor, whiten_k: torch.Tensor, whiten_mean: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
whiten_mean = whiten_mean
|
||||||
|
whiten_k = whiten_k[:, :, 0]
|
||||||
|
|
||||||
|
data = (data - whiten_mean.unsqueeze(0)) * whiten_k.unsqueeze(0)
|
||||||
|
data_svd = data.sum(dim=-1).sum(dim=-1).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
|
||||||
|
factor = (data * data_svd).sum(dim=0, keepdim=True) / (data_svd**2).sum(
|
||||||
|
dim=0, keepdim=True
|
||||||
|
)
|
||||||
|
to_remove = data_svd * factor
|
||||||
|
to_remove /= whiten_k.unsqueeze(0) + 1e-20
|
||||||
|
to_remove += whiten_mean.unsqueeze(0)
|
||||||
|
|
||||||
|
return to_remove
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def calculate_svd(
|
||||||
|
input: torch.Tensor, lowrank_q: int = 6
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
selection = torch.flatten(
|
||||||
|
input.clone().movedim(0, -1),
|
||||||
|
start_dim=0,
|
||||||
|
end_dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
whiten_mean = torch.mean(selection, dim=-1)
|
||||||
|
selection -= whiten_mean.unsqueeze(-1)
|
||||||
|
whiten_mean = whiten_mean.reshape((input.shape[1], input.shape[2]))
|
||||||
|
|
||||||
|
svd_u, svd_s, _ = torch.svd_lowrank(selection, q=lowrank_q)
|
||||||
|
|
||||||
|
whiten_k = (
|
||||||
|
torch.sign(svd_u[0, :]).unsqueeze(0) * svd_u / (svd_s.unsqueeze(0) + 1e-20)
|
||||||
|
)
|
||||||
|
whiten_k = whiten_k.reshape((input.shape[1], input.shape[2], svd_s.shape[0]))
|
||||||
|
eigenvalues = svd_s
|
||||||
|
|
||||||
|
return whiten_mean, whiten_k, eigenvalues
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def filtfilt(
|
||||||
|
input: torch.Tensor,
|
||||||
|
butter_a: torch.Tensor,
|
||||||
|
butter_b: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert butter_a.ndim == 1
|
||||||
|
assert butter_b.ndim == 1
|
||||||
|
assert butter_a.shape[0] == butter_b.shape[0]
|
||||||
|
|
||||||
|
process_data: torch.Tensor = input.movedim(0, -1).detach().clone()
|
||||||
|
|
||||||
|
padding_length = 12 * int(butter_a.shape[0])
|
||||||
|
left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[
|
||||||
|
..., 1 : padding_length + 1
|
||||||
|
].flip(-1)
|
||||||
|
right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[
|
||||||
|
..., -(padding_length + 1) : -1
|
||||||
|
].flip(-1)
|
||||||
|
process_data_padded = torch.cat((left_padding, process_data, right_padding), dim=-1)
|
||||||
|
|
||||||
|
output = ta.functional.filtfilt(
|
||||||
|
process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False
|
||||||
|
).squeeze(0)
|
||||||
|
output = output[..., padding_length:-padding_length].movedim(-1, 0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def butter_bandpass(
|
||||||
|
device: torch.device,
|
||||||
|
low_frequency: float = 0.1,
|
||||||
|
high_frequency: float = 1.0,
|
||||||
|
fs: float = 30.0,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
import scipy
|
||||||
|
|
||||||
|
butter_b_np, butter_a_np = scipy.signal.butter(
|
||||||
|
4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs
|
||||||
|
)
|
||||||
|
butter_a = torch.tensor(butter_a_np, device=device, dtype=torch.float32)
|
||||||
|
butter_b = torch.tensor(butter_b_np, device=device, dtype=torch.float32)
|
||||||
|
return butter_a, butter_b
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def chunk_iterator(array: torch.Tensor, chunk_size: int):
|
||||||
|
for i in range(0, array.shape[0], chunk_size):
|
||||||
|
yield array[i : i + chunk_size]
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def lowpass(
|
||||||
|
data: torch.Tensor,
|
||||||
|
device: torch.device,
|
||||||
|
low_frequency: float = 0.1,
|
||||||
|
high_frequency: float = 1.0,
|
||||||
|
fs=30.0,
|
||||||
|
filtfilt_chuck_size: int = 10,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
butter_a, butter_b = butter_bandpass(
|
||||||
|
device=device,
|
||||||
|
low_frequency=low_frequency,
|
||||||
|
high_frequency=high_frequency,
|
||||||
|
fs=fs,
|
||||||
|
)
|
||||||
|
|
||||||
|
index_full_dataset: torch.Tensor = torch.arange(
|
||||||
|
0, data.shape[1], device=device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk in chunk_iterator(index_full_dataset, filtfilt_chuck_size):
|
||||||
|
temp_filtfilt = filtfilt(
|
||||||
|
data[:, chunk, :],
|
||||||
|
butter_a=butter_a,
|
||||||
|
butter_b=butter_b,
|
||||||
|
)
|
||||||
|
data[:, chunk, :] = temp_filtfilt
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def temporal_filter(
|
||||||
|
data: torch.Tensor,
|
||||||
|
device: torch.device,
|
||||||
|
orig_freq: int = 30,
|
||||||
|
new_freq: int = 3,
|
||||||
|
filtfilt_chuck_size: int = 10,
|
||||||
|
bp_low_frequency: float = 0.1,
|
||||||
|
bp_high_frequency: float = 1.0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
data = ta.functional.resample(
|
||||||
|
data.movedim(0, -1), orig_freq=orig_freq, new_freq=new_freq
|
||||||
|
).movedim(-1, 0)
|
||||||
|
|
||||||
|
data = lowpass(
|
||||||
|
data,
|
||||||
|
device=device,
|
||||||
|
low_frequency=bp_low_frequency,
|
||||||
|
high_frequency=bp_high_frequency,
|
||||||
|
fs=float(new_freq),
|
||||||
|
filtfilt_chuck_size=filtfilt_chuck_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def svd_denoise(data: torch.Tensor, window_size: int) -> torch.Tensor:
|
||||||
|
data_out = torch.zeros_like(data)
|
||||||
|
|
||||||
|
for x in trange(0, data.shape[1]):
|
||||||
|
for y in range(0, data.shape[2]):
|
||||||
|
if (
|
||||||
|
((x - window_size) > 0)
|
||||||
|
and ((y - window_size) > 0)
|
||||||
|
and ((x + window_size) <= data.shape[1])
|
||||||
|
and ((y + window_size) <= data.shape[2])
|
||||||
|
):
|
||||||
|
data_sel: torch.Tensor = data[
|
||||||
|
:,
|
||||||
|
x - window_size : x + window_size + 1,
|
||||||
|
y - window_size : y + window_size + 1,
|
||||||
|
]
|
||||||
|
|
||||||
|
whiten_mean, whiten_k, eigenvalues = calculate_svd(data_sel.clone())
|
||||||
|
to_remove_data = to_remove(data_sel, whiten_k, whiten_mean)
|
||||||
|
data_out[:, x, y] = to_remove_data[:, window_size, window_size]
|
||||||
|
return data_out
|
Loading…
Reference in a new issue