From b0e61e13d8b0776376416527e130812ac854b889 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Thu, 13 Jul 2023 02:11:36 +0200 Subject: [PATCH] Add files via upload --- inspection.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/inspection.py b/inspection.py index 654e03b..739189f 100644 --- a/inspection.py +++ b/inspection.py @@ -5,6 +5,7 @@ import skimage from scipy.stats import skew filename: str = "example_data_crop" +use_svd: bool = True torch_device: torch.device = torch.device( "cuda:0" if torch.cuda.is_available() else "cpu" @@ -18,6 +19,13 @@ print("loading done") stored_contours = np.load("cells.npy", allow_pickle=True) +if use_svd: + data_flat = torch.flatten( + data.nan_to_num(nan=0.0).movedim(0, -1), + start_dim=0, + end_dim=1, + ) + to_plot = torch.zeros( (int(data.shape[0]), int(stored_contours.shape[0])), device=torch_device, @@ -31,9 +39,30 @@ for id in range(0, stored_contours.shape[0]): device=torch_device, dtype=torch.float32, ) + if use_svd: + mask_flat = torch.flatten( + mask.unsqueeze(0).nan_to_num(nan=0.0).movedim(0, -1), + start_dim=0, + end_dim=1, + ) + idx = torch.where(mask_flat > 0)[0] + temp = data_flat[idx, :].clone() + whiten_mean = torch.mean(temp, dim=-1) + temp -= whiten_mean.unsqueeze(-1) + svd_u, svd_s, _ = torch.svd_lowrank(temp, q=6) - ts = (data * mask.unsqueeze(0)).nan_to_num(nan=0.0).sum(dim=(-2, -1)) / mask.sum() - to_plot[:, id] = ts + whiten_k = ( + torch.sign(svd_u[0, :]).unsqueeze(0) * svd_u / (svd_s.unsqueeze(0) + 1e-20) + )[:, 0] + + temp = temp * whiten_k.unsqueeze(-1) + data_svd = temp.movedim(-1, 0).sum(dim=-1) + to_plot[:, id] = data_svd + else: + ts = (data * mask.unsqueeze(0)).nan_to_num(nan=0.0).sum( + dim=(-2, -1) + ) / mask.sum() + to_plot[:, id] = ts skew_value = skew(to_plot.cpu().numpy(), axis=0)