diff --git a/reproduction_effort/functions/bandpass.py b/reproduction_effort/functions/bandpass.py
new file mode 100644
index 0000000..171baf5
--- /dev/null
+++ b/reproduction_effort/functions/bandpass.py
@@ -0,0 +1,85 @@
+import torchaudio as ta  # type: ignore
+import torch
+
+
+@torch.no_grad()
+def filtfilt(
+    input: torch.Tensor,
+    butter_a: torch.Tensor,
+    butter_b: torch.Tensor,
+) -> torch.Tensor:
+    assert butter_a.ndim == 1
+    assert butter_b.ndim == 1
+    assert butter_a.shape[0] == butter_b.shape[0]
+
+    process_data: torch.Tensor = input.detach().clone()
+
+    padding_length = 12 * int(butter_a.shape[0])
+    left_padding = 2 * process_data[..., 0].unsqueeze(-1) - process_data[
+        ..., 1 : padding_length + 1
+    ].flip(-1)
+    right_padding = 2 * process_data[..., -1].unsqueeze(-1) - process_data[
+        ..., -(padding_length + 1) : -1
+    ].flip(-1)
+    process_data_padded = torch.cat((left_padding, process_data, right_padding), dim=-1)
+
+    output = ta.functional.filtfilt(
+        process_data_padded.unsqueeze(0), butter_a, butter_b, clamp=False
+    ).squeeze(0)
+
+    output = output[..., padding_length:-padding_length]
+    return output
+
+
+@torch.no_grad()
+def butter_bandpass(
+    device: torch.device,
+    low_frequency: float = 0.1,
+    high_frequency: float = 1.0,
+    fs: float = 30.0,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    import scipy.signal  # type: ignore
+
+    butter_b_np, butter_a_np = scipy.signal.butter(
+        4, [low_frequency, high_frequency], btype="bandpass", output="ba", fs=fs
+    )
+    butter_a = torch.tensor(butter_a_np, device=device, dtype=torch.float32)
+    butter_b = torch.tensor(butter_b_np, device=device, dtype=torch.float32)
+    return butter_a, butter_b
+
+
+@torch.no_grad()
+def chunk_iterator(array: torch.Tensor, chunk_size: int):
+    for i in range(0, array.shape[0], chunk_size):
+        yield array[i : i + chunk_size]
+
+
+@torch.no_grad()
+def bandpass(
+    data: torch.Tensor,
+    device: torch.device,
+    low_frequency: float = 0.1,
+    high_frequency: float = 1.0,
+    fs: float = 30.0,
+    filtfilt_chunk_size: int = 10,
+) -> torch.Tensor:
+    butter_a, butter_b = butter_bandpass(
+        device=device,
+        low_frequency=low_frequency,
+        high_frequency=high_frequency,
+        fs=fs,
+    )
+
+    index_full_dataset: torch.Tensor = torch.arange(
+        0, data.shape[1], device=device, dtype=torch.int64
+    )
+
+    for chunk in chunk_iterator(index_full_dataset, filtfilt_chunk_size):
+        temp_filtfilt = filtfilt(
+            data[:, chunk, :],
+            butter_a=butter_a,
+            butter_b=butter_b,
+        )
+        data[:, chunk, :] = temp_filtfilt
+
+    return data
diff --git a/reproduction_effort/functions/convert_camera_sequenc_to_list.py b/reproduction_effort/functions/convert_camera_sequenc_to_list.py
new file mode 100644
index 0000000..8aa2058
--- /dev/null
+++ b/reproduction_effort/functions/convert_camera_sequenc_to_list.py
@@ -0,0 +1,13 @@
+import torch
+
+
+@torch.no_grad()
+def convert_camera_sequenc_to_list(
+    data: torch.Tensor, required_order: list[str], cameras: list[str]
+) -> list[torch.Tensor]:
+    camera_sequence: list[torch.Tensor] = []
+
+    for cam in required_order:
+        camera_sequence.append(data[:, :, :, cameras.index(cam)].clone())
+
+    return camera_sequence
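As a sanity check on the padded GPU filtfilt above, one can compare it against SciPy's reference implementation, which performs the same odd (reflect-and-negate) edge extension internally. A minimal sketch, not part of the diff, assuming the functions/ package is importable (the scripts further down run from reproduction_effort/); agreement is only approximate because the torch path keeps float32 coefficients:

import numpy as np
import scipy.signal  # type: ignore
import torch

from functions.bandpass import butter_bandpass, filtfilt

butter_a, butter_b = butter_bandpass(
    device=torch.device("cpu"), low_frequency=5.0, high_frequency=14.0, fs=100.0
)

signal = torch.randn(2, 3, 1000)
ours = filtfilt(signal, butter_a=butter_a, butter_b=butter_b)

# Same odd-extension length as the manual padding above: 12 * len(a).
reference = scipy.signal.filtfilt(
    butter_b.numpy(),
    butter_a.numpy(),
    signal.numpy(),
    axis=-1,
    padtype="odd",
    padlen=12 * int(butter_a.shape[0]),
)
print(np.abs(ours.numpy() - reference).max())  # small; float32 round-off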
diff --git a/reproduction_effort/functions/gauss_smear.py b/reproduction_effort/functions/gauss_smear.py
new file mode 100644
index 0000000..15abfed
--- /dev/null
+++ b/reproduction_effort/functions/gauss_smear.py
@@ -0,0 +1,41 @@
+import torch
+from functions.gauss_smear_individual import gauss_smear_individual
+
+
+@torch.no_grad()
+def gauss_smear(
+    input_cameras: list[torch.Tensor],
+    input_mask: torch.Tensor,
+    spatial_width: float,
+    temporal_width: float,
+    use_matlab_mask: bool = True,
+    epsilon: float = float(torch.finfo(torch.float64).eps),
+) -> list[torch.Tensor]:
+    assert len(input_cameras) == 4
+
+    filtered_mask: torch.Tensor
+    filtered_mask, _ = gauss_smear_individual(
+        input=input_mask,
+        spatial_width=spatial_width,
+        temporal_width=temporal_width,
+        use_matlab_mask=use_matlab_mask,
+        epsilon=epsilon,
+    )
+
+    overwrite_fft_gauss: None | torch.Tensor = None
+    for id in range(0, len(input_cameras)):
+
+        input_cameras[id] *= input_mask.unsqueeze(-1)
+        input_cameras[id], overwrite_fft_gauss = gauss_smear_individual(
+            input=input_cameras[id],
+            spatial_width=spatial_width,
+            temporal_width=temporal_width,
+            overwrite_fft_gauss=overwrite_fft_gauss,
+            use_matlab_mask=use_matlab_mask,
+            epsilon=epsilon,
+        )
+
+        input_cameras[id] /= filtered_mask + 1e-20
+        input_cameras[id] += 1.0 - input_mask.unsqueeze(-1)
+
+    return input_cameras
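gauss_smear is a normalized (mask-weighted) smoothing: data and mask are blurred with the same kernel and then divided, so valid pixels are not dragged toward zero at the mask border, and pixels outside the mask are pushed back to ~1.0. A minimal sketch of that invariant, not part of the diff, assuming functions/ is importable: a camera that is constant 1.0 inside the mask stays ~1.0 inside the mask.

import torch

from functions.gauss_smear import gauss_smear

mask = torch.zeros((20, 20))
mask[5:15, 5:15] = 1.0

# Four constant cameras, zeroed outside the mask.
cameras = [torch.ones((20, 20, 30)) * mask.unsqueeze(-1) for _ in range(4)]
smoothed = gauss_smear(cameras, mask, spatial_width=2.0, temporal_width=0.1)

# Deviation from 1.0 inside the mask stays at FFT round-off level.
print((smoothed[0][mask.bool()] - 1.0).abs().max())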
diff --git a/reproduction_effort/functions/gauss_smear_individual.py b/reproduction_effort/functions/gauss_smear_individual.py
new file mode 100644
index 0000000..36700e7
--- /dev/null
+++ b/reproduction_effort/functions/gauss_smear_individual.py
@@ -0,0 +1,127 @@
+import torch
+import math
+
+
+@torch.no_grad()
+def gauss_smear_individual(
+    input: torch.Tensor,
+    spatial_width: float,
+    temporal_width: float,
+    overwrite_fft_gauss: None | torch.Tensor = None,
+    use_matlab_mask: bool = True,
+    epsilon: float = float(torch.finfo(torch.float64).eps),
+) -> tuple[torch.Tensor, torch.Tensor]:
+
+    dim_x: int = int(2 * math.ceil(2 * spatial_width) + 1)
+    dim_y: int = int(2 * math.ceil(2 * spatial_width) + 1)
+    dim_t: int = int(2 * math.ceil(2 * temporal_width) + 1)
+    dims_xyt: torch.Tensor = torch.tensor(
+        [dim_x, dim_y, dim_t], dtype=torch.int64, device=input.device
+    )
+
+    if input.ndim == 2:
+        input = input.unsqueeze(-1)
+
+    input_padded = torch.nn.functional.pad(
+        input.unsqueeze(0),
+        pad=(
+            dim_t,
+            dim_t,
+            dim_y,
+            dim_y,
+            dim_x,
+            dim_x,
+        ),
+        mode="replicate",
+    ).squeeze(0)
+
+    if overwrite_fft_gauss is None:
+        center_x: int = int(math.ceil(input_padded.shape[0] / 2))
+        center_y: int = int(math.ceil(input_padded.shape[1] / 2))
+        center_z: int = int(math.ceil(input_padded.shape[2] / 2))
+        grid_x: torch.Tensor = (
+            torch.arange(0, input_padded.shape[0], device=input.device) - center_x + 1
+        )
+        grid_y: torch.Tensor = (
+            torch.arange(0, input_padded.shape[1], device=input.device) - center_y + 1
+        )
+        grid_z: torch.Tensor = (
+            torch.arange(0, input_padded.shape[2], device=input.device) - center_z + 1
+        )
+
+        grid_x = grid_x.unsqueeze(-1).unsqueeze(-1) ** 2
+        grid_y = grid_y.unsqueeze(0).unsqueeze(-1) ** 2
+        grid_z = grid_z.unsqueeze(0).unsqueeze(0) ** 2
+
+        gauss_kernel: torch.Tensor = (
+            (grid_x / (spatial_width**2))
+            + (grid_y / (spatial_width**2))
+            + (grid_z / (temporal_width**2))
+        )
+
+        if use_matlab_mask:
+            filter_radius: torch.Tensor = (dims_xyt - 1) // 2
+
+            border_lower: list[int] = [
+                center_x - int(filter_radius[0]) - 1,
+                center_y - int(filter_radius[1]) - 1,
+                center_z - int(filter_radius[2]) - 1,
+            ]
+
+            border_upper: list[int] = [
+                center_x + int(filter_radius[0]),
+                center_y + int(filter_radius[1]),
+                center_z + int(filter_radius[2]),
+            ]
+
+            matlab_mask: torch.Tensor = torch.zeros_like(gauss_kernel)
+            matlab_mask[
+                border_lower[0] : border_upper[0],
+                border_lower[1] : border_upper[1],
+                border_lower[2] : border_upper[2],
+            ] = 1.0
+
+        gauss_kernel = torch.exp(-gauss_kernel / 2.0)
+        if use_matlab_mask:
+            gauss_kernel = gauss_kernel * matlab_mask
+
+        gauss_kernel[gauss_kernel < (epsilon * gauss_kernel.max())] = 0
+
+        sum_gauss_kernel: float = float(gauss_kernel.sum())
+
+        if sum_gauss_kernel != 0.0:
+            gauss_kernel = gauss_kernel / sum_gauss_kernel
+
+        # FFT Shift
+        gauss_kernel = torch.cat(
+            (gauss_kernel[center_x - 1 :, :, :], gauss_kernel[: center_x - 1, :, :]),
+            dim=0,
+        )
+        gauss_kernel = torch.cat(
+            (gauss_kernel[:, center_y - 1 :, :], gauss_kernel[:, : center_y - 1, :]),
+            dim=1,
+        )
+        gauss_kernel = torch.cat(
+            (gauss_kernel[:, :, center_z - 1 :], gauss_kernel[:, :, : center_z - 1]),
+            dim=2,
+        )
+        overwrite_fft_gauss = torch.fft.fftn(gauss_kernel)
+        input_padded_gauss_filtered: torch.Tensor = torch.real(
+            torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss)
+        )
+    else:
+        input_padded_gauss_filtered = torch.real(
+            torch.fft.ifftn(torch.fft.fftn(input_padded) * overwrite_fft_gauss)
+        )
+
+    start = dims_xyt
+    stop = (
+        torch.tensor(input_padded.shape, device=dims_xyt.device, dtype=dims_xyt.dtype)
+        - dims_xyt
+    )
+
+    output = input_padded_gauss_filtered[
+        start[0] : stop[0], start[1] : stop[1], start[2] : stop[2]
+    ]
+
+    return (output, overwrite_fft_gauss)
diff --git a/reproduction_effort/functions/interpolate_along_time.py b/reproduction_effort/functions/interpolate_along_time.py
new file mode 100644
index 0000000..d747cb9
--- /dev/null
+++ b/reproduction_effort/functions/interpolate_along_time.py
@@ -0,0 +1,11 @@
+import torch
+
+
+def interpolate_along_time(camera_sequence: list[torch.Tensor]) -> None:
+    camera_sequence[2][:, :, 1:] = (
+        camera_sequence[2][:, :, 1:] + camera_sequence[2][:, :, :-1]
+    ) / 2.0
+
+    camera_sequence[3][:, :, 1:] = (
+        camera_sequence[3][:, :, 1:] + camera_sequence[3][:, :, :-1]
+    ) / 2.0
diff --git a/reproduction_effort/functions/make_mask.py b/reproduction_effort/functions/make_mask.py
new file mode 100644
index 0000000..0e6bdf4
--- /dev/null
+++ b/reproduction_effort/functions/make_mask.py
@@ -0,0 +1,26 @@
+import scipy.io as sio  # type: ignore
+
+import torch
+
+
+@torch.no_grad()
+def make_mask(
+    filename_mask: str, data: torch.Tensor, device: torch.device, dtype: torch.dtype
+) -> torch.Tensor:
+    mask: torch.Tensor = torch.tensor(
+        sio.loadmat(filename_mask)["maskInfo"]["maskIdx2D"][0][0],
+        device=device,
+        dtype=torch.bool,
+    )
+    mask = mask > 0.5
+
+    limit: torch.Tensor = torch.tensor(
+        2**16 - 1,
+        device=device,
+        dtype=dtype,
+    )
+
+    if torch.any(data.flatten() >= limit):
+        mask = mask & ~(torch.any(torch.any(data >= limit, dim=-1), dim=-1))
+
+    return mask
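The three torch.cat calls in gauss_smear_individual rotate the kernel so that its center lands on index 0 before the FFT, i.e. an ifftshift. A minimal equivalence check, not part of the diff, showing the same rotation expressed with torch.roll:

import math

import torch

kernel = torch.randn(11, 11, 5)
center = [int(math.ceil(s / 2)) for s in kernel.shape]

shifted_cat = torch.cat((kernel[center[0] - 1 :, :, :], kernel[: center[0] - 1, :, :]), dim=0)
shifted_cat = torch.cat((shifted_cat[:, center[1] - 1 :, :], shifted_cat[:, : center[1] - 1, :]), dim=1)
shifted_cat = torch.cat((shifted_cat[:, :, center[2] - 1 :], shifted_cat[:, :, : center[2] - 1]), dim=2)

# Rolling each dimension left by (center - 1) moves the center to index 0.
shifted_roll = torch.roll(kernel, shifts=[-(c - 1) for c in center], dims=(0, 1, 2))
print(torch.equal(shifted_cat, shifted_roll))  # True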
diff --git a/reproduction_effort/functions/preprocess_camera_sequence.py b/reproduction_effort/functions/preprocess_camera_sequence.py
new file mode 100644
index 0000000..d6d2555
--- /dev/null
+++ b/reproduction_effort/functions/preprocess_camera_sequence.py
@@ -0,0 +1,36 @@
+import torch
+
+
+@torch.no_grad()
+def preprocess_camera_sequence(
+    camera_sequence: torch.Tensor,
+    mask: torch.Tensor,
+    first_none_ramp_frame: int,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
+
+    limit: torch.Tensor = torch.tensor(
+        2**16 - 1,
+        device=device,
+        dtype=dtype,
+    )
+
+    camera_sequence = camera_sequence / camera_sequence[
+        :, :, first_none_ramp_frame:
+    ].mean(
+        dim=2,
+        keepdim=True,
+    )
+
+    camera_sequence = camera_sequence.nan_to_num(nan=0.0)
+
+    camera_sequence_zero_idx = torch.any(camera_sequence == 0, dim=-1, keepdim=True)
+    mask &= (~camera_sequence_zero_idx.squeeze(-1)).type(dtype=torch.bool)
+    camera_sequence_zero_idx = torch.tile(
+        camera_sequence_zero_idx, (1, 1, camera_sequence.shape[-1])
+    )
+    camera_sequence_zero_idx[:, :, :first_none_ramp_frame] = False
+    camera_sequence[camera_sequence_zero_idx] = limit
+
+    return camera_sequence, mask
diff --git a/reproduction_effort/functions/preprocessing.py b/reproduction_effort/functions/preprocessing.py
new file mode 100644
index 0000000..2fd12b4
--- /dev/null
+++ b/reproduction_effort/functions/preprocessing.py
@@ -0,0 +1,93 @@
+import scipy.io as sio  # type: ignore
+import torch
+import numpy as np
+import json
+
+from functions.make_mask import make_mask
+from functions.convert_camera_sequenc_to_list import convert_camera_sequenc_to_list
+from functions.preprocess_camera_sequence import preprocess_camera_sequence
+from functions.interpolate_along_time import interpolate_along_time
+from functions.gauss_smear import gauss_smear
+from functions.regression import regression
+
+
+@torch.no_grad()
+def preprocessing(
+    filename_metadata: str,
+    filename_data: str,
+    filename_mask: str,
+    device: torch.device,
+    first_none_ramp_frame: int,
+    spatial_width: float,
+    temporal_width: float,
+    target_camera: list[str],
+    regressor_cameras: list[str],
+    dtype: torch.dtype = torch.float32,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    data: torch.Tensor = torch.tensor(
+        sio.loadmat(filename_data)["data"].astype(np.float32),
+        device=device,
+        dtype=dtype,
+    )
+
+    with open(filename_metadata, "r") as file_handle:
+        metadata: dict = json.load(file_handle)
+    cameras: list[str] = metadata["channelKey"]
+
+    required_order: list[str] = ["acceptor", "donor", "oxygenation", "volume"]
+
+    mask: torch.Tensor = make_mask(
+        filename_mask=filename_mask, data=data, device=device, dtype=dtype
+    )
+
+    camera_sequence: list[torch.Tensor] = convert_camera_sequenc_to_list(
+        data=data, required_order=required_order, cameras=cameras
+    )
+
+    for num_cams in range(len(camera_sequence)):
+        camera_sequence[num_cams], mask = preprocess_camera_sequence(
+            camera_sequence=camera_sequence[num_cams],
+            mask=mask,
+            first_none_ramp_frame=first_none_ramp_frame,
+            device=device,
+            dtype=dtype,
+        )
+
+    # Interpolate in-between images
+    interpolate_along_time(camera_sequence)
+
+    camera_sequence_filtered: list[torch.Tensor] = []
+    for id in range(0, len(camera_sequence)):
+        camera_sequence_filtered.append(camera_sequence[id].clone())
+
+    camera_sequence_filtered = gauss_smear(
+        camera_sequence_filtered,
+        mask.type(dtype=dtype),
+        spatial_width=spatial_width,
+        temporal_width=temporal_width,
+    )
+
+    # camera_sequence is ordered by required_order, hence index into it.
+    regressor_camera_ids: list[int] = []
+
+    for cam in regressor_cameras:
+        regressor_camera_ids.append(required_order.index(cam))
+
+    results: list[torch.Tensor] = []
+
+    for channel_position in range(0, len(target_camera)):
+        print(f"channel position: {channel_position}")
+        target_camera_selected = target_camera[channel_position]
+        target_camera_id: int = required_order.index(target_camera_selected)
+
+        output = regression(
+            target_camera_id=target_camera_id,
+            regressor_camera_ids=regressor_camera_ids,
+            mask=mask,
+            camera_sequence=camera_sequence,
+            camera_sequence_filtered=camera_sequence_filtered,
+            first_none_ramp_frame=first_none_ramp_frame,
+        )
+        results.append(output)
+
+    return results[0], results[1], mask
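preprocess_camera_sequence divides each pixel by its temporal mean over the non-ramp frames, so the contract is that every valid pixel comes out with a mean of ~1.0 over those frames. A minimal check of that contract, not part of the diff, assuming functions/ is importable:

import torch

from functions.preprocess_camera_sequence import preprocess_camera_sequence

device = torch.device("cpu")
dtype = torch.float32

# Strictly positive random data: no pixels are zeroed or masked out.
camera = torch.rand((10, 10, 200), device=device, dtype=dtype) + 0.5
mask = torch.ones((10, 10), device=device, dtype=torch.bool)

camera, mask = preprocess_camera_sequence(
    camera_sequence=camera,
    mask=mask,
    first_none_ramp_frame=100,
    device=device,
    dtype=dtype,
)
print((camera[:, :, 100:].mean(dim=-1)[mask] - 1.0).abs().max())  # ~0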
diff --git a/reproduction_effort/functions/regression.py b/reproduction_effort/functions/regression.py
new file mode 100644
index 0000000..c273d0a
--- /dev/null
+++ b/reproduction_effort/functions/regression.py
@@ -0,0 +1,118 @@
+import torch
+from functions.regression_internal import regression_internal
+
+
+@torch.no_grad()
+def regression(
+    target_camera_id: int,
+    regressor_camera_ids: list[int],
+    mask: torch.Tensor,
+    camera_sequence: list[torch.Tensor],
+    camera_sequence_filtered: list[torch.Tensor],
+    first_none_ramp_frame: int,
+) -> torch.Tensor:
+
+    assert len(regressor_camera_ids) > 0
+
+    # ------- Prepare the target signals ----------
+    target_signals_train: torch.Tensor = (
+        camera_sequence_filtered[target_camera_id][..., first_none_ramp_frame:].clone()
+        - 1.0
+    )
+    target_signals_train[target_signals_train < -1] = 0.0
+
+    target_signals_perform: torch.Tensor = (
+        camera_sequence[target_camera_id].clone() - 1.0
+    )
+
+    # Check if everything is happy
+    assert target_signals_train.ndim == 3
+    assert target_signals_train.ndim == target_signals_perform.ndim
+    assert target_signals_train.shape[0] == target_signals_perform.shape[0]
+    assert target_signals_train.shape[1] == target_signals_perform.shape[1]
+    assert (
+        target_signals_train.shape[2] + first_none_ramp_frame
+    ) == target_signals_perform.shape[2]
+    # --==DONE==-
+
+    # ------- Prepare the regressor signals ----------
+
+    # --- Train ---
+
+    regressor_signals_train: torch.Tensor = torch.zeros(
+        (
+            camera_sequence_filtered[0].shape[0],
+            camera_sequence_filtered[0].shape[1],
+            camera_sequence_filtered[0].shape[2],
+            len(regressor_camera_ids) + 1,
+        ),
+        device=camera_sequence_filtered[0].device,
+        dtype=camera_sequence_filtered[0].dtype,
+    )
+
+    # Copy the regressor signals (minus the 1.0 baseline)
+    for matrix_id, id in enumerate(regressor_camera_ids):
+        regressor_signals_train[..., matrix_id] = camera_sequence_filtered[id] - 1.0
+
+    regressor_signals_train[regressor_signals_train < -1] = 0.0
+
+    # Linear trend regressor
+    trend = torch.arange(
+        0, regressor_signals_train.shape[-2], device=camera_sequence_filtered[0].device
+    ) / float(regressor_signals_train.shape[-2] - 1)
+    trend -= trend.mean()
+    trend = trend.unsqueeze(0).unsqueeze(0)
+    trend = trend.tile(
+        (regressor_signals_train.shape[0], regressor_signals_train.shape[1], 1)
+    )
+    regressor_signals_train[..., -1] = trend
+
+    regressor_signals_train = regressor_signals_train[:, :, first_none_ramp_frame:, :]
+
+    # --- Perform ---
+
+    regressor_signals_perform: torch.Tensor = torch.zeros(
+        (
+            camera_sequence[0].shape[0],
+            camera_sequence[0].shape[1],
+            camera_sequence[0].shape[2],
+            len(regressor_camera_ids) + 1,
+        ),
+        device=camera_sequence[0].device,
+        dtype=camera_sequence[0].dtype,
+    )
+
+    # Copy the regressor signals (minus the 1.0 baseline)
+    for matrix_id, id in enumerate(regressor_camera_ids):
+        regressor_signals_perform[..., matrix_id] = camera_sequence[id] - 1.0
+
+    # Linear trend regressor
+    trend = torch.arange(
+        0, regressor_signals_perform.shape[-2], device=camera_sequence[0].device
+    ) / float(regressor_signals_perform.shape[-2] - 1)
+    trend -= trend.mean()
+    trend = trend.unsqueeze(0).unsqueeze(0)
+    trend = trend.tile(
+        (regressor_signals_perform.shape[0], regressor_signals_perform.shape[1], 1)
+    )
+    regressor_signals_perform[..., -1] = trend
+
+    # --==DONE==-
+
+    coefficients, intercept = regression_internal(
+        input_regressor=regressor_signals_train, input_target=target_signals_train
+    )
+
+    target_signals_perform -= (
+        regressor_signals_perform * coefficients.unsqueeze(-2)
+    ).sum(dim=-1)
+
+    target_signals_perform -= intercept.unsqueeze(-1)
+
+    target_signals_perform[
+        ~mask.unsqueeze(-1).tile((1, 1, target_signals_perform.shape[-1]))
+    ] = 0.0
+
+    target_signals_perform += 1.0
+
+    return target_signals_perform
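regression fits each pixel's (smoothed) target trace with the regressor channels plus a centered linear trend, then subtracts that fit from the raw trace. A minimal end-to-end sketch on synthetic data, not part of the diff, assuming functions/ is importable: a target built from one regressor plus a drift should come back flat at ~1.0.

import torch

from functions.regression import regression

x, y, t = 4, 4, 300
drift = torch.linspace(-0.05, 0.05, t)
regressor = 1.0 + 0.02 * torch.randn(x, y, t)
target = 1.0 + 0.5 * (regressor - 1.0) + drift

# Indices follow required_order; channels 2 and 3 are unused here.
camera_sequence = [target, regressor, regressor.clone(), regressor.clone()]
camera_sequence_filtered = [c.clone() for c in camera_sequence]

cleaned = regression(
    target_camera_id=0,
    regressor_camera_ids=[1],
    mask=torch.ones((x, y), dtype=torch.bool),
    camera_sequence=camera_sequence,
    camera_sequence_filtered=camera_sequence_filtered,
    first_none_ramp_frame=0,
)
print((cleaned - 1.0).abs().max())  # ~0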
diff --git a/reproduction_effort/functions/regression_internal.py b/reproduction_effort/functions/regression_internal.py
new file mode 100644
index 0000000..352d7ba
--- /dev/null
+++ b/reproduction_effort/functions/regression_internal.py
@@ -0,0 +1,20 @@
+import torch
+
+
+def regression_internal(
+    input_regressor: torch.Tensor, input_target: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+
+    regressor_offset = input_regressor.mean(keepdim=True, dim=-2)
+    target_offset = input_target.mean(keepdim=True, dim=-1)
+
+    regressor = input_regressor - regressor_offset
+    target = input_target - target_offset
+
+    # rcond=None: use the machine-precision default cutoff
+    coefficients, _, _, _ = torch.linalg.lstsq(regressor, target, rcond=None)
+
+    intercept = target_offset.squeeze(-1) - (
+        coefficients * regressor_offset.squeeze(-2)
+    ).sum(dim=-1)
+
+    return coefficients, intercept
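regression_internal centers regressors and target before torch.linalg.lstsq and reconstructs the intercept from the means, which is algebraically the same as fitting with an explicit all-ones column. A minimal sketch, not part of the diff, assuming functions/ is importable and that the unbatched 1-D target case behaves like the batched one used above:

import torch

from functions.regression_internal import regression_internal

t, n = 200, 3
A = torch.randn(t, n, dtype=torch.float64)
true_coefficients = torch.tensor([0.5, -1.0, 2.0], dtype=torch.float64)
b = A @ true_coefficients + 4.0  # known intercept of 4

coefficients, intercept = regression_internal(input_regressor=A, input_target=b)
print(coefficients)  # ~[0.5, -1.0, 2.0]
print(intercept)  # ~4.0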
diff --git a/reproduction_effort/heartbeatanalyse.py b/reproduction_effort/heartbeatanalyse.py
new file mode 100644
index 0000000..8a26474
--- /dev/null
+++ b/reproduction_effort/heartbeatanalyse.py
@@ -0,0 +1,122 @@
+import torch
+import matplotlib.pyplot as plt
+
+
+from functions.preprocessing import preprocessing
+from functions.bandpass import bandpass
+
+
+if __name__ == "__main__":
+
+    if torch.cuda.is_available():
+        device_name: str = "cuda:0"
+    else:
+        device_name = "cpu"
+    print(f"Using device: {device_name}")
+    device: torch.device = torch.device(device_name)
+
+    filename_metadata: str = "raw/Exp001_Trial001_Part001_meta.txt"
+    filename_data: str = "Exp001_Trial001_Part001.mat"
+    filename_mask: str = "2020-12-08maskPixelraw2.mat"
+
+    first_none_ramp_frame: int = 100
+    spatial_width: float = 2.0
+    temporal_width: float = 0.1
+
+    lower_frequency_bandpass: float = 5.0
+    upper_frequency_bandpass: float = 14.0
+
+    target_camera: list[str] = ["acceptor", "donor"]
+    regressor_cameras: list[str] = ["oxygenation", "volume"]
+
+    ratio_sequence_a, ratio_sequence_b, mask = preprocessing(
+        filename_metadata=filename_metadata,
+        filename_data=filename_data,
+        filename_mask=filename_mask,
+        device=device,
+        first_none_ramp_frame=first_none_ramp_frame,
+        spatial_width=spatial_width,
+        temporal_width=temporal_width,
+        target_camera=target_camera,
+        regressor_cameras=regressor_cameras,
+    )
+
+    ratio_sequence_a = bandpass(
+        data=ratio_sequence_a,
+        device=ratio_sequence_a.device,
+        low_frequency=lower_frequency_bandpass,
+        high_frequency=upper_frequency_bandpass,
+        fs=100.0,
+        filtfilt_chunk_size=10,
+    )
+
+    ratio_sequence_b = bandpass(
+        data=ratio_sequence_b,
+        device=ratio_sequence_b.device,
+        low_frequency=lower_frequency_bandpass,
+        high_frequency=upper_frequency_bandpass,
+        fs=100.0,
+        filtfilt_chunk_size=10,
+    )
+
+    original_shape = ratio_sequence_a.shape
+
+    ratio_sequence_a = ratio_sequence_a.flatten(start_dim=0, end_dim=-2)
+    ratio_sequence_b = ratio_sequence_b.flatten(start_dim=0, end_dim=-2)
+
+    mask = mask.flatten(start_dim=0, end_dim=-1)
+    ratio_sequence_a = ratio_sequence_a[mask, :]
+    ratio_sequence_b = ratio_sequence_b[mask, :]
+
+    ratio_sequence_a = ratio_sequence_a.movedim(0, -1)
+    ratio_sequence_b = ratio_sequence_b.movedim(0, -1)
+
+    ratio_sequence_a -= ratio_sequence_a.mean(dim=0, keepdim=True)
+    ratio_sequence_b -= ratio_sequence_b.mean(dim=0, keepdim=True)
+
+    u_a, s_a, Vh_a = torch.linalg.svd(ratio_sequence_a, full_matrices=False)
+    u_a = u_a[:, 0]
+    s_a = s_a[0]
+    Vh_a = Vh_a[0, :]
+
+    heartbeat_activity_map_a = torch.zeros(
+        (original_shape[0], original_shape[1]), device=Vh_a.device, dtype=Vh_a.dtype
+    ).flatten(start_dim=0, end_dim=-1)
+
+    heartbeat_activity_map_a *= torch.nan
+    heartbeat_activity_map_a[mask] = s_a * Vh_a
+    heartbeat_activity_map_a = heartbeat_activity_map_a.reshape(
+        (original_shape[0], original_shape[1])
+    )
+
+    u_b, s_b, Vh_b = torch.linalg.svd(ratio_sequence_b, full_matrices=False)
+    u_b = u_b[:, 0]
+    s_b = s_b[0]
+    Vh_b = Vh_b[0, :]
+
+    heartbeat_activity_map_b = torch.zeros(
+        (original_shape[0], original_shape[1]), device=Vh_b.device, dtype=Vh_b.dtype
+    ).flatten(start_dim=0, end_dim=-1)
+
+    heartbeat_activity_map_b *= torch.nan
+    heartbeat_activity_map_b[mask] = s_b * Vh_b
+    heartbeat_activity_map_b = heartbeat_activity_map_b.reshape(
+        (original_shape[0], original_shape[1])
+    )
+
+    plt.subplot(2, 1, 1)
+    plt.plot(u_a.cpu(), label="acceptor")
+    plt.plot(u_b.cpu(), label="donor")
+    plt.legend()
+    plt.subplot(2, 1, 2)
+    plt.imshow(
+        torch.cat(
+            (
+                heartbeat_activity_map_a,
+                heartbeat_activity_map_b,
+            ),
+            dim=1,
+        ).cpu()
+    )
+    plt.colorbar()
+    plt.show()
diff --git a/reproduction_effort/preprocessing.py b/reproduction_effort/preprocessing.py
new file mode 100644
index 0000000..483a81e
--- /dev/null
+++ b/reproduction_effort/preprocessing.py
@@ -0,0 +1,68 @@
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import h5py  # type: ignore
+
+from functions.preprocessing import preprocessing
+
+
+if __name__ == "__main__":
+
+    if torch.cuda.is_available():
+        device_name: str = "cuda:0"
+    else:
+        device_name = "cpu"
+    print(f"Using device: {device_name}")
+    device: torch.device = torch.device(device_name)
+
+    filename_metadata: str = "raw/Exp001_Trial001_Part001_meta.txt"
+    filename_data: str = "Exp001_Trial001_Part001.mat"
+    filename_mask: str = "2020-12-08maskPixelraw2.mat"
+
+    first_none_ramp_frame: int = 100
+    spatial_width: float = 2.0
+    temporal_width: float = 0.1
+
+    target_camera: list[str] = ["acceptor", "donor"]
+    regressor_cameras: list[str] = ["oxygenation", "volume"]
+
+    data_acceptor, data_donor, mask = preprocessing(
+        filename_metadata=filename_metadata,
+        filename_data=filename_data,
+        filename_mask=filename_mask,
+        device=device,
+        first_none_ramp_frame=first_none_ramp_frame,
+        spatial_width=spatial_width,
+        temporal_width=temporal_width,
+        target_camera=target_camera,
+        regressor_cameras=regressor_cameras,
+    )
+
+    ratio_sequence: torch.Tensor = data_acceptor / data_donor
+
+    new: np.ndarray = ratio_sequence.cpu().numpy()
+
+    file_handle = h5py.File("old.mat", "r")
+    old: np.ndarray = np.array(file_handle["ratioSequence"])
+    # HDF5 loads everything backwards...
+    old = np.moveaxis(old, 0, -1)
+    old = np.moveaxis(old, 0, -2)
+
+    pos_x = 25
+    pos_y = 75
+
+    plt.subplot(2, 1, 1)
+    new_select = new[pos_x, pos_y, :]
+    old_select = old[pos_x, pos_y, :]
+    plt.plot(new_select, label="New")
+    plt.plot(old_select, "--", label="Old")
+    plt.plot(old_select - new_select + 1.0, label="Old - New + 1")
+    plt.title(f"Position: {pos_x}, {pos_y}")
+    plt.legend()
+
+    plt.subplot(2, 1, 2)
+    differences = (np.abs(new - old)).max(axis=-1)
+    plt.imshow(differences)
+    plt.title("Max of abs(new-old) along time")
+    plt.colorbar()
+    plt.show()
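For reference, the rank-1 SVD readout in heartbeatanalyse.py can be illustrated on synthetic data. A minimal sketch, not part of the diff: with a single oscillating source, u[:, 0] recovers the time course and s[0] * Vh[0, :] the per-pixel amplitude, each only up to a common sign flip.

import math

import torch

t, pixels = 1000, 50
time_course = torch.sin(torch.linspace(0.0, 20.0 * math.pi, t))
amplitude = torch.rand(pixels)

data = time_course.unsqueeze(-1) * amplitude.unsqueeze(0)  # (time, pixels)
data -= data.mean(dim=0, keepdim=True)

u, s, vh = torch.linalg.svd(data, full_matrices=False)
print(float(s[0]), float(s[1]))  # first singular value dominates, second ~0
print(s[0] * vh[0, :5])  # proportional to amplitude[:5], up to sign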