ca_imaging_funhouse/inspection_30fps_no_glow_main_svd_removed.py
2023-07-14 22:33:18 +02:00

166 lines
4.9 KiB
Python

import numpy as np
import torch
import matplotlib.pyplot as plt
import skimage
from scipy.stats import skew
from svd import to_remove
import torchvision as tv
from ImageAlignment import ImageAlignment
from Anime import Anime
filename: str = "example_data_crop"
use_svd: bool = True
show_movie: bool = True
torch_device: torch.device = torch.device(
"cuda:0" if torch.cuda.is_available() else "cpu"
)
with torch.no_grad():
print("Load data")
input = np.load(filename + str(".npy")) # str("_decorrelated.npy"))
# del input
print("loading done")
stored_contours = np.load("cells.npy", allow_pickle=True)
kernel_size_pooling = int(np.load("kernel_size_pooling.npy"))
fill_value = float(np.load("fill_value.npy"))
image_alignment = ImageAlignment(default_dtype=torch.float32, device=torch_device)
tvec = torch.tensor(np.load(filename + "_tvec.npy"), device=torch_device)
np_svd_data = np.load(
filename + "_svd.npz",
)
whiten_mean = torch.tensor(np_svd_data["whiten_mean"], device=torch_device)
whiten_k = torch.tensor(np_svd_data["whiten_k"], device=torch_device)
eigenvalues = torch.tensor(np_svd_data["eigenvalues"], device=torch_device)
del np_svd_data
data = torch.tensor(input, device=torch_device)
for id in range(0, data.shape[0]):
data[id, ...] = tv.transforms.functional.affine(
img=data[id, ...].unsqueeze(0),
angle=0,
translate=[tvec[id, 1], tvec[id, 0]],
scale=1.0,
shear=0,
fill=fill_value,
).squeeze(0)
data -= data.min(dim=0, keepdim=True)[0]
to_remove_data = to_remove(data, whiten_k, whiten_mean)
data -= to_remove_data
del to_remove_data
print("Pooling")
# Warning: The contour masks have the same size as the binned data!!!
avage_pooling = torch.nn.AvgPool2d(
kernel_size=(kernel_size_pooling, kernel_size_pooling),
stride=(kernel_size_pooling, kernel_size_pooling),
)
data = avage_pooling(data)
if use_svd:
data_flat = torch.flatten(
data.nan_to_num(nan=0.0).movedim(0, -1),
start_dim=0,
end_dim=1,
)
to_plot = torch.zeros(
(int(data.shape[0]), int(stored_contours.shape[0])),
device=torch_device,
dtype=torch.float32,
)
print("Calculate cell's time series")
for id in range(0, stored_contours.shape[0]):
mask = torch.tensor(
skimage.draw.polygon2mask(
(int(data.shape[1]), int(data.shape[2])), stored_contours[id]
),
device=torch_device,
dtype=torch.float32,
)
if use_svd:
mask_flat = torch.flatten(
mask.unsqueeze(0).nan_to_num(nan=0.0).movedim(0, -1),
start_dim=0,
end_dim=1,
)
idx = torch.where(mask_flat > 0)[0]
temp = data_flat[idx, :].clone()
whiten_mean = torch.mean(temp, dim=-1)
temp -= whiten_mean.unsqueeze(-1)
svd_u, svd_s, _ = torch.svd_lowrank(temp, q=6)
whiten_k = (
torch.sign(svd_u[0, :]).unsqueeze(0)
* svd_u
/ (svd_s.unsqueeze(0) + 1e-20)
)[:, 0]
temp = temp * whiten_k.unsqueeze(-1)
data_svd = temp.movedim(-1, 0).sum(dim=-1)
to_plot[:, id] = data_svd
else:
ts = (data * mask.unsqueeze(0)).nan_to_num(nan=0.0).sum(
dim=(-2, -1)
) / mask.sum()
to_plot[:, id] = ts
if show_movie:
print("Calculate movie")
# Clean tensor
data *= 0.0
for id in range(0, stored_contours.shape[0]):
mask = torch.tensor(
skimage.draw.polygon2mask(
(int(data.shape[1]), int(data.shape[2])), stored_contours[id]
),
device=torch_device,
dtype=torch.float32,
)
# * 1.0 - mask: otherwise the overlapping outlines look bad
# Yes... reshape and indices would be faster...
data *= 1.0 - mask.unsqueeze(0)
data += mask.unsqueeze(0) * to_plot[:, id].unsqueeze(1).unsqueeze(2)
ani = Anime()
ani.show(data)
exit()
skew_value = skew(to_plot.cpu().numpy(), axis=0)
skew_idx = np.flip(skew_value.argsort())
skew_value = skew_value[skew_idx]
to_plot_np = to_plot.cpu().numpy()
to_plot_np = to_plot_np[:, skew_idx]
plt.imshow(to_plot_np.T, cmap="gray_r", interpolation="nearest")
plt.colorbar()
plt.show()
# plt.plot(to_plot[:, 0:5].cpu())
# plt.show()
# block_size: int = 8
# # print(to_plot.shape[1] // block_size)
# for i in range(0, 4 * 8):
# plt.subplot(8, 4, i + 1)
# plt.plot(to_plot[:, i * block_size : (i + 1) * block_size].cpu())
# plt.ylim(
# [
# to_plot.min().cpu(),
# to_plot.max().cpu(),
# ]
# )
# plt.show()