2023-07-12 14:02:35 +02:00
|
|
|
import torch
|
|
|
|
import numpy as np
|
2023-07-13 11:05:07 +02:00
|
|
|
import os
|
|
|
|
|
|
|
|
import torchvision as tv
|
|
|
|
|
|
|
|
from svd import (
|
|
|
|
calculate_svd,
|
|
|
|
to_remove,
|
|
|
|
temporal_filter,
|
|
|
|
svd_denoise,
|
|
|
|
convert_avi_to_npy,
|
|
|
|
calculate_translation,
|
|
|
|
)
|
|
|
|
|
|
|
|
from ImageAlignment import ImageAlignment
|
|
|
|
|
2023-07-12 14:02:35 +02:00
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
filename: str = "example_data_crop"
|
|
|
|
window_size: int = 2
|
|
|
|
kernel_size_pooling: int = 2
|
|
|
|
orig_freq: int = 30
|
|
|
|
new_freq: int = 3
|
|
|
|
filtfilt_chuck_size: int = 10
|
|
|
|
bp_low_frequency: float = 0.1
|
|
|
|
bp_high_frequency: float = 1.0
|
|
|
|
|
2023-07-13 11:05:07 +02:00
|
|
|
convert_overwrite: bool | None = None
|
|
|
|
|
|
|
|
fill_value: float = 0.0
|
|
|
|
|
2023-07-12 14:02:35 +02:00
|
|
|
torch_device: torch.device = torch.device(
|
|
|
|
"cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
|
)
|
|
|
|
|
2023-07-13 11:05:07 +02:00
|
|
|
if (
|
|
|
|
(convert_overwrite is None)
|
|
|
|
and (os.path.isfile("example_data_crop" + ".npy") is False)
|
|
|
|
) or (convert_overwrite):
|
|
|
|
print("Convert AVI file to npy file.")
|
|
|
|
convert_avi_to_npy(filename)
|
|
|
|
print("--==-- DONE --==--")
|
2023-07-12 14:02:35 +02:00
|
|
|
|
2023-07-13 11:05:07 +02:00
|
|
|
with torch.no_grad():
|
|
|
|
print("Load data")
|
|
|
|
input = np.load(filename + str(".npy"))
|
|
|
|
data = torch.tensor(input, device=torch_device)
|
2023-07-12 14:02:35 +02:00
|
|
|
|
2023-07-13 11:05:07 +02:00
|
|
|
print("Movement compensation [BROKEN!!!!]")
|
|
|
|
print("During development, information about what could move was missing.")
|
|
|
|
print("Thus the preprocessing before shift determination may not work.")
|
|
|
|
data -= data.min(dim=0)[0]
|
|
|
|
data /= data.std(dim=0, keepdim=True) + 1e-20
|
2023-07-12 14:02:35 +02:00
|
|
|
|
2023-07-13 11:05:07 +02:00
|
|
|
image_alignment = ImageAlignment(
|
|
|
|
default_dtype=torch.float32, device=torch_device
|
|
|
|
)
|
2023-07-12 14:02:35 +02:00
|
|
|
|
2023-07-13 11:05:07 +02:00
|
|
|
tvec = calculate_translation(
|
|
|
|
input=data,
|
|
|
|
reference_image=data[0, ...].clone(),
|
|
|
|
image_alignment=image_alignment,
|
|
|
|
)
|
|
|
|
tvec_media = tvec.median(dim=0)[0]
|
|
|
|
print(f"Median of movement: {tvec_media[0]}, {tvec_media[1]}")
|
2023-07-12 14:02:35 +02:00
|
|
|
|
2023-07-13 11:05:07 +02:00
|
|
|
data = torch.tensor(input, device=torch_device)
|
2023-07-12 14:02:35 +02:00
|
|
|
|
2023-07-13 11:05:07 +02:00
|
|
|
for id in range(0, data.shape[0]):
|
|
|
|
data[id, ...] = tv.transforms.functional.affine(
|
|
|
|
img=data[id, ...].unsqueeze(0),
|
|
|
|
angle=0,
|
|
|
|
translate=[tvec[id, 1], tvec[id, 0]],
|
|
|
|
scale=1.0,
|
|
|
|
shear=0,
|
|
|
|
fill=fill_value,
|
|
|
|
).squeeze(0)
|
2023-07-12 14:02:35 +02:00
|
|
|
|
2023-07-13 11:05:07 +02:00
|
|
|
print("SVD")
|
|
|
|
whiten_mean, whiten_k, eigenvalues = calculate_svd(data)
|
|
|
|
|
|
|
|
print("Calculate to_remove")
|
|
|
|
data = torch.tensor(input, device=torch_device)
|
|
|
|
to_remove_data = to_remove(data, whiten_k, whiten_mean)
|
|
|
|
|
|
|
|
data -= to_remove_data
|
|
|
|
del to_remove_data
|
|
|
|
|
|
|
|
print("apply temporal filter")
|
|
|
|
data = temporal_filter(
|
|
|
|
data,
|
|
|
|
device=torch_device,
|
|
|
|
orig_freq=orig_freq,
|
|
|
|
new_freq=new_freq,
|
|
|
|
filtfilt_chuck_size=filtfilt_chuck_size,
|
|
|
|
bp_low_frequency=bp_low_frequency,
|
|
|
|
bp_high_frequency=bp_high_frequency,
|
|
|
|
)
|
|
|
|
|
|
|
|
print("SVD Denosing")
|
|
|
|
data_out = svd_denoise(data, window_size=window_size)
|
|
|
|
|
|
|
|
print("Pooling")
|
|
|
|
avage_pooling = torch.nn.AvgPool2d(
|
|
|
|
kernel_size=(kernel_size_pooling, kernel_size_pooling),
|
|
|
|
stride=(kernel_size_pooling, kernel_size_pooling),
|
|
|
|
)
|
|
|
|
data_out = avage_pooling(data_out)
|
2023-07-12 14:02:35 +02:00
|
|
|
|
2023-07-13 11:05:07 +02:00
|
|
|
np.save(filename + str("_decorrelated.npy"), data_out.cpu())
|