Add files via upload

parent 2518b7f8f1, commit fc84e1842e
4 changed files with 173 additions and 34 deletions
run_svd.py (118 changed lines)

@@ -1,6 +1,20 @@
 import torch
 import numpy as np
-from svd import calculate_svd, to_remove, temporal_filter, svd_denoise
+import os
+
+import torchvision as tv
+
+from svd import (
+    calculate_svd,
+    to_remove,
+    temporal_filter,
+    svd_denoise,
+    convert_avi_to_npy,
+    calculate_translation,
+)
+
+from ImageAlignment import ImageAlignment
+
 
 if __name__ == "__main__":
     filename: str = "example_data_crop"
@@ -12,46 +26,86 @@ if __name__ == "__main__":
     bp_low_frequency: float = 0.1
     bp_high_frequency: float = 1.0
 
+    convert_overwrite: bool | None = None
+
+    fill_value: float = 0.0
+
     torch_device: torch.device = torch.device(
         "cuda:0" if torch.cuda.is_available() else "cpu"
     )
 
-    print("Load data")
-    input = np.load(filename + str(".npy"))
-    data = torch.tensor(input, device=torch_device)
+    if (
+        (convert_overwrite is None)
+        and (os.path.isfile("example_data_crop" + ".npy") is False)
+    ) or (convert_overwrite):
+        print("Convert AVI file to npy file.")
+        convert_avi_to_npy(filename)
+        print("--==-- DONE --==--")
 
-    print("Movement compensation [MISSING!!!!]")
-    print("(include ImageAlignment.py into processing chain)")
+    with torch.no_grad():
+        print("Load data")
+        input = np.load(filename + str(".npy"))
+        data = torch.tensor(input, device=torch_device)
 
-    print("SVD")
-    whiten_mean, whiten_k, eigenvalues = calculate_svd(data)
+        print("Movement compensation [BROKEN!!!!]")
+        print("During development, information about what could move was missing.")
+        print("Thus the preprocessing before shift determination may not work.")
+        data -= data.min(dim=0)[0]
+        data /= data.std(dim=0, keepdim=True) + 1e-20
 
-    print("Calculate to_remove")
-    data = torch.tensor(input, device=torch_device)
-    to_remove_data = to_remove(data, whiten_k, whiten_mean)
+        image_alignment = ImageAlignment(
+            default_dtype=torch.float32, device=torch_device
+        )
 
-    data -= to_remove_data
-    del to_remove_data
+        tvec = calculate_translation(
+            input=data,
+            reference_image=data[0, ...].clone(),
+            image_alignment=image_alignment,
+        )
+        tvec_media = tvec.median(dim=0)[0]
+        print(f"Median of movement: {tvec_media[0]}, {tvec_media[1]}")
 
-    print("apply temporal filter")
-    data = temporal_filter(
-        data,
-        device=torch_device,
-        orig_freq=orig_freq,
-        new_freq=new_freq,
-        filtfilt_chuck_size=filtfilt_chuck_size,
-        bp_low_frequency=bp_low_frequency,
-        bp_high_frequency=bp_high_frequency,
-    )
+        data = torch.tensor(input, device=torch_device)
 
-    print("SVD Denosing")
-    data_out = svd_denoise(data, window_size=window_size)
+        for id in range(0, data.shape[0]):
+            data[id, ...] = tv.transforms.functional.affine(
+                img=data[id, ...].unsqueeze(0),
+                angle=0,
+                translate=[tvec[id, 1], tvec[id, 0]],
+                scale=1.0,
+                shear=0,
+                fill=fill_value,
+            ).squeeze(0)
 
-    print("Pooling")
-    avage_pooling = torch.nn.AvgPool2d(
-        kernel_size=(kernel_size_pooling, kernel_size_pooling),
-        stride=(kernel_size_pooling, kernel_size_pooling),
-    )
-    data_out = avage_pooling(data_out)
+        print("SVD")
+        whiten_mean, whiten_k, eigenvalues = calculate_svd(data)
 
-    np.save(filename + str("_decorrelated.npy"), data_out.cpu())
+        print("Calculate to_remove")
+        data = torch.tensor(input, device=torch_device)
+        to_remove_data = to_remove(data, whiten_k, whiten_mean)
+
+        data -= to_remove_data
+        del to_remove_data
+
+        print("apply temporal filter")
+        data = temporal_filter(
+            data,
+            device=torch_device,
+            orig_freq=orig_freq,
+            new_freq=new_freq,
+            filtfilt_chuck_size=filtfilt_chuck_size,
+            bp_low_frequency=bp_low_frequency,
+            bp_high_frequency=bp_high_frequency,
+        )
+
+        print("SVD Denosing")
+        data_out = svd_denoise(data, window_size=window_size)
+
+        print("Pooling")
+        avage_pooling = torch.nn.AvgPool2d(
+            kernel_size=(kernel_size_pooling, kernel_size_pooling),
+            stride=(kernel_size_pooling, kernel_size_pooling),
+        )
+        data_out = avage_pooling(data_out)
+
+        np.save(filename + str("_decorrelated.npy"), data_out.cpu())
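
The movement-compensation step added to run_svd.py shifts every frame by its rounded translation vector with torchvision's affine transform. A minimal, self-contained sketch of that per-frame shift (the synthetic tensors frames and shifts are illustrative stand-ins, not part of the commit):

import torch
import torchvision as tv

# Synthetic stack of 5 single-channel frames: (time, height, width).
frames = torch.rand((5, 64, 64))
# One (dy, dx) shift per frame, e.g. the rounded output of calculate_translation().
shifts = torch.tensor([[0, 0], [1, -2], [2, 1], [0, 3], [-1, 0]])

compensated = torch.empty_like(frames)
for t in range(frames.shape[0]):
    # affine() expects a (C, H, W) image; translate is given as [dx, dy].
    compensated[t] = tv.transforms.functional.affine(
        img=frames[t].unsqueeze(0),
        angle=0,
        translate=[int(shifts[t, 1]), int(shifts[t, 0])],
        scale=1.0,
        shear=0,
        fill=0.0,
    ).squeeze(0)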

show.py (new file, 24 lines)

@@ -0,0 +1,24 @@
+import numpy as np
+import torch
+from Anime import Anime
+
+# Convert from avi to npy
+filename: str = "example_data_crop"
+
+
+torch_device: torch.device = torch.device(
+    "cuda:0" if torch.cuda.is_available() else "cpu"
+)
+
+print("Load data")
+input = np.load(filename + str("_decorrelated.npy"))
+data = torch.tensor(input, device=torch_device)
+del input
+print("loading done")
+
+data = data.nan_to_num(nan=0.0)
+#data -= data.min(dim=0, keepdim=True)[0]
+
+
+ani = Anime()
+ani.show(data, vmin=0.0)

show_b.py (new file, 24 lines)

@@ -0,0 +1,24 @@
+import numpy as np
+import torch
+from Anime import Anime
+
+# Convert from avi to npy
+filename: str = "example_data_crop"
+
+
+torch_device: torch.device = torch.device(
+    "cuda:0" if torch.cuda.is_available() else "cpu"
+)
+
+print("Load data")
+input = np.load(filename + str("_decorrelated.npy"))
+data = torch.tensor(input, device=torch_device)
+del input
+print("loading done")
+
+data = data.nan_to_num(nan=0.0)
+data -= data.min(dim=0, keepdim=True)[0]
+data *= data.std(dim=0, keepdim=True)
+
+ani = Anime()
+ani.show(data)
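
show.py and show_b.py differ only in how the decorrelated movie is scaled before display: show.py leaves the data untouched and passes vmin=0.0 to Anime.show, while show_b.py subtracts the per-pixel minimum over time and then multiplies by the per-pixel standard deviation. If a z-score-like scaling is what is wanted, dividing by the standard deviation is the more usual form; a minimal sketch of that variant (the random stand-in data and the 1e-20 guard are assumptions, not part of the commit):

import torch

data = torch.rand((100, 32, 32))  # (time, height, width) stand-in for the loaded movie
data = data.nan_to_num(nan=0.0)  # replace NaNs, as both scripts do
data = data - data.min(dim=0, keepdim=True)[0]  # remove the per-pixel baseline over time
data = data / (data.std(dim=0, keepdim=True) + 1e-20)  # divide (rather than multiply) by the std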

svd.py (41 changed lines)

@@ -96,6 +96,7 @@ def filtfilt(
     output = ta.functional.filtfilt(
         process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False
     ).squeeze(0)
 
     output = output[..., padding_length:-padding_length].movedim(-1, 0)
     return output
 
@@ -124,7 +125,7 @@ def chunk_iterator(array: torch.Tensor, chunk_size: int):
 
 
 @torch.no_grad()
-def lowpass(
+def bandpass(
     data: torch.Tensor,
     device: torch.device,
     low_frequency: float = 0.1,
@@ -168,7 +169,7 @@ def temporal_filter(
         data.movedim(0, -1), orig_freq=orig_freq, new_freq=new_freq
     ).movedim(-1, 0)
 
-    data = lowpass(
+    data = bandpass(
         data,
         device=device,
         low_frequency=bp_low_frequency,
@@ -202,3 +203,39 @@ def svd_denoise(data: torch.Tensor, window_size: int) -> torch.Tensor:
             to_remove_data = to_remove(data_sel, whiten_k, whiten_mean)
             data_out[:, x, y] = to_remove_data[:, window_size, window_size]
     return data_out
+
+
+@torch.no_grad()
+def calculate_translation(
+    input: torch.Tensor,
+    reference_image: torch.Tensor,
+    image_alignment,
+    start_position_coefficients: int = 0,
+    batch_size: int = 100,
+) -> torch.Tensor:
+    tvec = torch.zeros((input.shape[0], 2))
+
+    data_loader = torch.utils.data.DataLoader(
+        torch.utils.data.TensorDataset(input[start_position_coefficients:, ...]),
+        batch_size=batch_size,
+        shuffle=False,
+    )
+    start_position: int = 0
+    for input_batch in data_loader:
+        assert len(input_batch) == 1
+
+        end_position = start_position + input_batch[0].shape[0]
+
+        tvec_temp = image_alignment.dry_run_translation(
+            input=input_batch[0],
+            new_reference_image=reference_image,
+        )
+
+        assert tvec_temp is not None
+
+        tvec[start_position:end_position, :] = tvec_temp
+
+        start_position += input_batch[0].shape[0]
+
+    tvec = torch.round(tvec)
+    return tvec
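
calculate_translation, added at the end of svd.py, walks the frame stack through a DataLoader in batches of batch_size, asks ImageAlignment.dry_run_translation for the shift of each batch relative to the fixed reference frame, and rounds the collected vectors to whole pixels. The batching pattern in isolation, with find_shift as a hypothetical stand-in for dry_run_translation (zero shifts, synthetic data):

import torch
import torch.utils.data


def find_shift(batch: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    # Hypothetical placeholder: a real implementation would register each
    # frame in the batch against the reference image.
    return torch.zeros((batch.shape[0], 2))


frames = torch.rand((250, 64, 64))
reference = frames[0].clone()

tvec = torch.zeros((frames.shape[0], 2))
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(frames), batch_size=100, shuffle=False
)

start = 0
for (batch,) in loader:  # TensorDataset yields one-element tuples
    end = start + batch.shape[0]
    tvec[start:end, :] = find_shift(batch, reference)
    start = end

tvec = torch.round(tvec)  # integer pixel shifts, as in the committed code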