From dab7dcb786972f24b55c39425913e32dd1826695 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Mon, 31 Jul 2023 15:23:13 +0200 Subject: [PATCH] Add files via upload --- processing_chain/BuildImage.py | 89 ++++ processing_chain/ContourExtract.py | 288 +++++++++++ processing_chain/DiscardElements.py | 162 ++++++ processing_chain/OnlineEncoding.py | 711 ++++++++++++++++++++++++++ processing_chain/OnlinePerception.py | 299 +++++++++++ processing_chain/PatchGenerator.py | 237 +++++++++ processing_chain/Sparsifier.py | 418 +++++++++++++++ processing_chain/WebCam.py | 325 ++++++++++++ processing_chain/Yolo5Segmentation.py | 402 +++++++++++++++ 9 files changed, 2931 insertions(+) create mode 100644 processing_chain/BuildImage.py create mode 100644 processing_chain/ContourExtract.py create mode 100644 processing_chain/DiscardElements.py create mode 100644 processing_chain/OnlineEncoding.py create mode 100644 processing_chain/OnlinePerception.py create mode 100644 processing_chain/PatchGenerator.py create mode 100644 processing_chain/Sparsifier.py create mode 100644 processing_chain/WebCam.py create mode 100644 processing_chain/Yolo5Segmentation.py diff --git a/processing_chain/BuildImage.py b/processing_chain/BuildImage.py new file mode 100644 index 0000000..eedf3f1 --- /dev/null +++ b/processing_chain/BuildImage.py @@ -0,0 +1,89 @@ +import torch + + +def clip_coordinates(x_canvas: int, dx_canvas: int, dx_dict: int): + + x_canvas = int(x_canvas) + dx_canvas = int(dx_canvas) + dx_dict = int(dx_dict) + dr_dict = int(dx_dict // 2) + + x0_canvas = int(x_canvas - dr_dict) + # placement outside right boundary? + if x0_canvas >= dx_canvas: + return None + + x1_canvas = int(x_canvas + dr_dict + (dx_dict % 2)) + # placement outside left boundary? + if x1_canvas <= 0: + return None + + # clip to the left? + if x0_canvas < 0: + x0_dict = -x0_canvas + x0_canvas = 0 + else: + x0_dict = 0 + + # clip to the right? 
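# Worked example (hypothetical numbers): x_canvas=98, dx_canvas=100, dx_dict=11
# gives dr_dict=5, x0_canvas=93 and x1_canvas=104; the patch sticks out on the
# right, so x1_canvas is clipped to 100 and x1_dict becomes 11 - (104 - 100) = 7,
# leaving equal spans of 7 pixels on the canvas and on the dictionary patch.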
+ if x1_canvas > dx_canvas: + x1_dict = dx_dict - (x1_canvas - dx_canvas) + x1_canvas = dx_canvas + else: + x1_dict = dx_dict + + # print(x0_canvas, x1_canvas, x0_dict, x1_dict) + assert (x1_canvas - x0_canvas) == (x1_dict - x0_dict) + + return x0_canvas, x1_canvas, x0_dict, x1_dict + + +def BuildImage( + canvas_size: torch.Size, + dictionary: torch.Tensor, + position_found: torch.Tensor, + default_dtype, + torch_device, +): + + assert position_found is not None + assert dictionary is not None + + canvas_size_copy = torch.tensor(canvas_size) + assert canvas_size_copy.shape[0] == 4 + canvas_size_copy[1] = 1 + output = torch.zeros( + canvas_size_copy.tolist(), + device=torch_device, + dtype=default_dtype, + ) + + dx_canvas = canvas_size[-2] + dy_canvas = canvas_size[-1] + dx_dict = dictionary.shape[-2] + dy_dict = dictionary.shape[-1] + + for pattern_id in range(0, position_found.shape[0]): + for patch_id in range(0, position_found.shape[1]): + + x_canvas = position_found[pattern_id, patch_id, 1] + y_canvas = position_found[pattern_id, patch_id, 2] + + xv = clip_coordinates(x_canvas, dx_canvas, dx_dict) + if xv == None: + break + + yv = clip_coordinates(y_canvas, dy_canvas, dy_dict) + if yv == None: + break + + if dictionary.shape[0] > 1: + elem_idx = int(position_found[pattern_id, patch_id, 0]) + else: + elem_idx = 0 + + output[pattern_id, 0, xv[0] : xv[1], yv[0] : yv[1]] += dictionary[ + elem_idx, 0, xv[2] : xv[3], yv[2] : yv[3] + ] + + return output diff --git a/processing_chain/ContourExtract.py b/processing_chain/ContourExtract.py new file mode 100644 index 0000000..7f1d82e --- /dev/null +++ b/processing_chain/ContourExtract.py @@ -0,0 +1,288 @@ +# ContourExtract.py +# ==================================== +# extracts contours from gray-level images +# +# Version V1.0, pre-07.03.2023: +# no actual changes, is David's last code version... 
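# A minimal usage sketch for BuildImage above (shapes and values are
# illustrative only): the canvas is [patterns, channels, height, width], the
# dictionary [n_elements, 1, patch, patch], and position_found is
# [patterns, n_patches, 3] with entries (element_index, x, y).
import torch
from processing_chain.BuildImage import BuildImage

canvas_size = torch.Size([1, 8, 128, 128])        # the channel count is replaced by 1
dictionary = torch.rand((16, 1, 11, 11))
positions = torch.zeros((1, 5, 3), dtype=torch.int64)
positions[0, :, 0] = torch.randint(16, (5,))      # dictionary element per patch
positions[0, :, 1:] = torch.randint(128, (5, 2))  # canvas coordinates
image = BuildImage(
    canvas_size=canvas_size,
    dictionary=dictionary,
    position_found=positions,
    default_dtype=torch.float32,
    torch_device="cpu",
)
# image.shape == torch.Size([1, 1, 128, 128]); patches overlapping the border
# are clipped by clip_coordinates rather than raising an error.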
+# +# Version V1.1, 07.03.2023: +# merged David's rebuild code (GUI capable) +# + +import torch +import time +import math + + +class ContourExtract(torch.nn.Module): + image_width: int = -1 + image_height: int = -1 + output_threshold: torch.Tensor = torch.tensor(-1) + + n_orientations: int + sigma_kernel: float = 0 + lambda_kernel: float = 0 + gamma_aspect_ratio: float = 0 + + kernel_axis_x: torch.Tensor | None = None + kernel_axis_y: torch.Tensor | None = None + target_orientations: torch.Tensor + weight_vector: torch.Tensor + + image_scale: float | None + + psi_phase_offset_cos: torch.Tensor + psi_phase_offset_sin: torch.Tensor + + fft_gabor_cos_bank: torch.Tensor | None = None + fft_gabor_sin_bank: torch.Tensor | None = None + + pi: torch.Tensor + torch_device: torch.device + default_dtype = torch.float32 + + rebuild_kernels: bool + + def __init__( + self, + n_orientations: int, + sigma_kernel: float, + lambda_kernel: float, + gamma_aspect_ratio: float = 1.0, + image_scale: float | None = 255.0, + torch_device: str = "cpu", + ): + super().__init__() + self.torch_device = torch.device(torch_device) + + self.pi = torch.tensor( + math.pi, + device=self.torch_device, + dtype=self.default_dtype, + ) + + self.psi_phase_offset_cos = torch.tensor( + 0.0, + device=self.torch_device, + dtype=self.default_dtype, + ) + self.psi_phase_offset_sin = -self.pi / 2 + self.gamma_aspect_ratio = gamma_aspect_ratio + self.image_scale = image_scale + + self.update_settings( + n_orientations=n_orientations, + sigma_kernel=sigma_kernel, + lambda_kernel=lambda_kernel, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + + assert input.ndim == 4, "input must have 4 dims!" + + # We can only handle one color channel + assert input.shape[1] == 1, "input.shape[1] must be 1!" + + input = input.type(dtype=self.default_dtype).to(device=self.torch_device) + if self.image_scale is not None: + # scale grey level [0, 255] to range [0.0, 1.0] + input /= self.image_scale + + # Do we have valid kernels? 
+ if input.shape[-2] != self.image_width: + self.image_width = input.shape[-2] + self.rebuild_kernels = True + + if input.shape[-1] != self.image_height: + self.image_height = input.shape[-1] + self.rebuild_kernels = True + + assert self.image_width > 0 + assert self.image_height > 0 + + t0 = time.time() + + # We need to rebuild the kernels + if self.rebuild_kernels is True: + + self.kernel_axis_x = self.create_kernel_axis(self.image_width) + self.kernel_axis_y = self.create_kernel_axis(self.image_height) + + assert self.kernel_axis_x is not None + assert self.kernel_axis_y is not None + + gabor_cos_bank, gabor_sin_bank = self.create_gabor_filter_bank() + + assert gabor_cos_bank is not None + assert gabor_sin_bank is not None + + self.fft_gabor_cos_bank = torch.fft.rfft2( + gabor_cos_bank, s=None, dim=(-2, -1), norm=None + ).unsqueeze(0) + + self.fft_gabor_sin_bank = torch.fft.rfft2( + gabor_sin_bank, s=None, dim=(-2, -1), norm=None + ).unsqueeze(0) + + # compute threshold for ignoring non-zero outputs arising due + # to numerical imprecision (fft, kernel definition) + assert self.weight_vector is not None + norm_input = torch.full( + (1, 1, self.image_width, self.image_height), + 1.0, + device=self.torch_device, + dtype=self.default_dtype, + ) + norm_fft_input = torch.fft.rfft2( + norm_input, s=None, dim=(-2, -1), norm=None + ) + + norm_output_sin = torch.fft.irfft2( + norm_fft_input * self.fft_gabor_sin_bank, + s=None, + dim=(-2, -1), + norm=None, + ) + norm_output_cos = torch.fft.irfft2( + norm_fft_input * self.fft_gabor_cos_bank, + s=None, + dim=(-2, -1), + norm=None, + ) + norm_value_matrix = torch.sqrt(norm_output_sin**2 + norm_output_cos**2) + + # norm_output = torch.abs( + # (self.weight_vector * norm_value_matrix).sum(dim=-1) + # ).type(dtype=torch.float32) + + self.output_threshold = norm_value_matrix.max() + + self.rebuild_kernels = False + + assert self.fft_gabor_cos_bank is not None + assert self.fft_gabor_sin_bank is not None + + t1 = time.time() + + fft_input = torch.fft.rfft2(input, s=None, dim=(-2, -1), norm=None) + + output_sin = torch.fft.irfft2( + fft_input * self.fft_gabor_sin_bank, s=None, dim=(-2, -1), norm=None + ) + + output_cos = torch.fft.irfft2( + fft_input * self.fft_gabor_cos_bank, s=None, dim=(-2, -1), norm=None + ) + + t2 = time.time() + + output = torch.sqrt(output_sin**2 + output_cos**2) + + t3 = time.time() + + print( + "ContourExtract {:.3f}s: prep-{:.3f}s, fft-{:.3f}s, out-{:.3f}s".format( + t3 - t0, t1 - t0, t2 - t1, t3 - t2 + ) + ) + + return output + + def create_collapse(self, input: torch.Tensor) -> torch.Tensor: + + assert self.weight_vector is not None + + output = torch.abs((self.weight_vector * input).sum(dim=1)).type( + dtype=self.default_dtype + ) + + return output + + def create_kernel_axis(self, axis_size: int) -> torch.Tensor: + + lower_bound_axis: int = -int(math.floor(axis_size / 2)) + upper_bound_axis: int = int(math.ceil(axis_size / 2)) + + kernel_axis = torch.arange( + lower_bound_axis, + upper_bound_axis, + device=self.torch_device, + dtype=self.default_dtype, + ) + kernel_axis = torch.roll(kernel_axis, int(math.ceil(axis_size / 2))) + + return kernel_axis + + def create_gabor_filter_bank(self) -> tuple[torch.Tensor, torch.Tensor]: + + assert self.kernel_axis_x is not None + assert self.kernel_axis_y is not None + assert self.target_orientations is not None + + orientation_matrix = ( + self.target_orientations.unsqueeze(-1).unsqueeze(-1).detach().clone() + ) + x_kernel_matrix = 
self.kernel_axis_x.unsqueeze(0).unsqueeze(-1).detach().clone() + y_kernel_matrix = self.kernel_axis_y.unsqueeze(0).unsqueeze(0).detach().clone() + + r2 = x_kernel_matrix**2 + self.gamma_aspect_ratio * y_kernel_matrix**2 + + kr = x_kernel_matrix * torch.cos( + orientation_matrix + ) + y_kernel_matrix * torch.sin(orientation_matrix) + + c0 = torch.exp(-2 * (self.pi * self.sigma_kernel / self.lambda_kernel) ** 2) + + gauss: torch.Tensor = torch.exp(-r2 / 2 / self.sigma_kernel**2) + + gabor_cos_bank: torch.Tensor = gauss * ( + torch.cos(2 * self.pi * kr / self.lambda_kernel + self.psi_phase_offset_cos) + - c0 * torch.cos(self.psi_phase_offset_cos) + ) + + gabor_sin_bank: torch.Tensor = gauss * ( + torch.cos(2 * self.pi * kr / self.lambda_kernel + self.psi_phase_offset_sin) + - c0 * torch.cos(self.psi_phase_offset_sin) + ) + + return gabor_cos_bank, gabor_sin_bank + + def update_settings( + self, + n_orientations: int, + sigma_kernel: float, + lambda_kernel: float, + ): + + self.rebuild_kernels = True + + self.n_orientations = n_orientations + self.sigma_kernel = sigma_kernel + self.lambda_kernel = lambda_kernel + + # generate orientation axis and axis for complex summation + self.target_orientations: torch.Tensor = ( + torch.arange( + start=0, + end=int(self.n_orientations), + device=self.torch_device, + dtype=self.default_dtype, + ) + * torch.tensor( + math.pi, + device=self.torch_device, + dtype=self.default_dtype, + ) + / self.n_orientations + ) + + self.weight_vector: torch.Tensor = ( + torch.exp( + 2.0 + * torch.complex(torch.tensor(0.0), torch.tensor(1.0)) + * self.target_orientations + ) + .unsqueeze(-1) + .unsqueeze(-1) + .unsqueeze(0) + ) diff --git a/processing_chain/DiscardElements.py b/processing_chain/DiscardElements.py new file mode 100644 index 0000000..a0ef43a --- /dev/null +++ b/processing_chain/DiscardElements.py @@ -0,0 +1,162 @@ +#%% +# DiscardElements.py +# ==================================== +# removes elements from a sparse image representation +# such that a 'most uniform' coverage still exists +# +# Version V1.0, pre-07.03.2023: +# no actual changes, is David's last code version... + +import numpy as np + +# assume locations is array [n, 3] +# the three entries are [shape_index, pos_x, pos_y] + + +def discard_elements_simple(locations: np.ndarray, target_number_elements: list): + + n_locations: int = locations.shape[0] + locations_remain: list = [] + + # Loop across all target number of elements + for target_elem in target_number_elements: + + assert target_elem > 0, "Number of target elements must be larger than 0!" + assert ( + target_elem <= n_locations + ), "Number of target elements must be <= number of available locations!" + + # Build distance matrix between positions in locations_highest_res. + # Its diagonal is defined as Inf because we don't want to consider these values in our + # search for the minimum distances. + distance_matrix = np.sqrt( + ((locations[np.newaxis, :, 1:] - locations[:, np.newaxis, 1:]) ** 2).sum( + axis=-1 + ) + ) + distance_matrix[np.arange(n_locations), np.arange(n_locations)] = np.inf + + # Find the minimal distances in upper triangle of matrix. + idcs_remove: list = [] + while (n_locations - len(idcs_remove)) != target_elem: + + # Get index of matrix with minimal distance + row_idcs, col_idcs = np.where( + distance_matrix == distance_matrix[distance_matrix > 0].min() + ) + + # Get the max index. 
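# Of the two endpoints of the currently closest pair, the one with the larger
# index is removed; this is an arbitrary but deterministic tie-break, since
# dropping either endpoint thins out the same dense spot.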
+ # It correspond to the index of the element we will remove in the locations_highest_res list + sel_idx: int = max(row_idcs[0], col_idcs[0]) + idcs_remove.append(sel_idx) # Save the index + + # Set current distance as Inf because we don't want to consider it further in our search + distance_matrix[sel_idx, :] = np.inf + distance_matrix[:, sel_idx] = np.inf + + idcs_remain: list = np.setdiff1d(np.arange(n_locations), idcs_remove) + locations_remain.append(locations[idcs_remain, :]) + + return locations_remain + + +# assume locations is array [n, 3] +# the three entries are [shape_index, pos_x, pos_y] +def discard_elements( + locations: np.ndarray, target_number_elements: list, prior: np.ndarray +): + + n_locations: int = locations.shape[0] + locations_remain: list = [] + disable_value: float = np.nan + + # if type(prior) != np.ndarray: + # prior = np.ones((n_locations,)) + assert prior.shape == ( + n_locations, + ), "Prior must have same number of entries as elements in locations!" + print(prior) + + # Loop across all target number of elements + for target_elem in target_number_elements: + + assert target_elem > 0, "Number of target elements must be larger than 0!" + assert ( + target_elem <= n_locations + ), "Number of target elements must be <= number of available locations!" + + # Build distance matrix between positions in locations_highest_res. + # Its diagonal is defined as Inf because we don't want to consider these values in our + # search for the minimum distances. + distance_matrix = np.sqrt( + ((locations[np.newaxis, :, 1:] - locations[:, np.newaxis, 1:]) ** 2).sum( + axis=-1 + ) + ) + prior_matrix = prior[np.newaxis, :] * prior[:, np.newaxis] + distance_matrix *= prior_matrix + distance_matrix[np.arange(n_locations), np.arange(n_locations)] = disable_value + print(distance_matrix) + + # Find the minimal distances in upper triangle of matrix. + idcs_remove: list = [] + while (n_locations - len(idcs_remove)) != target_elem: + + # Get index of matrix with minimal distance + row_idcs, col_idcs = np.where( + # distance_matrix == distance_matrix[distance_matrix > 0].min() + distance_matrix + == np.nanmin(distance_matrix) + ) + + # Get the max index. 
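# Unlike the simple variant, the endpoint to drop is chosen by comparing the
# nan-aware total distance of each endpoint to all remaining elements: the one
# with the smaller sum, i.e. the more crowded element, is removed.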
+ # It correspond to the index of the element we will remove in the locations_highest_res list + print(row_idcs[0], col_idcs[0]) + # if prior[row_idcs[0]] >= prior[col_idcs[0]]: + # sel_idx = row_idcs[0] + # else: + # sel_idx = col_idcs[0] + d_row = np.nansum(distance_matrix[row_idcs[0], :]) + d_col = np.nansum(distance_matrix[:, col_idcs[0]]) + if d_row > d_col: + sel_idx = col_idcs[0] + else: + sel_idx = row_idcs[0] + # sel_idx: int = max(row_idcs[0], col_idcs[0]) + idcs_remove.append(sel_idx) # Save the index + + # Set current distance as Inf because we don't want to consider it further in our search + distance_matrix[sel_idx, :] = disable_value + distance_matrix[:, sel_idx] = disable_value + + idcs_remain: list = np.setdiff1d(np.arange(n_locations), idcs_remove) + locations_remain.append(locations[idcs_remain, :]) + + return locations_remain + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + + # generate a circle with n locations + n_locations: int = 20 + phi = np.arange(n_locations) / n_locations * 2 * np.pi + locations = np.ones((n_locations, 3)) + locations[:, 1] = np.cos(phi) + locations[:, 2] = np.sin(phi) + prior = np.ones((n_locations,)) + prior[:10] = 0.1 + locations_remain = discard_elements(locations, [n_locations // 5], prior=prior) + + plt.plot(locations[:, 1], locations[:, 2], "ko") + plt.plot(locations_remain[0][:, 1], locations_remain[0][:, 2], "rx") + plt.show() + + locations_remain_simple = discard_elements_simple(locations, [n_locations // 5]) + + plt.plot(locations[:, 1], locations[:, 2], "ko") + plt.plot(locations_remain_simple[0][:, 1], locations_remain_simple[0][:, 2], "rx") + plt.show() + + +# %% diff --git a/processing_chain/OnlineEncoding.py b/processing_chain/OnlineEncoding.py new file mode 100644 index 0000000..b7400fb --- /dev/null +++ b/processing_chain/OnlineEncoding.py @@ -0,0 +1,711 @@ +# %% +# +# test_OnlineEncoding.py +# ======================================================== +# encode visual scenes into sparse representations using +# different kinds of dictionaries +# +# -> derived from test_PsychophysicsEncoding.py +# +# Version 1.0, 29.04.2023: +# +# Version 1.1, 21.06.2023: +# define proper class +# + + +# Import Python modules +# ======================================================== +# import csv +# import time +import os +import glob +import matplotlib.pyplot as plt +import torch +import torchvision as tv +from PIL import Image +import cv2 +import numpy as np + + +# Import our modules +# ======================================================== +from processing_chain.ContourExtract import ContourExtract +from processing_chain.PatchGenerator import PatchGenerator +from processing_chain.Sparsifier import Sparsifier +from processing_chain.DiscardElements import discard_elements_simple +from processing_chain.BuildImage import BuildImage +from processing_chain.WebCam import WebCam +from processing_chain.Yolo5Segmentation import Yolo5Segmentation + + +# TODO required? 
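# A minimal sketch of the contour-extraction step used further below (values
# are illustrative; 2.4 and 4.8 pixels correspond to the 0.06 / 0.12 DVA
# defaults at 40 pixels per DVA):
import torch
from processing_chain.ContourExtract import ContourExtract

_ce = ContourExtract(n_orientations=8, sigma_kernel=2.4, lambda_kernel=4.8)
_gray = torch.rand((1, 1, 128, 128)) * 255.0  # [batch, 1, height, width], grey levels
_contour = _ce(_gray)                         # [1, 8, 128, 128], one map per orientation
_collapsed = _ce.create_collapse(_contour)    # [1, 128, 128], orientation-collapsed magnitude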
+def show_torch_frame( + frame_torch: torch.Tensor, + title: str = "", + cmap: str = "viridis", + target: str = "pyplot", +): + frame_numpy = ( + (frame_torch.movedim(0, -1) * 255).type(dtype=torch.uint8).cpu().numpy() + ) + if target == "pyplot": + plt.imshow(frame_numpy, cmap=cmap) + plt.title(title) + plt.show() + if target == "cv2": + if frame_numpy.ndim == 3: + if frame_numpy.shape[-1] == 1: + frame_numpy = np.tile(frame_numpy, [1, 1, 3]) + frame_numpy = (frame_numpy - frame_numpy.min()) / ( + frame_numpy.max() - frame_numpy.min() + ) + # print(frame_numpy.shape, frame_numpy.max(), frame_numpy.min()) + cv2.namedWindow(title, cv2.WINDOW_NORMAL) + cv2.imshow(title, frame_numpy[:, :, (2, 1, 0)]) + cv2.waitKey(1) + + return + + +# TODO required? +def embed_image(frame_torch, out_height, out_width, init_value=0): + + out_shape = torch.tensor(frame_torch.shape) + + frame_width = frame_torch.shape[-1] + frame_height = frame_torch.shape[-2] + + frame_width_idx0 = max([0, (frame_width - out_width) // 2]) + frame_height_idx0 = max([0, (frame_height - out_height) // 2]) + + select_width = min([frame_width, out_width]) + select_height = min([frame_height, out_height]) + + out_shape[-1] = out_width + out_shape[-2] = out_height + + out_torch = init_value * torch.ones(tuple(out_shape)) + + out_width_idx0 = max([0, (out_width - frame_width) // 2]) + out_height_idx0 = max([0, (out_height - frame_height) // 2]) + + out_torch[ + ..., + out_height_idx0 : (out_height_idx0 + select_height), + out_width_idx0 : (out_width_idx0 + select_width), + ] = frame_torch[ + ..., + frame_height_idx0 : (frame_height_idx0 + select_height), + frame_width_idx0 : (frame_width_idx0 + select_width), + ] + + return out_torch + + +class OnlineEncoding: + + # TODO: also pre-populate self-ies here? + # + # DEFINED IN "__init__": + # + # display (fixed) + # gabor (changeable) + # encoding (changeable) + # dictionary (changeable) + # control (fixed) + # path (fixed) + # verbose + # torch_device, default_dtype + # display_size_max_x_PIX, display_size_max_y_PIX + # padding_fill + # cap + # yolo + # classes_detect + # + # + # DEFINED IN "apply_parameter_changes": + # + # padding_PIX + # sigma_kernel_PIX, lambda_kernel_PIX + # out_x, out_y + # clocks, phosphene, clocks_filter + # + + def __init__(self, source=0, verbose=False): + + # Define parameters + # ======================================================== + # Unit abbreviations: + # dva = degrees of visual angle + # pix = pixels + print("OE-Init: Defining default parameters...") + self.verbose = verbose + + # display: Defines geometry of target display + # ======================================================== + # The encoded image will be scaled such that it optimally uses + # the max space available. If the orignal image has a different aspect + # ratio than the display region, it will only use one spatial + # dimension (horizontal or vertical) to its full extent + # + # If one DVA corresponds to different PIX_per_DVA on the display, + # (i.e. varying distance observers from screen), it should be set + # larger than the largest PIX_per_DVA required, for avoiding + # extrapolation artefacts or blur. 
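# Example: with PIX_per_DVA = 40 and size_max_x/y_DVA = 10, the encoded image
# occupies at most 400 x 400 pixels before padding is added.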
+ # + self.display = { + "size_max_x_DVA": 10.0, # maximum x size of encoded image + "size_max_y_DVA": 10.0, # minimum y size of encoded image + "PIX_per_DVA": 40.0, # scaling factor pixels to DVA + "scale": "same_range", # "same_luminance" or "same_range" + } + + # gabor: Defines paras of Gabor filters for contour extraction + # ============================================================== + self.gabor = { + "sigma_kernel_DVA": 0.06, + "lambda_kernel_DVA": 0.12, + "n_orientations": 8, + } + + # encoding: Defines parameters of sparse encoding process + # ======================================================== + # Roughly speaking, after contour extraction dictionary elements + # will be placed starting from the position with the highest + # overlap with the contour. Elements placed can be surrounded + # by a dead or inhibitory zone to prevent placing further elements + # too closely. The procedure will map 'n_patches_compute' elements + # and then stop. For each element one obtains an overlap with the + # contour image. + # + # After placement, the overlaps found are normalized to the max + # overlap found, and then all elements with a larger normalized overlap + # than 'overlap_threshold' will be selected. These remaining + # elements will comprise a 'full' encoding of the contour. + # + # To generate even sparser representations, the full encoding can + # be reduced to a certain percentage of elements in the full encoding + # by setting the variable 'percentages' + # + # Example: n_patches_compute = 100 reduced by overlap_threshold = 0.1 + # to 80 elements. Requesting a percentage of 30% yields a representation + # with 24 elements. + # + self.encoding = { + "n_patches_compute": 100, # this amount of patches will be placed + "use_exp_deadzone": True, # parameters of Gaussian deadzone + "size_exp_deadzone_DVA": 1.20, # PREVIOUSLY 1.4283 + "use_cutout_deadzone": True, # parameters of cutout deadzone + "size_cutout_deadzone_DVA": 0.65, # PREVIOUSLY 0.7575 + "overlap_threshold": 0.1, # relative overlap threshold + "percentages": torch.tensor([100]), + } + self.number_of_patches = self.encoding["n_patches_compute"] + + # dictionary: Defines parameters of dictionary + # ======================================================== + self.dictionary = { + "size_DVA": 1.0, # PREVIOUSLY 1.25, + "clocks": None, # parameters for clocks dictionary, see below + "phosphene": None, # paramters for phosphene dictionary, see below + } + + self.dictionary["phosphene"]: dict[float] = { + "sigma_width": 0.18, # DEFAULT 0.15, # half-width of Gaussian + } + + self.dictionary["clocks"]: dict[int, int, float, float] = { + "n_dir": 8, # number of directions for clock pointer segments + "n_open": 4, # number of opening angles between two clock pointer segments + "pointer_width": 0.07, # PREVIOUSLY 0.05, # relative width and size of tip extension of clock pointer + "pointer_length": 0.18, # PREVIOUSLY 0.15, # relative length of clock pointer + } + + # control: For controlling plotting options and flow of script + # ======================================================== + self.control = { + "force_torch_use_cpu": False, # force using CPU even if GPU available + "show_capture": True, # shows captured image + "show_object": True, # shows detected object + "show_contours": True, # shows extracted contours + "show_percept": True, # shows percept + } + + # specify classes to detect + class_person = 0 + self.classes_detect = [class_person] + + print( + "OE-Init: Defining paths, creating dirs, setting default device and 
datatype" + ) + + # path: Path infos for input and output images + # ======================================================== + self.path = {"output": "test/output/level1/", "input": "test/images_test/"} + # Make output directories, if necessary: the place were we dump the new images to... + # os.makedirs(self.path["output"], mode=0o777, exist_ok=True) + + # Check if GPU is available and use it, if possible + # ================================================= + self.default_dtype = torch.float32 + torch.set_default_dtype(self.default_dtype) + if self.control["force_torch_use_cpu"]: + torch_device: str = "cpu" + else: + torch_device: str = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using {torch_device} as TORCH device...") + self.torch_device = torch_device + + print("OE-Init: Compute display scaling factors and padding RGB values") + + # global scaling factors for all pixel-related length scales + self.display_size_max_x_PIX: float = ( + self.display["size_max_x_DVA"] * self.display["PIX_per_DVA"] + ) + self.display_size_max_y_PIX: float = ( + self.display["size_max_y_DVA"] * self.display["PIX_per_DVA"] + ) + + # determine padding fill value + tmp = tv.transforms.Grayscale(num_output_channels=1) + tmp_value = torch.full((3, 1, 1), 0) + self.padding_fill: int = int(tmp(tmp_value).squeeze()) + + print(f"OE-Init: Opening camera source or video file '{source}'") + + # open source + self.cap = WebCam(source) + if not self.cap.open_cam(): + raise OSError(f"Opening source {source} failed!") + + # get the video frame size, frame count and frame rate + frame_width = self.cap.cap_frame_width + frame_height = self.cap.cap_frame_height + fps = self.cap.cap_fps + print( + f"OE-Init: Processing frames of {frame_width} x {frame_height} @ {fps} fps." 
+ ) + + # open output file if we want to save frames + # if output_file != None: + # out = cv2.VideoWriter( + # output_file, + # cv2.VideoWriter_fourcc(*"MJPG"), + # fps, + # (out_x, out_y), + # ) + # if out == None: + # raise OSError(f"Can not open file {output_file} for writing!") + + # get an instance of the Yolo segmentation network + print("OE-Init: initialize YOLO") + self.yolo = Yolo5Segmentation(torch_device=self.torch_device) + + self.send_dictionaries = False + + self.apply_parameter_changes() + + return + + def apply_parameter_changes(self): + + # GET NEW PARAMETERS + print("OE-AppParChg: Computing sizes from new parameters") + + ### BLOCK: dictionary ---------------- + # set patch size for both dictionaries, make sure it is odd number + dictionary_size_PIX: int = ( + 1 + + (int(self.dictionary["size_DVA"] * self.display["PIX_per_DVA"]) // 2) * 2 + ) + + ### BLOCK: gabor --------------------- + # convert contour-related parameters to pixel units + self.sigma_kernel_PIX: float = ( + self.gabor["sigma_kernel_DVA"] * self.display["PIX_per_DVA"] + ) + self.lambda_kernel_PIX: float = ( + self.gabor["lambda_kernel_DVA"] * self.display["PIX_per_DVA"] + ) + + ### BLOCK: gabor & dictionary ------------------ + # Padding + # ------- + self.padding_PIX: int = int( + max(3.0 * self.sigma_kernel_PIX, 1.1 * dictionary_size_PIX) + ) + + # define target video/representation width/height + multiple_of = 4 + out_x = self.display_size_max_x_PIX + 2 * self.padding_PIX + out_y = self.display_size_max_y_PIX + 2 * self.padding_PIX + out_x += (multiple_of - (out_x % multiple_of)) % multiple_of + out_y += (multiple_of - (out_y % multiple_of)) % multiple_of + self.out_x = int(out_x) + self.out_y = int(out_y) + + # generate dictionaries + # --------------------- + ### BLOCK: dictionary -------------------------- + print("OE-AppParChg: Generating dictionaries...") + patch_generator = PatchGenerator(torch_device=self.torch_device) + self.phosphene = patch_generator.alphabet_phosphene( + patch_size=dictionary_size_PIX, + sigma_width=self.dictionary["phosphene"]["sigma_width"] + * dictionary_size_PIX, + ) + ### BLOCK: dictionary & gabor -------------------------- + self.clocks_filter, self.clocks, segments = patch_generator.alphabet_clocks( + patch_size=dictionary_size_PIX, + n_dir=self.dictionary["clocks"]["n_dir"], + n_filter=self.gabor["n_orientations"], + segment_width=self.dictionary["clocks"]["pointer_width"] + * dictionary_size_PIX, + segment_length=self.dictionary["clocks"]["pointer_length"] + * dictionary_size_PIX, + ) + + self.send_dictionaries = True + + return + + # classes_detect, out_x, out_y + def update(self, data_in): + + # handle parameter change + + if data_in: + + print("Incoming -----------> ", data_in) + + self.number_of_patches = data_in["number_of_patches"] + + self.classes_detect = data_in["value"] + + self.gabor["sigma_kernel_DVA"] = data_in["sigma_kernel_DVA"] + self.gabor["lambda_kernel_DVA"] = data_in["sigma_kernel_DVA"] * 2 + self.gabor["n_orientations"] = data_in["n_orientations"] + + self.dictionary["size_DVA"] = data_in["size_DVA"] + self.dictionary["phosphene"]["sigma_width"] = data_in["sigma_width"] + self.dictionary["clocks"]["n_dir"] = data_in["n_dir"] + self.dictionary["clocks"]["n_open"] = data_in["n_dir"] // 2 + self.dictionary["clocks"]["pointer_width"] = data_in["pointer_width"] + self.dictionary["clocks"]["pointer_length"] = data_in["pointer_length"] + + self.encoding["use_exp_deadzone"] = data_in["use_exp_deadzone"] + self.encoding["size_exp_deadzone_DVA"] = 
data_in["size_exp_deadzone_DVA"] + self.encoding["use_cutout_deadzone"] = data_in["use_cutout_deadzone"] + self.encoding["size_cutout_deadzone_DVA"] = data_in[ + "size_cutout_deadzone_DVA" + ] + + self.control["show_capture"] = data_in["enable_cam"] + self.control["show_object"] = data_in["enable_yolo"] + self.control["show_contours"] = data_in["enable_contour"] + # TODO Fenster zumachen + self.apply_parameter_changes() + + # some constants for addressing specific components of output arrays + image_id_CONST: int = 0 + overlap_index_CONST: int = 1 + + # format: color_RGB, height, width float, range=0,1 + print("OE-ProcessFrame: capturing frame") + frame = self.cap.get_frame() + if frame == None: + raise OSError(f"Can not capture frame {i_frame}") + if self.verbose: + if self.control["show_capture"]: + show_torch_frame(frame, title="Captured", target=self.verbose) + else: + try: + cv2.destroyWindow("Captured") + except: + pass + + # perform segmentation + + frame = frame.to(device=self.torch_device) + print("OE-ProcessFrame: frame segmentation by YOLO") + frame_segmented = self.yolo(frame.unsqueeze(0), classes=self.classes_detect) + + # This extracts the frame in x to convert the mask in a video format + if self.yolo.found_class_id != None: + + n_found = len(self.yolo.found_class_id) + print( + f"OE-ProcessFrame: {n_found} occurrences of desired object found in frame!" + ) + + mask = frame_segmented[0] + + # is there something in the mask? + if not mask.sum() == 0: + + # yes, cut only the part of the frame that has our object of interest + frame_masked = mask * frame + + x_height = mask.sum(axis=-2) + x_indices = torch.where(x_height > 0) + x_max = x_indices[0].max() + 1 + x_min = x_indices[0].min() + + y_height = mask.sum(axis=-1) + y_indices = torch.where(y_height > 0) + y_max = y_indices[0].max() + 1 + y_min = y_indices[0].min() + + frame_cut = frame_masked[:, y_min:y_max, x_min:x_max] + else: + print(f"OE-ProcessFrame: Mask contains all zeros in current frame!") + frame_cut = None + else: + print(f"OE-ProcessFrame: No objects found in current frame!") + frame_cut = None + + if frame_cut == None: + # out_torch = torch.zeros([self.out_y, self.out_x]) + position_selection = torch.zeros((1, 0, 3)) + contour_shape = [1, self.gabor["n_orientations"], 1, 1] + else: + if self.verbose: + if self.control["show_object"]: + show_torch_frame( + frame_cut, title="Selected Object", target=self.verbose + ) + else: + try: + cv2.destroyWindow("Selected Object") + except: + pass + + # UDO: from here on, we proceed as before, just handing + # UDO: over the frame_cut --> image + image = frame_cut + + # Determine target size of image + # image: [RGB, Height, Width], dtype= tensor.torch.uint8 + print("OE-ProcessFrame: Computing downsampling factor image -> display") + f_x: float = self.display_size_max_x_PIX / image.shape[-1] + f_y: float = self.display_size_max_y_PIX / image.shape[-2] + f_xy_min: float = min(f_x, f_y) + downsampling_x: int = int(f_xy_min * image.shape[-1]) + downsampling_y: int = int(f_xy_min * image.shape[-2]) + + # CURRENTLY we do not crop in the end... 
+ # Image size for removing the fft crop later + # center_crop_x: int = downsampling_x + # center_crop_y: int = downsampling_y + + # define contour extraction processing chain + # ------------------------------------------ + print("OE-ProcessFrame: Extracting contours") + train_processing_chain = tv.transforms.Compose( + transforms=[ + tv.transforms.Grayscale(num_output_channels=1), # RGB to grayscale + tv.transforms.Resize( + size=(downsampling_y, downsampling_x) + ), # downsampling + tv.transforms.Pad( # extra white padding around the picture + padding=(self.padding_PIX, self.padding_PIX), + fill=self.padding_fill, + ), + ContourExtract( # contour extraction + n_orientations=self.gabor["n_orientations"], + sigma_kernel=self.sigma_kernel_PIX, + lambda_kernel=self.lambda_kernel_PIX, + torch_device=self.torch_device, + ), + # CURRENTLY we do not crop in the end! + # tv.transforms.CenterCrop( # Remove the padding + # size=(center_crop_x, center_crop_y) + # ), + ], + ) + # ...with and without orientation channels + contour = train_processing_chain(image.unsqueeze(0)) + contour_collapse = train_processing_chain.transforms[-1].create_collapse( + contour + ) + + if self.verbose: + if self.control["show_contours"]: + show_torch_frame( + contour_collapse, + title="Contours Extracted", + cmap="gray", + target=self.verbose, + ) + else: + try: + cv2.destroyWindow("Contours Extracted") + except: + pass + + # generate a prior for mapping the contour to the dictionary + # CURRENTLY we use an uniform prior... + # ---------------------------------------------------------- + dictionary_prior = torch.ones( + (self.clocks_filter.shape[0]), + dtype=self.default_dtype, + device=torch.device(self.torch_device), + ) + + # instantiate and execute sparsifier + # ---------------------------------- + print("OE-ProcessFrame: Performing sparsification") + sparsifier = Sparsifier( + dictionary_filter=self.clocks_filter, + dictionary=self.clocks, + dictionary_prior=dictionary_prior, + number_of_patches=self.encoding["n_patches_compute"], + size_exp_deadzone=self.encoding["size_exp_deadzone_DVA"] + * self.display["PIX_per_DVA"], + plot_use_map=False, # self.control["plot_deadzone"], + deadzone_exp=self.encoding["use_exp_deadzone"], + deadzone_hard_cutout=self.encoding["use_cutout_deadzone"], + deadzone_hard_cutout_size=self.encoding["size_cutout_deadzone_DVA"] + * self.display["PIX_per_DVA"], + padding_deadzone_size_x=self.padding_PIX, + padding_deadzone_size_y=self.padding_PIX, + torch_device=self.torch_device, + ) + sparsifier(contour) + assert sparsifier.position_found is not None + + # extract and normalize the overlap found + overlap_found = sparsifier.overlap_found[ + image_id_CONST, :, overlap_index_CONST + ] + overlap_found = overlap_found / overlap_found.max() + + # get overlap above certain threshold, extract corresponding elements + overlap_idcs_valid = torch.where( + overlap_found >= self.encoding["overlap_threshold"] + )[0] + position_selection = sparsifier.position_found[ + image_id_CONST : image_id_CONST + 1, overlap_idcs_valid, : + ] + n_elements = len(overlap_idcs_valid) + print(f"OE-ProcessFrame: {n_elements} elements positioned!") + + contour_shape = contour.shape + + n_cut = min(position_selection.shape[-2], self.number_of_patches) + + data_out = { + "position_found": position_selection[:, :n_cut, :], + "canvas_size": contour_shape, + } + if self.send_dictionaries: + data_out["features"] = self.clocks + data_out["phosphene"] = self.phosphene + self.send_dictionaries = False + + return data_out + + def 
__del__(self): + + print("OE-Delete: exiting gracefully!") + self.cap.close_cam() + try: + cv2.destroyAllWindows() + except: + pass + + +# TODO no output file +# TODO detect end of file if input is video file + +if __name__ == "__main__": + + verbose = "cv2" + source = 0 # "GoProWireless" + frame_count = 20 + i_frame = 0 + + data_in = None + + oe = OnlineEncoding(source=source, verbose=verbose) + + # Loop over the frames + while i_frame < frame_count: + + i_frame += 1 + + if i_frame == (frame_count // 3): + oe.dictionary["size_DVA"] = 0.5 + oe.apply_parameter_changes() + + if i_frame == (frame_count * 2 // 3): + oe.dictionary["size_DVA"] = 2.0 + oe.apply_parameter_changes() + + data_out = oe.update(data_in) + position_selection = data_out["position_found"] + contour_shape = data_out["canvas_size"] + + # SENDE/EMPANGSLOGIK: + # + # <- PACKET empfangen + # Parameteränderungen? + # in Instanz se übertragen + # "apply_parameter_changes" aufrufen + # folgende variablen in sendepacket: + # se.clocks, se.phosphene, se.out_x, se.out_y + # "process_frame" + # folgende variablen in sendepacket: + # position_selection, contour_shape + # -> PACKET zurückgeben + + # build the full image! + image_clocks = BuildImage( + canvas_size=contour_shape, + dictionary=oe.clocks, + position_found=position_selection, + default_dtype=oe.default_dtype, + torch_device=oe.torch_device, + ) + # image_phosphenes = BuildImage( + # canvas_size=contour.shape, + # dictionary=dictionary_phosphene, + # position_found=position_selection, + # default_dtype=default_dtype, + # torch_device=torch_device, + # ) + + # normalize to range [0...1] + m = image_clocks[0].max() + if m == 0: + m = 1 + image_clocks_normalized = image_clocks[0] / m + + # embed into frame of desired output size + out_torch = embed_image( + image_clocks_normalized, out_height=oe.out_y, out_width=oe.out_x + ) + + # show, if desired + if verbose: + if oe.control["show_percept"]: + show_torch_frame( + out_torch, title="Percept", cmap="gray", target=verbose + ) + + # if output_file != None: + # out_pixel = ( + # (out_torch * torch.ones([3, 1, 1]) * 255) + # .type(dtype=torch.uint8) + # .movedim(0, -1) + # .numpy() + # ) + # out.write(out_pixel) + + del oe + + # if output_file != None: + # out.release() + +# %% diff --git a/processing_chain/OnlinePerception.py b/processing_chain/OnlinePerception.py new file mode 100644 index 0000000..7586ddc --- /dev/null +++ b/processing_chain/OnlinePerception.py @@ -0,0 +1,299 @@ +#%% + +import torch +import numpy as np +from processing_chain.BuildImage import BuildImage +import matplotlib.pyplot as plt +import time +import cv2 +from gui.GUIEvents import GUIEvents +from gui.GUICombiData import GUICombiData +import tkinter as tk + + +# TODO required? 
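# embed_image centres frame_torch on an out_height x out_width canvas filled
# with init_value and crops symmetrically wherever the frame is larger than
# the canvas. Illustrative shapes: a [1, 1, 100, 150] frame embedded into
# 128 x 128 is zero-padded to 128 rows and centre-cropped to 128 columns,
# giving a [1, 1, 128, 128] result.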
+def embed_image(frame_torch, out_height, out_width, torch_device, init_value=0): + + out_shape = torch.tensor(frame_torch.shape) + + frame_width = frame_torch.shape[-1] + frame_height = frame_torch.shape[-2] + + frame_width_idx0 = max([0, (frame_width - out_width) // 2]) + frame_height_idx0 = max([0, (frame_height - out_height) // 2]) + + select_width = min([frame_width, out_width]) + select_height = min([frame_height, out_height]) + + out_shape[-1] = out_width + out_shape[-2] = out_height + + out_torch = init_value * torch.ones(tuple(out_shape), device=torch_device) + + out_width_idx0 = max([0, (out_width - frame_width) // 2]) + out_height_idx0 = max([0, (out_height - frame_height) // 2]) + + out_torch[ + ..., + out_height_idx0 : (out_height_idx0 + select_height), + out_width_idx0 : (out_width_idx0 + select_width), + ] = frame_torch[ + ..., + frame_height_idx0 : (frame_height_idx0 + select_height), + frame_width_idx0 : (frame_width_idx0 + select_width), + ] + + return out_torch + + +class OnlinePerception: + + # SELFies... + # + # torch_device, default_dtype (fixed) + # canvas_size, features, phosphene, position_found (parameters) + # percept + # + # root, events, confdata, use_gui + # + + def __init__(self, target, use_gui=False): + + # CPU or GPU? + self.default_dtype = torch.float32 + torch.set_default_dtype(self.default_dtype) + if torch.cuda.is_available(): + self.torch_device = torch.device("cuda") + else: + self.torch_device = torch.device("cpu") + + self.use_gui = use_gui + if self.use_gui: + self.root = tk.Tk() + self.confdata: GUICombiData = GUICombiData() + self.events = GUIEvents(tk_root=self.root, confdata=self.confdata) + + # default dictionary parameters + n_xy_canvas = 400 + n_xy_features = 41 + n_features = 32 + n_positions = 1 + + # populate internal parameters + self.canvas_size = [1, 8, n_xy_canvas, n_xy_canvas] + self.features = torch.rand( + (n_features, 1, n_xy_features, n_xy_features), device=self.torch_device + ) + self.phosphene = torch.ones( + (1, 1, n_xy_features, n_xy_features), device=self.torch_device + ) + self.position_found = torch.zeros((1, n_positions, 3), device=self.torch_device) + self.position_found[0, :, 0] = torch.randint(n_features, (n_positions,)) + self.position_found[0, :, 1:] = torch.randint(n_xy_canvas, (n_positions, 2)) + self.percept = torch.zeros( + (1, 1, n_xy_canvas, n_xy_canvas), device=self.torch_device + ) + + self.p_FEATperSECperPOS = 3.0 + self.tau_SEC = 0.3 # percept time constant + self.t = time.time() + + self.target = target # display target + self.display = target + + self.selection = 0 + + return + + def update(self, data_in: dict): + + if data_in: # not NONE? + print(f"Parameter update requested for keys {data_in.keys()}!") + if "position_found" in data_in.keys(): + self.position_found = data_in["position_found"] + if "canvas_size" in data_in.keys(): + self.canvas_size = data_in["canvas_size"] + self.percept = embed_image( + self.percept, + self.canvas_size[-2], + self.canvas_size[-1], + torch_device=self.torch_device, + init_value=0, + ) + if "features" in data_in.keys(): + print(self.features.shape, self.features.max(), self.features.min()) + self.features = data_in["features"] + self.features /= self.features.max() + print(self.features.shape, self.features.max(), self.features.min()) + if "phosphene" in data_in.keys(): + self.phosphene = data_in["phosphene"] + self.phosphene /= self.phosphene.max() + + # parameters of (optional) GUI changed? 
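# With the GUI enabled, Tk events are polled here; if any setting changed, the
# new values are packed into data_out for the encoder, while the purely
# perceptual parameters (tau_SEC, p_FEATperSECperPOS, dictionary selection,
# display mode) are applied locally.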
+ data_out = None + if self.use_gui: + self.root.update() + if not self.confdata.gui_running: + data_out = {"exit": 42} + # graceful exit + else: + if self.confdata.check_for_change() is True: + + data_out = { + "value": self.confdata.yolo_class.value, + "sigma_kernel_DVA": self.confdata.contour_extraction.sigma_kernel_DVA, + "n_orientations": self.confdata.contour_extraction.n_orientations, + "size_DVA": self.confdata.alphabet.size_DVA, + # "tau_SEC": self.confdata.alphabet.tau_SEC, + # "p_FEATperSECperPOS": self.confdata.alphabet.p_FEATperSECperPOS, + "sigma_width": self.confdata.alphabet.phosphene_sigma_width, + "n_dir": self.confdata.alphabet.clocks_n_dir, + "pointer_width": self.confdata.alphabet.clocks_pointer_width, + "pointer_length": self.confdata.alphabet.clocks_pointer_length, + "number_of_patches": self.confdata.sparsifier.number_of_patches, + "use_exp_deadzone": self.confdata.sparsifier.use_exp_deadzone, + "use_cutout_deadzone": self.confdata.sparsifier.use_cutout_deadzone, + "size_exp_deadzone_DVA": self.confdata.sparsifier.size_exp_deadzone_DVA, + "size_cutout_deadzone_DVA": self.confdata.sparsifier.size_cutout_deadzone_DVA, + "enable_cam": self.confdata.output_mode.enable_cam, + "enable_yolo": self.confdata.output_mode.enable_yolo, + "enable_contour": self.confdata.output_mode.enable_contour, + } + print(data_out) + + self.p_FEATperSECperPOS = self.confdata.alphabet.p_FEATperSECperPOS + self.tau_SEC = self.confdata.alphabet.tau_SEC + + # print(self.confdata.alphabet.selection) + self.selection = self.confdata.alphabet.selection + # print(f"Selektion gemacht {self.selection}") + + # print(self.confdata.output_mode.enable_percept) + if self.confdata.output_mode.enable_percept: + self.display = self.target + else: + self.display = None + if self.target == "cv2": + cv2.destroyWindow("Percept") + + self.confdata.reset_change_detector() + + # keep track of time, yields dt + t_new = time.time() + dt = t_new - self.t + self.t = t_new + + # exponential decay + self.percept *= torch.exp(-torch.tensor(dt / self.tau_SEC)) + + # new stimulation + p_dt = self.p_FEATperSECperPOS * dt + n_positions = self.position_found.shape[-2] + position_select = torch.rand((n_positions,), device=self.torch_device) < p_dt + n_select = position_select.sum() + if n_select > 0: + # print(f"Selektion ausgewertet {self.selection}") + if self.selection: + dictionary = self.features + else: + dictionary = self.phosphene + percept_addon = BuildImage( + canvas_size=self.canvas_size, + dictionary=dictionary, + position_found=self.position_found[:, position_select, :], + default_dtype=self.default_dtype, + torch_device=self.torch_device, + ) + self.percept += percept_addon + + # prepare for display + display = self.percept[0, 0].cpu().numpy() + if self.display == "cv2": + + # display, and update + cv2.namedWindow("Percept", cv2.WINDOW_NORMAL) + cv2.imshow("Percept", display) + q = cv2.waitKey(1) + + if self.display == "pyplot": + + # display, RUNS SLOWLY, just for testing + plt.imshow(display, cmap="gray", vmin=0, vmax=1) + plt.show() + time.sleep(0.5) + + return data_out + + def __del__(self): + print("Now I'm deleting me!") + try: + cv2.destroyAllWindows() + except: + pass + if self.use_gui: + print("...and the GUI!") + try: + if self.confdata.gui_running is True: + self.root.destroy() + except: + pass + + +if __name__ == "__main__": + + use_gui = True + + data_in = None + op = OnlinePerception("cv2", use_gui=use_gui) + + t_max = 40.0 + dt_update = 10.0 + t_update = dt_update + + t = time.time() + t0 = t + while 
t - t0 < t_max: + + data_out = op.update(data_in) + data_in = None + if data_out: + print("Output given!") + if "exit" in data_out.keys(): + break + + t = time.time() + if t - t0 > t_update: + + # new canvas size + n_xy_canvas = int(400 + torch.randint(200, (1,))) + canvas_size = [1, 8, n_xy_canvas, n_xy_canvas] + + # new features/phosphenes + n_features = int(16 + torch.randint(16, (1,))) + n_xy_features = int(31 + 2 * torch.randint(10, (1,))) + features = torch.rand( + (n_features, 1, n_xy_features, n_xy_features), device=op.torch_device + ) + phosphene = torch.ones( + (1, 1, n_xy_features, n_xy_features), device=op.torch_device + ) + + # new positions + n_positions = int(1 + torch.randint(3, (1,))) + position_found = torch.zeros((1, n_positions, 3), device=op.torch_device) + position_found[0, :, 0] = torch.randint(n_features, (n_positions,)) + position_found[0, :, 1] = torch.randint(n_xy_canvas, (n_positions,)) + position_found[0, :, 2] = torch.randint(n_xy_canvas, (n_positions,)) + t_update += dt_update + data_in = { + "position_found": position_found, + "canvas_size": canvas_size, + "features": features, + "phosphene": phosphene, + } + + print("done!") + del op + + +# %% diff --git a/processing_chain/PatchGenerator.py b/processing_chain/PatchGenerator.py new file mode 100644 index 0000000..4406372 --- /dev/null +++ b/processing_chain/PatchGenerator.py @@ -0,0 +1,237 @@ +# %% +# PatchGenerator.py +# ==================================== +# generates dictionaries (currently: phosphenes or clocks) +# +# Version V1.0, pre-07.03.2023: +# no actual changes, is David's last code version... +# +# Version V1.1, 07.03.2023: +# merged David's rebuild code (GUI capable) +# (there was not really anything to merge :-)) +# + +import torch +import math + + +class PatchGenerator: + + pi: torch.Tensor + torch_device: torch.device + default_dtype = torch.float32 + + def __init__(self, torch_device: str = "cpu"): + self.torch_device = torch.device(torch_device) + + self.pi = torch.tensor( + math.pi, + device=self.torch_device, + dtype=self.default_dtype, + ) + + def alphabet_phosphene( + self, + sigma_width: float = 2.5, + patch_size: int = 41, + ) -> torch.Tensor: + + n: int = int(patch_size // 2) + temp_grid: torch.Tensor = torch.arange( + start=-n, + end=n + 1, + device=self.torch_device, + dtype=self.default_dtype, + ) + + x, y = torch.meshgrid(temp_grid, temp_grid, indexing="ij") + + phosphene: torch.Tensor = torch.exp(-(x**2 + y**2) / (2 * sigma_width**2)) + phosphene /= phosphene.sum() + + return phosphene.unsqueeze(0).unsqueeze(0) + + def alphabet_clocks( + self, + n_dir: int = 8, + n_open: int = 4, + n_filter: int = 4, + patch_size: int = 41, + segment_width: float = 2.5, + segment_length: float = 15.0, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # for n_dir directions, there are n_open_max opening angles possible... 
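# Example: n_dir = 8 pointer directions give n_open_max = 4 distinct opening
# angles; with n_open = 4 all of them are used (mul_open = 1), so the alphabet
# contains n_open * n_dir = 32 clock patches.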
+ assert n_dir % 2 == 0, "n_dir must be multiple of 2" + n_open_max: int = n_dir // 2 + + # ...but this number can be reduced by integer multiples: + assert ( + n_open_max % n_open == 0 + ), "n_open_max = n_dir // 2 must be multiple of n_open" + mul_open: int = n_open_max // n_open + + # filter planes must be multiple of number of orientations implied by n_dir + assert n_filter % n_open_max == 0, "n_filter must be multiple of (n_dir // 2)" + mul_filter: int = n_filter // n_open_max + # compute single segments + segments: torch.Tensor = torch.zeros( + (n_dir, patch_size, patch_size), + device=self.torch_device, + dtype=self.default_dtype, + ) + dirs: torch.Tensor = ( + 2 + * self.pi + * torch.arange( + start=0, end=n_dir, device=self.torch_device, dtype=self.default_dtype + ) + / n_dir + ) + + for i_dir in range(n_dir): + segments[i_dir] = self.draw_segment( + patch_size=patch_size, + phi=float(dirs[i_dir]), + segment_length=segment_length, + segment_width=segment_width, + ) + + # compute patches from segments + clocks = torch.zeros( + (n_open, n_dir, patch_size, patch_size), + device=self.torch_device, + dtype=self.default_dtype, + ) + clocks_filter = torch.zeros( + (n_open, n_dir, n_filter, patch_size, patch_size), + device=self.torch_device, + dtype=self.default_dtype, + ) + + for i_dir in range(n_dir): + for i_open in range(n_open): + + seg1 = segments[i_dir] + seg2 = segments[(i_dir + (i_open + 1) * mul_open) % n_dir] + clocks[i_open, i_dir] = torch.where(seg1 > seg2, seg1, seg2) + + i_filter_seg1 = (i_dir * mul_filter) % n_filter + i_filter_seg2 = ( + (i_dir + (i_open + 1) * mul_open) * mul_filter + ) % n_filter + + if i_filter_seg1 == i_filter_seg2: + clock_merged = torch.where(seg1 > seg2, seg1, seg2) + clocks_filter[i_open, i_dir, i_filter_seg1] = clock_merged + else: + clocks_filter[i_open, i_dir, i_filter_seg1] = seg1 + clocks_filter[i_open, i_dir, i_filter_seg2] = seg2 + + clocks_filter = clocks_filter.reshape( + (n_open * n_dir, n_filter, patch_size, patch_size) + ) + clocks_filter = clocks_filter / clocks_filter.sum( + axis=(-3, -2, -1), keepdims=True + ) + clocks = clocks.reshape((n_open * n_dir, 1, patch_size, patch_size)) + clocks = clocks / clocks.sum(axis=(-2, -1), keepdims=True) + + return clocks_filter, clocks, segments + + def draw_segment( + self, + patch_size: float, + phi: float, + segment_length: float, + segment_width: float, + ) -> torch.Tensor: + + # extension of patch beyond origin + n: int = int(patch_size // 2) + + # grid for the patch + temp_grid: torch.Tensor = torch.arange( + start=-n, + end=n + 1, + device=self.torch_device, + dtype=self.default_dtype, + ) + + x, y = torch.meshgrid(temp_grid, temp_grid, indexing="ij") + + r: torch.Tensor = torch.cat((x.unsqueeze(-1), y.unsqueeze(-1)), dim=2) + + # target orientation of basis segment + phi90: torch.Tensor = phi + self.pi / 2 + + # vector pointing to the ending point of segment (direction) + # + # when result is displayed with plt.imshow(segment), + # phi=0 points to the right, and increasing phi rotates + # the segment counterclockwise + # + e: torch.Tensor = torch.tensor( + [torch.cos(phi90), torch.sin(phi90)], + device=self.torch_device, + dtype=self.default_dtype, + ) + + # tangential vectors + e_tang: torch.Tensor = e.flip(dims=[0]) * torch.tensor( + [-1, 1], device=self.torch_device, dtype=self.default_dtype + ) + + # compute distances to segment: parallel/tangential + d = torch.maximum( + torch.zeros( + (r.shape[0], r.shape[1]), + device=self.torch_device, + dtype=self.default_dtype, + ), + 
torch.abs( + (r * e.unsqueeze(0).unsqueeze(0)).sum(dim=-1) - segment_length / 2 + ) + - segment_length / 2, + ) + + d_tang = (r * e_tang.unsqueeze(0).unsqueeze(0)).sum(dim=-1) + + # compute minimum distance to any of the two pointers + dr = torch.sqrt(d**2 + d_tang**2) + + segment = torch.exp(-(dr**2) / 2 / segment_width**2) + segment = segment / segment.sum() + + return segment + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + + pg = PatchGenerator() + + patch1 = pg.draw_segment( + patch_size=81, phi=math.pi / 4, segment_length=30, segment_width=5 + ) + patch2 = pg.draw_segment(patch_size=81, phi=0, segment_length=30, segment_width=5) + plt.imshow(torch.where(patch1 > patch2, patch1, patch2).cpu()) + + phos = pg.alphabet_phosphene() + plt.imshow(phos[0, 0].cpu()) + + n_filter = 8 + n_dir = 8 + clocks_filter, clocks, segments = pg.alphabet_clocks(n_dir=n_dir, n_filter=n_filter) + + n_features = clocks_filter.shape[0] + print(n_features, "clock features generated!") + for i_feature in range(n_features): + for i_filter in range(n_filter): + plt.subplot(1, n_filter, i_filter + 1) + plt.imshow(clocks_filter[i_feature, i_filter].cpu()) + plt.title("Feature #{}, Dir #{}".format(i_feature, i_filter)) + plt.show() + + +# %% diff --git a/processing_chain/Sparsifier.py b/processing_chain/Sparsifier.py new file mode 100644 index 0000000..a084660 --- /dev/null +++ b/processing_chain/Sparsifier.py @@ -0,0 +1,418 @@ +# Sparsifier.py +# ==================================== +# matches dictionary patches to contour images +# +# Version V1.0, 07.03.2023: +# slight parameter scaling changes to David's last code version... +# +# Version V1.1, 07.03.2023: +# merged David's rebuild code (GUI capable) +# + +import torch +import math + +import matplotlib.pyplot as plt + +import time + + +class Sparsifier(torch.nn.Module): + + dictionary_filter_fft: torch.Tensor | None = None + dictionary_filter: torch.Tensor + dictionary: torch.Tensor + + parameter_ready: bool + dictionary_ready: bool + + contour_convolved_sum: torch.Tensor | None = None + use_map: torch.Tensor | None = None + position_found: torch.Tensor | None = None + + size_exp_deadzone: float + + number_of_patches: int + padding_deadzone_size_x: int + padding_deadzone_size_y: int + + plot_use_map: bool + + deadzone_exp: bool + deadzone_hard_cutout: int # 0 = not, 1 = round, 2 = box + deadzone_hard_cutout_size: float + + dictionary_prior: torch.Tensor | None + + pi: torch.Tensor + torch_device: torch.device + default_dtype = torch.float32 + + def __init__( + self, + dictionary_filter: torch.Tensor, + dictionary: torch.Tensor, + dictionary_prior: torch.Tensor | None = None, + number_of_patches: int = 10, # + size_exp_deadzone: float = 1.0, # + padding_deadzone_size_x: int = 0, # + padding_deadzone_size_y: int = 0, # + plot_use_map: bool = False, + deadzone_exp: bool = True, # + deadzone_hard_cutout: int = 1, # 0 = not, 1 = round + deadzone_hard_cutout_size: float = 1.0, # + torch_device: str = "cpu", + ): + super().__init__() + + self.dictionary_ready = False + self.parameter_ready = False + + self.torch_device = torch.device(torch_device) + + self.pi = torch.tensor( + math.pi, + device=self.torch_device, + dtype=self.default_dtype, + ) + + self.plot_use_map = plot_use_map + + self.update_parameter( + number_of_patches, + size_exp_deadzone, + padding_deadzone_size_x, + padding_deadzone_size_y, + deadzone_exp, + deadzone_hard_cutout, + deadzone_hard_cutout_size, + ) + + self.update_dictionary(dictionary_filter, dictionary, dictionary_prior) + 
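# The sparsifier performs a greedy, matching-pursuit-like selection:
# fourier_convolution() correlates every dictionary filter with the contour
# once via FFT, and find_next_element() then repeatedly picks the
# (feature, x, y) triple with the largest remaining weighted overlap and
# multiplies a dead zone (exponential and/or hard cutout) into use_map so that
# later picks keep their distance; the overlaps themselves are not re-estimated
# after each pick.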
+ def update_parameter( + self, + number_of_patches: int, + size_exp_deadzone: float, + padding_deadzone_size_x: int, + padding_deadzone_size_y: int, + deadzone_exp: bool, + deadzone_hard_cutout: int, + deadzone_hard_cutout_size: float, + ) -> None: + + self.parameter_ready = False + + assert size_exp_deadzone > 0.0 + assert number_of_patches > 0 + assert padding_deadzone_size_x >= 0 + assert padding_deadzone_size_y >= 0 + + self.number_of_patches = number_of_patches + self.size_exp_deadzone = size_exp_deadzone + self.padding_deadzone_size_x = padding_deadzone_size_x + self.padding_deadzone_size_y = padding_deadzone_size_y + + self.deadzone_exp = deadzone_exp + self.deadzone_hard_cutout = deadzone_hard_cutout + self.deadzone_hard_cutout_size = deadzone_hard_cutout_size + + self.parameter_ready = True + + def update_dictionary( + self, + dictionary_filter: torch.Tensor, + dictionary: torch.Tensor, + dictionary_prior: torch.Tensor | None = None, + ) -> None: + + self.dictionary_ready = False + + assert dictionary_filter.ndim == 4 + assert dictionary.ndim == 4 + + # Odd number of pixels. Please! + assert (dictionary_filter.shape[-2] % 2) == 1 + assert (dictionary_filter.shape[-1] % 2) == 1 + + self.dictionary_filter = dictionary_filter.detach().clone() + self.dictionary = dictionary + + if dictionary_prior is not None: + assert dictionary_prior.ndim == 1 + assert dictionary_prior.shape[0] == dictionary_filter.shape[0] + + self.dictionary_prior = dictionary_prior + + self.dictionary_filter_fft = None + + self.dictionary_ready = True + + def fourier_convolution(self, contour: torch.Tensor): + # Pattern, X, Y + assert contour.dim() == 4 + assert self.dictionary_filter is not None + assert contour.shape[-2] >= self.dictionary_filter.shape[-2] + assert contour.shape[-1] >= self.dictionary_filter.shape[-1] + + t0 = time.time() + + contour_fft = torch.fft.rfft2(contour, dim=(-2, -1)) + + t1 = time.time() + + if ( + (self.dictionary_filter_fft is None) + or (self.dictionary_filter_fft.dim() != 4) + or (self.dictionary_filter_fft.shape[-2] != contour.shape[-2]) + or (self.dictionary_filter_fft.shape[-1] != contour.shape[-1]) + ): + dictionary_padded: torch.Tensor = torch.zeros( + ( + self.dictionary_filter.shape[0], + self.dictionary_filter.shape[1], + contour.shape[-2], + contour.shape[-1], + ), + device=self.torch_device, + dtype=self.default_dtype, + ) + dictionary_padded[ + :, + :, + : self.dictionary_filter.shape[-2], + : self.dictionary_filter.shape[-1], + ] = self.dictionary_filter.flip(dims=[-2, -1]) + + dictionary_padded = dictionary_padded.roll( + ( + -(self.dictionary_filter.shape[-2] // 2), + -(self.dictionary_filter.shape[-1] // 2), + ), + (-2, -1), + ) + self.dictionary_filter_fft = torch.fft.rfft2( + dictionary_padded, dim=(-2, -1) + ) + + t2 = time.time() + + assert self.dictionary_filter_fft is not None + + # dimension order for multiplication: [pat, feat, ori, x, y] + self.contour_convolved_sum = torch.fft.irfft2( + contour_fft.unsqueeze(1) * self.dictionary_filter_fft.unsqueeze(0), + dim=(-2, -1), + ).sum(dim=2) + # --> [pat, feat, x, y] + t3 = time.time() + + self.use_map: torch.Tensor = torch.ones( + (contour.shape[0], 1, contour.shape[-2], contour.shape[-1]), + device=self.torch_device, + dtype=self.default_dtype, + ) + + if self.padding_deadzone_size_x > 0: + self.use_map[:, 0, : self.padding_deadzone_size_x, :] = 0.0 + self.use_map[:, 0, -self.padding_deadzone_size_x :, :] = 0.0 + + if self.padding_deadzone_size_y > 0: + self.use_map[:, 0, :, : self.padding_deadzone_size_y] = 
0.0 + self.use_map[:, 0, :, -self.padding_deadzone_size_y :] = 0.0 + + t4 = time.time() + print( + "Sparsifier-convol {:.3f}s: fft-img-{:.3f}s, fft-fil-{:.3f}s, convol-{:.3f}s, pad-{:.3f}s".format( + t4 - t0, t1 - t0, t2 - t1, t3 - t2, t4 - t3 + ) + ) + + return + + def find_next_element(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert self.use_map is not None + assert self.contour_convolved_sum is not None + assert self.dictionary_filter is not None + + # store feature index, x pos and y pos (3 entries) + position_found = torch.zeros( + (self.contour_convolved_sum.shape[0], 3), + dtype=torch.int64, + device=self.torch_device, + ) + + overlap_found = torch.zeros( + (self.contour_convolved_sum.shape[0], 2), + device=self.torch_device, + dtype=self.default_dtype, + ) + + for pattern_id in range(0, self.contour_convolved_sum.shape[0]): + + t0 = time.time() + # search_tensor: torch.Tensor = ( + # self.contour_convolved[pattern_id] * self.use_map[pattern_id] + # ).sum(dim=1) + search_tensor: torch.Tensor = ( + self.contour_convolved_sum[pattern_id] * self.use_map[pattern_id] + ) + + t1 = time.time() + + if self.dictionary_prior is not None: + search_tensor *= self.dictionary_prior.unsqueeze(-1).unsqueeze(-1) + + t2 = time.time() + + temp, max_0 = search_tensor.max(dim=0) + temp, max_1 = temp.max(dim=0) + temp_overlap, max_2 = temp.max(dim=0) + + position_base_3: int = int(max_2) + position_base_2: int = int(max_1[position_base_3]) + position_base_1: int = int(max_0[position_base_2, position_base_3]) + position_base_0: int = int(pattern_id) + + position_found[position_base_0, 0] = position_base_1 + position_found[position_base_0, 1] = position_base_2 + position_found[position_base_0, 2] = position_base_3 + + overlap_found[pattern_id, 0] = temp_overlap + overlap_found[pattern_id, 1] = self.contour_convolved_sum[ + position_base_0, position_base_1, position_base_2, position_base_3 + ] + + t3 = time.time() + + x_max: int = int(self.contour_convolved_sum.shape[-2]) + y_max: int = int(self.contour_convolved_sum.shape[-1]) + + # Center arround the max position + x_0 = int(-position_base_2) + x_1 = int(x_max - position_base_2) + y_0 = int(-position_base_3) + y_1 = int(y_max - position_base_3) + + temp_grid_x: torch.Tensor = torch.arange( + start=x_0, + end=x_1, + device=self.torch_device, + dtype=self.default_dtype, + ) + + temp_grid_y: torch.Tensor = torch.arange( + start=y_0, + end=y_1, + device=self.torch_device, + dtype=self.default_dtype, + ) + + x, y = torch.meshgrid(temp_grid_x, temp_grid_y, indexing="ij") + + # discourage the neigbourhood around for the future + if self.deadzone_exp is True: + self.temp_map: torch.Tensor = 1.0 - torch.exp( + -(x**2 + y**2) / (2 * self.size_exp_deadzone**2) + ) + else: + self.temp_map = torch.ones( + (x.shape[0], x.shape[1]), + device=self.torch_device, + dtype=self.default_dtype, + ) + + assert self.deadzone_hard_cutout >= 0 + assert self.deadzone_hard_cutout <= 1 + assert self.deadzone_hard_cutout_size >= 0 + + if self.deadzone_hard_cutout == 1: + temp = x**2 + y**2 + self.temp_map *= torch.where( + temp <= self.deadzone_hard_cutout_size**2, + 0.0, + 1.0, + ) + + self.use_map[position_base_0, 0, :, :] *= self.temp_map + + t4 = time.time() + + # Only for keeping it in the float32 range: + self.use_map[position_base_0, 0, :, :] /= self.use_map[ + position_base_0, 0, :, : + ].max() + + t5 = time.time() + + # print( + # "{}, {}, {}, {}, {}".format(t1 - t0, t2 - t1, t3 - t2, t4 - t3, t5 - t4) + # ) + + return ( + position_found, + overlap_found, + 
torch.tensor((t1 - t0, t2 - t1, t3 - t2, t4 - t3, t5 - t4)), + ) + + def forward(self, input: torch.Tensor) -> None: + assert self.number_of_patches > 0 + + assert self.dictionary_ready is True + assert self.parameter_ready is True + + self.position_found = torch.zeros( + (input.shape[0], self.number_of_patches, 3), + dtype=torch.int64, + device=self.torch_device, + ) + self.overlap_found = torch.zeros( + (input.shape[0], self.number_of_patches, 2), + device=self.torch_device, + dtype=self.default_dtype, + ) + + # Folding the images and the dictionary + self.fourier_convolution(input) + + t0 = time.time() + dt = torch.tensor([0, 0, 0, 0, 0]) + for patch_id in range(0, self.number_of_patches): + ( + self.position_found[:, patch_id, :], + self.overlap_found[:, patch_id, :], + dt_tmp, + ) = self.find_next_element() + + dt = dt + dt_tmp + + if self.plot_use_map is True: + + assert self.position_found.shape[0] == 1 + assert self.use_map is not None + + print("Position Saliency:") + print(self.overlap_found[0, :, 0]) + print("Overlap with Contour Image:") + print(self.overlap_found[0, :, 1]) + plt.subplot(1, 2, 1) + plt.imshow(self.use_map[0, 0].cpu(), cmap="gray") + plt.title(f"patch-number: {patch_id}") + plt.axis("off") + plt.colorbar(shrink=0.5) + plt.subplot(1, 2, 2) + plt.imshow( + (self.use_map[0, 0] * input[0].sum(dim=0)).cpu(), cmap="gray" + ) + plt.title(f"patch-number: {patch_id}") + plt.axis("off") + plt.colorbar(shrink=0.5) + plt.show() + + # self.overlap_found /= self.overlap_found.max(dim=1, keepdim=True)[0] + t1 = time.time() + print( + "Sparsifier-forward {:.3f}s: usemap-{:.3f}s, prior-{:.3f}s, findmax-{:.3f}s, notch-{:.3f}s, norm-{:.3f}s: (sum-{:.3f})".format( + t1 - t0, dt[0], dt[1], dt[2], dt[3], dt[4], dt.sum() + ) + ) diff --git a/processing_chain/WebCam.py b/processing_chain/WebCam.py new file mode 100644 index 0000000..6a78e1c --- /dev/null +++ b/processing_chain/WebCam.py @@ -0,0 +1,325 @@ +#%% +# +# WebCam.py +# ======================================================== +# interface to cv2 for using a webcam or for reading from +# a video file +# +# Version 1.0, before 30.03.2023: +# written by David... 
+# +# Version 1.1, 30.03.2023: +# thrown out test image +# added test code +# added code to "capture" from video file +# +# Version 1.2, 20.06.2023: +# added code to capture wirelessly from "GoPro" camera +# added test code for "GoPro" +# +# Version 1.3, 23.06.2023 +# test display in pyplot or cv2 +# +# Version 1.4, 28.06.2023 +# solved Windows DirectShow problem +# + + +from PIL import Image +import os +import cv2 +import torch +import torchvision as tv + +# for GoPro +import time +import socket + +try: + print("Trying to import GoPro modules...") + from goprocam import GoProCamera, constants + + gopro_exists = True +except: + print("...not found, continuing!") + gopro_exists = False +import platform + + +class WebCam: + + # test_pattern: torch.Tensor + # test_pattern_gray: torch.Tensor + + source: int + framesize: tuple[int, int] + fps: float + cap_frame_width: int + cap_frame_height: int + cap_fps: float + webcam_is_ready: bool + cap_frames_available: int + + default_dtype = torch.float32 + + def __init__( + self, + source: str | int = 1, + framesize: tuple[int, int] = (720, 1280), # (1920, 1080), # (640, 480), + fps: float = 30.0, + ): + super().__init__() + assert fps > 0 + + self.source = source + self.framesize = framesize + self.cap = None + self.fps = fps + self.webcam_is_ready = False + + def open_cam(self) -> bool: + if self.cap is not None: + self.cap.release() + self.cap = None + self.webcam_is_ready = False + + # handle GoPro... + if self.source == "GoProWireless": + + if not gopro_exists: + print("No GoPro driver/support!") + self.webcam_is_ready = False + return False + + print("GoPro: Starting access") + gpCam = GoProCamera.GoPro() + self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + print("GoPro: Socket created!!!") + + self.t = time.time() + gpCam.livestream("start") + gpCam.video_settings(res="1080p", fps="30") + gpCam.gpControlSet( + constants.Stream.WINDOW_SIZE, constants.Stream.WindowSize.R720 + ) + + self.cap = cv2.VideoCapture("udp://10.5.5.9:8554", cv2.CAP_FFMPEG) + print("GoPro: Video capture started!!!") + + else: + self.sock = None + self.t = -1 + if platform.system().lower() == "windows": + self.cap = cv2.VideoCapture(self.source, cv2.CAP_DSHOW) + else: + self.cap = cv2.VideoCapture(self.source) + print("Normal capture started!!!") + + assert self.cap is not None + + if self.cap.isOpened() is not True: + self.webcam_is_ready = False + return False + + if type(self.source) != str: + self.cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG")) + self.cap.set(cv2.CAP_PROP_FPS, self.fps) + self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.framesize[0]) + self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.framesize[1]) + self.cap_frames_available = None + else: + # ...for files and GoPro... + self.cap_frames_available = self.cap.get(cv2.CAP_PROP_FRAME_COUNT) + + self.cap_frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.cap_frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.cap_fps = float(self.cap.get(cv2.CAP_PROP_FPS)) + + print( + ( + f"Capturing or reading with: {self.cap_frame_width:.0f} x " + f"{self.cap_frame_height:.0f} @ " + f"{self.cap_fps:.1f}." 
+ ) + ) + self.webcam_is_ready = True + return True + + def close_cam(self) -> None: + if self.cap is not None: + self.cap.release() + + def get_frame(self) -> torch.Tensor | None: + if self.cap is None: + return None + else: + if self.sock: + dt_min = 0.015 + success, frame = self.cap.read() + t_next = time.time() + t_prev = t_next + while t_next - t_prev < dt_min: + t_prev = t_next + success, frame = self.cap.read() + t_next = time.time() + + if self.t >= 0: + if time.time() - self.t > 2.5: + print("GoPro-Stream must be kept awake!...") + self.sock.sendto( + "_GPHD_:0:0:2:0.000000\n".encode(), ("10.5.5.9", 8554) + ) + self.t = time.time() + else: + success, frame = self.cap.read() + + if success is False: + self.webcam_is_ready = False + return None + + output = ( + torch.tensor(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + .movedim(-1, 0) + .type(dtype=self.default_dtype) + / 255.0 + ) + return output + + +# for testing the code if module is executed from command line +if __name__ == "__main__": + + import matplotlib.pyplot as plt + import cv2 + + TEST_FILEREAD = False + TEST_WEBCAM = True + TEST_GOPRO = False + + display = "cv2" + n_capture = 200 + delay_capture = 0.001 + + print("Testing the WebCam interface") + + if TEST_FILEREAD: + + file_name = "level1.mp4" + + # open + print("Opening video file") + w = WebCam(file_name) + if not w.open_cam(): + raise OSError(f"Opening file with name {file_name} failed!") + + # print information + print( + f"Frame size {w.cap_frame_width} x {w.cap_frame_height} at {w.cap_fps} fps." + ) + + # capture three frames and show them + for i in range(min([n_capture, w.cap_frames_available])): # TODO: available? + frame = w.get_frame() + if frame == None: + raise OSError(f"Can not get frame from file with name {file_name}!") + print(f"frame {i} has shape {frame.shape}") + + frame_numpy = (frame.movedim(0, -1) * 255).type(dtype=torch.uint8).numpy() + + if display == "pyplot": + plt.imshow(frame_numpy) + plt.show() + if display == "cv2": + cv2.imshow("File", frame_numpy[:, :, (2, 1, 0)]) + cv2.waitKey(1) + time.sleep(delay_capture) + + # close + print("Closing file") + w.close_cam() + + if TEST_WEBCAM: + + camera_index = 0 + + # open + print("Opening camera") + w = WebCam(camera_index) + if not w.open_cam(): + raise OSError(f"Opening web cam with index {camera_index} failed!") + + # print information + print( + f"Frame size {w.cap_frame_width} x {w.cap_frame_height} at {w.cap_fps} fps." + ) + + # capture three frames and show them + for i in range(n_capture): + frame = w.get_frame() + if frame == None: + raise OSError( + f"Can not get frame from camera with index {camera_index}!" + ) + print(f"frame {i} has shape {frame.shape}") + + frame_numpy = (frame.movedim(0, -1) * 255).type(dtype=torch.uint8).numpy() + if display == "pyplot": + plt.imshow(frame_numpy) + plt.show() + if display == "cv2": + cv2.imshow("WebCam", frame_numpy[:, :, (2, 1, 0)]) + cv2.waitKey(1) + time.sleep(delay_capture) + + # close + print("Closing camera") + w.close_cam() + + if TEST_GOPRO: + + camera_name = "GoProWireless" + + # open + print("Opening GoPro") + w = WebCam(camera_name) + if not w.open_cam(): + raise OSError(f"Opening GoPro with index {camera_index} failed!") + + # print information + print( + f"Frame size {w.cap_frame_width} x {w.cap_frame_height} at {w.cap_fps} fps." 
+ ) + w.cap.set(cv2.CAP_PROP_BUFFERSIZE, 0) + + # capture three frames and show them + # print("Empty Buffer...") + # for i in range(500): + # print(i) + # frame = w.get_frame() + # print("Buffer Emptied...") + + for i in range(n_capture): + frame = w.get_frame() + if frame == None: + raise OSError( + f"Can not get frame from camera with index {camera_index}!" + ) + print(f"frame {i} has shape {frame.shape}") + + frame_numpy = (frame.movedim(0, -1) * 255).type(dtype=torch.uint8).numpy() + if display == "pyplot": + plt.imshow(frame_numpy) + plt.show() + if display == "cv2": + cv2.imshow("GoPro", frame_numpy[:, :, (2, 1, 0)]) + cv2.waitKey(1) + time.sleep(delay_capture) + + # close + print("Closing Cam/File/GoPro") + w.close_cam() + + if display == "cv2": + cv2.destroyAllWindows() + +# %% diff --git a/processing_chain/Yolo5Segmentation.py b/processing_chain/Yolo5Segmentation.py new file mode 100644 index 0000000..ae80363 --- /dev/null +++ b/processing_chain/Yolo5Segmentation.py @@ -0,0 +1,402 @@ +import torch +import os +import torchvision as tv +import time + +from models.yolo import Detect, Model + +# Warning: The line "#matplotlib.use('Agg') # for writing to files only" +# in utils/plots.py prevents the further use of matplotlib + + +class Yolo5Segmentation(torch.nn.Module): + + default_dtype = torch.float32 + + conf: float = 0.25 # NMS confidence threshold + iou: float = 0.45 # NMS IoU threshold + agnostic: bool = False # NMS class-agnostic + multi_label: bool = False # NMS multiple labels per box + max_det: int = 1000 # maximum number of detections per image + number_of_maps: int = 32 + imgsz: tuple[int, int] = (640, 640) # inference size (height, width) + + device: torch.device = torch.device("cpu") + weigh_path: str = "" + + class_names: dict + stride: int + + found_class_id: torch.Tensor | None = None + + def __init__(self, mode: int = 3, torch_device: str = "cpu"): + super().__init__() + + model_pretrained_path: str = "segment_pretrained" + assert mode < 5 + assert mode >= 0 + if mode == 0: + model_pretrained_weights: str = "yolov5n-seg.pt" + elif mode == 1: + model_pretrained_weights = "yolov5s-seg.pt" + elif mode == 2: + model_pretrained_weights = "yolov5m-seg.pt" + elif mode == 3: + model_pretrained_weights = "yolov5l-seg.pt" + elif mode == 4: + model_pretrained_weights = "yolov5x-seg.pt" + + self.weigh_path = os.path.join(model_pretrained_path, model_pretrained_weights) + + self.device = torch.device(torch_device) + + self.network = self.attempt_load( + self.weigh_path, device=self.device, inplace=True, fuse=True + ) + self.stride = max(int(self.network.stride.max()), 32) # model stride + self.network.float() + self.class_names = dict(self.network.names) # type: ignore + + # classes: (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs + def forward(self, input: torch.Tensor, classes=None) -> torch.Tensor: + + assert input.ndim == 4 + assert input.shape[0] == 1 + assert input.shape[1] == 3 + + input_resized, ( + remove_left, + remove_top, + remove_height, + remove_width, + ) = self.scale_and_pad( + input, + ) + + network_output = self.network(input_resized) + number_of_classes = network_output[0].shape[2] - self.number_of_maps - 5 + assert len(self.class_names) == number_of_classes + + maps = network_output[1] + + # results matrix: + # Fist Dimension: + # Image Number + # ... 
+ # Last Dimension: + # center_x: 0 + # center_y: 1 + # width: 2 + # height: 3 + # obj_conf (object): 4 + # cls_conf (class): 5 + + results = non_max_suppression( + network_output[0], + self.conf, + self.iou, + classes, + self.agnostic, + self.multi_label, + max_det=self.max_det, + nm=self.number_of_maps, + ) + + image_id = 0 + + if results[image_id].shape[0] > 0: + masks = self.process_mask_native( + maps[image_id], + results[image_id][:, 6:], + results[image_id][:, :4], + ) + self.found_class_id = results[image_id][:, 5] + + output = tv.transforms.functional.resize( + tv.transforms.functional.crop( + img=masks, + top=int(remove_top), + left=int(remove_left), + height=int(remove_height), + width=int(remove_width), + ), + size=(input.shape[-2], input.shape[-1]), + ) + else: + output = None + self.found_class_id = None + + return output + + # ------------------------------------------------------------------------- + # ------------------------------------------------------------------------- + # ------------------------------------------------------------------------- + # ------------------------------------------------------------------------- + + # code stolen and/or modified from Yolov5 -> + def scale_and_pad( + self, + input, + ): + ratio = min(self.imgsz[0] / input.shape[-2], self.imgsz[1] / input.shape[-1]) + + shape_new_x = int(input.shape[-2] * ratio) + shape_new_y = int(input.shape[-1] * ratio) + + dx = self.imgsz[0] - shape_new_x + dy = self.imgsz[1] - shape_new_y + + dx_0 = dx // 2 + dy_0 = dy // 2 + + image_resize = tv.transforms.functional.pad( + tv.transforms.functional.resize(input, size=(shape_new_x, shape_new_y)), + padding=[dy_0, dx_0, int(dy - dy_0), int(dx - dx_0)], + fill=float(114.0 / 255.0), + ) + + return image_resize, (dy_0, dx_0, shape_new_x, shape_new_y) + + def process_mask_native(self, protos, masks_in, bboxes): + masks = ( + (masks_in @ protos.float().view(protos.shape[0], -1)) + .sigmoid() + .view(-1, protos.shape[1], protos.shape[2]) + ) + + masks = torch.nn.functional.interpolate( + masks[None], + (self.imgsz[0], self.imgsz[1]), + mode="bilinear", + align_corners=False, + )[ + 0 + ] # CHW + + x1, y1, x2, y2 = torch.chunk(bboxes[:, :, None], 4, 1) # x1 shape(1,1,n) + r = torch.arange(masks.shape[2], device=masks.device, dtype=x1.dtype)[ + None, None, : + ] # rows shape(1,w,1) + c = torch.arange(masks.shape[1], device=masks.device, dtype=x1.dtype)[ + None, :, None + ] # cols shape(h,1,1) + + masks = masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) + + return masks.gt_(0.5) + + def attempt_load(self, weights, device=None, inplace=True, fuse=True): + # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a + + model = Ensemble() + for w in weights if isinstance(weights, list) else [weights]: + ckpt = torch.load(w, map_location="cpu") # load + ckpt = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model + + # Model compatibility updates + if not hasattr(ckpt, "stride"): + ckpt.stride = torch.tensor([32.0]) + if hasattr(ckpt, "names") and isinstance(ckpt.names, (list, tuple)): + ckpt.names = dict(enumerate(ckpt.names)) # convert to dict + + model.append( + ckpt.fuse().eval() if fuse and hasattr(ckpt, "fuse") else ckpt.eval() + ) # model in eval mode + + # Module compatibility updates + for m in model.modules(): + t = type(m) + if t in ( + torch.nn.Hardswish, + torch.nn.LeakyReLU, + torch.nn.ReLU, + torch.nn.ReLU6, + torch.nn.SiLU, + Detect, + Model, + ): + m.inplace = inplace # torch 1.7.0 compatibility + if t is 
Detect and not isinstance(m.anchor_grid, list): + delattr(m, "anchor_grid") + setattr(m, "anchor_grid", [torch.zeros(1)] * m.nl) + elif t is torch.nn.Upsample and not hasattr(m, "recompute_scale_factor"): + m.recompute_scale_factor = None # torch 1.11.0 compatibility + + # Return model + if len(model) == 1: + return model[-1] + + # Return detection ensemble + print(f"Ensemble created with {weights}\n") + for k in "names", "nc", "yaml": + setattr(model, k, getattr(model[0], k)) + model.stride = model[ + torch.argmax(torch.tensor([m.stride.max() for m in model])).int() + ].stride # max stride + assert all( + model[0].nc == m.nc for m in model + ), f"Models have different class counts: {[m.nc for m in model]}" + return model + + +class Ensemble(torch.nn.ModuleList): + # Ensemble of models + def __init__(self): + super().__init__() + + def forward(self, x, augment=False, profile=False, visualize=False): + y = [module(x, augment, profile, visualize)[0] for module in self] + # y = torch.stack(y).max(0)[0] # max ensemble + # y = torch.stack(y).mean(0) # mean ensemble + y = torch.cat(y, 1) # nms ensemble + return y, None # inference, train output + + +def non_max_suppression( + prediction, + conf_thres=0.25, + iou_thres=0.45, + classes=None, + agnostic=False, + multi_label=False, + labels=(), + max_det=300, + nm=0, # number of masks +): + """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections + + Returns: + list of detections, on (n,6) tensor per image [xyxy, conf, cls] + """ + + if isinstance( + prediction, (list, tuple) + ): # YOLOv5 model in validation model, output = (inference_out, loss_out) + prediction = prediction[0] # select only inference output + + device = prediction.device + mps = "mps" in device.type # Apple MPS + if mps: # MPS not fully supported yet, convert tensors to CPU before NMS + prediction = prediction.cpu() + bs = prediction.shape[0] # batch size + nc = prediction.shape[2] - nm - 5 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Checks + assert ( + 0 <= conf_thres <= 1 + ), f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0" + assert ( + 0 <= iou_thres <= 1 + ), f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" + + # Settings + # min_wh = 2 # (pixels) minimum box width and height + max_wh = 7680 # (pixels) maximum box width and height + max_nms = 30000 # maximum number of boxes into torchvision.ops.nms() + time_limit = 0.5 + 0.05 * bs # seconds to quit after + redundant = True # require redundant detections + multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + mi = 5 + nc # mask start index + output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + lb = labels[xi] + v = torch.zeros((len(lb), nc + nm + 5), device=x.device) + v[:, :4] = lb[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box/Mask + box = xywh2xyxy( + x[:, :4] + ) # center_x, center_y, width, height) to (x1, y1, x2, y2) + mask = x[:, 
mi:]  # zero columns if no masks
+
+        # Detections matrix nx6 (xyxy, conf, cls)
+        if multi_label:
+            i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
+            x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
+        else:  # best class only
+            conf, j = x[:, 5:mi].max(1, keepdim=True)
+            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
+
+        # Filter by class
+        if classes is not None:
+            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+        # Apply finite constraint
+        # if not torch.isfinite(x).all():
+        #     x = x[torch.isfinite(x).all(1)]
+
+        # Check shape
+        n = x.shape[0]  # number of boxes
+        if not n:  # no boxes
+            continue
+        elif n > max_nms:  # excess boxes
+            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
+        else:
+            x = x[x[:, 4].argsort(descending=True)]  # sort by confidence
+
+        # Batched NMS
+        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
+        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
+        i = tv.ops.nms(boxes, scores, iou_thres)  # NMS
+        if i.shape[0] > max_det:  # limit detections
+            i = i[:max_det]
+        if merge and (1 < n < 3e3):  # Merge NMS (boxes merged using weighted mean)
+            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
+            weights = iou * scores[None]  # box weights
+            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(
+                1, keepdim=True
+            )  # merged boxes
+            if redundant:
+                i = i[iou.sum(1) > 1]  # require redundancy
+
+        output[xi] = x[i]
+        if mps:
+            output[xi] = output[xi].to(device)
+        if (time.time() - t) > time_limit:
+            print(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
+            break  # time limit exceeded
+
+    return output
+
+
+def box_iou(box1, box2, eps=1e-7):
+    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
+    (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
+    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
+
+    # IoU = inter / (area1 + area2 - inter)
+    return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
+
+
+def xywh2xyxy(x):
+    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+    y = x.clone()
+    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
+    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
+    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
+    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
+    return y
+
+
+# <- code stolen and/or modified from Yolov5
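+
+
+# Usage sketch (added for illustration, mirroring the command-line test blocks
+# in PatchGenerator.py and WebCam.py): a minimal way to exercise this module,
+# assuming the pretrained weights are present under "segment_pretrained/".
+# The random input frame is a stand-in for a real camera image, so normally no
+# objects are detected and the output is None.
+if __name__ == "__main__":
+
+    net = Yolo5Segmentation(mode=0, torch_device="cpu")  # yolov5n-seg.pt, smallest model
+
+    frame = torch.rand((1, 3, 480, 640), dtype=torch.float32)  # [batch, RGB, x, y] in [0, 1]
+    masks = net(frame)  # optionally: net(frame, classes=[0]) to keep only COCO "person"
+
+    if masks is None:
+        print("No objects found in this frame.")
+    else:
+        print(f"Mask tensor shape: {masks.shape}")
+        print(f"Found class ids: {net.found_class_id}")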