Add files via upload

This commit is contained in:
David Rotermund 2023-07-13 11:05:07 +02:00 committed by GitHub
parent 2518b7f8f1
commit fc84e1842e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 173 additions and 34 deletions

View file

@ -1,6 +1,20 @@
import torch import torch
import numpy as np import numpy as np
from svd import calculate_svd, to_remove, temporal_filter, svd_denoise import os
import torchvision as tv
from svd import (
calculate_svd,
to_remove,
temporal_filter,
svd_denoise,
convert_avi_to_npy,
calculate_translation,
)
from ImageAlignment import ImageAlignment
if __name__ == "__main__": if __name__ == "__main__":
filename: str = "example_data_crop" filename: str = "example_data_crop"
@ -12,46 +26,86 @@ if __name__ == "__main__":
bp_low_frequency: float = 0.1 bp_low_frequency: float = 0.1
bp_high_frequency: float = 1.0 bp_high_frequency: float = 1.0
convert_overwrite: bool | None = None
fill_value: float = 0.0
torch_device: torch.device = torch.device( torch_device: torch.device = torch.device(
"cuda:0" if torch.cuda.is_available() else "cpu" "cuda:0" if torch.cuda.is_available() else "cpu"
) )
print("Load data") if (
input = np.load(filename + str(".npy")) (convert_overwrite is None)
data = torch.tensor(input, device=torch_device) and (os.path.isfile("example_data_crop" + ".npy") is False)
) or (convert_overwrite):
print("Convert AVI file to npy file.")
convert_avi_to_npy(filename)
print("--==-- DONE --==--")
print("Movement compensation [MISSING!!!!]") with torch.no_grad():
print("(include ImageAlignment.py into processing chain)") print("Load data")
input = np.load(filename + str(".npy"))
data = torch.tensor(input, device=torch_device)
print("SVD") print("Movement compensation [BROKEN!!!!]")
whiten_mean, whiten_k, eigenvalues = calculate_svd(data) print("During development, information about what could move was missing.")
print("Thus the preprocessing before shift determination may not work.")
data -= data.min(dim=0)[0]
data /= data.std(dim=0, keepdim=True) + 1e-20
print("Calculate to_remove") image_alignment = ImageAlignment(
data = torch.tensor(input, device=torch_device) default_dtype=torch.float32, device=torch_device
to_remove_data = to_remove(data, whiten_k, whiten_mean) )
data -= to_remove_data tvec = calculate_translation(
del to_remove_data input=data,
reference_image=data[0, ...].clone(),
image_alignment=image_alignment,
)
tvec_media = tvec.median(dim=0)[0]
print(f"Median of movement: {tvec_media[0]}, {tvec_media[1]}")
print("apply temporal filter") data = torch.tensor(input, device=torch_device)
data = temporal_filter(
data,
device=torch_device,
orig_freq=orig_freq,
new_freq=new_freq,
filtfilt_chuck_size=filtfilt_chuck_size,
bp_low_frequency=bp_low_frequency,
bp_high_frequency=bp_high_frequency,
)
print("SVD Denosing") for id in range(0, data.shape[0]):
data_out = svd_denoise(data, window_size=window_size) data[id, ...] = tv.transforms.functional.affine(
img=data[id, ...].unsqueeze(0),
angle=0,
translate=[tvec[id, 1], tvec[id, 0]],
scale=1.0,
shear=0,
fill=fill_value,
).squeeze(0)
print("Pooling") print("SVD")
avage_pooling = torch.nn.AvgPool2d( whiten_mean, whiten_k, eigenvalues = calculate_svd(data)
kernel_size=(kernel_size_pooling, kernel_size_pooling),
stride=(kernel_size_pooling, kernel_size_pooling),
)
data_out = avage_pooling(data_out)
np.save(filename + str("_decorrelated.npy"), data_out.cpu()) print("Calculate to_remove")
data = torch.tensor(input, device=torch_device)
to_remove_data = to_remove(data, whiten_k, whiten_mean)
data -= to_remove_data
del to_remove_data
print("apply temporal filter")
data = temporal_filter(
data,
device=torch_device,
orig_freq=orig_freq,
new_freq=new_freq,
filtfilt_chuck_size=filtfilt_chuck_size,
bp_low_frequency=bp_low_frequency,
bp_high_frequency=bp_high_frequency,
)
print("SVD Denosing")
data_out = svd_denoise(data, window_size=window_size)
print("Pooling")
avage_pooling = torch.nn.AvgPool2d(
kernel_size=(kernel_size_pooling, kernel_size_pooling),
stride=(kernel_size_pooling, kernel_size_pooling),
)
data_out = avage_pooling(data_out)
np.save(filename + str("_decorrelated.npy"), data_out.cpu())

24
show.py Normal file
View file

@ -0,0 +1,24 @@
import numpy as np
import torch
from Anime import Anime
# Convert from avi to npy
filename: str = "example_data_crop"
torch_device: torch.device = torch.device(
"cuda:0" if torch.cuda.is_available() else "cpu"
)
print("Load data")
input = np.load(filename + str("_decorrelated.npy"))
data = torch.tensor(input, device=torch_device)
del input
print("loading done")
data = data.nan_to_num(nan=0.0)
#data -= data.min(dim=0, keepdim=True)[0]
ani = Anime()
ani.show(data, vmin=0.0)

24
show_b.py Normal file
View file

@ -0,0 +1,24 @@
import numpy as np
import torch
from Anime import Anime
# Convert from avi to npy
filename: str = "example_data_crop"
torch_device: torch.device = torch.device(
"cuda:0" if torch.cuda.is_available() else "cpu"
)
print("Load data")
input = np.load(filename + str("_decorrelated.npy"))
data = torch.tensor(input, device=torch_device)
del input
print("loading done")
data = data.nan_to_num(nan=0.0)
data -= data.min(dim=0, keepdim=True)[0]
data *= data.std(dim=0, keepdim=True)
ani = Anime()
ani.show(data)

41
svd.py
View file

@ -96,6 +96,7 @@ def filtfilt(
output = ta.functional.filtfilt( output = ta.functional.filtfilt(
process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False
).squeeze(0) ).squeeze(0)
output = output[..., padding_length:-padding_length].movedim(-1, 0) output = output[..., padding_length:-padding_length].movedim(-1, 0)
return output return output
@ -124,7 +125,7 @@ def chunk_iterator(array: torch.Tensor, chunk_size: int):
@torch.no_grad() @torch.no_grad()
def lowpass( def bandpass(
data: torch.Tensor, data: torch.Tensor,
device: torch.device, device: torch.device,
low_frequency: float = 0.1, low_frequency: float = 0.1,
@ -168,7 +169,7 @@ def temporal_filter(
data.movedim(0, -1), orig_freq=orig_freq, new_freq=new_freq data.movedim(0, -1), orig_freq=orig_freq, new_freq=new_freq
).movedim(-1, 0) ).movedim(-1, 0)
data = lowpass( data = bandpass(
data, data,
device=device, device=device,
low_frequency=bp_low_frequency, low_frequency=bp_low_frequency,
@ -202,3 +203,39 @@ def svd_denoise(data: torch.Tensor, window_size: int) -> torch.Tensor:
to_remove_data = to_remove(data_sel, whiten_k, whiten_mean) to_remove_data = to_remove(data_sel, whiten_k, whiten_mean)
data_out[:, x, y] = to_remove_data[:, window_size, window_size] data_out[:, x, y] = to_remove_data[:, window_size, window_size]
return data_out return data_out
@torch.no_grad()
def calculate_translation(
input: torch.Tensor,
reference_image: torch.Tensor,
image_alignment,
start_position_coefficients: int = 0,
batch_size: int = 100,
) -> torch.Tensor:
tvec = torch.zeros((input.shape[0], 2))
data_loader = torch.utils.data.DataLoader(
torch.utils.data.TensorDataset(input[start_position_coefficients:, ...]),
batch_size=batch_size,
shuffle=False,
)
start_position: int = 0
for input_batch in data_loader:
assert len(input_batch) == 1
end_position = start_position + input_batch[0].shape[0]
tvec_temp = image_alignment.dry_run_translation(
input=input_batch[0],
new_reference_image=reference_image,
)
assert tvec_temp is not None
tvec[start_position:end_position, :] = tvec_temp
start_position += input_batch[0].shape[0]
tvec = torch.round(tvec)
return tvec