diff --git a/offline_encoder/config.json b/offline_encoder/config.json new file mode 100644 index 0000000..81e0a60 --- /dev/null +++ b/offline_encoder/config.json @@ -0,0 +1,100 @@ +{ + // Define parameters + // ======================================================== + // Unit abbreviations: + // dva: degrees of visual angle + // pix: pixels + "verbose": true, + + // display: Defines geometry of target display + // ======================================================== + // The encoded image will be scaled such that it optimally uses + // the max space available. If the original image has a different aspect + // ratio than the display region, it will only use one spatial + // dimension (horizontal or vertical) to its full extent + // + // If one dva corresponds to different pix_per_dva on the display, + // (i.e. observers at varying distances from the screen), it should be set + // larger than the largest pix_per_dva required, for avoiding + // extrapolation artefacts or blur. + // + "display": { + "size_max_x_dva": 10.0, // maximum x size of encoded image + "size_max_y_dva": 10.0, // maximum y size of encoded image + "pix_per_dva": 40.0, // scaling factor pixels to dva + "scale": "same_range" // "same_luminance" or "same_range" + }, + + // gabor: Defines parameters of Gabor filters for contour extraction + // ============================================================== + "gabor": { + "sigma_kernel_dva": 0.06, + "lambda_kernel_dva": 0.12, + "n_orientations": 8 + }, + + // encoding: Defines parameters of sparse encoding process + // ======================================================== + // Roughly speaking, after contour extraction dictionary elements + // will be placed starting from the position with the highest + // overlap with the contour. Elements placed can be surrounded + // by a dead or inhibitory zone to prevent placing further elements + // too closely. The procedure will map 'n_patches_compute' elements + // and then stop. 
For each element one obtains an overlap with the + // contour image. + // + // After placement, the overlaps found are normalized to the max + // overlap found, and then all elements with a larger normalized overlap + // than 'overlap_threshold' will be selected. These remaining + // elements will comprise a 'full' encoding of the contour. + // + // To generate even sparser representations, the full encoding can + // be reduced to a certain percentage of elements in the full encoding + // by setting the variable 'percentages' + // + // Example: n_patches_compute: 100 reduced by overlap_threshold: 0.1 + // to 80 elements. Requesting a percentage of 30% yields representation + // with 24 elements. + // + "encoding": { + "n_patches_compute": 100, // this amount of patches will be placed + "use_exp_deadzone": true, // parameters of Gaussian deadzone + "size_exp_deadzone_dva": 1.20, // PREVIOUSLY 1.4283 + "use_cutout_deadzone": true, // parameters of cutout deadzone + "size_cutout_deadzone_dva": 0.65, // PREVIOUSLY 0.7575 + "overlap_threshold": 0.1, // relative overlap threshold + "percentages": 100 + }, + + "number_of_patches": 100, // TODO: Repeated from encoding + + // dictionary: Defines parameters of dictionary + // ======================================================== + "dictionary": { + "size_dva": 1.0, // PREVIOUSLY 1.25, + "clocks": { + "n_dir": 8, // number of directions for clock pointer segments + "n_open": 4, // number of opening angles between two clock pointer segments + "pointer_width": 0.07, // PREVIOUSLY 0.05, // relative width and size of tip extension of clock pointer + "pointer_length": 0.18 // PREVIOUSLY 0.15, // relative length of clock pointer + }, + "phosphene": { + "sigma_width": 0.18 // DEFAULT 0.15, // half-width of Gaussian + } + }, + + + // control: For controlling plotting options and flow of script + // ======================================================== + "control": { + "force_torch_use_cpu": false, // force using CPU even if GPU 
available + // "show_capture": true, // shows captured image + // "show_object": true, // shows detected object + "show_mode": "cv2", // "pyplot" or "cv2" + "show_image": true, // shows input image + "show_contours": true, // shows extracted contours + "show_percept": true // shows percept +} + + +} \ No newline at end of file diff --git a/offline_encoder/offline_encoding.py b/offline_encoder/offline_encoding.py new file mode 100644 index 0000000..1c451da --- /dev/null +++ b/offline_encoder/offline_encoding.py @@ -0,0 +1,463 @@ +# %% +# +# offline_encoding.py +# ======================================================== +# encode visual scenes into sparse representations using +# different kinds of dictionaries +# +# -> derived from OnlineEncoding.py +# +# Version 1.0, 16.04.2024: +# + + +# Import Python modules +# ======================================================== +# import csv +# import time +# import os +# import glob +import matplotlib.pyplot as plt +import torch +import torchvision as tv # type:ignore +# from PIL import Image +import cv2 +import numpy as np +import json +from jsmin import jsmin # type:ignore + + +# Import our modules +# ======================================================== +from processing_chain.ContourExtract import ContourExtract +from processing_chain.PatchGenerator import PatchGenerator +from processing_chain.Sparsifier import Sparsifier +# from processing_chain.DiscardElements import discard_elements_simple +from processing_chain.BuildImage import BuildImage +# from processing_chain.WebCam import WebCam +# from processing_chain.Yolo5Segmentation import Yolo5Segmentation + + +class OfflineEncoding: + + # INPUT PARAMETERS + config: dict + + # DERIVED PARAMETERS + default_dtype: torch.dtype + torch_device: str + display_size_max_x_pix: float + display_size_max_y_pix: float + # padding_fill: float + # DEFINED PREVIOUSLY IN "apply_parameter_changes": + padding_pix: int + sigma_kernel_pix: float + lambda_kernel_pix: float + out_x: int + 
out_y: int + clocks: torch.Tensor + phosphene: torch.Tensor + clocks_filter: torch.Tensor + + # DELIVERED BY ENCODING + position_found: None | torch.Tensor + canvas_size: None | torch.Tensor + + def __init__(self, config="config.json"): + + # Define parameters + # ======================================================== + print("OffE-Init: Loading configuration parameters...") + with open(config, "r") as file: + config = json.loads(jsmin(file.read())) + + # store in class + self.config = config + self.position_found = None + self.canvas_size = None + + # get sub-dicts for easier access + display = self.config["display"] + dictionary = self.config["dictionary"] + gabor = self.config["gabor"] + + # print( + # "OE-Init: Defining paths, creating dirs, setting default device and datatype" + # ) + # self.path = {"output": "test/output/level1/", "input": "test/images_test/"} + # Make output directories, if necessary: the place were we dump the new images to... + # os.makedirs(self.path["output"], mode=0o777, exist_ok=True) + + # Check if GPU is available and use it, if possible + # ================================================= + self.default_dtype = torch.float32 + torch.set_default_dtype(self.default_dtype) + if self.config["control"]["force_torch_use_cpu"]: + torch_device = "cpu" + else: + torch_device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using {torch_device} as TORCH device...") + self.torch_device = torch_device + + print("OffE-Init: Compute display scaling factors and padding RGB values") + + # global scaling factors for all pixel-related length scales + self.display_size_max_x_pix = ( + display["size_max_x_dva"] * display["pix_per_dva"] + ) + self.display_size_max_y_pix = ( + display["size_max_y_dva"] * display["pix_per_dva"] + ) + + # determine padding fill value + tmp = tv.transforms.Grayscale(num_output_channels=1) + tmp_value = torch.full((3, 1, 1), 254.0/255) + self.padding_fill = int(tmp(tmp_value).squeeze()) + + # PREVIOUSLY, A 
SEPARATE ROUTINE APPLIED PARAMETER CHANGES + # WE DISCARD THIS HERE BUT KEEP THE CODE AS EXAMPLE + # + # self.apply_parameter_changes() + # return + # + # def apply_parameter_changes(self): + # + # GET NEW PARAMETERS + print("OffE-Init: Computing image/patch sizes from parameters") + + # BLOCK: dictionary ---------------- + # set patch size for both dictionaries, make sure it is odd number + dictionary_size_pix = ( + 1 + + (int(dictionary["size_dva"] * + display["pix_per_dva"]) // 2) * 2 + ) + + # BLOCK: gabor --------------------- + # convert contour-related parameters to pixel units + self.sigma_kernel_pix = ( + gabor["sigma_kernel_dva"] * + display["pix_per_dva"] + ) + self.lambda_kernel_pix = ( + gabor["lambda_kernel_dva"] * + display["pix_per_dva"] + ) + + # BLOCK: gabor & dictionary ------------------ + # Padding + # ------- + self.padding_pix = int( + max(3.0 * self.sigma_kernel_pix, 1.1 * dictionary_size_pix) + ) + + # define target video/representation width/height + multiple_of = 4 + out_x = self.display_size_max_x_pix + 2 * self.padding_pix + out_y = self.display_size_max_y_pix + 2 * self.padding_pix + out_x += (multiple_of - (out_x % multiple_of)) % multiple_of + out_y += (multiple_of - (out_y % multiple_of)) % multiple_of + self.out_x = int(out_x) + self.out_y = int(out_y) + + # generate dictionaries + # --------------------- + # BLOCK: dictionary -------------------------- + print("OffE-Init: Generating dictionaries...") + patch_generator = PatchGenerator(torch_device=self.torch_device) + self.phosphene = patch_generator.alphabet_phosphene( + patch_size=dictionary_size_pix, + sigma_width=dictionary["phosphene"]["sigma_width"] + * dictionary_size_pix, + ) + # BLOCK: dictionary & gabor -------------------------- + self.clocks_filter, self.clocks, segments = patch_generator.alphabet_clocks( + patch_size=dictionary_size_pix, + n_dir=dictionary["clocks"]["n_dir"], + n_filter=gabor["n_orientations"], + segment_width=dictionary["clocks"]["pointer_width"] + * 
dictionary_size_pix, + segment_length=dictionary["clocks"]["pointer_length"] + * dictionary_size_pix, + ) + + return + + # TODO image supposed to be torch.Tensor(3, Y, X) within 0...1 + def encode(self, image: torch.Tensor, number_of_patches: int = 42, border_pixel_value: float = 254.0 / 255) -> dict: + + assert len(image.shape) == 3, "Input image must be RGB (3 dimensions)!" + assert image.shape[0] == 3, "Input image format must be (3, HEIGHT, WIDTH)!" + control = self.config["control"] + + + # determine padding fill value + tmp = tv.transforms.Grayscale(num_output_channels=1) + tmp_value = torch.full((3, 1, 1), border_pixel_value) + padding_fill = float(tmp(tmp_value).squeeze()) + + # show input image, if desired... + if control["show_image"]: + self.__show_torch_frame( + image, + title="Encode: Input Image", + target=control["show_mode"] + ) + + # some constants for addressing specific components of output arrays + image_id_const: int = 0 + overlap_index_const: int = 1 + + # Determine target size of image + # image: [RGB, Height, Width], dtype= tensor.torch.uint8 + print("OffE-Encode: Computing downsampling factor image -> display") + f_x: float = self.display_size_max_x_pix / image.shape[-1] + f_y: float = self.display_size_max_y_pix / image.shape[-2] + f_xy_min: float = min(f_x, f_y) + downsampling_x: int = int(f_xy_min * image.shape[-1]) + downsampling_y: int = int(f_xy_min * image.shape[-2]) + + # CURRENTLY we do not crop in the end... 
+ # Image size for removing the fft crop later + # center_crop_x: int = downsampling_x + # center_crop_y: int = downsampling_y + + # define contour extraction processing chain + # ------------------------------------------ + print("OffE-Encode: Extracting contours") + train_processing_chain = tv.transforms.Compose( + transforms=[ + tv.transforms.Grayscale(num_output_channels=1), # RGB to grayscale + tv.transforms.Resize( + size=(downsampling_y, downsampling_x) + ), # downsampling + tv.transforms.Pad( # extra white padding around the picture + padding=(self.padding_pix, self.padding_pix), + fill=padding_fill, + ), + ContourExtract( # contour extraction + n_orientations=self.config["gabor"]["n_orientations"], + sigma_kernel=self.sigma_kernel_pix, + lambda_kernel=self.lambda_kernel_pix, + torch_device=self.torch_device, + ), + # CURRENTLY we do not crop in the end! + # tv.transforms.CenterCrop( # Remove the padding + # size=(center_crop_x, center_crop_y) + # ), + ], + ) + # ...with and without orientation channels + contour = train_processing_chain(image.unsqueeze(0)) + contour_collapse = train_processing_chain.transforms[-1].create_collapse( + contour + ) + + if control["show_contours"]: + self.__show_torch_frame( + contour_collapse, + title="Encode: Contours Extracted", + cmap="gray", + target=control["show_mode"], + ) + + # generate a prior for mapping the contour to the dictionary + # CURRENTLY we use an uniform prior... 
+ # ---------------------------------------------------------- + dictionary_prior = torch.ones( + (self.clocks_filter.shape[0]), + dtype=self.default_dtype, + device=torch.device(self.torch_device), + ) + + # instantiate and execute sparsifier + # ---------------------------------- + print("OffE-Encode: Performing sparsification") + encoding = self.config["encoding"] + display = self.config["display"] + sparsifier = Sparsifier( + dictionary_filter=self.clocks_filter, + dictionary=self.clocks, + dictionary_prior=dictionary_prior, + number_of_patches=encoding["n_patches_compute"], + size_exp_deadzone=encoding["size_exp_deadzone_dva"] + * display["pix_per_dva"], + plot_use_map=False, # self.control["plot_deadzone"], + deadzone_exp=encoding["use_exp_deadzone"], + deadzone_hard_cutout=encoding["use_cutout_deadzone"], + deadzone_hard_cutout_size=encoding["size_cutout_deadzone_dva"] + * display["pix_per_dva"], + padding_deadzone_size_x=self.padding_pix, + padding_deadzone_size_y=self.padding_pix, + torch_device=self.torch_device, + ) + sparsifier(contour) + assert sparsifier.position_found is not None + + # extract and normalize the overlap found + overlap_found = sparsifier.overlap_found[ + image_id_const, :, overlap_index_const + ] + overlap_found = overlap_found / overlap_found.max() + + # get overlap above certain threshold, extract corresponding elements + overlap_idcs_valid = torch.where( + overlap_found >= encoding["overlap_threshold"] + )[0] + position_selection = sparsifier.position_found[ + image_id_const : image_id_const + 1, overlap_idcs_valid, : + ] + n_elements = len(overlap_idcs_valid) + print(f"OffE-Encode: {n_elements} elements positioned!") + + contour_shape = contour.shape + + n_cut = min(position_selection.shape[-2], number_of_patches) + + data_out = { + "position_found": position_selection[:, :n_cut, :], + "canvas_size": contour_shape, + } + + self.position_found = data_out["position_found"] + self.canvas_size = data_out["canvas_size"] + + return 
data_out + + def render(self): + + assert self.position_found is not None, "Use ""encode"" before rendering!" + assert self.canvas_size is not None, "Use ""encode"" before rendering!" + + control = self.config["control"] + + # build the full image! + image_clocks = BuildImage( + canvas_size=self.canvas_size, + dictionary=self.clocks, + position_found=self.position_found, + default_dtype=self.default_dtype, + torch_device=self.torch_device, + ) + + # normalize to range [0...1] + m = image_clocks[0].max() + if m == 0: + m = 1 + image_clocks_normalized = image_clocks[0] / m + + # embed into frame of desired output size + out_torch = self.__embed_image( + image_clocks_normalized, out_height=self.out_y, out_width=self.out_x + ) + + # show, if desired... + if control["show_percept"]: + self.__show_torch_frame( + out_torch, title="Percept", + cmap="gray", target=control["show_mode"] + ) + + return + + def __show_torch_frame(self, + frame_torch: torch.Tensor, + title: str = "default", + cmap: str = "viridis", + target: str = "pyplot", + ): + frame_numpy = ( + (frame_torch.movedim(0, -1) * 255).type(dtype=torch.uint8).cpu().numpy() + ) + if target == "pyplot": + plt.imshow(frame_numpy, cmap=cmap) + plt.title(title) + plt.show() + if target == "cv2": + if frame_numpy.ndim == 3: + if frame_numpy.shape[-1] == 1: + frame_numpy = np.tile(frame_numpy, [1, 1, 3]) + frame_numpy = (frame_numpy - frame_numpy.min()) / ( + frame_numpy.max() - frame_numpy.min() + ) + # print(frame_numpy.shape, frame_numpy.max(), frame_numpy.min()) + cv2.namedWindow(title, cv2.WINDOW_NORMAL) + cv2.imshow(title, frame_numpy[:, :, (2, 1, 0)]) + cv2.waitKey(1) + + return + + def __embed_image(self, frame_torch, out_height, out_width, init_value=0): + + out_shape = torch.tensor(frame_torch.shape) + + frame_width = frame_torch.shape[-1] + frame_height = frame_torch.shape[-2] + + frame_width_idx0 = max([0, (frame_width - out_width) // 2]) + frame_height_idx0 = max([0, (frame_height - out_height) // 2]) + + 
select_width = min([frame_width, out_width]) + select_height = min([frame_height, out_height]) + + out_shape[-1] = out_width + out_shape[-2] = out_height + + out_torch = init_value * torch.ones(tuple(out_shape)) + + out_width_idx0 = max([0, (out_width - frame_width) // 2]) + out_height_idx0 = max([0, (out_height - frame_height) // 2]) + + out_torch[ + ..., + out_height_idx0: (out_height_idx0 + select_height), + out_width_idx0: (out_width_idx0 + select_width), + ] = frame_torch[ + ..., + frame_height_idx0: (frame_height_idx0 + select_height), + frame_width_idx0: (frame_width_idx0 + select_width), + ] + + return out_torch + + def __del__(self): + + print("OffE-Delete: exiting gracefully!") + # TODO ...only do it when necessary + cv2.destroyAllWindows() + + return + + +if __name__ == "__main__": + + source = 'bernd.jpg' + img_cv2 = cv2.imread(source) + img_torch = torch.Tensor(img_cv2[:, :, (2, 1, 0)]).movedim(-1, 0) / 255 + # show_torch_frame(img_torch, target="cv2", title=source) + print(f"CV2 Shape: {img_cv2.shape}") + print(f"Torch Shape: {img_torch.shape}") + + img = img_torch + frame_width = img.shape[-1] + frame_height = img.shape[-2] + print( + f"OffE-Test: Processing image {source} of {frame_width} x {frame_height}." + ) + + # TEST tfg = tv.transforms.Grayscale(num_output_channels=1) + # TEST pixel_fill = torch.full((3, 1, 1), 254.0 / 255) + # TEST value_fill = float(tfg(pixel_fill).squeeze()) + # TEST tfp = tv.transforms.Pad(padding=(1, 1), fill=value_fill) + + # TEST img_gray = tfg(img[:, :3, :3]) + # TEST img_pad = tfp(img_gray) + + oe = OfflineEncoding() + encoding = oe.encode(img) + stimulus = oe.render() + if oe.config["control"]["show_mode"] == "cv2": + cv2.waitKey(5000) + del oe + +# %%