diff --git a/initial_cell_estimate.py b/initial_cell_estimate.py new file mode 100644 index 0000000..4a299f2 --- /dev/null +++ b/initial_cell_estimate.py @@ -0,0 +1,91 @@ +import numpy as np +import torch +import matplotlib.pyplot as plt +import skimage + +filename: str = "example_data_crop" +threshold: float = 0.8 +tolerance: float = 1.0 +minimum_area: int = 100 + +torch_device: torch.device = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu" +) + +print("Load data") +input = np.load(filename + str("_decorrelated.npy")) +data = torch.tensor(input, device=torch_device) +del input +print("loading done") + +data = data.nan_to_num(nan=0.0) +data -= data.mean(dim=0, keepdim=True) +data /= data.std(dim=0, keepdim=True) + +master_image = (data.max(dim=0)[0] - data.min(dim=0)[0]).nan_to_num(nan=0.0).clone() +temp_image = master_image.clone() +master_mask = torch.ones_like(temp_image) + +stored_contours: list = [] +counter: int = 0 +contours_found: int = 0 +while int(master_mask.sum()) > 0: + if counter % 100 == 0: + print( + f"number of pixel tested: {counter} remaining pixels: {int(master_mask.sum())} cells found: {contours_found}" + ) + counter += 1 + mask: np.ndarray | None = None + + temp_image *= master_mask + + # Convert index to 2D + temp_idx = temp_image.argmax() + x = int(temp_idx // int(temp_image.shape[1])) + y = int(temp_idx - x * int(temp_image.shape[1])) + if bool(master_mask[x, y]) is False: + break + + test_data = data[:, x, y].clone() + + # Calculate the correlation + scale = (data * test_data.unsqueeze(-1).unsqueeze(-1)).mean(dim=0) + scale = scale.nan_to_num(nan=0.0) + scale *= master_mask + + # Check for areas with high correlation + image = (scale > threshold).type(torch.uint8).cpu().numpy() + + found_something: bool = False + # Find the coutours + for contour in skimage.measure.find_contours(image, 0): + # soften outline + coords = skimage.measure.approximate_polygon( + contour, tolerance=tolerance + ).astype(dtype=np.float32) + # Make a mask out of the polygon + mask = skimage.draw.polygon2mask(scale.shape, coords) + assert mask is not None + + # check if this is the contour in which the original point was + if mask[x, y]: + found_something = True + + if mask.sum() > minimum_area: + stored_contours.append(coords) + contours_found += 1 + idx_set_mask = torch.where(torch.tensor(mask, device=torch_device) > 0) + + master_mask[idx_set_mask] = 0.0 + break + + if found_something is False: + master_mask[x, y] = 0.0 +print("-==- DONE -==-") +np.save("cells.npy", np.array(stored_contours, dtype=object)) + +plt.imshow(master_image.cpu(), cmap="hot") +for i in range(0, len(stored_contours)): + plt.plot(stored_contours[i][:, 1], stored_contours[i][:, 0], "-g", linewidth=2) +plt.colorbar() +plt.show()