diff --git a/run_svd.py b/run_svd.py index e12ea03..aaa7b50 100644 --- a/run_svd.py +++ b/run_svd.py @@ -1,6 +1,20 @@ import torch import numpy as np -from svd import calculate_svd, to_remove, temporal_filter, svd_denoise +import os + +import torchvision as tv + +from svd import ( + calculate_svd, + to_remove, + temporal_filter, + svd_denoise, + convert_avi_to_npy, + calculate_translation, +) + +from ImageAlignment import ImageAlignment + if __name__ == "__main__": filename: str = "example_data_crop" @@ -12,46 +26,86 @@ if __name__ == "__main__": bp_low_frequency: float = 0.1 bp_high_frequency: float = 1.0 + convert_overwrite: bool | None = None + + fill_value: float = 0.0 + torch_device: torch.device = torch.device( "cuda:0" if torch.cuda.is_available() else "cpu" ) - print("Load data") - input = np.load(filename + str(".npy")) - data = torch.tensor(input, device=torch_device) + if ( + (convert_overwrite is None) + and (os.path.isfile("example_data_crop" + ".npy") is False) + ) or (convert_overwrite): + print("Convert AVI file to npy file.") + convert_avi_to_npy(filename) + print("--==-- DONE --==--") - print("Movement compensation [MISSING!!!!]") - print("(include ImageAlignment.py into processing chain)") + with torch.no_grad(): + print("Load data") + input = np.load(filename + str(".npy")) + data = torch.tensor(input, device=torch_device) - print("SVD") - whiten_mean, whiten_k, eigenvalues = calculate_svd(data) + print("Movement compensation [BROKEN!!!!]") + print("During development, information about what could move was missing.") + print("Thus the preprocessing before shift determination may not work.") + data -= data.min(dim=0)[0] + data /= data.std(dim=0, keepdim=True) + 1e-20 - print("Calculate to_remove") - data = torch.tensor(input, device=torch_device) - to_remove_data = to_remove(data, whiten_k, whiten_mean) + image_alignment = ImageAlignment( + default_dtype=torch.float32, device=torch_device + ) - data -= to_remove_data - del to_remove_data + tvec = calculate_translation( + input=data, + reference_image=data[0, ...].clone(), + image_alignment=image_alignment, + ) + tvec_media = tvec.median(dim=0)[0] + print(f"Median of movement: {tvec_media[0]}, {tvec_media[1]}") - print("apply temporal filter") - data = temporal_filter( - data, - device=torch_device, - orig_freq=orig_freq, - new_freq=new_freq, - filtfilt_chuck_size=filtfilt_chuck_size, - bp_low_frequency=bp_low_frequency, - bp_high_frequency=bp_high_frequency, - ) + data = torch.tensor(input, device=torch_device) - print("SVD Denosing") - data_out = svd_denoise(data, window_size=window_size) + for id in range(0, data.shape[0]): + data[id, ...] = tv.transforms.functional.affine( + img=data[id, ...].unsqueeze(0), + angle=0, + translate=[tvec[id, 1], tvec[id, 0]], + scale=1.0, + shear=0, + fill=fill_value, + ).squeeze(0) - print("Pooling") - avage_pooling = torch.nn.AvgPool2d( - kernel_size=(kernel_size_pooling, kernel_size_pooling), - stride=(kernel_size_pooling, kernel_size_pooling), - ) - data_out = avage_pooling(data_out) + print("SVD") + whiten_mean, whiten_k, eigenvalues = calculate_svd(data) - np.save(filename + str("_decorrelated.npy"), data_out.cpu()) + print("Calculate to_remove") + data = torch.tensor(input, device=torch_device) + to_remove_data = to_remove(data, whiten_k, whiten_mean) + + data -= to_remove_data + del to_remove_data + + print("apply temporal filter") + data = temporal_filter( + data, + device=torch_device, + orig_freq=orig_freq, + new_freq=new_freq, + filtfilt_chuck_size=filtfilt_chuck_size, + bp_low_frequency=bp_low_frequency, + bp_high_frequency=bp_high_frequency, + ) + + print("SVD Denosing") + data_out = svd_denoise(data, window_size=window_size) + + print("Pooling") + avage_pooling = torch.nn.AvgPool2d( + kernel_size=(kernel_size_pooling, kernel_size_pooling), + stride=(kernel_size_pooling, kernel_size_pooling), + ) + data_out = avage_pooling(data_out) + + np.save(filename + str("_decorrelated.npy"), data_out.cpu()) diff --git a/show.py b/show.py new file mode 100644 index 0000000..7276dd6 --- /dev/null +++ b/show.py @@ -0,0 +1,24 @@ +import numpy as np +import torch +from Anime import Anime + +# Convert from avi to npy +filename: str = "example_data_crop" + + +torch_device: torch.device = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu" +) + +print("Load data") +input = np.load(filename + str("_decorrelated.npy")) +data = torch.tensor(input, device=torch_device) +del input +print("loading done") + +data = data.nan_to_num(nan=0.0) +#data -= data.min(dim=0, keepdim=True)[0] + + +ani = Anime() +ani.show(data, vmin=0.0) diff --git a/show_b.py b/show_b.py new file mode 100644 index 0000000..fc6d637 --- /dev/null +++ b/show_b.py @@ -0,0 +1,24 @@ +import numpy as np +import torch +from Anime import Anime + +# Convert from avi to npy +filename: str = "example_data_crop" + + +torch_device: torch.device = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu" +) + +print("Load data") +input = np.load(filename + str("_decorrelated.npy")) +data = torch.tensor(input, device=torch_device) +del input +print("loading done") + +data = data.nan_to_num(nan=0.0) +data -= data.min(dim=0, keepdim=True)[0] +data *= data.std(dim=0, keepdim=True) + +ani = Anime() +ani.show(data) diff --git a/svd.py b/svd.py index 5c21120..f99b51d 100644 --- a/svd.py +++ b/svd.py @@ -96,6 +96,7 @@ def filtfilt( output = ta.functional.filtfilt( process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False ).squeeze(0) + output = output[..., padding_length:-padding_length].movedim(-1, 0) return output @@ -124,7 +125,7 @@ def chunk_iterator(array: torch.Tensor, chunk_size: int): @torch.no_grad() -def lowpass( +def bandpass( data: torch.Tensor, device: torch.device, low_frequency: float = 0.1, @@ -168,7 +169,7 @@ def temporal_filter( data.movedim(0, -1), orig_freq=orig_freq, new_freq=new_freq ).movedim(-1, 0) - data = lowpass( + data = bandpass( data, device=device, low_frequency=bp_low_frequency, @@ -202,3 +203,39 @@ def svd_denoise(data: torch.Tensor, window_size: int) -> torch.Tensor: to_remove_data = to_remove(data_sel, whiten_k, whiten_mean) data_out[:, x, y] = to_remove_data[:, window_size, window_size] return data_out + + +@torch.no_grad() +def calculate_translation( + input: torch.Tensor, + reference_image: torch.Tensor, + image_alignment, + start_position_coefficients: int = 0, + batch_size: int = 100, +) -> torch.Tensor: + tvec = torch.zeros((input.shape[0], 2)) + + data_loader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(input[start_position_coefficients:, ...]), + batch_size=batch_size, + shuffle=False, + ) + start_position: int = 0 + for input_batch in data_loader: + assert len(input_batch) == 1 + + end_position = start_position + input_batch[0].shape[0] + + tvec_temp = image_alignment.dry_run_translation( + input=input_batch[0], + new_reference_image=reference_image, + ) + + assert tvec_temp is not None + + tvec[start_position:end_position, :] = tvec_temp + + start_position += input_batch[0].shape[0] + + tvec = torch.round(tvec) + return tvec