diff --git a/reproduction_effort/aligned.py b/reproduction_effort/aligned.py
deleted file mode 100644
index a1d6aea..0000000
--- a/reproduction_effort/aligned.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import scipy.io as sio  # type: ignore
-import torch
-import numpy as np
-import matplotlib.pyplot as plt
-import json
-
-from functions.align_cameras import align_cameras
-
-if __name__ == "__main__":
-
-    if torch.cuda.is_available():
-        device_name: str = "cuda:0"
-    else:
-        device_name = "cpu"
-    print(f"Using device: {device_name}")
-    device: torch.device = torch.device(device_name)
-    dtype: torch.dtype = torch.float32
-
-    filename_raw_json: str = "raw/Exp001_Trial001_Part001_meta.txt"
-    filename_data_binning_replace: str = "bin_old/Exp001_Trial001_Part001.mat"
-    batch_size: int = 200
-
-    filename_aligned_mat: str = "aligned_old/Exp001_Trial001_Part001.mat"
-
-    with open(filename_raw_json, "r") as file_handle:
-        metadata: dict = json.load(file_handle)
-    channels: list[str] = metadata["channelKey"]
-
-    data = torch.tensor(
-        sio.loadmat(filename_data_binning_replace)["nparray"].astype(np.float32),
-        dtype=dtype,
-        device=device,
-    )
-
-    ref_image = data[:, :, data.shape[-2] // 2, :].clone()
-
-    (
-        acceptor,
-        donor,
-        oxygenation,
-        volume,
-        angle_donor_volume,
-        tvec_donor_volume,
-        angle_refref,
-        tvec_refref,
-    ) = align_cameras(
-        channels=channels,
-        data=data,
-        ref_image=ref_image,
-        device=device,
-        dtype=dtype,
-        batch_size=batch_size,
-        fill_value=-1,
-    )
-    del data
-
-    mat_data = torch.tensor(
-        sio.loadmat(filename_aligned_mat)["data"].astype(dtype=np.float32),
-        dtype=dtype,
-        device=device,
-    )
-
-    old: list = []
-    old.append(mat_data[..., 0].movedim(-1, 0))
-    old.append(mat_data[..., 1].movedim(-1, 0))
-    old.append(mat_data[..., 2].movedim(-1, 0))
-    old.append(mat_data[..., 3].movedim(-1, 0))
-
-    new: list = []
-    new.append(acceptor)
-    new.append(donor)
-    new.append(oxygenation)
-    new.append(volume)
-
-    names: list = []
-    names.append("acceptor")
-    names.append("donor")
-    names.append("oxygenation")
-    names.append("volume")
-
-    mask = torch.zeros(
-        (acceptor.shape[-2], acceptor.shape[-1]),
-        dtype=torch.bool,
-        device=device,
-    )
-
-    mask[torch.any(acceptor < 0, dim=0)] = True
-    mask[torch.any(donor < 0, dim=0)] = True
-    mask[torch.any(oxygenation < 0, dim=0)] = True
-    mask[torch.any(volume < 0, dim=0)] = True
-
-    frame_id: int = 0
-    image: list = []
-    for channel_id in range(0, len(old)):
-        temp = np.zeros((new[channel_id].shape[-2], new[channel_id].shape[-1], 3))
-        temp[:, :, 0] = (
-            old[channel_id][frame_id, ...] / old[channel_id][frame_id, ...].max()
-        ).cpu()
-        temp[:, :, 1] = (
-            new[channel_id][frame_id, ...]
/ new[channel_id][frame_id, ...].max() - ).cpu() - temp[:, :, 2] = 0.0 - image.append(temp) - - subplot_position: int = 1 - for channel_id in range(0, len(old)): - difference = (image[channel_id][..., 0] - image[channel_id][..., 1]) / ( - image[channel_id][..., 0] + image[channel_id][..., 1] - ) - plt.subplot(4, 2, subplot_position) - plt.imshow(difference, cmap="hot") - plt.colorbar() - subplot_position += 1 - - plt.subplot(4, 2, subplot_position) - plt.plot(np.sort(difference.flatten())) - subplot_position += 1 - - plt.show() diff --git a/reproduction_effort/binning.py b/reproduction_effort/binning.py deleted file mode 100644 index be70805..0000000 --- a/reproduction_effort/binning.py +++ /dev/null @@ -1,18 +0,0 @@ -import numpy as np -import torch -import os -import scipy.io as sio # type: ignore - -from functions.binning import binning - -filename_raw: str = f"raw{os.sep}Exp001_Trial001_Part001.npy" -filename_old_mat: str = "Exp001_Trial001_Part001.mat" - -data = torch.tensor(np.load(filename_raw).astype(np.float32)) - -data = binning(data) - -mat_data = torch.tensor(sio.loadmat(filename_old_mat)["nparray"].astype(np.float32)) - -diff = torch.abs(mat_data - data) -print(diff.min(), diff.max()) diff --git a/reproduction_effort/binning_aligned_process.py b/reproduction_effort/binning_aligned_process.py deleted file mode 100644 index d3bfd89..0000000 --- a/reproduction_effort/binning_aligned_process.py +++ /dev/null @@ -1,257 +0,0 @@ -import numpy as np -import torch -import os -import json -import matplotlib.pyplot as plt -import h5py # type: ignore -import scipy.io as sio # type: ignore - - -from functions.binning import binning -from functions.align_cameras import align_cameras -from functions.preprocessing import preprocessing -from functions.bandpass import bandpass -from functions.make_mask import make_mask -from functions.interpolate_along_time import interpolate_along_time - -if torch.cuda.is_available(): - device_name: str = "cuda:0" -else: - device_name = "cpu" -print(f"Using device: {device_name}") -device: torch.device = torch.device(device_name) -dtype: torch.dtype = torch.float32 - - -filename_raw: str = f"raw{os.sep}Exp001_Trial001_Part001.npy" -filename_raw_json: str = f"raw{os.sep}Exp001_Trial001_Part001_meta.txt" -filename_mask: str = "2020-12-08maskPixelraw2.mat" - -first_none_ramp_frame: int = 100 -spatial_width: float = 2 -temporal_width: float = 0.1 - -lower_freqency_bandpass: float = 5.0 -upper_freqency_bandpass: float = 14.0 - -lower_frequency_heartbeat: float = 5.0 -upper_frequency_heartbeat: float = 14.0 -sample_frequency: float = 100.0 - -target_camera: list[str] = ["acceptor", "donor"] -regressor_cameras: list[str] = ["oxygenation", "volume"] -batch_size: int = 200 -required_order: list[str] = ["acceptor", "donor", "oxygenation", "volume"] - - -test_overwrite_with_old_bining: bool = False -test_overwrite_with_old_aligned: bool = True -filename_data_binning_replace: str = "bin_old/Exp001_Trial001_Part001.mat" -filename_data_aligned_replace: str = "aligned_old/Exp001_Trial001_Part001.mat" - -data = torch.tensor(np.load(filename_raw).astype(np.float32), dtype=dtype) - -with open(filename_raw_json, "r") as file_handle: - metadata: dict = json.load(file_handle) -channels: list[str] = metadata["channelKey"] - -data = binning(data).to(device) - -if test_overwrite_with_old_bining: - data = torch.tensor( - sio.loadmat(filename_data_binning_replace)["nparray"].astype(np.float32), - dtype=dtype, - device=device, - ) - -ref_image = data[:, :, data.shape[-2] // 2, 
:].clone() - -( - acceptor, - donor, - oxygenation, - volume, - angle_donor_volume, - tvec_donor_volume, - angle_refref, - tvec_refref, -) = align_cameras( - channels=channels, - data=data, - ref_image=ref_image, - device=device, - dtype=dtype, - batch_size=batch_size, - fill_value=-1, -) -del data - - -camera_sequence: list[torch.Tensor] = [] - -for cam in required_order: - if cam.startswith("acceptor"): - camera_sequence.append(acceptor.movedim(0, -1).clone()) - del acceptor - if cam.startswith("donor"): - camera_sequence.append(donor.movedim(0, -1).clone()) - del donor - if cam.startswith("oxygenation"): - camera_sequence.append(oxygenation.movedim(0, -1).clone()) - del oxygenation - if cam.startswith("volume"): - camera_sequence.append(volume.movedim(0, -1).clone()) - del volume - -if test_overwrite_with_old_aligned: - - data_aligned_replace: torch.Tensor = torch.tensor( - sio.loadmat(filename_data_aligned_replace)["data"].astype(np.float32), - device=device, - dtype=dtype, - ) - - camera_sequence[0] = data_aligned_replace[..., 0].clone() - camera_sequence[1] = data_aligned_replace[..., 1].clone() - camera_sequence[2] = data_aligned_replace[..., 2].clone() - camera_sequence[3] = data_aligned_replace[..., 3].clone() - del data_aligned_replace - -# -> - - -mask: torch.Tensor = make_mask( - filename_mask=filename_mask, - camera_sequence=camera_sequence, - device=device, - dtype=dtype, -) - -mask_flatten = mask.flatten(start_dim=0, end_dim=-1) - -interpolate_along_time(camera_sequence) - -heartbeat_ts: torch.Tensor = bandpass( - data=camera_sequence[channels.index("volume")].clone(), - device=camera_sequence[channels.index("volume")].device, - low_frequency=lower_freqency_bandpass, - high_frequency=upper_freqency_bandpass, - fs=sample_frequency, - filtfilt_chuck_size=10, -) - -heartbeat_ts_copy = heartbeat_ts.clone() - -heartbeat_ts = heartbeat_ts.flatten(start_dim=0, end_dim=-2) -heartbeat_ts = heartbeat_ts[mask_flatten, :] - -heartbeat_ts = heartbeat_ts.movedim(0, -1) -heartbeat_ts -= heartbeat_ts.mean(dim=0, keepdim=True) - -volume_heartbeat, _, _ = torch.linalg.svd(heartbeat_ts, full_matrices=False) -volume_heartbeat = volume_heartbeat[:, 0] -volume_heartbeat -= volume_heartbeat[first_none_ramp_frame:].mean() -volume_heartbeat = volume_heartbeat.unsqueeze(0).unsqueeze(0) - -heartbeat_coefficients: list[torch.Tensor] = [] -for i in range(0, len(camera_sequence)): - y = bandpass( - data=camera_sequence[i].clone(), - device=camera_sequence[i].device, - low_frequency=lower_freqency_bandpass, - high_frequency=upper_freqency_bandpass, - fs=sample_frequency, - filtfilt_chuck_size=10, - )[..., first_none_ramp_frame:] - y -= y.mean(dim=-1, keepdim=True) - - heartbeat_coefficients.append( - ( - (volume_heartbeat[..., first_none_ramp_frame:] * y).sum( - dim=-1, keepdim=True - ) - / (volume_heartbeat[..., first_none_ramp_frame:] ** 2).sum( - dim=-1, keepdim=True - ) - ) - * mask.unsqueeze(-1) - ) -del y - -donor_correction_factor = heartbeat_coefficients[channels.index("donor")].clone() -acceptor_correction_factor = heartbeat_coefficients[channels.index("acceptor")].clone() - - -for i in range(0, len(camera_sequence)): - camera_sequence[i] -= heartbeat_coefficients[i] * volume_heartbeat - - -donor_factor: torch.Tensor = (donor_correction_factor + acceptor_correction_factor) / ( - 2 * donor_correction_factor -) -acceptor_factor: torch.Tensor = ( - donor_correction_factor + acceptor_correction_factor -) / (2 * acceptor_correction_factor) - - -# mean_values: list = [] -# for i in range(0, 
len(channels)): -# mean_values.append( -# camera_sequence[i][..., first_none_ramp_frame:].nanmean(dim=-1, keepdim=True) -# ) -# camera_sequence[i] -= mean_values[i] - -camera_sequence[channels.index("acceptor")] *= acceptor_factor * mask.unsqueeze(-1) -camera_sequence[channels.index("donor")] *= donor_factor * mask.unsqueeze(-1) - -# for i in range(0, len(channels)): -# camera_sequence[i] -= mean_values[i] - -# exit() -# <- - -data_acceptor, data_donor, mask = preprocessing( - cameras=channels, - camera_sequence=camera_sequence, - filename_mask=filename_mask, - device=device, - first_none_ramp_frame=first_none_ramp_frame, - spatial_width=spatial_width, - temporal_width=temporal_width, - target_camera=target_camera, - regressor_cameras=regressor_cameras, -) - -ratio_sequence: torch.Tensor = data_acceptor / data_donor - -ratio_sequence /= ratio_sequence.mean(dim=-1, keepdim=True) -ratio_sequence = torch.nan_to_num(ratio_sequence, nan=0.0) - -new: np.ndarray = ratio_sequence.cpu().numpy() - -file_handle = h5py.File("old.mat", "r") -old: np.ndarray = np.array(file_handle["ratioSequence"]) # type:ignore -# HDF5 loads everything backwards... -old = np.moveaxis(old, 0, -1) -old = np.moveaxis(old, 0, -2) - -pos_x = 25 -pos_y = 75 - -plt.figure(1) -plt.subplot(2, 1, 1) -new_select = new[pos_x, pos_y, :] -old_select = old[pos_x, pos_y, :] -plt.plot(old_select, "r", label="Old") -plt.plot(new_select, "k", label="New") - -# plt.plot(old_select - new_select + 1.0, label="Old - New + 1") -plt.title(f"Position: {pos_x}, {pos_y}") -plt.legend() - -plt.subplot(2, 1, 2) -differences = (np.abs(new - old)).max(axis=-1) * mask.cpu().numpy() -plt.imshow(differences, cmap="hot") -plt.title("Max of abs(new-old) along time") -plt.colorbar() -plt.show() diff --git a/reproduction_effort/binning_aligned_process_classic.py b/reproduction_effort/binning_aligned_process_classic.py deleted file mode 100644 index 19a9e85..0000000 --- a/reproduction_effort/binning_aligned_process_classic.py +++ /dev/null @@ -1,292 +0,0 @@ -import numpy as np -import torch -import os -import json -import matplotlib.pyplot as plt -import h5py # type: ignore -import scipy.io as sio # type: ignore - -from functions.binning import binning -from functions.align_cameras import align_cameras -from functions.preprocessing_classsic import preprocessing -from functions.bandpass import bandpass - - -if torch.cuda.is_available(): - device_name: str = "cuda:0" -else: - device_name = "cpu" -print(f"Using device: {device_name}") -device: torch.device = torch.device(device_name) -dtype: torch.dtype = torch.float32 - - -filename_raw: str = f"raw{os.sep}Exp001_Trial001_Part001.npy" -filename_raw_json: str = f"raw{os.sep}Exp001_Trial001_Part001_meta.txt" -filename_mask: str = "2020-12-08maskPixelraw2.mat" - -first_none_ramp_frame: int = 100 -spatial_width: float = 2 -temporal_width: float = 0.1 - -lower_freqency_bandpass: float = 5.0 -upper_freqency_bandpass: float = 14.0 - -lower_frequency_heartbeat: float = 5.0 -upper_frequency_heartbeat: float = 14.0 -sample_frequency: float = 100.0 - -target_camera: list[str] = ["acceptor", "donor"] -regressor_cameras: list[str] = ["oxygenation", "volume"] -batch_size: int = 200 -required_order: list[str] = ["acceptor", "donor", "oxygenation", "volume"] - - -test_overwrite_with_old_bining: bool = False -test_overwrite_with_old_aligned: bool = False -filename_data_binning_replace: str = "bin_old/Exp001_Trial001_Part001.mat" -filename_data_aligned_replace: str = "aligned_old/Exp001_Trial001_Part001.mat" - -data = 
torch.tensor(np.load(filename_raw).astype(np.float32), dtype=dtype) - -with open(filename_raw_json, "r") as file_handle: - metadata: dict = json.load(file_handle) -channels: list[str] = metadata["channelKey"] - -data = binning(data).to(device) - -if test_overwrite_with_old_bining: - data = torch.tensor( - sio.loadmat(filename_data_binning_replace)["nparray"].astype(np.float32), - dtype=dtype, - device=device, - ) - -ref_image = data[:, :, data.shape[-2] // 2, :].clone() - -( - acceptor, - donor, - oxygenation, - volume, - angle_donor_volume, - tvec_donor_volume, - angle_refref, - tvec_refref, -) = align_cameras( - channels=channels, - data=data, - ref_image=ref_image, - device=device, - dtype=dtype, - batch_size=batch_size, - fill_value=-1, -) -del data - - -camera_sequence: list[torch.Tensor] = [] - -for cam in required_order: - if cam.startswith("acceptor"): - camera_sequence.append(acceptor.movedim(0, -1).clone()) - del acceptor - if cam.startswith("donor"): - camera_sequence.append(donor.movedim(0, -1).clone()) - del donor - if cam.startswith("oxygenation"): - camera_sequence.append(oxygenation.movedim(0, -1).clone()) - del oxygenation - if cam.startswith("volume"): - camera_sequence.append(volume.movedim(0, -1).clone()) - del volume - -if test_overwrite_with_old_aligned: - - data_aligned_replace: torch.Tensor = torch.tensor( - sio.loadmat(filename_data_aligned_replace)["data"].astype(np.float32), - device=device, - dtype=dtype, - ) - - camera_sequence[0] = data_aligned_replace[..., 0].clone() - camera_sequence[1] = data_aligned_replace[..., 1].clone() - camera_sequence[2] = data_aligned_replace[..., 2].clone() - camera_sequence[3] = data_aligned_replace[..., 3].clone() - del data_aligned_replace - -data_acceptor, data_donor, mask = preprocessing( - cameras=channels, - camera_sequence=camera_sequence, - filename_mask=filename_mask, - device=device, - first_none_ramp_frame=first_none_ramp_frame, - spatial_width=spatial_width, - temporal_width=temporal_width, - target_camera=target_camera, - regressor_cameras=regressor_cameras, - lower_frequency_heartbeat=lower_frequency_heartbeat, - upper_frequency_heartbeat=upper_frequency_heartbeat, - sample_frequency=sample_frequency, -) - -ratio_sequence: torch.Tensor = data_acceptor / data_donor - -ratio_sequence /= ratio_sequence.mean(dim=-1, keepdim=True) -ratio_sequence = torch.nan_to_num(ratio_sequence, nan=0.0) - -new: np.ndarray = ratio_sequence.cpu().numpy() - -file_handle = h5py.File("old.mat", "r") -old: np.ndarray = np.array(file_handle["ratioSequence"]) # type:ignore -# HDF5 loads everything backwards... 
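# (Note: MATLAB v7.3 .mat files are HDF5 containers stored column-major,
#  so h5py returns the array with its axis order reversed; the two
#  np.moveaxis calls below restore the (x, y, time) layout used here.)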
-old = np.moveaxis(old, 0, -1) -old = np.moveaxis(old, 0, -2) - -# pos_x = 25 -# pos_y = 75 - -# plt.figure(1) -# new_select = new[pos_x, pos_y, :] -# old_select = old[pos_x, pos_y, :] -# plt.plot(new_select, label="New") -# plt.plot(old_select, "--", label="Old") -# # plt.plot(old_select - new_select + 1.0, label="Old - New + 1") -# plt.title(f"Position: {pos_x}, {pos_y}") -# plt.legend() - -# plt.show(block=False) - - -# plt.figure(2) -# new_select1 = new[pos_x + 1, pos_y, :] -# old_select1 = old[pos_x + 1, pos_y, :] -# plt.plot(new_select1, label="New") -# plt.plot(old_select1, "--", label="Old") -# # plt.plot(old_select - new_select + 1.0, label="Old - New + 1") -# plt.title(f"Position: {pos_x+1}, {pos_y}") -# plt.legend() - -# plt.show(block=False) - - -# plt.figure(3) -# plt.plot(old_select, label="Old") -# plt.plot(old_select1, "--", label="Old") -# # plt.plot(old_select - new_select + 1.0, label="Old - New + 1") -# # plt.title(f"Position: {pos_x+1}, {pos_y}") -# plt.legend() - -# s1 = old_select[np.newaxis, 100:] -# s2 = new_select[np.newaxis, 100:] -# s3 = old_select1[np.newaxis, 100:] - -# print("old-new", np.corrcoef(np.concatenate((s1, s2)))) -# print("old-oldshift", np.corrcoef(np.concatenate((s1, s3)))) - -plt.figure(4) -mask = mask.cpu().numpy() -mask_flatten = np.reshape(mask, (mask.shape[0] * mask.shape[1])) -data = np.reshape(old, (old.shape[0] * old.shape[1], old.shape[-1])) -data = data[mask_flatten == 1, 100:] -cc = np.corrcoef(data) - -cc_back = np.zeros_like(mask, dtype=np.float32) -cc_back = np.reshape(cc_back, (mask.shape[0] * mask.shape[1])) - -rng = np.random.default_rng() -cc_back[mask_flatten] = cc[:, 400] -cc_back = np.reshape(cc_back, (mask.shape[0], mask.shape[1])) - -plt.subplot(1, 2, 1) -plt.imshow(cc_back, cmap="hot") -plt.colorbar() - -plt.subplot(1, 2, 2) -plt.plot(cc[:, 400]) - - -plt.show(block=True) - - -# block=False -# ratio_sequence_a = bandpass( -# data=data_acceptor, -# device=data_acceptor.device, -# low_frequency=lower_freqency_bandpass, -# high_frequency=upper_freqency_bandpass, -# fs=100.0, -# filtfilt_chuck_size=10, -# ) - -# ratio_sequence_b = bandpass( -# data=data_donor, -# device=data_donor.device, -# low_frequency=lower_freqency_bandpass, -# high_frequency=upper_freqency_bandpass, -# fs=100.0, -# filtfilt_chuck_size=10, -# ) - -# original_shape = ratio_sequence_a.shape - -# ratio_sequence_a = ratio_sequence_a.flatten(start_dim=0, end_dim=-2) -# ratio_sequence_b = ratio_sequence_b.flatten(start_dim=0, end_dim=-2) - -# mask = mask.flatten(start_dim=0, end_dim=-1) -# ratio_sequence_a = ratio_sequence_a[mask, :] -# ratio_sequence_b = ratio_sequence_b[mask, :] - -# ratio_sequence_a = ratio_sequence_a.movedim(0, -1) -# ratio_sequence_b = ratio_sequence_b.movedim(0, -1) - -# ratio_sequence_a -= ratio_sequence_a.mean(dim=0, keepdim=True) -# ratio_sequence_b -= ratio_sequence_b.mean(dim=0, keepdim=True) - -# u_a, s_a, Vh_a = torch.linalg.svd(ratio_sequence_a, full_matrices=False) -# u_a = u_a[:, 0] -# s_a = s_a[0] -# Vh_a = Vh_a[0, :] - -# heartbeatactivitmap_a = torch.zeros( -# (original_shape[0], original_shape[1]), device=Vh_a.device, dtype=Vh_a.dtype -# ).flatten(start_dim=0, end_dim=-1) - -# heartbeatactivitmap_a *= torch.nan -# heartbeatactivitmap_a[mask] = s_a * Vh_a -# heartbeatactivitmap_a = heartbeatactivitmap_a.reshape( -# (original_shape[0], original_shape[1]) -# ) - -# u_b, s_b, Vh_b = torch.linalg.svd(ratio_sequence_b, full_matrices=False) -# u_b = u_b[:, 0] -# s_b = s_b[0] -# Vh_b = Vh_b[0, :] - -# heartbeatactivitmap_b = 
torch.zeros( -# (original_shape[0], original_shape[1]), device=Vh_b.device, dtype=Vh_b.dtype -# ).flatten(start_dim=0, end_dim=-1) - -# heartbeatactivitmap_b *= torch.nan -# heartbeatactivitmap_b[mask] = s_b * Vh_b -# heartbeatactivitmap_b = heartbeatactivitmap_b.reshape( -# (original_shape[0], original_shape[1]) -# ) - -# plt.figure(2) -# plt.subplot(2, 1, 1) -# plt.plot(u_a.cpu(), label="aceptor") -# plt.plot(u_b.cpu(), label="donor") -# plt.legend() -# plt.subplot(2, 1, 2) -# plt.imshow( -# torch.cat( -# ( -# heartbeatactivitmap_a, -# heartbeatactivitmap_b, -# ), -# dim=1, -# ).cpu() -# ) -# plt.colorbar() -# plt.show() diff --git a/reproduction_effort/functions/ImageAlignment.py b/reproduction_effort/functions/ImageAlignment.py deleted file mode 100644 index 6472d02..0000000 --- a/reproduction_effort/functions/ImageAlignment.py +++ /dev/null @@ -1,1015 +0,0 @@ -import torch -import torchvision as tv # type: ignore - -# The source code is based on: -# https://github.com/matejak/imreg_dft - -# The original LICENSE: -# Copyright (c) 2014, Matěj Týč -# Copyright (c) 2011-2014, Christoph Gohlke -# Copyright (c) 2011-2014, The Regents of the University of California -# Produced at the Laboratory for Fluorescence Dynamics - -# All rights reserved. - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: - -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. - -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. - -# * Neither the name of the {organization} nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. - -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
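A minimal sketch for orientation (not part of the original file; the function name and shapes are illustrative): the class below is a batched PyTorch port of imreg_dft's similarity registration. Its translation estimate rests on plain phase correlation: the inverse FFT of the normalized cross-power spectrum of two images peaks at their relative displacement.

import torch

def phase_correlation_shift(ref: torch.Tensor, img: torch.Tensor) -> tuple[int, int]:
    # Normalized cross-power spectrum of the two images.
    f_ref = torch.fft.fft2(ref)
    f_img = torch.fft.fft2(img)
    cps = f_ref * f_img.conj()
    corr = torch.fft.ifft2(cps / (cps.abs() + 1e-15)).real
    # The correlation surface peaks at the displacement between the images.
    peak = int(torch.argmax(corr))
    dy, dx = divmod(peak, corr.shape[-1])
    # Indices past the midpoint encode negative shifts (FFT wrap-around).
    if dy > corr.shape[-2] // 2:
        dy -= corr.shape[-2]
    if dx > corr.shape[-1] // 2:
        dx -= corr.shape[-1]
    return dy, dx

Rotation and scale are recovered the same way after resampling the spectra onto a log-polar grid, where both become translations; ImageAlignment layers apodization, constraint masks, and sub-pixel peak interpolation on top of this idea.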
-
-
-class ImageAlignment(torch.nn.Module):
-    device: torch.device
-    default_dtype: torch.dtype
-    excess_const: float = 1.1
-    exponent: str = "inf"
-    success: torch.Tensor | None = None
-
-    # The factor that determines how many
-    # sub-pixels we will shift
-    scale_factor: int = 4
-
-    reference_image: torch.Tensor | None = None
-
-    last_scale: torch.Tensor | None = None
-    last_angle: torch.Tensor | None = None
-    last_tvec: torch.Tensor | None = None
-
-    # Cache
-    image_reference_dft: torch.Tensor | None = None
-    filt: torch.Tensor
-    pcorr_shape: torch.Tensor
-    log_base: torch.Tensor
-    image_reference_logp: torch.Tensor
-
-    def __init__(
-        self,
-        device: torch.device | None = None,
-        default_dtype: torch.dtype | None = None,
-    ) -> None:
-        super().__init__()
-
-        assert device is not None
-        assert default_dtype is not None
-        self.device = device
-        self.default_dtype = default_dtype
-
-    def set_new_reference_image(self, new_reference_image: torch.Tensor | None = None):
-        assert new_reference_image is not None
-        assert new_reference_image.ndim == 2
-        self.reference_image = (
-            new_reference_image.detach()
-            .clone()
-            .to(device=self.device)
-            .type(dtype=self.default_dtype)
-        )
-        self.image_reference_dft = None
-
-    def forward(
-        self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None
-    ) -> torch.Tensor:
-        assert input.ndim == 3
-
-        if new_reference_image is not None:
-            self.set_new_reference_image(new_reference_image)
-
-        assert self.reference_image is not None
-        assert self.reference_image.ndim == 2
-        assert input.shape[-2] == self.reference_image.shape[-2]
-        assert input.shape[-1] == self.reference_image.shape[-1]
-
-        self.last_scale, self.last_angle, self.last_tvec, output = self.similarity(
-            self.reference_image,
-            input.to(device=self.device).type(dtype=self.default_dtype),
-        )
-
-        return output
-
-    def dry_run(
-        self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None
-    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
-        assert input.ndim == 3
-
-        if new_reference_image is not None:
-            self.set_new_reference_image(new_reference_image)
-
-        assert self.reference_image is not None
-        assert self.reference_image.ndim == 2
-        assert input.shape[-2] == self.reference_image.shape[-2]
-        assert input.shape[-1] == self.reference_image.shape[-1]
-
-        images_todo = input.to(device=self.device).type(dtype=self.default_dtype)
-        image_reference = self.reference_image
-
-        assert image_reference.ndim == 2
-        assert images_todo.ndim == 3
-
-        bgval: torch.Tensor = self.get_borderval(img=images_todo, radius=5)
-
-        self.last_scale, self.last_angle, self.last_tvec = self._similarity(
-            image_reference,
-            images_todo,
-            bgval,
-        )
-
-        return self.last_scale, self.last_angle, self.last_tvec
-
-    def dry_run_translation(
-        self, input: torch.Tensor, new_reference_image: torch.Tensor | None = None
-    ) -> torch.Tensor:
-        assert input.ndim == 3
-
-        if new_reference_image is not None:
-            self.set_new_reference_image(new_reference_image)
-
-        assert self.reference_image is not None
-        assert self.reference_image.ndim == 2
-        assert input.shape[-2] == self.reference_image.shape[-2]
-        assert input.shape[-1] == self.reference_image.shape[-1]
-
-        images_todo = input.to(device=self.device).type(dtype=self.default_dtype)
-        image_reference = self.reference_image
-
-        assert image_reference.ndim == 2
-        assert images_todo.ndim == 3
-
-        tvec, _ = self._translation(image_reference, images_todo)
-
-        return tvec
-
-    # ---------------
-
-    def dry_run_angle(
-        self,
-        input:
torch.Tensor, - new_reference_image: torch.Tensor | None = None, - ) -> torch.Tensor: - assert input.ndim == 3 - - if new_reference_image is not None: - self.set_new_reference_image(new_reference_image) - - constraints_dynamic_angle_0: torch.Tensor = torch.zeros( - (input.shape[0]), dtype=self.default_dtype, device=self.device - ) - constraints_dynamic_angle_1: torch.Tensor | None = None - constraints_dynamic_scale_0: torch.Tensor = torch.ones( - (input.shape[0]), dtype=self.default_dtype, device=self.device - ) - constraints_dynamic_scale_1: torch.Tensor | None = None - - assert self.reference_image is not None - assert self.reference_image.ndim == 2 - assert input.shape[-2] == self.reference_image.shape[-2] - assert input.shape[-1] == self.reference_image.shape[-1] - - images_todo = input.to(device=self.device).type(dtype=self.default_dtype) - image_reference = self.reference_image - - assert image_reference.ndim == 2 - assert images_todo.ndim == 3 - - _, newangle = self._get_ang_scale( - image_reference, - images_todo, - constraints_dynamic_scale_0, - constraints_dynamic_scale_1, - constraints_dynamic_angle_0, - constraints_dynamic_angle_1, - ) - - return newangle - - # --------------- - - def _get_pcorr_shape(self, shape: torch.Size) -> tuple[int, int]: - ret = (int(max(shape[-2:]) * 1.0),) * 2 - return ret - - def _get_log_base(self, shape: torch.Size, new_r: torch.Tensor) -> torch.Tensor: - old_r = torch.tensor( - (float(shape[-2]) * self.excess_const) / 2.0, - dtype=self.default_dtype, - device=self.device, - ) - log_base = torch.exp(torch.log(old_r) / new_r) - return log_base - - def wrap_angle( - self, angles: torch.Tensor, ceil: float = 2 * torch.pi - ) -> torch.Tensor: - angles += ceil / 2.0 - angles %= ceil - angles -= ceil / 2.0 - return angles - - def get_borderval( - self, img: torch.Tensor, radius: int | None = None - ) -> torch.Tensor: - assert img.ndim == 3 - if radius is None: - mindim = min([int(img.shape[-2]), int(img.shape[-1])]) - radius = max(1, mindim // 20) - mask = torch.zeros( - (int(img.shape[-2]), int(img.shape[-1])), - dtype=torch.bool, - device=self.device, - ) - mask[:, :radius] = True - mask[:, -radius:] = True - mask[:radius, :] = True - mask[-radius:, :] = True - - mean = torch.median(img[:, mask], dim=-1)[0] - return mean - - def get_apofield(self, shape: torch.Size, aporad: int) -> torch.Tensor: - if aporad == 0: - return torch.ones( - shape[-2:], - dtype=self.default_dtype, - device=self.device, - ) - - assert int(shape[-2]) > aporad * 2 - assert int(shape[-1]) > aporad * 2 - - apos = torch.hann_window( - aporad * 2, dtype=self.default_dtype, periodic=False, device=self.device - ) - - toapp_0 = torch.ones( - shape[-2], - dtype=self.default_dtype, - device=self.device, - ) - toapp_0[:aporad] = apos[:aporad] - toapp_0[-aporad:] = apos[-aporad:] - - toapp_1 = torch.ones( - shape[-1], - dtype=self.default_dtype, - device=self.device, - ) - toapp_1[:aporad] = apos[:aporad] - toapp_1[-aporad:] = apos[-aporad:] - - apofield = torch.outer(toapp_0, toapp_1) - - return apofield - - def _get_subarr( - self, array: torch.Tensor, center: torch.Tensor, rad: int - ) -> torch.Tensor: - assert array.ndim == 3 - assert center.ndim == 2 - assert array.shape[0] == center.shape[0] - assert center.shape[1] == 2 - - dim = 1 + 2 * rad - subarr = torch.zeros( - (array.shape[0], dim, dim), dtype=self.default_dtype, device=self.device - ) - - corner = center - rad - idx_p = range(0, corner.shape[0]) - for ii in range(0, dim): - yidx = corner[:, 0] + ii - yidx %= array.shape[-2] 
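            # (The modulo wraps neighborhood indices around the image border,
            #  matching the circular correlation surface that the FFT produces.)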
- for jj in range(0, dim): - xidx = corner[:, 1] + jj - xidx %= array.shape[-1] - subarr[:, ii, jj] = array[idx_p, yidx, xidx] - - return subarr - - def _argmax_2d(self, array: torch.Tensor) -> torch.Tensor: - assert array.ndim == 3 - - max_pos = array.reshape( - (array.shape[0], array.shape[1] * array.shape[2]) - ).argmax(dim=1) - pos_0 = max_pos // array.shape[2] - - max_pos -= pos_0 * array.shape[2] - ret = torch.zeros( - (array.shape[0], 2), dtype=self.default_dtype, device=self.device - ) - ret[:, 0] = pos_0 - ret[:, 1] = max_pos - - return ret.type(dtype=torch.int64) - - def _apodize(self, what: torch.Tensor) -> torch.Tensor: - mindim = min([int(what.shape[-2]), int(what.shape[-1])]) - aporad = int(mindim * 0.12) - - apofield = self.get_apofield(what.shape, aporad).unsqueeze(0) - - res = what * apofield - bg = self.get_borderval(what, aporad // 2).unsqueeze(-1).unsqueeze(-1) - res += bg * (1 - apofield) - return res - - def _logpolar_filter(self, shape: torch.Size) -> torch.Tensor: - yy = torch.linspace( - -torch.pi / 2.0, - torch.pi / 2.0, - shape[-2], - dtype=self.default_dtype, - device=self.device, - ).unsqueeze(1) - - xx = torch.linspace( - -torch.pi / 2.0, - torch.pi / 2.0, - shape[-1], - dtype=self.default_dtype, - device=self.device, - ).unsqueeze(0) - - rads = torch.sqrt(yy**2 + xx**2) - filt = 1.0 - torch.cos(rads) ** 2 - - filt[torch.abs(rads) > torch.pi / 2] = 1 - return filt - - def _get_angles(self, shape: torch.Tensor) -> torch.Tensor: - ret = torch.zeros( - (int(shape[-2]), int(shape[-1])), - dtype=self.default_dtype, - device=self.device, - ) - ret -= torch.linspace( - 0, - torch.pi, - int(shape[-2] + 1), - dtype=self.default_dtype, - device=self.device, - )[:-1].unsqueeze(-1) - - return ret - - def _get_lograd(self, shape: torch.Tensor, log_base: torch.Tensor) -> torch.Tensor: - ret = torch.zeros( - (int(shape[-2]), int(shape[-1])), - dtype=self.default_dtype, - device=self.device, - ) - ret += torch.pow( - log_base, - torch.arange( - 0, - int(shape[-1]), - dtype=self.default_dtype, - device=self.device, - ), - ).unsqueeze(0) - return ret - - def _logpolar( - self, image: torch.Tensor, shape: torch.Tensor, log_base: torch.Tensor - ) -> torch.Tensor: - assert image.ndim == 3 - - imshape: torch.Tensor = torch.tensor( - image.shape[-2:], - dtype=self.default_dtype, - device=self.device, - ) - - center: torch.Tensor = imshape.clone() / 2 - - theta: torch.Tensor = self._get_angles(shape) - radius_x: torch.Tensor = self._get_lograd(shape, log_base) - radius_y: torch.Tensor = radius_x.clone() - - ellipse_coef: torch.Tensor = imshape[0] / imshape[1] - radius_x /= ellipse_coef - - y = radius_y * torch.sin(theta) + center[0] - y /= float(image.shape[-2]) - y *= 2 - y -= 1 - - x = radius_x * torch.cos(theta) + center[1] - x /= float(image.shape[-1]) - x *= 2 - x -= 1 - - idx_x = torch.where(torch.abs(x) <= 1.0, 1.0, 0.0) - idx_y = torch.where(torch.abs(y) <= 1.0, 1.0, 0.0) - - normalized_coords = torch.cat( - ( - x.unsqueeze(-1), - y.unsqueeze(-1), - ), - dim=-1, - ).unsqueeze(0) - - output = torch.empty( - (int(image.shape[0]), int(y.shape[0]), int(y.shape[1])), - dtype=self.default_dtype, - device=self.device, - ) - - for id in range(0, int(image.shape[0])): - bgval: torch.Tensor = torch.quantile(image[id, :, :], q=1.0 / 100.0) - - temp = torch.nn.functional.grid_sample( - image[id, :, :].unsqueeze(0).unsqueeze(0), - normalized_coords, - mode="bilinear", - padding_mode="zeros", - align_corners=False, - ) - - output[id, :, :] = torch.where((idx_x * idx_y) == 0.0, bgval, temp) 
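        # (grid_sample expects sampling coordinates normalized to [-1, 1],
        #  hence the rescaling of x and y above; samples that land outside the
        #  valid range are replaced with the per-image 1% quantile background.)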
- - return output - - def _argmax_ext(self, array: torch.Tensor, exponent: float | str) -> torch.Tensor: - assert array.ndim == 3 - - if exponent == "inf": - ret = self._argmax_2d(array) - else: - assert isinstance(exponent, float) or isinstance(exponent, int) - - col = ( - torch.arange( - 0, array.shape[-2], dtype=self.default_dtype, device=self.device - ) - .unsqueeze(-1) - .unsqueeze(0) - ) - row = ( - torch.arange( - 0, array.shape[-1], dtype=self.default_dtype, device=self.device - ) - .unsqueeze(0) - .unsqueeze(0) - ) - - arr2 = torch.pow(array, float(exponent)) - arrsum = arr2.sum(dim=-2).sum(dim=-1) - - ret = torch.zeros( - (array.shape[0], 2), dtype=self.default_dtype, device=self.device - ) - - arrprody = (arr2 * col).sum(dim=-1).sum(dim=-1) / arrsum - arrprodx = (arr2 * row).sum(dim=-1).sum(dim=-1) / arrsum - - ret[:, 0] = arrprody.squeeze(-1).squeeze(-1) - ret[:, 1] = arrprodx.squeeze(-1).squeeze(-1) - - idx = torch.where(arrsum == 0.0)[0] - ret[idx, :] = 0.0 - return ret - - def _interpolate( - self, array: torch.Tensor, rough: torch.Tensor, rad: int = 2 - ) -> torch.Tensor: - assert array.ndim == 3 - assert rough.ndim == 2 - - rough = torch.round(rough).type(torch.int64) - - surroundings = self._get_subarr(array, rough, rad) - - com = self._argmax_ext(surroundings, 1.0) - - offset = com - rad - ret = rough + offset - - ret += 0.5 - ret %= ( - torch.tensor(array.shape[-2:], dtype=self.default_dtype, device=self.device) - .type(dtype=torch.int64) - .unsqueeze(0) - ) - ret -= 0.5 - return ret - - def _get_success( - self, array: torch.Tensor, coord: torch.Tensor, radius: int = 2 - ) -> torch.Tensor: - assert array.ndim == 3 - assert coord.ndim == 2 - assert array.shape[0] == coord.shape[0] - assert coord.shape[1] == 2 - - coord = torch.round(coord).type(dtype=torch.int64) - subarr = self._get_subarr( - array, coord, 2 - ) # Not my fault. They want a 2 there. Not radius - - theval = subarr.sum(dim=-1).sum(dim=-1) - - theval2 = array[range(0, coord.shape[0]), coord[:, 0], coord[:, 1]] - - success = torch.sqrt(theval * theval2) - return success - - def _get_constraint_mask( - self, - shape: torch.Size, - log_base: torch.Tensor, - constraints_scale_0: torch.Tensor, - constraints_scale_1: torch.Tensor | None, - constraints_angle_0: torch.Tensor, - constraints_angle_1: torch.Tensor | None, - ) -> torch.Tensor: - assert constraints_scale_0 is not None - assert constraints_angle_0 is not None - assert constraints_scale_0.ndim == 1 - assert constraints_angle_0.ndim == 1 - - assert constraints_scale_0.shape[0] == constraints_angle_0.shape[0] - - mask: torch.Tensor = torch.ones( - (constraints_scale_0.shape[0], int(shape[-2]), int(shape[-1])), - device=self.device, - dtype=self.default_dtype, - ) - - scale: torch.Tensor = constraints_scale_0.clone() - if constraints_scale_1 is not None: - sigma: torch.Tensor | None = constraints_scale_1.clone() - else: - sigma = None - - scales = torch.fft.ifftshift( - self._get_lograd( - torch.tensor(shape[-2:], device=self.device, dtype=self.default_dtype), - log_base, - ) - ) - - scales *= log_base ** (-shape[-1] / 2.0) - scales = scales.unsqueeze(0) - (1.0 / scale).unsqueeze(-1).unsqueeze(-1) - - if sigma is not None: - assert sigma.shape[0] == constraints_scale_0.shape[0] - - for p_id in range(0, sigma.shape[0]): - if sigma[p_id] == 0: - ascales = torch.abs(scales[p_id, ...]) - scale_min = ascales.min() - binary_mask = torch.where(ascales > scale_min, 0.0, 1.0) - mask[p_id, ...] *= binary_mask - else: - mask[p_id, ...] 
*= torch.exp(
-                        -(torch.pow(scales[p_id, ...], 2)) / torch.pow(sigma[p_id], 2)
-                    )
-
-        angle: torch.Tensor = constraints_angle_0.clone()
-        if constraints_angle_1 is not None:
-            sigma = constraints_angle_1.clone()
-        else:
-            sigma = None
-
-        angles = self._get_angles(
-            torch.tensor(shape[-2:], device=self.device, dtype=self.default_dtype)
-        )
-
-        angles = angles.unsqueeze(0) + torch.deg2rad(angle).unsqueeze(-1).unsqueeze(-1)
-
-        angles = torch.rad2deg(angles)
-
-        if sigma is not None:
-            assert sigma.shape[0] == constraints_scale_0.shape[0]
-
-            for p_id in range(0, sigma.shape[0]):
-                if sigma[p_id] == 0:
-                    aangles = torch.abs(angles[p_id, ...])
-                    angle_min = aangles.min()
-                    binary_mask = torch.where(aangles > angle_min, 0.0, 1.0)
-                    mask[p_id, ...] *= binary_mask
-                else:
-                    mask[p_id, ...] *= torch.exp(
-                        -(torch.pow(angles[p_id, ...], 2)) / torch.pow(sigma[p_id], 2)
-                    )
-
-        mask = torch.fft.fftshift(mask, dim=(-2, -1))
-
-        return mask
-
-    def argmax_angscale(
-        self,
-        array: torch.Tensor,
-        log_base: torch.Tensor,
-        constraints_scale_0: torch.Tensor,
-        constraints_scale_1: torch.Tensor | None,
-        constraints_angle_0: torch.Tensor,
-        constraints_angle_1: torch.Tensor | None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        assert array.ndim == 3
-        assert constraints_scale_0 is not None
-        assert constraints_angle_0 is not None
-        assert constraints_scale_0.ndim == 1
-        assert constraints_angle_0.ndim == 1
-
-        mask = self._get_constraint_mask(
-            array.shape[-2:],
-            log_base,
-            constraints_scale_0,
-            constraints_scale_1,
-            constraints_angle_0,
-            constraints_angle_1,
-        )
-
-        array_orig = array.clone()
-
-        array *= mask
-        ret = self._argmax_ext(array, self.exponent)
-
-        ret_final = self._interpolate(array, ret)
-
-        success = self._get_success(array_orig, ret_final, 0)
-
-        return ret_final, success
-
-    def argmax_translation(
-        self, array: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        assert array.ndim == 3
-
-        array_orig = array.clone()
-
-        ashape = torch.tensor(array.shape[-2:], device=self.device).type(
-            dtype=torch.int64
-        )
-
-        aporad = (ashape // 6).min()
-        mask2 = self.get_apofield(torch.Size(ashape), aporad).unsqueeze(0)
-        array *= mask2
-
-        tvec = self._argmax_ext(array, "inf")
-
-        tvec = self._interpolate(array_orig, tvec)
-
-        success = self._get_success(array_orig, tvec, 2)
-
-        return tvec, success
-
-    def transform_img(
-        self,
-        img: torch.Tensor,
-        scale: torch.Tensor | None = None,
-        angle: torch.Tensor | None = None,
-        tvec: torch.Tensor | None = None,
-        bgval: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        assert img.ndim == 3
-
-        if scale is None:
-            scale = torch.ones(
-                (img.shape[0],), dtype=self.default_dtype, device=self.device
-            )
-        assert scale.ndim == 1
-        assert scale.shape[0] == img.shape[0]
-
-        if angle is None:
-            angle = torch.zeros(
-                (img.shape[0],), dtype=self.default_dtype, device=self.device
-            )
-        assert angle.ndim == 1
-        assert angle.shape[0] == img.shape[0]
-
-        if tvec is None:
-            tvec = torch.zeros(
-                (img.shape[0], 2), dtype=self.default_dtype, device=self.device
-            )
-        assert tvec.ndim == 2
-        assert tvec.shape[0] == img.shape[0]
-        assert tvec.shape[1] == 2
-
-        if bgval is None:
-            bgval = self.get_borderval(img)
-        assert bgval.ndim == 1
-        assert bgval.shape[0] == img.shape[0]
-
-        # Otherwise we need to decompose it and put it back together
-        assert torch.is_complex(img) is False
-
-        output = torch.zeros_like(img)
-
-        for pos in range(0, img.shape[0]):
-            image_processed = img[pos, :, :].unsqueeze(0).clone()
-
-            temp_shift = [
-                int(round(tvec[pos, 1].item() *
self.scale_factor)), - int(round(tvec[pos, 0].item() * self.scale_factor)), - ] - - image_processed = torch.nn.functional.interpolate( - image_processed.unsqueeze(0), - scale_factor=self.scale_factor, - mode="bilinear", - ).squeeze(0) - - image_processed = tv.transforms.functional.affine( - img=image_processed, - angle=-float(angle[pos]), - translate=temp_shift, - scale=float(scale[pos]), - shear=[0, 0], - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=float(bgval[pos]), - center=None, - ) - - image_processed = torch.nn.functional.interpolate( - image_processed.unsqueeze(0), - scale_factor=1.0 / self.scale_factor, - mode="bilinear", - ).squeeze(0) - - image_processed = tv.transforms.functional.center_crop( - image_processed, img.shape[-2:] - ) - - output[pos, ...] = image_processed.squeeze(0) - - return output - - def transform_img_dict( - self, - img: torch.Tensor, - scale: torch.Tensor | None = None, - angle: torch.Tensor | None = None, - tvec: torch.Tensor | None = None, - bgval: torch.Tensor | None = None, - invert=False, - ) -> torch.Tensor: - if invert: - if scale is not None: - scale = 1.0 / scale - if angle is not None: - angle *= -1 - if tvec is not None: - tvec *= -1 - - res = self.transform_img(img, scale, angle, tvec, bgval=bgval) - return res - - def _phase_correlation( - self, image_reference: torch.Tensor, images_todo: torch.Tensor, callback, *args - ) -> tuple[torch.Tensor, torch.Tensor]: - assert image_reference.ndim == 3 - assert image_reference.shape[0] == 1 - assert images_todo.ndim == 3 - - assert callback is not None - - image_reference_fft = torch.fft.fft2(image_reference, dim=(-2, -1)) - images_todo_fft = torch.fft.fft2(images_todo, dim=(-2, -1)) - - eps = torch.abs(images_todo_fft).max(dim=-1)[0].max(dim=-1)[0] * 1e-15 - - cps = abs( - torch.fft.ifft2( - (image_reference_fft * images_todo_fft.conj()) - / ( - torch.abs(image_reference_fft) * torch.abs(images_todo_fft) - + eps.unsqueeze(-1).unsqueeze(-1) - ), - dim=(-2, -1), - ) - ) - - scps = torch.fft.fftshift(cps, dim=(-2, -1)) - - ret, success = callback(scps, *args) - - ret[:, 0] -= image_reference_fft.shape[-2] // 2 - ret[:, 1] -= image_reference_fft.shape[-1] // 2 - - return ret, success - - def _translation( - self, im0: torch.Tensor, im1: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - assert im0.ndim == 2 - ret, succ = self._phase_correlation( - im0.unsqueeze(0), im1, self.argmax_translation - ) - - return ret, succ - - def _get_ang_scale( - self, - image_reference: torch.Tensor, - images_todo: torch.Tensor, - constraints_scale_0: torch.Tensor, - constraints_scale_1: torch.Tensor | None, - constraints_angle_0: torch.Tensor, - constraints_angle_1: torch.Tensor | None, - ) -> tuple[torch.Tensor, torch.Tensor]: - assert image_reference.ndim == 2 - assert images_todo.ndim == 3 - assert image_reference.shape[-1] == images_todo.shape[-1] - assert image_reference.shape[-2] == images_todo.shape[-2] - assert constraints_scale_0.shape[0] == images_todo.shape[0] - assert constraints_angle_0.shape[0] == images_todo.shape[0] - - if constraints_scale_1 is not None: - assert constraints_scale_1.shape[0] == images_todo.shape[0] - - if constraints_angle_1 is not None: - assert constraints_angle_1.shape[0] == images_todo.shape[0] - - if self.image_reference_dft is None: - image_reference_apod = self._apodize(image_reference.unsqueeze(0)) - self.image_reference_dft = torch.fft.fftshift( - torch.fft.fft2(image_reference_apod, dim=(-2, -1)), dim=(-2, -1) - ) - self.filt = 
self._logpolar_filter(image_reference.shape).unsqueeze(0) - self.image_reference_dft *= self.filt - self.pcorr_shape = torch.tensor( - self._get_pcorr_shape(image_reference.shape[-2:]), - dtype=self.default_dtype, - device=self.device, - ) - self.log_base = self._get_log_base( - image_reference.shape, - self.pcorr_shape[1], - ) - self.image_reference_logp = self._logpolar( - torch.abs(self.image_reference_dft), self.pcorr_shape, self.log_base - ) - - images_todo_apod = self._apodize(images_todo) - images_todo_dft = torch.fft.fftshift( - torch.fft.fft2(images_todo_apod, dim=(-2, -1)), dim=(-2, -1) - ) - - images_todo_dft *= self.filt - - images_todo_lopg = self._logpolar( - torch.abs(images_todo_dft), self.pcorr_shape, self.log_base - ) - - temp, _ = self._phase_correlation( - self.image_reference_logp, - images_todo_lopg, - self.argmax_angscale, - self.log_base, - constraints_scale_0, - constraints_scale_1, - constraints_angle_0, - constraints_angle_1, - ) - - arg_ang = temp[:, 0].clone() - arg_rad = temp[:, 1].clone() - - angle = -torch.pi * arg_ang / float(self.pcorr_shape[0]) - angle = torch.rad2deg(angle) - - angle = self.wrap_angle(angle, 360) - - scale = torch.pow(self.log_base, arg_rad) - - angle = -angle - scale = 1.0 / scale - - assert torch.where(scale < 2)[0].shape[0] == scale.shape[0] - assert torch.where(scale > 0.5)[0].shape[0] == scale.shape[0] - - return scale, angle - - def translation( - self, im0: torch.Tensor, im1: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - angle = torch.zeros( - (im1.shape[0]), dtype=self.default_dtype, device=self.device - ) - assert im1.ndim == 3 - assert im0.shape[-2] == im1.shape[-2] - assert im0.shape[-1] == im1.shape[-1] - - tvec, succ = self._translation(im0, im1) - tvec2, succ2 = self._translation(im0, torch.rot90(im1, k=2, dims=[-2, -1])) - - assert tvec.shape[0] == tvec2.shape[0] - assert tvec.ndim == 2 - assert tvec2.ndim == 2 - assert tvec.shape[1] == 2 - assert tvec2.shape[1] == 2 - assert succ.shape[0] == succ2.shape[0] - assert succ.ndim == 1 - assert succ2.ndim == 1 - assert tvec.shape[0] == succ.shape[0] - assert angle.shape[0] == tvec.shape[0] - assert angle.ndim == 1 - - for pos in range(0, angle.shape[0]): - pick_rotated = False - if succ2[pos] > succ[pos]: - pick_rotated = True - - if pick_rotated: - tvec[pos, :] = tvec2[pos, :] - succ[pos] = succ2[pos] - angle[pos] += 180 - - return tvec, succ, angle - - def _similarity( - self, - image_reference: torch.Tensor, - images_todo: torch.Tensor, - bgval: torch.Tensor, - ): - assert image_reference.ndim == 2 - assert images_todo.ndim == 3 - assert image_reference.shape[-1] == images_todo.shape[-1] - assert image_reference.shape[-2] == images_todo.shape[-2] - - # We are going to iterate and precise scale and angle estimates - scale: torch.Tensor = torch.ones( - (images_todo.shape[0]), dtype=self.default_dtype, device=self.device - ) - angle: torch.Tensor = torch.zeros( - (images_todo.shape[0]), dtype=self.default_dtype, device=self.device - ) - - constraints_dynamic_angle_0: torch.Tensor = torch.zeros( - (images_todo.shape[0]), dtype=self.default_dtype, device=self.device - ) - constraints_dynamic_angle_1: torch.Tensor | None = None - constraints_dynamic_scale_0: torch.Tensor = torch.ones( - (images_todo.shape[0]), dtype=self.default_dtype, device=self.device - ) - constraints_dynamic_scale_1: torch.Tensor | None = None - - newscale, newangle = self._get_ang_scale( - image_reference, - images_todo, - constraints_dynamic_scale_0, - constraints_dynamic_scale_1, - 
constraints_dynamic_angle_0, - constraints_dynamic_angle_1, - ) - scale *= newscale - angle += newangle - - im2 = self.transform_img(images_todo, scale, angle, bgval=bgval) - - tvec, self.success, res_angle = self.translation(image_reference, im2) - - angle += res_angle - - angle = self.wrap_angle(angle, 360) - - return scale, angle, tvec - - def similarity( - self, - image_reference: torch.Tensor, - images_todo: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - assert image_reference.ndim == 2 - assert images_todo.ndim == 3 - - bgval: torch.Tensor = self.get_borderval(img=images_todo, radius=5) - - scale, angle, tvec = self._similarity( - image_reference, - images_todo, - bgval, - ) - - im2 = self.transform_img_dict( - img=images_todo, - scale=scale, - angle=angle, - tvec=tvec, - bgval=bgval, - ) - - return scale, angle, tvec, im2 diff --git a/reproduction_effort/functions/adjust_factor.py b/reproduction_effort/functions/adjust_factor.py deleted file mode 100644 index 2adc4e3..0000000 --- a/reproduction_effort/functions/adjust_factor.py +++ /dev/null @@ -1,95 +0,0 @@ -import torch -import math - - -def adjust_factor( - input_acceptor: torch.Tensor, - input_donor: torch.Tensor, - lower_frequency_heartbeat: float, - upper_frequency_heartbeat: float, - sample_frequency: float, - mask: torch.Tensor, - power_factors: None | list[float], -) -> tuple[float, float]: - - number_of_active_pixel: torch.Tensor = mask.type(dtype=torch.float32).sum() - signal_acceptor: torch.Tensor = (input_acceptor * mask.unsqueeze(-1)).sum( - dim=0 - ).sum(dim=0) / number_of_active_pixel - - signal_donor: torch.Tensor = (input_donor * mask.unsqueeze(-1)).sum(dim=0).sum( - dim=0 - ) / number_of_active_pixel - - signal_acceptor_offset = signal_acceptor.mean() - signal_donor_offset = signal_donor.mean() - - if power_factors is None: - signal_acceptor = signal_acceptor - signal_acceptor_offset - signal_donor = signal_donor - signal_donor_offset - - blackman_window = torch.blackman_window( - window_length=signal_acceptor.shape[0], - periodic=True, - dtype=signal_acceptor.dtype, - device=signal_acceptor.device, - ) - - signal_acceptor *= blackman_window - signal_donor *= blackman_window - nfft: int = int(2 ** math.ceil(math.log2(signal_donor.shape[0]))) - nfft = max([256, nfft]) - - signal_acceptor_fft: torch.Tensor = torch.fft.rfft(signal_acceptor, n=nfft) - signal_donor_fft: torch.Tensor = torch.fft.rfft(signal_donor, n=nfft) - - frequency_axis: torch.Tensor = ( - torch.fft.rfftfreq(nfft, device=signal_acceptor_fft.device) - * sample_frequency - ) - - signal_acceptor_power: torch.Tensor = torch.abs(signal_acceptor_fft) ** 2 - signal_acceptor_power[1:-1] *= 2 - - signal_donor_power: torch.Tensor = torch.abs(signal_donor_fft) ** 2 - signal_donor_power[1:-1] *= 2 - - if frequency_axis[-1] != (sample_frequency / 2.0): - signal_acceptor_power[-1] *= 2 - signal_donor_power[-1] *= 2 - - signal_acceptor_power /= blackman_window.sum() ** 2 - signal_donor_power /= blackman_window.sum() ** 2 - - idx = torch.where( - (frequency_axis >= lower_frequency_heartbeat) - * (frequency_axis <= upper_frequency_heartbeat) - )[0] - - frequency_axis = frequency_axis[idx] - signal_acceptor_power = signal_acceptor_power[idx] - signal_donor_power = signal_donor_power[idx] - - acceptor_range: float = float( - signal_acceptor_power.max() - signal_acceptor_power.min() - ) - - donor_range: float = float(signal_donor_power.max() - signal_donor_power.min()) - else: - donor_range = float(power_factors[0]) - 
acceptor_range = float(power_factors[1]) - - acceptor_correction_factor: float = float( - 0.5 - * ( - 1 - + (signal_acceptor_offset * math.sqrt(donor_range)) - / (signal_donor_offset * math.sqrt(acceptor_range)) - ) - ) - - donor_correction_factor: float = float( - acceptor_correction_factor / (2 * acceptor_correction_factor - 1) - ) - - return donor_correction_factor, acceptor_correction_factor diff --git a/reproduction_effort/functions/align_cameras.py b/reproduction_effort/functions/align_cameras.py deleted file mode 100644 index 97d52bb..0000000 --- a/reproduction_effort/functions/align_cameras.py +++ /dev/null @@ -1,162 +0,0 @@ -import torch -import torchvision as tv # type: ignore - -from functions.align_refref import align_refref -from functions.perform_donor_volume_rotation import perform_donor_volume_rotation -from functions.perform_donor_volume_translation import perform_donor_volume_translation -from functions.ImageAlignment import ImageAlignment - - -@torch.no_grad() -def align_cameras( - channels: list[str], - data: torch.Tensor, - ref_image: torch.Tensor, - device: torch.device, - dtype: torch.dtype, - batch_size: int, - fill_value: float = 0, -) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - image_alignment = ImageAlignment(default_dtype=dtype, device=device) - - # --- Get reference image --- - acceptor_index: int = channels.index("acceptor") - donor_index: int = channels.index("donor") - oxygenation_index: int = channels.index("oxygenation") - volume_index: int = channels.index("volume") - - # --==-- DONE --==-- - - # --- Sort data --- - acceptor = data[..., acceptor_index].moveaxis(-1, 0).clone() - donor = data[..., donor_index].moveaxis(-1, 0).clone() - oxygenation = data[..., oxygenation_index].moveaxis(-1, 0).clone() - volume = data[..., volume_index].moveaxis(-1, 0).clone() - - ref_image_acceptor = ref_image[..., acceptor_index].clone() - ref_image_donor = ref_image[..., donor_index].clone() - ref_image_oxygenation = ref_image[..., oxygenation_index].clone() - ref_image_volume = ref_image[..., volume_index].clone() - del data - # --==-- DONE --==-- - - # --- Calculate translation and rotation between the reference images --- - angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor = align_refref( - ref_image_acceptor=ref_image_acceptor, - ref_image_donor=ref_image_donor, - image_alignment=image_alignment, - batch_size=batch_size, - fill_value=fill_value, - ) - - ref_image_oxygenation = tv.transforms.functional.affine( - img=ref_image_oxygenation.unsqueeze(0), - angle=-float(angle_refref), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ) - - ref_image_oxygenation = tv.transforms.functional.affine( - img=ref_image_oxygenation, - angle=0, - translate=[tvec_refref[1], tvec_refref[0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - # --==-- DONE --==-- - - # --- Rotate and translate the acceptor and oxygenation data accordingly --- - acceptor = tv.transforms.functional.affine( - img=acceptor, - angle=-float(angle_refref), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ) - - acceptor = tv.transforms.functional.affine( - img=acceptor, - angle=0, - translate=[tvec_refref[1], tvec_refref[0]], - scale=1.0, - shear=0, - 
interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ) - - oxygenation = tv.transforms.functional.affine( - img=oxygenation, - angle=-float(angle_refref), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ) - - oxygenation = tv.transforms.functional.affine( - img=oxygenation, - angle=0, - translate=[tvec_refref[1], tvec_refref[0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ) - # --==-- DONE --==-- - - acceptor, donor, oxygenation, volume, angle_donor_volume = ( - perform_donor_volume_rotation( - acceptor=acceptor, - donor=donor, - oxygenation=oxygenation, - volume=volume, - ref_image_donor=ref_image_donor, - ref_image_volume=ref_image_volume, - image_alignment=image_alignment, - batch_size=batch_size, - fill_value=fill_value, - ) - ) - - acceptor, donor, oxygenation, volume, tvec_donor_volume = ( - perform_donor_volume_translation( - acceptor=acceptor, - donor=donor, - oxygenation=oxygenation, - volume=volume, - ref_image_donor=ref_image_donor, - ref_image_volume=ref_image_volume, - image_alignment=image_alignment, - batch_size=batch_size, - fill_value=fill_value, - ) - ) - - return ( - acceptor, - donor, - oxygenation, - volume, - angle_donor_volume, - tvec_donor_volume, - angle_refref, - tvec_refref, - ) diff --git a/reproduction_effort/functions/align_refref.py b/reproduction_effort/functions/align_refref.py deleted file mode 100644 index 7361849..0000000 --- a/reproduction_effort/functions/align_refref.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -import torchvision as tv # type: ignore - -from functions.ImageAlignment import ImageAlignment -from functions.calculate_translation import calculate_translation -from functions.calculate_rotation import calculate_rotation - - -@torch.no_grad() -def align_refref( - ref_image_acceptor: torch.Tensor, - ref_image_donor: torch.Tensor, - image_alignment: ImageAlignment, - batch_size: int, - fill_value: float = 0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - - angle_refref = calculate_rotation( - image_alignment=image_alignment, - input=ref_image_acceptor.unsqueeze(0), - reference_image=ref_image_donor, - batch_size=batch_size, - ) - - ref_image_acceptor = tv.transforms.functional.affine( - img=ref_image_acceptor.unsqueeze(0), - angle=-float(angle_refref), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ) - - tvec_refref = calculate_translation( - image_alignment=image_alignment, - input=ref_image_acceptor, - reference_image=ref_image_donor, - batch_size=batch_size, - ) - - tvec_refref = tvec_refref[0, :] - - ref_image_acceptor = tv.transforms.functional.affine( - img=ref_image_acceptor, - angle=0, - translate=[tvec_refref[1], tvec_refref[0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - return angle_refref, tvec_refref, ref_image_acceptor, ref_image_donor diff --git a/reproduction_effort/functions/bandpass.py b/reproduction_effort/functions/bandpass.py deleted file mode 100644 index 171baf5..0000000 --- a/reproduction_effort/functions/bandpass.py +++ /dev/null @@ -1,85 +0,0 @@ -import torchaudio as ta # type: ignore -import torch - - -@torch.no_grad() -def filtfilt( - input: torch.Tensor, - butter_a: torch.Tensor, - butter_b: torch.Tensor, -) -> torch.Tensor: - assert butter_a.ndim == 1 - assert butter_b.ndim == 1 - 
assert butter_a.shape[0] == butter_b.shape[0] - - process_data: torch.Tensor = input.detach().clone() - - padding_length = 12 * int(butter_a.shape[0]) - left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[ - ..., 1 : padding_length + 1 - ].flip(-1) - right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[ - ..., -(padding_length + 1) : -1 - ].flip(-1) - process_data_padded = torch.cat((left_padding, process_data, right_padding), dim=-1) - - output = ta.functional.filtfilt( - process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False - ).squeeze(0) - - output = output[..., padding_length:-padding_length] - return output - - -@torch.no_grad() -def butter_bandpass( - device: torch.device, - low_frequency: float = 0.1, - high_frequency: float = 1.0, - fs: float = 30.0, -) -> tuple[torch.Tensor, torch.Tensor]: - import scipy # type: ignore - - butter_b_np, butter_a_np = scipy.signal.butter( - 4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs - ) - butter_a = torch.tensor(butter_a_np, device=device, dtype=torch.float32) - butter_b = torch.tensor(butter_b_np, device=device, dtype=torch.float32) - return butter_a, butter_b - - -@torch.no_grad() -def chunk_iterator(array: torch.Tensor, chunk_size: int): - for i in range(0, array.shape[0], chunk_size): - yield array[i : i + chunk_size] - - -@torch.no_grad() -def bandpass( - data: torch.Tensor, - device: torch.device, - low_frequency: float = 0.1, - high_frequency: float = 1.0, - fs=30.0, - filtfilt_chuck_size: int = 10, -) -> torch.Tensor: - butter_a, butter_b = butter_bandpass( - device=device, - low_frequency=low_frequency, - high_frequency=high_frequency, - fs=fs, - ) - - index_full_dataset: torch.Tensor = torch.arange( - 0, data.shape[1], device=device, dtype=torch.int64 - ) - - for chunk in chunk_iterator(index_full_dataset, filtfilt_chuck_size): - temp_filtfilt = filtfilt( - data[:, chunk, :], - butter_a=butter_a, - butter_b=butter_b, - ) - data[:, chunk, :] = temp_filtfilt - - return data diff --git a/reproduction_effort/functions/binning.py b/reproduction_effort/functions/binning.py deleted file mode 100644 index 5e1cebb..0000000 --- a/reproduction_effort/functions/binning.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch - - -def binning( - data: torch.Tensor, - kernel_size: int = 4, - stride: int = 4, - divisor_override: int | None = 1, -) -> torch.Tensor: - - return ( - torch.nn.functional.avg_pool2d( - input=data.movedim(0, -1).movedim(0, -1), - kernel_size=kernel_size, - stride=stride, - divisor_override=divisor_override, - ) - .movedim(-1, 0) - .movedim(-1, 0) - ) diff --git a/reproduction_effort/functions/calculate_rotation.py b/reproduction_effort/functions/calculate_rotation.py deleted file mode 100644 index 6a53afd..0000000 --- a/reproduction_effort/functions/calculate_rotation.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch - -from functions.ImageAlignment import ImageAlignment - - -@torch.no_grad() -def calculate_rotation( - image_alignment: ImageAlignment, - input: torch.Tensor, - reference_image: torch.Tensor, - batch_size: int, -) -> torch.Tensor: - angle = torch.zeros((input.shape[0])) - - data_loader = torch.utils.data.DataLoader( - torch.utils.data.TensorDataset(input), - batch_size=batch_size, - shuffle=False, - ) - start_position: int = 0 - for input_batch in data_loader: - assert len(input_batch) == 1 - - end_position = start_position + input_batch[0].shape[0] - - angle_temp = image_alignment.dry_run_angle( - input=input_batch[0], - 
new_reference_image=reference_image, - ) - - assert angle_temp is not None - - angle[start_position:end_position] = angle_temp - - start_position += input_batch[0].shape[0] - - angle = torch.where(angle >= 180, angle - 360.0, angle) - angle = torch.where(angle <= -180, 360.0 + angle, angle) - - return angle diff --git a/reproduction_effort/functions/calculate_translation.py b/reproduction_effort/functions/calculate_translation.py deleted file mode 100644 index 9eadf59..0000000 --- a/reproduction_effort/functions/calculate_translation.py +++ /dev/null @@ -1,37 +0,0 @@ -import torch - -from functions.ImageAlignment import ImageAlignment - - -@torch.no_grad() -def calculate_translation( - image_alignment: ImageAlignment, - input: torch.Tensor, - reference_image: torch.Tensor, - batch_size: int, -) -> torch.Tensor: - tvec = torch.zeros((input.shape[0], 2)) - - data_loader = torch.utils.data.DataLoader( - torch.utils.data.TensorDataset(input), - batch_size=batch_size, - shuffle=False, - ) - start_position: int = 0 - for input_batch in data_loader: - assert len(input_batch) == 1 - - end_position = start_position + input_batch[0].shape[0] - - tvec_temp = image_alignment.dry_run_translation( - input=input_batch[0], - new_reference_image=reference_image, - ) - - assert tvec_temp is not None - - tvec[start_position:end_position, :] = tvec_temp - - start_position += input_batch[0].shape[0] - - return tvec diff --git a/reproduction_effort/functions/convert_camera_sequenc_to_list.py b/reproduction_effort/functions/convert_camera_sequenc_to_list.py deleted file mode 100644 index 8aa2058..0000000 --- a/reproduction_effort/functions/convert_camera_sequenc_to_list.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch - - -@torch.no_grad() -def convert_camera_sequenc_to_list( - data: torch.Tensor, required_order: list[str], cameras: list[str] -) -> list[torch.Tensor]: - camera_sequence: list[torch.Tensor] = [] - - for cam in required_order: - camera_sequence.append(data[:, :, :, cameras.index(cam)].clone()) - - return camera_sequence diff --git a/reproduction_effort/functions/gauss_smear.py b/reproduction_effort/functions/gauss_smear.py deleted file mode 100644 index 15abfed..0000000 --- a/reproduction_effort/functions/gauss_smear.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch -from functions.gauss_smear_individual import gauss_smear_individual - - -@torch.no_grad() -def gauss_smear( - input_cameras: list[torch.Tensor], - input_mask: torch.Tensor, - spatial_width: float, - temporal_width: float, - use_matlab_mask: bool = True, - epsilon: float = float(torch.finfo(torch.float64).eps), -) -> list[torch.Tensor]: - assert len(input_cameras) == 4 - - filtered_mask: torch.Tensor - filtered_mask, _ = gauss_smear_individual( - input=input_mask, - spatial_width=spatial_width, - temporal_width=temporal_width, - use_matlab_mask=use_matlab_mask, - epsilon=epsilon, - ) - - overwrite_fft_gauss: None | torch.Tensor = None - for id in range(0, len(input_cameras)): - - input_cameras[id] *= input_mask.unsqueeze(-1) - input_cameras[id], overwrite_fft_gauss = gauss_smear_individual( - input=input_cameras[id], - spatial_width=spatial_width, - temporal_width=temporal_width, - overwrite_fft_gauss=overwrite_fft_gauss, - use_matlab_mask=use_matlab_mask, - epsilon=epsilon, - ) - - input_cameras[id] /= filtered_mask + 1e-20 - input_cameras[id] += 1.0 - input_mask.unsqueeze(-1) - - return input_cameras diff --git a/reproduction_effort/functions/gauss_smear_individual.py b/reproduction_effort/functions/gauss_smear_individual.py deleted file
mode 100644 index 36700e7..0000000 --- a/reproduction_effort/functions/gauss_smear_individual.py +++ /dev/null @@ -1,127 +0,0 @@ -import torch -import math - - -@torch.no_grad() -def gauss_smear_individual( - input: torch.Tensor, - spatial_width: float, - temporal_width: float, - overwrite_fft_gauss: None | torch.Tensor = None, - use_matlab_mask: bool = True, - epsilon: float = float(torch.finfo(torch.float64).eps), -) -> tuple[torch.Tensor, torch.Tensor]: - - dim_x: int = int(2 * math.ceil(2 * spatial_width) + 1) - dim_y: int = int(2 * math.ceil(2 * spatial_width) + 1) - dim_t: int = int(2 * math.ceil(2 * temporal_width) + 1) - dims_xyt: torch.Tensor = torch.tensor( - [dim_x, dim_y, dim_t], dtype=torch.int64, device=input.device - ) - - if input.ndim == 2: - input = input.unsqueeze(-1) - - input_padded = torch.nn.functional.pad( - input.unsqueeze(0), - pad=( - dim_t, - dim_t, - dim_y, - dim_y, - dim_x, - dim_x, - ), - mode="replicate", - ).squeeze(0) - - if overwrite_fft_gauss is None: - center_x: int = int(math.ceil(input_padded.shape[0] / 2)) - center_y: int = int(math.ceil(input_padded.shape[1] / 2)) - center_z: int = int(math.ceil(input_padded.shape[2] / 2)) - grid_x: torch.Tensor = ( - torch.arange(0, input_padded.shape[0], device=input.device) - center_x + 1 - ) - grid_y: torch.Tensor = ( - torch.arange(0, input_padded.shape[1], device=input.device) - center_y + 1 - ) - grid_z: torch.Tensor = ( - torch.arange(0, input_padded.shape[2], device=input.device) - center_z + 1 - ) - - grid_x = grid_x.unsqueeze(-1).unsqueeze(-1) ** 2 - grid_y = grid_y.unsqueeze(0).unsqueeze(-1) ** 2 - grid_z = grid_z.unsqueeze(0).unsqueeze(0) ** 2 - - gauss_kernel: torch.Tensor = ( - (grid_x / (spatial_width**2)) - + (grid_y / (spatial_width**2)) - + (grid_z / (temporal_width**2)) - ) - - if use_matlab_mask: - filter_radius: torch.Tensor = (dims_xyt - 1) // 2 - - border_lower: list[int] = [ - center_x - int(filter_radius[0]) - 1, - center_y - int(filter_radius[1]) - 1, - center_z - int(filter_radius[2]) - 1, - ] - - border_upper: list[int] = [ - center_x + int(filter_radius[0]), - center_y + int(filter_radius[1]), - center_z + int(filter_radius[2]), - ] - - matlab_mask: torch.Tensor = torch.zeros_like(gauss_kernel) - matlab_mask[ - border_lower[0] : border_upper[0], - border_lower[1] : border_upper[1], - border_lower[2] : border_upper[2], - ] = 1.0 - - gauss_kernel = torch.exp(-gauss_kernel / 2.0) - if use_matlab_mask: - gauss_kernel = gauss_kernel * matlab_mask - - gauss_kernel[gauss_kernel < (epsilon * gauss_kernel.max())] = 0 - - sum_gauss_kernel: float = float(gauss_kernel.sum()) - - if sum_gauss_kernel != 0.0: - gauss_kernel = gauss_kernel / sum_gauss_kernel - - # FFT Shift - gauss_kernel = torch.cat( - (gauss_kernel[center_x - 1 :, :, :], gauss_kernel[: center_x - 1, :, :]), - dim=0, - ) - gauss_kernel = torch.cat( - (gauss_kernel[:, center_y - 1 :, :], gauss_kernel[:, : center_y - 1, :]), - dim=1, - ) - gauss_kernel = torch.cat( - (gauss_kernel[:, :, center_z - 1 :], gauss_kernel[:, :, : center_z - 1]), - dim=2, - ) - overwrite_fft_gauss = torch.fft.fftn(gauss_kernel) - input_padded_gauss_filtered: torch.Tensor = torch.real( - torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss) - ) - else: - input_padded_gauss_filtered = torch.real( - torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss) - ) - - start = dims_xyt - stop = ( - torch.tensor(input_padded.shape, device=dims_xyt.device, dtype=dims_xyt.dtype) - - dims_xyt - ) - - output = input_padded_gauss_filtered[ - start[0] 
: stop[0], start[1] : stop[1], start[2] : stop[2] - ] - - return (output, overwrite_fft_gauss) diff --git a/reproduction_effort/functions/get_experiments.py b/reproduction_effort/functions/get_experiments.py deleted file mode 100644 index d92b936..0000000 --- a/reproduction_effort/functions/get_experiments.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch -import os -import glob - - -@torch.no_grad() -def get_experiments(path: str) -> torch.Tensor: - filename_np: str = os.path.join( - path, - "Exp*_Part001.npy", - ) - - list_str = glob.glob(filename_np) - list_int: list[int] = [] - for i in range(0, len(list_str)): - list_int.append(int(list_str[i].split("Exp")[-1].split("_Trial")[0])) - list_int = sorted(list_int) - - return torch.tensor(list_int).unique() diff --git a/reproduction_effort/functions/get_parts.py b/reproduction_effort/functions/get_parts.py deleted file mode 100644 index d68e1ae..0000000 --- a/reproduction_effort/functions/get_parts.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch -import os -import glob - - -@torch.no_grad() -def get_parts(path: str, experiment_id: int, trial_id: int) -> torch.Tensor: - filename_np: str = os.path.join( - path, - f"Exp{experiment_id:03d}_Trial{trial_id:03d}_Part*.npy", - ) - - list_str = glob.glob(filename_np) - list_int: list[int] = [] - for i in range(0, len(list_str)): - list_int.append(int(list_str[i].split("_Part")[-1].split(".npy")[0])) - list_int = sorted(list_int) - return torch.tensor(list_int).unique() diff --git a/reproduction_effort/functions/get_trials.py b/reproduction_effort/functions/get_trials.py deleted file mode 100644 index 8c687d9..0000000 --- a/reproduction_effort/functions/get_trials.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch -import os -import glob - - -@torch.no_grad() -def get_trials(path: str, experiment_id: int) -> torch.Tensor: - filename_np: str = os.path.join( - path, - f"Exp{experiment_id:03d}_Trial*_Part001.npy", - ) - - list_str = glob.glob(filename_np) - list_int: list[int] = [] - for i in range(0, len(list_str)): - list_int.append(int(list_str[i].split("_Trial")[-1].split("_Part")[0])) - list_int = sorted(list_int) - return torch.tensor(list_int).unique() diff --git a/reproduction_effort/functions/heart_beat_frequency.py b/reproduction_effort/functions/heart_beat_frequency.py deleted file mode 100644 index 99b2985..0000000 --- a/reproduction_effort/functions/heart_beat_frequency.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch - - -def heart_beat_frequency( - input: torch.Tensor, - lower_frequency_heartbeat: float, - upper_frequency_heartbeat: float, - sample_frequency: float, - mask: torch.Tensor, -) -> float: - - number_of_active_pixel: torch.Tensor = mask.type(dtype=torch.float32).sum() - signal: torch.Tensor = (input * mask.unsqueeze(-1)).sum(dim=0).sum( - dim=0 - ) / number_of_active_pixel - signal = signal - signal.mean() - - hamming_window = torch.hamming_window( - window_length=signal.shape[0], - periodic=True, - alpha=0.54, - beta=0.46, - dtype=signal.dtype, - device=signal.device, - ) - - signal *= hamming_window - - signal_fft: torch.Tensor = torch.fft.rfft(signal) - frequency_axis: torch.Tensor = ( - torch.fft.rfftfreq(signal.shape[0], device=input.device) * sample_frequency - ) - signal_power: torch.Tensor = torch.abs(signal_fft) ** 2 - signal_power[1:-1] *= 2 - - if frequency_axis[-1] != (sample_frequency / 2.0): - signal_power[-1] *= 2 - signal_power /= hamming_window.sum() ** 2 - - idx = torch.where( - (frequency_axis > lower_frequency_heartbeat) - * (frequency_axis < 
upper_frequency_heartbeat) - )[0] - frequency_axis = frequency_axis[idx] - signal_power = signal_power[idx] - - heart_rate = float(frequency_axis[torch.argmax(signal_power)]) - - return heart_rate diff --git a/reproduction_effort/functions/interpolate_along_time.py b/reproduction_effort/functions/interpolate_along_time.py deleted file mode 100644 index d747cb9..0000000 --- a/reproduction_effort/functions/interpolate_along_time.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch - - -def interpolate_along_time(camera_sequence: list[torch.Tensor]) -> None: - camera_sequence[2][:, :, 1:] = ( - camera_sequence[2][:, :, 1:] + camera_sequence[2][:, :, :-1] - ) / 2.0 - - camera_sequence[3][:, :, 1:] = ( - camera_sequence[3][:, :, 1:] + camera_sequence[3][:, :, :-1] - ) / 2.0 diff --git a/reproduction_effort/functions/make_mask.py b/reproduction_effort/functions/make_mask.py deleted file mode 100644 index 3a57fac..0000000 --- a/reproduction_effort/functions/make_mask.py +++ /dev/null @@ -1,32 +0,0 @@ -import scipy.io as sio # type: ignore - -import torch - - -@torch.no_grad() -def make_mask( - filename_mask: str, - camera_sequence: list[torch.Tensor], - device: torch.device, - dtype: torch.dtype, -) -> torch.Tensor: - mask: torch.Tensor = torch.tensor( - sio.loadmat(filename_mask)["maskInfo"]["maskIdx2D"][0][0], - device=device, - dtype=torch.bool, - ) - mask = mask > 0.5 - - limit: torch.Tensor = torch.tensor( - 2**16 - 1, - device=device, - dtype=dtype, - ) - - for id in range(0, len(camera_sequence)): - if torch.any(camera_sequence[id].flatten() >= limit): - mask = mask & ~(torch.any(camera_sequence[id] >= limit, dim=-1)) - if torch.any(camera_sequence[id].flatten() < 0): - mask = mask & ~(torch.any(camera_sequence[id] < 0, dim=-1)) - - return mask diff --git a/reproduction_effort/functions/perform_donor_volume_rotation.py b/reproduction_effort/functions/perform_donor_volume_rotation.py deleted file mode 100644 index 519b111..0000000 --- a/reproduction_effort/functions/perform_donor_volume_rotation.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch -import torchvision as tv # type: ignore -from functions.calculate_rotation import calculate_rotation -from functions.ImageAlignment import ImageAlignment - - -@torch.no_grad() -def perform_donor_volume_rotation( - acceptor: torch.Tensor, - donor: torch.Tensor, - oxygenation: torch.Tensor, - volume: torch.Tensor, - ref_image_donor: torch.Tensor, - ref_image_volume: torch.Tensor, - image_alignment: ImageAlignment, - batch_size: int, - fill_value: float = 0, -) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - - angle_donor = calculate_rotation( - input=donor, - reference_image=ref_image_donor, - image_alignment=image_alignment, - batch_size=batch_size, - ) - - angle_volume = calculate_rotation( - input=volume, - reference_image=ref_image_volume, - image_alignment=image_alignment, - batch_size=batch_size, - ) - - angle_donor_volume = (angle_donor + angle_volume) / 2.0 - - for frame_id in range(0, angle_donor_volume.shape[0]): - - acceptor[frame_id, ...] = tv.transforms.functional.affine( - img=acceptor[frame_id, ...].unsqueeze(0), - angle=-float(angle_donor_volume[frame_id]), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - donor[frame_id, ...] 
= tv.transforms.functional.affine( - img=donor[frame_id, ...].unsqueeze(0), - angle=-float(angle_donor_volume[frame_id]), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - oxygenation[frame_id, ...] = tv.transforms.functional.affine( - img=oxygenation[frame_id, ...].unsqueeze(0), - angle=-float(angle_donor_volume[frame_id]), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - volume[frame_id, ...] = tv.transforms.functional.affine( - img=volume[frame_id, ...].unsqueeze(0), - angle=-float(angle_donor_volume[frame_id]), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - return (acceptor, donor, oxygenation, volume, angle_donor_volume) diff --git a/reproduction_effort/functions/perform_donor_volume_translation.py b/reproduction_effort/functions/perform_donor_volume_translation.py deleted file mode 100644 index b43ff95..0000000 --- a/reproduction_effort/functions/perform_donor_volume_translation.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch -import torchvision as tv # type: ignore -from functions.calculate_translation import calculate_translation -from functions.ImageAlignment import ImageAlignment - - -@torch.no_grad() -def perform_donor_volume_translation( - acceptor: torch.Tensor, - donor: torch.Tensor, - oxygenation: torch.Tensor, - volume: torch.Tensor, - ref_image_donor: torch.Tensor, - ref_image_volume: torch.Tensor, - image_alignment: ImageAlignment, - batch_size: int, - fill_value: float = 0, -) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - - tvec_donor = calculate_translation( - input=donor, - reference_image=ref_image_donor, - image_alignment=image_alignment, - batch_size=batch_size, - ) - - tvec_volume = calculate_translation( - input=volume, - reference_image=ref_image_volume, - image_alignment=image_alignment, - batch_size=batch_size, - ) - - tvec_donor_volume = (tvec_donor + tvec_volume) / 2.0 - - for frame_id in range(0, tvec_donor_volume.shape[0]): - - acceptor[frame_id, ...] = tv.transforms.functional.affine( - img=acceptor[frame_id, ...].unsqueeze(0), - angle=0, - translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - donor[frame_id, ...] = tv.transforms.functional.affine( - img=donor[frame_id, ...].unsqueeze(0), - angle=0, - translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - oxygenation[frame_id, ...] = tv.transforms.functional.affine( - img=oxygenation[frame_id, ...].unsqueeze(0), - angle=0, - translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - volume[frame_id, ...] 
= tv.transforms.functional.affine( - img=volume[frame_id, ...].unsqueeze(0), - angle=0, - translate=[tvec_donor_volume[frame_id, 1], tvec_donor_volume[frame_id, 0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - - return (acceptor, donor, oxygenation, volume, tvec_donor_volume) diff --git a/reproduction_effort/functions/preprocess_camera_sequence.py b/reproduction_effort/functions/preprocess_camera_sequence.py deleted file mode 100644 index 607d0d2..0000000 --- a/reproduction_effort/functions/preprocess_camera_sequence.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch - - -@torch.no_grad() -def preprocess_camera_sequence( - camera_sequence: torch.Tensor, - mask: torch.Tensor, - first_none_ramp_frame: int, - device: torch.device, - dtype: torch.dtype, -) -> tuple[torch.Tensor, torch.Tensor]: - - limit: torch.Tensor = torch.tensor( - 2**16 - 1, - device=device, - dtype=dtype, - ) - - camera_sequence = camera_sequence / camera_sequence[ - :, :, first_none_ramp_frame: - ].nanmean( - dim=2, - keepdim=True, - ) - - camera_sequence = camera_sequence.nan_to_num(nan=0.0) - - camera_sequence_zero_idx = torch.any(camera_sequence == 0, dim=-1, keepdim=True) - mask &= (~camera_sequence_zero_idx.squeeze(-1)).type(dtype=torch.bool) - camera_sequence_zero_idx = torch.tile( - camera_sequence_zero_idx, (1, 1, camera_sequence.shape[-1]) - ) - camera_sequence_zero_idx[:, :, :first_none_ramp_frame] = False - camera_sequence[camera_sequence_zero_idx] = limit - - return camera_sequence, mask diff --git a/reproduction_effort/functions/preprocessing.py b/reproduction_effort/functions/preprocessing.py deleted file mode 100644 index 64ee8fc..0000000 --- a/reproduction_effort/functions/preprocessing.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch - - -from functions.make_mask import make_mask -from functions.preprocess_camera_sequence import preprocess_camera_sequence -from functions.gauss_smear import gauss_smear -from functions.regression import regression - - -@torch.no_grad() -def preprocessing( - cameras: list[str], - camera_sequence: list[torch.Tensor], - filename_mask: str, - device: torch.device, - first_none_ramp_frame: int, - spatial_width: float, - temporal_width: float, - target_camera: list[str], - regressor_cameras: list[str], - dtype: torch.dtype = torch.float32, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - mask: torch.Tensor = make_mask( - filename_mask=filename_mask, - camera_sequence=camera_sequence, - device=device, - dtype=dtype, - ) - - for num_cams in range(len(camera_sequence)): - camera_sequence[num_cams], mask = preprocess_camera_sequence( - camera_sequence=camera_sequence[num_cams], - mask=mask, - first_none_ramp_frame=first_none_ramp_frame, - device=device, - dtype=dtype, - ) - - camera_sequence_filtered: list[torch.Tensor] = [] - for id in range(0, len(camera_sequence)): - camera_sequence_filtered.append(camera_sequence[id].clone()) - - camera_sequence_filtered = gauss_smear( - camera_sequence_filtered, - mask.type(dtype=dtype), - spatial_width=spatial_width, - temporal_width=temporal_width, - ) - - regressor_camera_ids: list[int] = [] - - for cam in regressor_cameras: - regressor_camera_ids.append(cameras.index(cam)) - - results: list[torch.Tensor] = [] - - for channel_position in range(0, len(target_camera)): - print(f"channel position: {channel_position}") - target_camera_selected = target_camera[channel_position] - target_camera_id: int = cameras.index(target_camera_selected) - - output = regression( - 
target_camera_id=target_camera_id, - regressor_camera_ids=regressor_camera_ids, - mask=mask, - camera_sequence=camera_sequence, - camera_sequence_filtered=camera_sequence_filtered, - first_none_ramp_frame=first_none_ramp_frame, - ) - results.append(output) - - return results[0], results[1], mask diff --git a/reproduction_effort/functions/preprocessing_classsic.py b/reproduction_effort/functions/preprocessing_classsic.py deleted file mode 100644 index 27a1172..0000000 --- a/reproduction_effort/functions/preprocessing_classsic.py +++ /dev/null @@ -1,137 +0,0 @@ -import torch - - -from functions.make_mask import make_mask -from functions.heart_beat_frequency import heart_beat_frequency -from functions.adjust_factor import adjust_factor -from functions.preprocess_camera_sequence import preprocess_camera_sequence -from functions.interpolate_along_time import interpolate_along_time -from functions.gauss_smear import gauss_smear -from functions.regression import regression - - -@torch.no_grad() -def preprocessing( - cameras: list[str], - camera_sequence: list[torch.Tensor], - filename_mask: str, - device: torch.device, - first_none_ramp_frame: int, - spatial_width: float, - temporal_width: float, - target_camera: list[str], - regressor_cameras: list[str], - lower_frequency_heartbeat: float, - upper_frequency_heartbeat: float, - sample_frequency: float, - dtype: torch.dtype = torch.float32, - power_factors: None | list[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - mask: torch.Tensor = make_mask( - filename_mask=filename_mask, - camera_sequence=camera_sequence, - device=device, - dtype=dtype, - ) - - for num_cams in range(len(camera_sequence)): - camera_sequence[num_cams], mask = preprocess_camera_sequence( - camera_sequence=camera_sequence[num_cams], - mask=mask, - first_none_ramp_frame=first_none_ramp_frame, - device=device, - dtype=dtype, - ) - - # Interpolate in-between images - if power_factors is None: - interpolate_along_time(camera_sequence) - - camera_sequence_filtered: list[torch.Tensor] = [] - for id in range(0, len(camera_sequence)): - camera_sequence_filtered.append(camera_sequence[id].clone()) - - if power_factors is None: - idx_volume: int = cameras.index("volume") - heart_rate: None | float = heart_beat_frequency( - input=camera_sequence_filtered[idx_volume], - lower_frequency_heartbeat=lower_frequency_heartbeat, - upper_frequency_heartbeat=upper_frequency_heartbeat, - sample_frequency=sample_frequency, - mask=mask, - ) - else: - heart_rate = None - - camera_sequence_filtered = gauss_smear( - camera_sequence_filtered, - mask.type(dtype=dtype), - spatial_width=spatial_width, - temporal_width=temporal_width, - ) - - regressor_camera_ids: list[int] = [] - - for cam in regressor_cameras: - regressor_camera_ids.append(cameras.index(cam)) - - results: list[torch.Tensor] = [] - - for channel_position in range(0, len(target_camera)): - print(f"channel position: {channel_position}") - target_camera_selected = target_camera[channel_position] - target_camera_id: int = cameras.index(target_camera_selected) - - output = regression( - target_camera_id=target_camera_id, - regressor_camera_ids=regressor_camera_ids, - mask=mask, - camera_sequence=camera_sequence, - camera_sequence_filtered=camera_sequence_filtered, - first_none_ramp_frame=first_none_ramp_frame, - ) - results.append(output) - - if heart_rate is not None: - lower_frequency_heartbeat_selection: float = heart_rate - 3 - upper_frequency_heartbeat_selection: float = heart_rate + 3 - else: - 
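# Editor's note: this branch handles user-supplied power_factors. The
# heartbeat-frequency search is skipped, the selection window is zeroed here
# because adjust_factor is never called, and the fixed factors are applied in
# the else-branch further below instead.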
lower_frequency_heartbeat_selection = 0 - upper_frequency_heartbeat_selection = 0 - - donor_correction_factor: torch.Tensor | float - acceptor_correction_factor: torch.Tensor | float - if heart_rate is not None: - donor_correction_factor, acceptor_correction_factor = adjust_factor( - input_acceptor=results[0], - input_donor=results[1], - lower_frequency_heartbeat=lower_frequency_heartbeat_selection, - upper_frequency_heartbeat=upper_frequency_heartbeat_selection, - sample_frequency=sample_frequency, - mask=mask, - power_factors=power_factors, - ) - - results[0] = acceptor_correction_factor * ( - results[0] - results[0].mean(dim=-1, keepdim=True) - ) + results[0].mean(dim=-1, keepdim=True) - - results[1] = donor_correction_factor * ( - results[1] - results[1].mean(dim=-1, keepdim=True) - ) + results[1].mean(dim=-1, keepdim=True) - else: - assert power_factors is not None - donor_correction_factor = power_factors[0] - acceptor_correction_factor = power_factors[1] - donor_factor: torch.Tensor = ( - donor_correction_factor + acceptor_correction_factor - ) / (2 * donor_correction_factor) - acceptor_factor: torch.Tensor = ( - donor_correction_factor + acceptor_correction_factor - ) / (2 * acceptor_correction_factor) - - results[0] *= acceptor_factor * mask.unsqueeze(-1) - results[1] *= donor_factor * mask.unsqueeze(-1) - - return results[0], results[1], mask diff --git a/reproduction_effort/functions/regression.py b/reproduction_effort/functions/regression.py deleted file mode 100644 index c273d0a..0000000 --- a/reproduction_effort/functions/regression.py +++ /dev/null @@ -1,118 +0,0 @@ -import torch -from functions.regression_internal import regression_internal - - -@torch.no_grad() -def regression( - target_camera_id: int, - regressor_camera_ids: list[int], - mask: torch.Tensor, - camera_sequence: list[torch.Tensor], - camera_sequence_filtered: list[torch.Tensor], - first_none_ramp_frame: int, -) -> torch.Tensor: - - assert len(regressor_camera_ids) > 0 - - # ------- Prepare the target signals ---------- - target_signals_train: torch.Tensor = ( - camera_sequence_filtered[target_camera_id][..., first_none_ramp_frame:].clone() - - 1.0 - ) - target_signals_train[target_signals_train < -1] = 0.0 - - target_signals_perform: torch.Tensor = ( - camera_sequence[target_camera_id].clone() - 1.0 - ) - - # Check if everything is happy - assert target_signals_train.ndim == 3 - assert target_signals_train.ndim == target_signals_perform.ndim - assert target_signals_train.shape[0] == target_signals_perform.shape[0] - assert target_signals_train.shape[1] == target_signals_perform.shape[1] - assert ( - target_signals_train.shape[2] + first_none_ramp_frame - ) == target_signals_perform.shape[2] - # --==DONE==- - - # ------- Prepare the regressor signals ---------- - - # --- Train --- - - regressor_signals_train: torch.Tensor = torch.zeros( - ( - camera_sequence_filtered[0].shape[0], - camera_sequence_filtered[0].shape[1], - camera_sequence_filtered[0].shape[2], - len(regressor_camera_ids) + 1, - ), - device=camera_sequence_filtered[0].device, - dtype=camera_sequence_filtered[0].dtype, - ) - - # Copy the regressor signals -1 - for matrix_id, id in enumerate(regressor_camera_ids): - regressor_signals_train[..., matrix_id] = camera_sequence_filtered[id] - 1.0 - - regressor_signals_train[regressor_signals_train < -1] = 0.0 - - # Linear regressor - trend = torch.arange( - 0, regressor_signals_train.shape[-2], device=camera_sequence_filtered[0].device - ) / float(regressor_signals_train.shape[-2] - 1) - trend -= 
trend.mean() - trend = trend.unsqueeze(0).unsqueeze(0) - trend = trend.tile( - (regressor_signals_train.shape[0], regressor_signals_train.shape[1], 1) - ) - regressor_signals_train[..., -1] = trend - - regressor_signals_train = regressor_signals_train[:, :, first_none_ramp_frame:, :] - - # --- Perform --- - - regressor_signals_perform: torch.Tensor = torch.zeros( - ( - camera_sequence[0].shape[0], - camera_sequence[0].shape[1], - camera_sequence[0].shape[2], - len(regressor_camera_ids) + 1, - ), - device=camera_sequence[0].device, - dtype=camera_sequence[0].dtype, - ) - - # Copy the regressor signals -1 - for matrix_id, id in enumerate(regressor_camera_ids): - regressor_signals_perform[..., matrix_id] = camera_sequence[id] - 1.0 - - # Linear regressor - trend = torch.arange( - 0, regressor_signals_perform.shape[-2], device=camera_sequence[0].device - ) / float(regressor_signals_perform.shape[-2] - 1) - trend -= trend.mean() - trend = trend.unsqueeze(0).unsqueeze(0) - trend = trend.tile( - (regressor_signals_perform.shape[0], regressor_signals_perform.shape[1], 1) - ) - regressor_signals_perform[..., -1] = trend - - # --==DONE==- - - coefficients, intercept = regression_internal( - input_regressor=regressor_signals_train, input_target=target_signals_train - ) - - target_signals_perform -= ( - regressor_signals_perform * coefficients.unsqueeze(-2) - ).sum(dim=-1) - - target_signals_perform -= intercept.unsqueeze(-1) - - target_signals_perform[ - ~mask.unsqueeze(-1).tile((1, 1, target_signals_perform.shape[-1])) - ] = 0.0 - - target_signals_perform += 1.0 - - return target_signals_perform diff --git a/reproduction_effort/functions/regression_internal.py b/reproduction_effort/functions/regression_internal.py deleted file mode 100644 index 352d7ba..0000000 --- a/reproduction_effort/functions/regression_internal.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch - - -def regression_internal( - input_regressor: torch.Tensor, input_target: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - - regressor_offset = input_regressor.mean(keepdim=True, dim=-2) - target_offset = input_target.mean(keepdim=True, dim=-1) - - regressor = input_regressor - regressor_offset - target = input_target - target_offset - - coefficients, _, _, _ = torch.linalg.lstsq(regressor, target, rcond=None) # None ? 
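# Editor's note on the "None ?" question above: in torch.linalg.lstsq, rcond
# sets the cutoff for small singular values and is only consulted by the
# rank-revealing LAPACK drivers ("gelsy", "gelsd", "gelss"); rcond=None falls
# back to machine precision times max(m, n), so None is a reasonable default.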
- - intercept = target_offset.squeeze(-1) - ( - coefficients * regressor_offset.squeeze(-2) - ).sum(dim=-1) - - return coefficients, intercept diff --git a/reproduction_effort/heartbeat.py b/reproduction_effort/heartbeat.py deleted file mode 100644 index 5a392ab..0000000 --- a/reproduction_effort/heartbeat.py +++ /dev/null @@ -1,201 +0,0 @@ -import numpy as np -import torch -import os -import json -import matplotlib.pyplot as plt -import scipy.io as sio # type: ignore - -from functions.binning import binning -from functions.align_cameras import align_cameras -from functions.bandpass import bandpass -from functions.make_mask import make_mask - -if torch.cuda.is_available(): - device_name: str = "cuda:0" -else: - device_name = "cpu" -print(f"Using device: {device_name}") -device: torch.device = torch.device(device_name) -dtype: torch.dtype = torch.float32 - - -filename_raw: str = f"raw{os.sep}Exp001_Trial001_Part001.npy" -filename_raw_json: str = f"raw{os.sep}Exp001_Trial001_Part001_meta.txt" -filename_mask: str = "2020-12-08maskPixelraw2.mat" - -first_none_ramp_frame: int = 100 -spatial_width: float = 2 -temporal_width: float = 0.1 - -lower_freqency_bandpass: float = 5.0 -upper_freqency_bandpass: float = 14.0 - -lower_frequency_heartbeat: float = 5.0 -upper_frequency_heartbeat: float = 14.0 -sample_frequency: float = 100.0 - -target_camera: list[str] = ["acceptor", "donor"] -regressor_cameras: list[str] = ["oxygenation", "volume"] -batch_size: int = 200 -required_order: list[str] = ["acceptor", "donor", "oxygenation", "volume"] - - -test_overwrite_with_old_bining: bool = False -test_overwrite_with_old_aligned: bool = False -filename_data_binning_replace: str = "bin_old/Exp001_Trial001_Part001.mat" -filename_data_aligned_replace: str = "aligned_old/Exp001_Trial001_Part001.mat" - -data = torch.tensor(np.load(filename_raw).astype(np.float32), dtype=dtype) - -with open(filename_raw_json, "r") as file_handle: - metadata: dict = json.load(file_handle) -channels: list[str] = metadata["channelKey"] - - -if test_overwrite_with_old_bining: - data = torch.tensor( - sio.loadmat(filename_data_binning_replace)["nparray"].astype(np.float32), - dtype=dtype, - device=device, - ) -else: - data = binning(data).to(device) - -ref_image = data[:, :, data.shape[-2] // 2, :].clone() - -( - acceptor, - donor, - oxygenation, - volume, - angle_donor_volume, - tvec_donor_volume, - angle_refref, - tvec_refref, -) = align_cameras( - channels=channels, - data=data, - ref_image=ref_image, - device=device, - dtype=dtype, - batch_size=batch_size, - fill_value=-1, -) -del data - - -camera_sequence: list[torch.Tensor] = [] - -for cam in required_order: - if cam.startswith("acceptor"): - camera_sequence.append(acceptor.movedim(0, -1).clone()) - del acceptor - if cam.startswith("donor"): - camera_sequence.append(donor.movedim(0, -1).clone()) - del donor - if cam.startswith("oxygenation"): - camera_sequence.append(oxygenation.movedim(0, -1).clone()) - del oxygenation - if cam.startswith("volume"): - camera_sequence.append(volume.movedim(0, -1).clone()) - del volume - -if test_overwrite_with_old_aligned: - - data_aligned_replace: torch.Tensor = torch.tensor( - sio.loadmat(filename_data_aligned_replace)["data"].astype(np.float32), - device=device, - dtype=dtype, - ) - - camera_sequence[0] = data_aligned_replace[..., 0].clone() - camera_sequence[1] = data_aligned_replace[..., 1].clone() - camera_sequence[2] = data_aligned_replace[..., 2].clone() - camera_sequence[3] = data_aligned_replace[..., 3].clone() - del data_aligned_replace 
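# Editor's note: the remainder of the script isolates a common heartbeat
# component. Each aligned channel is bandpass-filtered to the heartbeat band,
# pixels outside the mask are dropped, the four channels are concatenated
# along the pixel axis, and a rank-1 SVD splits the stack into one shared
# time course (u_a) and per-channel amplitude maps (Vh_a); a standalone
# sketch of this factorisation follows after the diff.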
- - -mask: torch.Tensor = make_mask( - filename_mask=filename_mask, - camera_sequence=camera_sequence, - device=device, - dtype=dtype, -) - -mask_flatten = mask.flatten(start_dim=0, end_dim=-1) - -original_shape = camera_sequence[0].shape -for i in range(0, len(camera_sequence)): - camera_sequence[i] = bandpass( - data=camera_sequence[i].clone(), - device=camera_sequence[i].device, - low_frequency=lower_freqency_bandpass, - high_frequency=upper_freqency_bandpass, - fs=100.0, - filtfilt_chuck_size=10, - ) - - camera_sequence[i] = camera_sequence[i].flatten(start_dim=0, end_dim=-2) - camera_sequence[i] = camera_sequence[i][mask_flatten, :] - if (i == 0) or (i == 1): - camera_sequence[i] = camera_sequence[i][:, 1:] - else: - camera_sequence[i] = ( - camera_sequence[i][:, 1:] + camera_sequence[i][:, :-1] - ) / 2.0 - - camera_sequence[i] = camera_sequence[i].movedim(0, -1) - camera_sequence[i] -= camera_sequence[i].mean(dim=0, keepdim=True) - - -camera_sequence_cat = torch.cat( - (camera_sequence[0], camera_sequence[1], camera_sequence[2], camera_sequence[3]), - dim=-1, -) - -print(camera_sequence_cat.min(), camera_sequence_cat.max()) - -u_a, s_a, Vh_a = torch.linalg.svd(camera_sequence_cat, full_matrices=False) -u_a = u_a[:, 0] -Vh_a = Vh_a[0, :] * s_a[0] - -heart_beat_activity_map: list[torch.Tensor] = [] - -start_pos: int = 0 -end_pos: int = 0 -for i in range(0, len(camera_sequence)): - end_pos = start_pos + int(mask_flatten.sum()) - heart_beat_activity_map.append( - torch.full( - (original_shape[0], original_shape[1]), - torch.nan, - device=Vh_a.device, - dtype=Vh_a.dtype, - ).flatten(start_dim=0, end_dim=-1) - ) - heart_beat_activity_map[-1][mask_flatten] = Vh_a[start_pos:end_pos] - heart_beat_activity_map[-1] = heart_beat_activity_map[-1].reshape( - (original_shape[0], original_shape[1]) - ) - start_pos = end_pos - -full_image = torch.cat(heart_beat_activity_map, dim=1) - - -# I want to scale the time series to std unity -# and therefore need to increase the amplitudes of the maps -u_a_std = torch.std(u_a) -u_a /= u_a_std -full_image *= u_a_std - -plt.subplot(2, 1, 1) -plt.plot(u_a.cpu()) -plt.xlabel("Frame ID") -plt.title( - f"Common heartbeat in {lower_freqency_bandpass}Hz - {upper_freqency_bandpass}Hz" -) -plt.subplot(2, 1, 2) -plt.imshow(full_image.cpu(), cmap="hot") -plt.colorbar() -plt.title("acceptor, donor, oxygenation, volume") -plt.show() diff --git a/reproduction_effort/make_test_data_aligned.py b/reproduction_effort/make_test_data_aligned.py deleted file mode 100644 index 5e9d165..0000000 --- a/reproduction_effort/make_test_data_aligned.py +++ /dev/null @@ -1,204 +0,0 @@ -import torch -import torchvision as tv # type: ignore -import numpy as np -import matplotlib.pyplot as plt -import scipy.io as sio # type: ignore -import json - -from functions.align_cameras import align_cameras - -if torch.cuda.is_available(): - device_name: str = "cuda:0" -else: - device_name = "cpu" -print(f"Using device: {device_name}") -device: torch.device = torch.device(device_name) -dtype: torch.dtype = torch.float32 - -filename_bin_mat: str = "bin_old/Exp001_Trial001_Part001.mat" -filename_bin_mat_fake: str = "Exp001_Trial001_Part001_fake.mat" -fill_value: float = 0.0 - -mat_data = torch.tensor( - sio.loadmat(filename_bin_mat)["nparray"].astype(dtype=np.float32), - dtype=dtype, - device=device, -) - -angle_refref_target = torch.tensor( - [2], - dtype=dtype, - device=device, -) - -tvec_refref_target = torch.tensor( - [10, 3], - dtype=dtype, - device=device, -) - - -t = ( - torch.arange( - 0, - 
mat_data.shape[-2], - dtype=dtype, - device=device, - ) - - mat_data.shape[-2] // 2 -) / float(mat_data.shape[-2] // 2) - -f_a: float = 8 -A_a: float = 2 -a_target = A_a * torch.sin(2 * torch.pi * t * f_a) - -f_x: float = 5 -A_x: float = 10 -x_target = A_x * torch.sin(2.0 * torch.pi * t * f_x) - -f_y: float = 7 -A_y: float = 7 -y_target = A_y * torch.sin(2 * torch.pi * t * f_y) - -master_images: torch.Tensor = mat_data[:, :, mat_data.shape[-2] // 2, 1] - -master_images_2: torch.Tensor = master_images.unsqueeze(-1).tile( - (1, 1, mat_data.shape[-1]) -) - -# Rotate and move the whole timeseries of the acceptor and oxygenation -master_images_2[..., 0] = tv.transforms.functional.affine( - img=master_images_2[..., 0].unsqueeze(0), - angle=-float(angle_refref_target), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, -).squeeze(0) - -master_images_2[..., 0] = tv.transforms.functional.affine( - img=master_images_2[..., 0].unsqueeze(0), - angle=0, - translate=[tvec_refref_target[1], tvec_refref_target[0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, -).squeeze(0) - -master_images_2[..., 2] = tv.transforms.functional.affine( - img=master_images_2[..., 2].unsqueeze(0), - angle=-float(angle_refref_target), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, -).squeeze(0) - -master_images_2[..., 2] = tv.transforms.functional.affine( - img=master_images_2[..., 2].unsqueeze(0), - angle=0, - translate=[tvec_refref_target[1], tvec_refref_target[0]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, -).squeeze(0) - -fake_data = master_images_2.unsqueeze(-2).tile((1, 1, mat_data.shape[-2], 1)).clone() - -for t_id in range(0, fake_data.shape[-2]): - for c_id in range(0, fake_data.shape[-1]): - fake_data[..., t_id, c_id] = tv.transforms.functional.affine( - img=fake_data[..., t_id, c_id].unsqueeze(0), - angle=-float(a_target[t_id]), - translate=[0, 0], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - -fake_data = fake_data.clone() - -for t_id in range(0, fake_data.shape[-2]): - for c_id in range(0, fake_data.shape[-1]): - fake_data[..., t_id, c_id] = tv.transforms.functional.affine( - img=fake_data[..., t_id, c_id].unsqueeze(0), - angle=0.0, - translate=[y_target[t_id], x_target[t_id]], - scale=1.0, - shear=0, - interpolation=tv.transforms.InterpolationMode.BILINEAR, - fill=fill_value, - ).squeeze(0) - -fake_data_np = fake_data.cpu().numpy() -mdic = {"nparray": fake_data_np} -sio.savemat(filename_bin_mat_fake, mdic) - -# ---------------------------------------------------- - -batch_size: int = 200 -filename_raw_json: str = "raw/Exp001_Trial001_Part001_meta.txt" - -with open(filename_raw_json, "r") as file_handle: - metadata: dict = json.load(file_handle) -channels: list[str] = metadata["channelKey"] - - -data = torch.tensor( - sio.loadmat(filename_bin_mat_fake)["nparray"].astype(np.float32), - dtype=dtype, - device=device, -) - -ref_image = data[:, :, data.shape[-2] // 2, :].clone() - -( - acceptor, - donor, - oxygenation, - volume, - angle_donor_volume, - tvec_donor_volume, - angle_refref, - tvec_refref, -) = align_cameras( - channels=channels, - data=data, - ref_image=ref_image, - device=device, - dtype=dtype, - batch_size=batch_size, - fill_value=-1, -) -del data - - -print("References Acceptor 
<-> Donor:") -print("Rotation:") -print(f"target: {float(angle_refref_target):.3f}") -print(f"found: {-float(angle_refref):.3f}") -print("Translation") -print(f"target: {float(tvec_refref_target[0]):.3f}, {float(tvec_refref_target[1]):.3f}") -print(f"found: {-float(tvec_refref[0]):.3f}, {-float(tvec_refref[1]):.3f}") - -plt.subplot(3, 1, 1) -plt.plot(-angle_donor_volume.cpu(), "g", label="angle found") -plt.plot(a_target.cpu(), "--k", label="angle target") -plt.legend() - -plt.subplot(3, 1, 2) -plt.plot(-tvec_donor_volume[:, 0].cpu(), "g", label="x found") -plt.plot(x_target.cpu(), "k--", label="x target") -plt.legend() - -plt.subplot(3, 1, 3) -plt.plot(-tvec_donor_volume[:, 1].cpu(), "g", label="y found") -plt.plot(y_target.cpu(), "k--", label="y target") -plt.legend() - -plt.show()