Add files via upload

This commit is contained in:
David Rotermund 2023-07-12 14:02:35 +02:00 committed by GitHub
parent 256be3a3c7
commit b37a79f487
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 1368 additions and 0 deletions

7
0_convert_avi_to_npy.py Normal file
View file

@ -0,0 +1,7 @@
from svd import convert_avi_to_npy
if __name__ == "__main__":
# Convert from avi to npy
filename: str = "example_data_crop"
convert_avi_to_npy(filename)

90
Anime.py Normal file
View file

@ -0,0 +1,90 @@
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.animation
class Anime:
def __init__(self) -> None:
super().__init__()
def show(
self,
input: torch.Tensor | np.ndarray,
mask: torch.Tensor | np.ndarray | None = None,
vmin: float | None = None,
vmax: float | None = None,
cmap: str = "hot",
axis_off: bool = True,
show_frame_count: bool = True,
interval: int = 100,
repeat: bool = False,
colorbar: bool = True,
vmin_scale: float | None = None,
vmax_scale: float | None = None,
) -> None:
assert input.ndim == 3
if isinstance(input, torch.Tensor):
input_np: np.ndarray = input.cpu().numpy()
if mask is not None:
mask_np: np.ndarray | None = (mask == 0).cpu().numpy()
else:
mask_np = None
else:
input_np = input
if mask is not None:
mask_np = mask == 0 # type: ignore
else:
mask_np = None
if vmin is None:
vmin = float(np.where(np.isfinite(input_np), input_np, 0.0).min())
if vmax is None:
vmax = float(np.where(np.isfinite(input_np), input_np, 0.0).max())
if vmin_scale is not None:
vmin *= vmin_scale
if vmax_scale is not None:
vmax *= vmax_scale
fig = plt.figure()
image = np.nan_to_num(input_np[0, ...], copy=True, nan=0.0)
if mask_np is not None:
image[mask_np] = float("NaN")
image_handle = plt.imshow(
image,
cmap=cmap,
vmin=vmin,
vmax=vmax,
)
if colorbar:
plt.colorbar()
if axis_off:
plt.axis("off")
def next_frame(i: int) -> None:
image = np.nan_to_num(input_np[i, ...], copy=True, nan=0.0)
if mask_np is not None:
image[mask_np] = float("NaN")
image_handle.set_data(image)
if show_frame_count:
bar_length: int = 10
filled_length = int(round(bar_length * i / input_np.shape[0]))
bar = "\u25A0" * filled_length + "\u25A1" * (bar_length - filled_length)
plt.title(f"{bar} {i} of {int(input_np.shape[0]-1)}", loc="left")
return
_ = matplotlib.animation.FuncAnimation(
fig,
next_frame,
frames=int(input.shape[0]),
interval=interval,
repeat=repeat,
)
plt.show()

1010
ImageAlignment.py Normal file

File diff suppressed because it is too large Load diff

57
run_svd.py Normal file
View file

@ -0,0 +1,57 @@
import torch
import numpy as np
from svd import calculate_svd, to_remove, temporal_filter, svd_denoise
if __name__ == "__main__":
filename: str = "example_data_crop"
window_size: int = 2
kernel_size_pooling: int = 2
orig_freq: int = 30
new_freq: int = 3
filtfilt_chuck_size: int = 10
bp_low_frequency: float = 0.1
bp_high_frequency: float = 1.0
torch_device: torch.device = torch.device(
"cuda:0" if torch.cuda.is_available() else "cpu"
)
print("Load data")
input = np.load(filename + str(".npy"))
data = torch.tensor(input, device=torch_device)
print("Movement compensation [MISSING!!!!]")
print("(include ImageAlignment.py into processing chain)")
print("SVD")
whiten_mean, whiten_k, eigenvalues = calculate_svd(data)
print("Calculate to_remove")
data = torch.tensor(input, device=torch_device)
to_remove_data = to_remove(data, whiten_k, whiten_mean)
data -= to_remove_data
del to_remove_data
print("apply temporal filter")
data = temporal_filter(
data,
device=torch_device,
orig_freq=orig_freq,
new_freq=new_freq,
filtfilt_chuck_size=filtfilt_chuck_size,
bp_low_frequency=bp_low_frequency,
bp_high_frequency=bp_high_frequency,
)
print("SVD Denosing")
data_out = svd_denoise(data, window_size=window_size)
print("Pooling")
avage_pooling = torch.nn.AvgPool2d(
kernel_size=(kernel_size_pooling, kernel_size_pooling),
stride=(kernel_size_pooling, kernel_size_pooling),
)
data_out = avage_pooling(data_out)
np.save(filename + str("_decorrelated.npy"), data_out.cpu())

204
svd.py Normal file
View file

@ -0,0 +1,204 @@
import torch
import torchaudio as ta
import cv2
import numpy as np
from tqdm import trange
def convert_avi_to_npy(filename: str) -> None:
capture_from_file = cv2.VideoCapture(filename + str(".avi"))
avi_length = int(capture_from_file.get(cv2.CAP_PROP_FRAME_COUNT))
# To torch and beyond
data: np.ndarray | None = None
for i in trange(0, avi_length):
read_ok, frame = capture_from_file.read()
assert read_ok
if data is None:
data = np.empty(
(avi_length, frame.shape[0], frame.shape[1]),
dtype=np.float32,
)
assert data is not None
data[i, :, :] = frame.mean(axis=-1).astype(np.float32)
assert data is not None
np.save(filename + str(".npy"), data)
@torch.no_grad()
def to_remove(
data: torch.Tensor, whiten_k: torch.Tensor, whiten_mean: torch.Tensor
) -> torch.Tensor:
whiten_mean = whiten_mean
whiten_k = whiten_k[:, :, 0]
data = (data - whiten_mean.unsqueeze(0)) * whiten_k.unsqueeze(0)
data_svd = data.sum(dim=-1).sum(dim=-1).unsqueeze(-1).unsqueeze(-1)
factor = (data * data_svd).sum(dim=0, keepdim=True) / (data_svd**2).sum(
dim=0, keepdim=True
)
to_remove = data_svd * factor
to_remove /= whiten_k.unsqueeze(0) + 1e-20
to_remove += whiten_mean.unsqueeze(0)
return to_remove
@torch.no_grad()
def calculate_svd(
input: torch.Tensor, lowrank_q: int = 6
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
selection = torch.flatten(
input.clone().movedim(0, -1),
start_dim=0,
end_dim=1,
)
whiten_mean = torch.mean(selection, dim=-1)
selection -= whiten_mean.unsqueeze(-1)
whiten_mean = whiten_mean.reshape((input.shape[1], input.shape[2]))
svd_u, svd_s, _ = torch.svd_lowrank(selection, q=lowrank_q)
whiten_k = (
torch.sign(svd_u[0, :]).unsqueeze(0) * svd_u / (svd_s.unsqueeze(0) + 1e-20)
)
whiten_k = whiten_k.reshape((input.shape[1], input.shape[2], svd_s.shape[0]))
eigenvalues = svd_s
return whiten_mean, whiten_k, eigenvalues
@torch.no_grad()
def filtfilt(
input: torch.Tensor,
butter_a: torch.Tensor,
butter_b: torch.Tensor,
) -> torch.Tensor:
assert butter_a.ndim == 1
assert butter_b.ndim == 1
assert butter_a.shape[0] == butter_b.shape[0]
process_data: torch.Tensor = input.movedim(0, -1).detach().clone()
padding_length = 12 * int(butter_a.shape[0])
left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[
..., 1 : padding_length + 1
].flip(-1)
right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[
..., -(padding_length + 1) : -1
].flip(-1)
process_data_padded = torch.cat((left_padding, process_data, right_padding), dim=-1)
output = ta.functional.filtfilt(
process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False
).squeeze(0)
output = output[..., padding_length:-padding_length].movedim(-1, 0)
return output
@torch.no_grad()
def butter_bandpass(
device: torch.device,
low_frequency: float = 0.1,
high_frequency: float = 1.0,
fs: float = 30.0,
) -> tuple[torch.Tensor, torch.Tensor]:
import scipy
butter_b_np, butter_a_np = scipy.signal.butter(
4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs
)
butter_a = torch.tensor(butter_a_np, device=device, dtype=torch.float32)
butter_b = torch.tensor(butter_b_np, device=device, dtype=torch.float32)
return butter_a, butter_b
@torch.no_grad()
def chunk_iterator(array: torch.Tensor, chunk_size: int):
for i in range(0, array.shape[0], chunk_size):
yield array[i : i + chunk_size]
@torch.no_grad()
def lowpass(
data: torch.Tensor,
device: torch.device,
low_frequency: float = 0.1,
high_frequency: float = 1.0,
fs=30.0,
filtfilt_chuck_size: int = 10,
) -> torch.Tensor:
butter_a, butter_b = butter_bandpass(
device=device,
low_frequency=low_frequency,
high_frequency=high_frequency,
fs=fs,
)
index_full_dataset: torch.Tensor = torch.arange(
0, data.shape[1], device=device, dtype=torch.int64
)
for chunk in chunk_iterator(index_full_dataset, filtfilt_chuck_size):
temp_filtfilt = filtfilt(
data[:, chunk, :],
butter_a=butter_a,
butter_b=butter_b,
)
data[:, chunk, :] = temp_filtfilt
return data
@torch.no_grad()
def temporal_filter(
data: torch.Tensor,
device: torch.device,
orig_freq: int = 30,
new_freq: int = 3,
filtfilt_chuck_size: int = 10,
bp_low_frequency: float = 0.1,
bp_high_frequency: float = 1.0,
) -> torch.Tensor:
data = ta.functional.resample(
data.movedim(0, -1), orig_freq=orig_freq, new_freq=new_freq
).movedim(-1, 0)
data = lowpass(
data,
device=device,
low_frequency=bp_low_frequency,
high_frequency=bp_high_frequency,
fs=float(new_freq),
filtfilt_chuck_size=filtfilt_chuck_size,
)
return data
@torch.no_grad()
def svd_denoise(data: torch.Tensor, window_size: int) -> torch.Tensor:
data_out = torch.zeros_like(data)
for x in trange(0, data.shape[1]):
for y in range(0, data.shape[2]):
if (
((x - window_size) > 0)
and ((y - window_size) > 0)
and ((x + window_size) <= data.shape[1])
and ((y + window_size) <= data.shape[2])
):
data_sel: torch.Tensor = data[
:,
x - window_size : x + window_size + 1,
y - window_size : y + window_size + 1,
]
whiten_mean, whiten_k, eigenvalues = calculate_svd(data_sel.clone())
to_remove_data = to_remove(data_sel, whiten_k, whiten_mean)
data_out[:, x, y] = to_remove_data[:, window_size, window_size]
return data_out