Add files via upload
parent 256be3a3c7
commit b37a79f487

5 changed files with 1368 additions and 0 deletions
0_convert_avi_to_npy.py (new file, 7 additions)
@@ -0,0 +1,7 @@
from svd import convert_avi_to_npy


if __name__ == "__main__":
    # Convert from avi to npy
    filename: str = "example_data_crop"
    convert_avi_to_npy(filename)
Anime.py (new file, 90 additions)
@@ -0,0 +1,90 @@
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.animation


class Anime:
    def __init__(self) -> None:
        super().__init__()

    def show(
        self,
        input: torch.Tensor | np.ndarray,
        mask: torch.Tensor | np.ndarray | None = None,
        vmin: float | None = None,
        vmax: float | None = None,
        cmap: str = "hot",
        axis_off: bool = True,
        show_frame_count: bool = True,
        interval: int = 100,
        repeat: bool = False,
        colorbar: bool = True,
        vmin_scale: float | None = None,
        vmax_scale: float | None = None,
    ) -> None:
        assert input.ndim == 3

        if isinstance(input, torch.Tensor):
            input_np: np.ndarray = input.cpu().numpy()
            if mask is not None:
                mask_np: np.ndarray | None = (mask == 0).cpu().numpy()
            else:
                mask_np = None
        else:
            input_np = input
            if mask is not None:
                mask_np = mask == 0  # type: ignore
            else:
                mask_np = None

        if vmin is None:
            vmin = float(np.where(np.isfinite(input_np), input_np, 0.0).min())
        if vmax is None:
            vmax = float(np.where(np.isfinite(input_np), input_np, 0.0).max())

        if vmin_scale is not None:
            vmin *= vmin_scale

        if vmax_scale is not None:
            vmax *= vmax_scale

        fig = plt.figure()
        image = np.nan_to_num(input_np[0, ...], copy=True, nan=0.0)
        if mask_np is not None:
            image[mask_np] = float("NaN")
        image_handle = plt.imshow(
            image,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
        )

        if colorbar:
            plt.colorbar()

        if axis_off:
            plt.axis("off")

        def next_frame(i: int) -> None:
            image = np.nan_to_num(input_np[i, ...], copy=True, nan=0.0)
            if mask_np is not None:
                image[mask_np] = float("NaN")

            image_handle.set_data(image)
            if show_frame_count:
                bar_length: int = 10
                filled_length = int(round(bar_length * i / input_np.shape[0]))
                bar = "\u25A0" * filled_length + "\u25A1" * (bar_length - filled_length)
                plt.title(f"{bar} {i} of {int(input_np.shape[0]-1)}", loc="left")
            return

        _ = matplotlib.animation.FuncAnimation(
            fig,
            next_frame,
            frames=int(input.shape[0]),
            interval=interval,
            repeat=repeat,
        )

        plt.show()
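Note (not part of the commit): a minimal usage sketch for Anime.show, assuming a (frames, height, width) array as the diff above expects; the tensor shape and values below are invented for illustration.

import torch
from Anime import Anime

# 100 frames of 64x64 random data, just to exercise the viewer
example = torch.rand((100, 64, 64))
Anime().show(example, cmap="hot", interval=50)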
ImageAlignment.py (new file, 1010 additions)
File diff suppressed because it is too large.
run_svd.py (new file, 57 additions)
@@ -0,0 +1,57 @@
import torch
import numpy as np
from svd import calculate_svd, to_remove, temporal_filter, svd_denoise

if __name__ == "__main__":
    filename: str = "example_data_crop"
    window_size: int = 2
    kernel_size_pooling: int = 2
    orig_freq: int = 30
    new_freq: int = 3
    filtfilt_chuck_size: int = 10
    bp_low_frequency: float = 0.1
    bp_high_frequency: float = 1.0

    torch_device: torch.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu"
    )

    print("Load data")
    input = np.load(filename + str(".npy"))
    data = torch.tensor(input, device=torch_device)

    print("Movement compensation [MISSING!!!!]")
    print("(include ImageAlignment.py into processing chain)")

    print("SVD")
    whiten_mean, whiten_k, eigenvalues = calculate_svd(data)

    print("Calculate to_remove")
    data = torch.tensor(input, device=torch_device)
    to_remove_data = to_remove(data, whiten_k, whiten_mean)

    data -= to_remove_data
    del to_remove_data

    print("Apply temporal filter")
    data = temporal_filter(
        data,
        device=torch_device,
        orig_freq=orig_freq,
        new_freq=new_freq,
        filtfilt_chuck_size=filtfilt_chuck_size,
        bp_low_frequency=bp_low_frequency,
        bp_high_frequency=bp_high_frequency,
    )

    print("SVD Denoising")
    data_out = svd_denoise(data, window_size=window_size)

    print("Pooling")
    average_pooling = torch.nn.AvgPool2d(
        kernel_size=(kernel_size_pooling, kernel_size_pooling),
        stride=(kernel_size_pooling, kernel_size_pooling),
    )
    data_out = average_pooling(data_out)

    np.save(filename + str("_decorrelated.npy"), data_out.cpu())
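Note (not part of the commit): the intended order of the two entry points, as far as these sources show, is roughly the following; file names are taken from the scripts above, and movement compensation is explicitly marked as missing.

from svd import convert_avi_to_npy

# Step 1 (0_convert_avi_to_npy.py): example_data_crop.avi -> example_data_crop.npy
convert_avi_to_npy("example_data_crop")

# Step 2 (run_svd.py): loads example_data_crop.npy, removes the global SVD
# component, temporally filters, denoises and pools, then writes
# example_data_crop_decorrelated.npy. Movement compensation (ImageAlignment.py)
# is printed as "[MISSING!!!!]" and is not yet wired into the chain.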
svd.py (new file, 204 additions)
@@ -0,0 +1,204 @@
import torch
import torchaudio as ta
import cv2
import numpy as np
from tqdm import trange


def convert_avi_to_npy(filename: str) -> None:
    capture_from_file = cv2.VideoCapture(filename + str(".avi"))
    avi_length = int(capture_from_file.get(cv2.CAP_PROP_FRAME_COUNT))

    # To torch and beyond
    data: np.ndarray | None = None
    for i in trange(0, avi_length):
        read_ok, frame = capture_from_file.read()

        assert read_ok

        if data is None:
            data = np.empty(
                (avi_length, frame.shape[0], frame.shape[1]),
                dtype=np.float32,
            )
        assert data is not None
        data[i, :, :] = frame.mean(axis=-1).astype(np.float32)
    assert data is not None
    np.save(filename + str(".npy"), data)


@torch.no_grad()
def to_remove(
    data: torch.Tensor, whiten_k: torch.Tensor, whiten_mean: torch.Tensor
) -> torch.Tensor:
    whiten_k = whiten_k[:, :, 0]

    data = (data - whiten_mean.unsqueeze(0)) * whiten_k.unsqueeze(0)
    data_svd = data.sum(dim=-1).sum(dim=-1).unsqueeze(-1).unsqueeze(-1)

    factor = (data * data_svd).sum(dim=0, keepdim=True) / (data_svd**2).sum(
        dim=0, keepdim=True
    )
    to_remove = data_svd * factor
    to_remove /= whiten_k.unsqueeze(0) + 1e-20
    to_remove += whiten_mean.unsqueeze(0)

    return to_remove


@torch.no_grad()
def calculate_svd(
    input: torch.Tensor, lowrank_q: int = 6
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    selection = torch.flatten(
        input.clone().movedim(0, -1),
        start_dim=0,
        end_dim=1,
    )

    whiten_mean = torch.mean(selection, dim=-1)
    selection -= whiten_mean.unsqueeze(-1)
    whiten_mean = whiten_mean.reshape((input.shape[1], input.shape[2]))

    svd_u, svd_s, _ = torch.svd_lowrank(selection, q=lowrank_q)

    whiten_k = (
        torch.sign(svd_u[0, :]).unsqueeze(0) * svd_u / (svd_s.unsqueeze(0) + 1e-20)
    )
    whiten_k = whiten_k.reshape((input.shape[1], input.shape[2], svd_s.shape[0]))
    eigenvalues = svd_s

    return whiten_mean, whiten_k, eigenvalues


@torch.no_grad()
def filtfilt(
    input: torch.Tensor,
    butter_a: torch.Tensor,
    butter_b: torch.Tensor,
) -> torch.Tensor:
    assert butter_a.ndim == 1
    assert butter_b.ndim == 1
    assert butter_a.shape[0] == butter_b.shape[0]

    process_data: torch.Tensor = input.movedim(0, -1).detach().clone()

    padding_length = 12 * int(butter_a.shape[0])
    left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[
        ..., 1 : padding_length + 1
    ].flip(-1)
    right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[
        ..., -(padding_length + 1) : -1
    ].flip(-1)
    process_data_padded = torch.cat((left_padding, process_data, right_padding), dim=-1)

    output = ta.functional.filtfilt(
        process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False
    ).squeeze(0)
    output = output[..., padding_length:-padding_length].movedim(-1, 0)
    return output


@torch.no_grad()
def butter_bandpass(
    device: torch.device,
    low_frequency: float = 0.1,
    high_frequency: float = 1.0,
    fs: float = 30.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    import scipy.signal

    butter_b_np, butter_a_np = scipy.signal.butter(
        4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs
    )
    butter_a = torch.tensor(butter_a_np, device=device, dtype=torch.float32)
    butter_b = torch.tensor(butter_b_np, device=device, dtype=torch.float32)
    return butter_a, butter_b


@torch.no_grad()
def chunk_iterator(array: torch.Tensor, chunk_size: int):
    for i in range(0, array.shape[0], chunk_size):
        yield array[i : i + chunk_size]


@torch.no_grad()
def lowpass(
    data: torch.Tensor,
    device: torch.device,
    low_frequency: float = 0.1,
    high_frequency: float = 1.0,
    fs: float = 30.0,
    filtfilt_chuck_size: int = 10,
) -> torch.Tensor:
    butter_a, butter_b = butter_bandpass(
        device=device,
        low_frequency=low_frequency,
        high_frequency=high_frequency,
        fs=fs,
    )

    index_full_dataset: torch.Tensor = torch.arange(
        0, data.shape[1], device=device, dtype=torch.int64
    )

    for chunk in chunk_iterator(index_full_dataset, filtfilt_chuck_size):
        temp_filtfilt = filtfilt(
            data[:, chunk, :],
            butter_a=butter_a,
            butter_b=butter_b,
        )
        data[:, chunk, :] = temp_filtfilt

    return data


@torch.no_grad()
def temporal_filter(
    data: torch.Tensor,
    device: torch.device,
    orig_freq: int = 30,
    new_freq: int = 3,
    filtfilt_chuck_size: int = 10,
    bp_low_frequency: float = 0.1,
    bp_high_frequency: float = 1.0,
) -> torch.Tensor:
    data = ta.functional.resample(
        data.movedim(0, -1), orig_freq=orig_freq, new_freq=new_freq
    ).movedim(-1, 0)

    data = lowpass(
        data,
        device=device,
        low_frequency=bp_low_frequency,
        high_frequency=bp_high_frequency,
        fs=float(new_freq),
        filtfilt_chuck_size=filtfilt_chuck_size,
    )

    return data


@torch.no_grad()
def svd_denoise(data: torch.Tensor, window_size: int) -> torch.Tensor:
    data_out = torch.zeros_like(data)

    for x in trange(0, data.shape[1]):
        for y in range(0, data.shape[2]):
            if (
                ((x - window_size) > 0)
                and ((y - window_size) > 0)
                and ((x + window_size) <= data.shape[1])
                and ((y + window_size) <= data.shape[2])
            ):
                data_sel: torch.Tensor = data[
                    :,
                    x - window_size : x + window_size + 1,
                    y - window_size : y + window_size + 1,
                ]

                whiten_mean, whiten_k, eigenvalues = calculate_svd(data_sel.clone())
                to_remove_data = to_remove(data_sel, whiten_k, whiten_mean)
                data_out[:, x, y] = to_remove_data[:, window_size, window_size]
    return data_out
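Note (not part of the commit): a minimal sketch of how calculate_svd and to_remove from svd.py fit together, on toy data; the tensor shape and variable names below are illustrative only.

import torch
from svd import calculate_svd, to_remove

# Toy movie: 200 frames of 16x16 pixels.
movie = torch.rand((200, 16, 16))

# Rank-6 low-rank SVD of the (pixels x frames) matrix: per-pixel mean,
# per-pixel whitening weights, and the singular values.
whiten_mean, whiten_k, eigenvalues = calculate_svd(movie)

# to_remove regresses each pixel's whitened trace onto the frame-wise sum of
# the whitened data and returns that fitted part, mapped back to raw units.
shared_component = to_remove(movie, whiten_k, whiten_mean)

# run_svd.py subtracts this estimate from the raw data.
cleaned = movie - shared_component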