diff --git a/inspection_30fps_no_glow_main_svd_removed.py b/inspection_30fps_no_glow_main_svd_removed.py new file mode 100644 index 0000000..fdf508f --- /dev/null +++ b/inspection_30fps_no_glow_main_svd_removed.py @@ -0,0 +1,156 @@ +import numpy as np +import torch +import matplotlib.pyplot as plt +import skimage +from scipy.stats import skew +from svd import calculate_svd, to_remove, calculate_translation +import torchvision as tv + +from ImageAlignment import ImageAlignment + +# from Anime import Anime + +filename: str = "example_data_crop" +use_svd: bool = True + +torch_device: torch.device = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu" +) + +with torch.no_grad(): + print("Load data") + input = np.load(filename + str(".npy")) # str("_decorrelated.npy")) + data = torch.tensor(input, device=torch_device) + # del input + print("loading done") + + fill_value: float = 0.0 + + print("Movement compensation [BROKEN!!!!]") + print("During development, information about what could move was missing.") + print("Thus the preprocessing before shift determination may not work.") + data -= data.min(dim=0)[0] + data /= data.std(dim=0, keepdim=True) + 1e-20 + + image_alignment = ImageAlignment(default_dtype=torch.float32, device=torch_device) + + tvec = calculate_translation( + input=data, + reference_image=data[0, ...].clone(), + image_alignment=image_alignment, + ) + tvec_media = tvec.median(dim=0)[0] + print(f"Median of movement: {tvec_media[0]}, {tvec_media[1]}") + + data = torch.tensor(input, device=torch_device) + data -= data.min(dim=0, keepdim=True)[0] + for id in range(0, data.shape[0]): + data[id, ...] = tv.transforms.functional.affine( + img=data[id, ...].unsqueeze(0), + angle=0, + translate=[tvec[id, 1], tvec[id, 0]], + scale=1.0, + shear=0, + fill=fill_value, + ).squeeze(0) + + print("SVD") + whiten_mean, whiten_k, eigenvalues = calculate_svd(data) + # ---- + data = torch.tensor(input, device=torch_device) + for id in range(0, data.shape[0]): + data[id, ...] = tv.transforms.functional.affine( + img=data[id, ...].unsqueeze(0), + angle=0, + translate=[tvec[id, 1], tvec[id, 0]], + scale=1.0, + shear=0, + fill=fill_value, + ).squeeze(0) + data -= data.min(dim=0, keepdim=True)[0] + to_remove_data = to_remove(data, whiten_k, whiten_mean) + + data -= to_remove_data + del to_remove_data + + stored_contours = np.load("cells.npy", allow_pickle=True) + + if use_svd: + data_flat = torch.flatten( + data.nan_to_num(nan=0.0).movedim(0, -1), + start_dim=0, + end_dim=1, + ) + + to_plot = torch.zeros( + (int(data.shape[0]), int(stored_contours.shape[0])), + device=torch_device, + dtype=torch.float32, + ) + + print("Calculate cell's time series") + + for id in range(0, stored_contours.shape[0]): + mask = torch.tensor( + skimage.draw.polygon2mask( + (int(data.shape[1]), int(data.shape[2])), stored_contours[id] + ), + device=torch_device, + dtype=torch.float32, + ) + if use_svd: + mask_flat = torch.flatten( + mask.unsqueeze(0).nan_to_num(nan=0.0).movedim(0, -1), + start_dim=0, + end_dim=1, + ) + idx = torch.where(mask_flat > 0)[0] + temp = data_flat[idx, :].clone() + whiten_mean = torch.mean(temp, dim=-1) + temp -= whiten_mean.unsqueeze(-1) + svd_u, svd_s, _ = torch.svd_lowrank(temp, q=6) + + whiten_k = ( + torch.sign(svd_u[0, :]).unsqueeze(0) + * svd_u + / (svd_s.unsqueeze(0) + 1e-20) + )[:, 0] + + temp = temp * whiten_k.unsqueeze(-1) + data_svd = temp.movedim(-1, 0).sum(dim=-1) + to_plot[:, id] = data_svd + else: + ts = (data * mask.unsqueeze(0)).nan_to_num(nan=0.0).sum( + dim=(-2, -1) + ) / mask.sum() + to_plot[:, id] = ts + + +skew_value = skew(to_plot.cpu().numpy(), axis=0) +skew_idx = np.flip(skew_value.argsort()) +skew_value = skew_value[skew_idx] + +to_plot_np = to_plot.cpu().numpy() +to_plot_np = to_plot_np[:, skew_idx] + + +plt.imshow(to_plot_np.T, cmap="gray_r", interpolation="nearest") +plt.colorbar() +plt.show() + + +# plt.plot(to_plot[:, 0:5].cpu()) +# plt.show() + +# block_size: int = 8 +# # print(to_plot.shape[1] // block_size) +# for i in range(0, 4 * 8): +# plt.subplot(8, 4, i + 1) +# plt.plot(to_plot[:, i * block_size : (i + 1) * block_size].cpu()) +# plt.ylim( +# [ +# to_plot.min().cpu(), +# to_plot.max().cpu(), +# ] +# ) +# plt.show()