percept_simulator_2023/processing_chain/OnlinePerception.py

300 lines
11 KiB
Python
Raw Permalink Normal View History

2023-07-31 15:23:13 +02:00
#%%
import torch
import numpy as np
from processing_chain.BuildImage import BuildImage
import matplotlib.pyplot as plt
import time
import cv2
from gui.GUIEvents import GUIEvents
from gui.GUICombiData import GUICombiData
import tkinter as tk
# TODO required?
def embed_image(frame_torch, out_height, out_width, torch_device, init_value=0):
out_shape = torch.tensor(frame_torch.shape)
frame_width = frame_torch.shape[-1]
frame_height = frame_torch.shape[-2]
frame_width_idx0 = max([0, (frame_width - out_width) // 2])
frame_height_idx0 = max([0, (frame_height - out_height) // 2])
select_width = min([frame_width, out_width])
select_height = min([frame_height, out_height])
out_shape[-1] = out_width
out_shape[-2] = out_height
out_torch = init_value * torch.ones(tuple(out_shape), device=torch_device)
out_width_idx0 = max([0, (out_width - frame_width) // 2])
out_height_idx0 = max([0, (out_height - frame_height) // 2])
out_torch[
...,
out_height_idx0 : (out_height_idx0 + select_height),
out_width_idx0 : (out_width_idx0 + select_width),
] = frame_torch[
...,
frame_height_idx0 : (frame_height_idx0 + select_height),
frame_width_idx0 : (frame_width_idx0 + select_width),
]
return out_torch
class OnlinePerception:
# SELFies...
#
# torch_device, default_dtype (fixed)
# canvas_size, features, phosphene, position_found (parameters)
# percept
#
# root, events, confdata, use_gui
#
def __init__(self, target, use_gui=False):
# CPU or GPU?
self.default_dtype = torch.float32
torch.set_default_dtype(self.default_dtype)
if torch.cuda.is_available():
self.torch_device = torch.device("cuda")
else:
self.torch_device = torch.device("cpu")
self.use_gui = use_gui
if self.use_gui:
self.root = tk.Tk()
self.confdata: GUICombiData = GUICombiData()
self.events = GUIEvents(tk_root=self.root, confdata=self.confdata)
# default dictionary parameters
n_xy_canvas = 400
n_xy_features = 41
n_features = 32
n_positions = 1
# populate internal parameters
self.canvas_size = [1, 8, n_xy_canvas, n_xy_canvas]
self.features = torch.rand(
(n_features, 1, n_xy_features, n_xy_features), device=self.torch_device
)
self.phosphene = torch.ones(
(1, 1, n_xy_features, n_xy_features), device=self.torch_device
)
self.position_found = torch.zeros((1, n_positions, 3), device=self.torch_device)
self.position_found[0, :, 0] = torch.randint(n_features, (n_positions,))
self.position_found[0, :, 1:] = torch.randint(n_xy_canvas, (n_positions, 2))
self.percept = torch.zeros(
(1, 1, n_xy_canvas, n_xy_canvas), device=self.torch_device
)
self.p_FEATperSECperPOS = 3.0
self.tau_SEC = 0.3 # percept time constant
self.t = time.time()
self.target = target # display target
self.display = target
self.selection = 0
return
def update(self, data_in: dict):
if data_in: # not NONE?
print(f"Parameter update requested for keys {data_in.keys()}!")
if "position_found" in data_in.keys():
self.position_found = data_in["position_found"]
if "canvas_size" in data_in.keys():
self.canvas_size = data_in["canvas_size"]
self.percept = embed_image(
self.percept,
self.canvas_size[-2],
self.canvas_size[-1],
torch_device=self.torch_device,
init_value=0,
)
if "features" in data_in.keys():
print(self.features.shape, self.features.max(), self.features.min())
self.features = data_in["features"]
self.features /= self.features.max()
print(self.features.shape, self.features.max(), self.features.min())
if "phosphene" in data_in.keys():
self.phosphene = data_in["phosphene"]
self.phosphene /= self.phosphene.max()
# parameters of (optional) GUI changed?
data_out = None
if self.use_gui:
self.root.update()
if not self.confdata.gui_running:
data_out = {"exit": 42}
# graceful exit
else:
if self.confdata.check_for_change() is True:
data_out = {
"value": self.confdata.yolo_class.value,
"sigma_kernel_DVA": self.confdata.contour_extraction.sigma_kernel_DVA,
"n_orientations": self.confdata.contour_extraction.n_orientations,
"size_DVA": self.confdata.alphabet.size_DVA,
# "tau_SEC": self.confdata.alphabet.tau_SEC,
# "p_FEATperSECperPOS": self.confdata.alphabet.p_FEATperSECperPOS,
"sigma_width": self.confdata.alphabet.phosphene_sigma_width,
"n_dir": self.confdata.alphabet.clocks_n_dir,
"pointer_width": self.confdata.alphabet.clocks_pointer_width,
"pointer_length": self.confdata.alphabet.clocks_pointer_length,
"number_of_patches": self.confdata.sparsifier.number_of_patches,
"use_exp_deadzone": self.confdata.sparsifier.use_exp_deadzone,
"use_cutout_deadzone": self.confdata.sparsifier.use_cutout_deadzone,
"size_exp_deadzone_DVA": self.confdata.sparsifier.size_exp_deadzone_DVA,
"size_cutout_deadzone_DVA": self.confdata.sparsifier.size_cutout_deadzone_DVA,
"enable_cam": self.confdata.output_mode.enable_cam,
"enable_yolo": self.confdata.output_mode.enable_yolo,
"enable_contour": self.confdata.output_mode.enable_contour,
}
print(data_out)
self.p_FEATperSECperPOS = self.confdata.alphabet.p_FEATperSECperPOS
self.tau_SEC = self.confdata.alphabet.tau_SEC
# print(self.confdata.alphabet.selection)
self.selection = self.confdata.alphabet.selection
# print(f"Selektion gemacht {self.selection}")
# print(self.confdata.output_mode.enable_percept)
if self.confdata.output_mode.enable_percept:
self.display = self.target
else:
self.display = None
if self.target == "cv2":
cv2.destroyWindow("Percept")
self.confdata.reset_change_detector()
# keep track of time, yields dt
t_new = time.time()
dt = t_new - self.t
self.t = t_new
# exponential decay
self.percept *= torch.exp(-torch.tensor(dt / self.tau_SEC))
# new stimulation
p_dt = self.p_FEATperSECperPOS * dt
n_positions = self.position_found.shape[-2]
position_select = torch.rand((n_positions,), device=self.torch_device) < p_dt
n_select = position_select.sum()
if n_select > 0:
# print(f"Selektion ausgewertet {self.selection}")
if self.selection:
dictionary = self.features
else:
dictionary = self.phosphene
percept_addon = BuildImage(
canvas_size=self.canvas_size,
dictionary=dictionary,
position_found=self.position_found[:, position_select, :],
default_dtype=self.default_dtype,
torch_device=self.torch_device,
)
self.percept += percept_addon
# prepare for display
display = self.percept[0, 0].cpu().numpy()
if self.display == "cv2":
# display, and update
cv2.namedWindow("Percept", cv2.WINDOW_NORMAL)
cv2.imshow("Percept", display)
q = cv2.waitKey(1)
if self.display == "pyplot":
# display, RUNS SLOWLY, just for testing
plt.imshow(display, cmap="gray", vmin=0, vmax=1)
plt.show()
time.sleep(0.5)
return data_out
def __del__(self):
print("Now I'm deleting me!")
try:
cv2.destroyAllWindows()
except:
pass
if self.use_gui:
print("...and the GUI!")
try:
if self.confdata.gui_running is True:
self.root.destroy()
except:
pass
if __name__ == "__main__":
use_gui = True
data_in = None
op = OnlinePerception("cv2", use_gui=use_gui)
t_max = 40.0
dt_update = 10.0
t_update = dt_update
t = time.time()
t0 = t
while t - t0 < t_max:
data_out = op.update(data_in)
data_in = None
if data_out:
print("Output given!")
if "exit" in data_out.keys():
break
t = time.time()
if t - t0 > t_update:
# new canvas size
n_xy_canvas = int(400 + torch.randint(200, (1,)))
canvas_size = [1, 8, n_xy_canvas, n_xy_canvas]
# new features/phosphenes
n_features = int(16 + torch.randint(16, (1,)))
n_xy_features = int(31 + 2 * torch.randint(10, (1,)))
features = torch.rand(
(n_features, 1, n_xy_features, n_xy_features), device=op.torch_device
)
phosphene = torch.ones(
(1, 1, n_xy_features, n_xy_features), device=op.torch_device
)
# new positions
n_positions = int(1 + torch.randint(3, (1,)))
position_found = torch.zeros((1, n_positions, 3), device=op.torch_device)
position_found[0, :, 0] = torch.randint(n_features, (n_positions,))
position_found[0, :, 1] = torch.randint(n_xy_canvas, (n_positions,))
position_found[0, :, 2] = torch.randint(n_xy_canvas, (n_positions,))
t_update += dt_update
data_in = {
"position_found": position_found,
"canvas_size": canvas_size,
"features": features,
"phosphene": phosphene,
}
print("done!")
del op
# %%