Add files via upload
This commit is contained in:
parent
4c0b9d63ec
commit
eedae1ae44
1 changed files with 156 additions and 0 deletions
156
inspection_30fps_no_glow_main_svd_removed.py
Normal file
156
inspection_30fps_no_glow_main_svd_removed.py
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import skimage
|
||||||
|
from scipy.stats import skew
|
||||||
|
from svd import calculate_svd, to_remove, calculate_translation
|
||||||
|
import torchvision as tv
|
||||||
|
|
||||||
|
from ImageAlignment import ImageAlignment
|
||||||
|
|
||||||
|
# from Anime import Anime
|
||||||
|
|
||||||
|
filename: str = "example_data_crop"
|
||||||
|
use_svd: bool = True
|
||||||
|
|
||||||
|
torch_device: torch.device = torch.device(
|
||||||
|
"cuda:0" if torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
print("Load data")
|
||||||
|
input = np.load(filename + str(".npy")) # str("_decorrelated.npy"))
|
||||||
|
data = torch.tensor(input, device=torch_device)
|
||||||
|
# del input
|
||||||
|
print("loading done")
|
||||||
|
|
||||||
|
fill_value: float = 0.0
|
||||||
|
|
||||||
|
print("Movement compensation [BROKEN!!!!]")
|
||||||
|
print("During development, information about what could move was missing.")
|
||||||
|
print("Thus the preprocessing before shift determination may not work.")
|
||||||
|
data -= data.min(dim=0)[0]
|
||||||
|
data /= data.std(dim=0, keepdim=True) + 1e-20
|
||||||
|
|
||||||
|
image_alignment = ImageAlignment(default_dtype=torch.float32, device=torch_device)
|
||||||
|
|
||||||
|
tvec = calculate_translation(
|
||||||
|
input=data,
|
||||||
|
reference_image=data[0, ...].clone(),
|
||||||
|
image_alignment=image_alignment,
|
||||||
|
)
|
||||||
|
tvec_media = tvec.median(dim=0)[0]
|
||||||
|
print(f"Median of movement: {tvec_media[0]}, {tvec_media[1]}")
|
||||||
|
|
||||||
|
data = torch.tensor(input, device=torch_device)
|
||||||
|
data -= data.min(dim=0, keepdim=True)[0]
|
||||||
|
for id in range(0, data.shape[0]):
|
||||||
|
data[id, ...] = tv.transforms.functional.affine(
|
||||||
|
img=data[id, ...].unsqueeze(0),
|
||||||
|
angle=0,
|
||||||
|
translate=[tvec[id, 1], tvec[id, 0]],
|
||||||
|
scale=1.0,
|
||||||
|
shear=0,
|
||||||
|
fill=fill_value,
|
||||||
|
).squeeze(0)
|
||||||
|
|
||||||
|
print("SVD")
|
||||||
|
whiten_mean, whiten_k, eigenvalues = calculate_svd(data)
|
||||||
|
# ----
|
||||||
|
data = torch.tensor(input, device=torch_device)
|
||||||
|
for id in range(0, data.shape[0]):
|
||||||
|
data[id, ...] = tv.transforms.functional.affine(
|
||||||
|
img=data[id, ...].unsqueeze(0),
|
||||||
|
angle=0,
|
||||||
|
translate=[tvec[id, 1], tvec[id, 0]],
|
||||||
|
scale=1.0,
|
||||||
|
shear=0,
|
||||||
|
fill=fill_value,
|
||||||
|
).squeeze(0)
|
||||||
|
data -= data.min(dim=0, keepdim=True)[0]
|
||||||
|
to_remove_data = to_remove(data, whiten_k, whiten_mean)
|
||||||
|
|
||||||
|
data -= to_remove_data
|
||||||
|
del to_remove_data
|
||||||
|
|
||||||
|
stored_contours = np.load("cells.npy", allow_pickle=True)
|
||||||
|
|
||||||
|
if use_svd:
|
||||||
|
data_flat = torch.flatten(
|
||||||
|
data.nan_to_num(nan=0.0).movedim(0, -1),
|
||||||
|
start_dim=0,
|
||||||
|
end_dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
to_plot = torch.zeros(
|
||||||
|
(int(data.shape[0]), int(stored_contours.shape[0])),
|
||||||
|
device=torch_device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Calculate cell's time series")
|
||||||
|
|
||||||
|
for id in range(0, stored_contours.shape[0]):
|
||||||
|
mask = torch.tensor(
|
||||||
|
skimage.draw.polygon2mask(
|
||||||
|
(int(data.shape[1]), int(data.shape[2])), stored_contours[id]
|
||||||
|
),
|
||||||
|
device=torch_device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
if use_svd:
|
||||||
|
mask_flat = torch.flatten(
|
||||||
|
mask.unsqueeze(0).nan_to_num(nan=0.0).movedim(0, -1),
|
||||||
|
start_dim=0,
|
||||||
|
end_dim=1,
|
||||||
|
)
|
||||||
|
idx = torch.where(mask_flat > 0)[0]
|
||||||
|
temp = data_flat[idx, :].clone()
|
||||||
|
whiten_mean = torch.mean(temp, dim=-1)
|
||||||
|
temp -= whiten_mean.unsqueeze(-1)
|
||||||
|
svd_u, svd_s, _ = torch.svd_lowrank(temp, q=6)
|
||||||
|
|
||||||
|
whiten_k = (
|
||||||
|
torch.sign(svd_u[0, :]).unsqueeze(0)
|
||||||
|
* svd_u
|
||||||
|
/ (svd_s.unsqueeze(0) + 1e-20)
|
||||||
|
)[:, 0]
|
||||||
|
|
||||||
|
temp = temp * whiten_k.unsqueeze(-1)
|
||||||
|
data_svd = temp.movedim(-1, 0).sum(dim=-1)
|
||||||
|
to_plot[:, id] = data_svd
|
||||||
|
else:
|
||||||
|
ts = (data * mask.unsqueeze(0)).nan_to_num(nan=0.0).sum(
|
||||||
|
dim=(-2, -1)
|
||||||
|
) / mask.sum()
|
||||||
|
to_plot[:, id] = ts
|
||||||
|
|
||||||
|
|
||||||
|
skew_value = skew(to_plot.cpu().numpy(), axis=0)
|
||||||
|
skew_idx = np.flip(skew_value.argsort())
|
||||||
|
skew_value = skew_value[skew_idx]
|
||||||
|
|
||||||
|
to_plot_np = to_plot.cpu().numpy()
|
||||||
|
to_plot_np = to_plot_np[:, skew_idx]
|
||||||
|
|
||||||
|
|
||||||
|
plt.imshow(to_plot_np.T, cmap="gray_r", interpolation="nearest")
|
||||||
|
plt.colorbar()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
# plt.plot(to_plot[:, 0:5].cpu())
|
||||||
|
# plt.show()
|
||||||
|
|
||||||
|
# block_size: int = 8
|
||||||
|
# # print(to_plot.shape[1] // block_size)
|
||||||
|
# for i in range(0, 4 * 8):
|
||||||
|
# plt.subplot(8, 4, i + 1)
|
||||||
|
# plt.plot(to_plot[:, i * block_size : (i + 1) * block_size].cpu())
|
||||||
|
# plt.ylim(
|
||||||
|
# [
|
||||||
|
# to_plot.min().cpu(),
|
||||||
|
# to_plot.max().cpu(),
|
||||||
|
# ]
|
||||||
|
# )
|
||||||
|
# plt.show()
|
Loading…
Reference in a new issue