2023-07-12 22:06:48 +02:00
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import skimage
|
|
|
|
|
|
|
|
filename: str = "example_data_crop"
|
|
|
|
threshold: float = 0.8
|
2023-07-12 23:08:51 +02:00
|
|
|
tolerance: float | None = None
|
2023-07-12 22:06:48 +02:00
|
|
|
minimum_area: int = 100
|
|
|
|
|
|
|
|
torch_device: torch.device = torch.device(
|
|
|
|
"cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
|
)
|
|
|
|
|
|
|
|
print("Load data")
# Loaded on the host first, then copied onto torch_device.
# (Named raw_data so the builtin `input` is not shadowed.)
raw_data = np.load(filename + "_decorrelated.npy")
data = torch.tensor(raw_data, device=torch_device)
del raw_data
print("loading done")

# Standardize every pixel's trace along dim 0 (zero mean, unit variance).
data = data.nan_to_num(nan=0.0)
data -= data.mean(dim=0, keepdim=True)
# Population std (1/N normalisation): the correlation is later computed as
# the mean of products of standardized traces, which equals Pearson's r
# only with 1/N. With the default Bessel correction it would be (N-1)/N * r,
# so even the seed's self-correlation would fall short of 1.
data /= data.std(dim=0, keepdim=True, unbiased=False)
# Constant traces have std 0 and become 0/0 = NaN here; zero them so NaN
# does not propagate into the max/min and correlation reductions.
data = data.nan_to_num(nan=0.0)
|
|
|
|
|
|
|
|
# Per-pixel dynamic range (max - min along dim 0); used to rank seed pixels
# and as the background of the final plot. NaNs are mapped to 0.
master_image = torch.nan_to_num(
    data.max(dim=0).values - data.min(dim=0).values, nan=0.0
)
# Working copy that the search loop progressively zeroes out.
temp_image = master_image.clone()

# 1 = pixel still available, 0 = pixel already consumed by the search.
master_mask = torch.ones_like(temp_image)

# Accumulators for the search loop.
stored_contours: list = []
counter: int = 0
contours_found: int = 0
|
|
|
|
# Greedy region growing: repeatedly pick the available pixel with the largest
# dynamic range as a seed, threshold every pixel's correlation with the seed
# trace, and keep the contour enclosing the seed if it is large enough.
while int(master_mask.sum()) > 0:
    if counter % 100 == 0:
        print(
            f"number of pixel tested: {counter} remaining pixels: {int(master_mask.sum())} cells found: {contours_found}"
        )
    counter += 1

    # Zero consumed pixels so they cannot win the argmax, then convert the
    # flat argmax index into (row, column).
    temp_image *= master_mask
    x, y = divmod(int(temp_image.argmax()), int(temp_image.shape[1]))

    # argmax only lands on a consumed pixel when every remaining available
    # pixel has a value of 0, so nothing useful is left to seed from.
    if not bool(master_mask[x, y]):
        break

    # Correlation of every pixel's trace with the seed trace: data is
    # standardized along dim 0, so this is the mean of products.
    seed_trace = data[:, x, y]
    scale = (data * seed_trace.unsqueeze(-1).unsqueeze(-1)).mean(dim=0)
    scale = scale.nan_to_num(nan=0.0)
    # Pixels already assigned to a cell cannot join another one.
    scale *= master_mask

    # Binary map of pixels sufficiently correlated with the seed.
    image = (scale > threshold).type(torch.uint8).cpu().numpy()

    # NOTE(review): level 0 on a 0/1 image places the contour on the
    # background pixels bordering each region; 0.5 is the usual level for
    # binary images. Kept as-is to preserve results.
    for contour in skimage.measure.find_contours(image, 0):
        # Optionally simplify (soften) the outline.
        if tolerance is not None:
            coords = skimage.measure.approximate_polygon(
                contour, tolerance=tolerance
            ).astype(dtype=np.float32)
        else:
            coords = contour.astype(dtype=np.float32)

        # Rasterize the polygon into a boolean mask of the image's shape.
        mask = skimage.draw.polygon2mask(scale.shape, coords)

        # Only the contour enclosing the seed pixel is of interest.
        if not mask[x, y]:
            continue

        if mask.sum() > minimum_area:
            stored_contours.append(coords)
            contours_found += 1
            # Retire every pixel of the accepted cell.
            master_mask[torch.from_numpy(mask).to(torch_device)] = 0.0
        break

    # Always retire the seed itself. Previously a seed whose enclosing region
    # was found but too small was never removed from master_mask, so the same
    # seed was picked again on every iteration and the loop never ended.
    master_mask[x, y] = 0.0
|
|
|
|
print("-==- DONE -==-")

# Store the contours as a 1-D object array (one (n_points, 2) array per
# cell). np.array(stored_contours, dtype=object) would instead produce an
# (n_cells, n_points, 2) array whenever all contours happen to have the same
# number of vertices, so the file layout would depend on the data.
# Object arrays are pickled: load with np.load("cells.npy", allow_pickle=True).
cells = np.empty(len(stored_contours), dtype=object)
for cell_index, cell_coords in enumerate(stored_contours):
    cells[cell_index] = cell_coords
np.save("cells.npy", cells)
|
|
|
|
|
|
|
|
# Show the dynamic-range image with every stored cell outline on top.
plt.imshow(master_image.cpu(), cmap="hot")
for outline in stored_contours:
    # Contour points are (row, column); plot takes (x=column, y=row).
    rows, cols = outline[:, 0], outline[:, 1]
    plt.plot(cols, rows, "-g", linewidth=2)
plt.colorbar()
plt.show()
|