300 lines
11 KiB
Python
300 lines
11 KiB
Python
|
#%%
|
||
|
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
from processing_chain.BuildImage import BuildImage
|
||
|
import matplotlib.pyplot as plt
|
||
|
import time
|
||
|
import cv2
|
||
|
from gui.GUIEvents import GUIEvents
|
||
|
from gui.GUICombiData import GUICombiData
|
||
|
import tkinter as tk
|
||
|
|
||
|
|
||
|
# TODO required?
|
||
|
def embed_image(frame_torch, out_height, out_width, torch_device, init_value=0):
|
||
|
|
||
|
out_shape = torch.tensor(frame_torch.shape)
|
||
|
|
||
|
frame_width = frame_torch.shape[-1]
|
||
|
frame_height = frame_torch.shape[-2]
|
||
|
|
||
|
frame_width_idx0 = max([0, (frame_width - out_width) // 2])
|
||
|
frame_height_idx0 = max([0, (frame_height - out_height) // 2])
|
||
|
|
||
|
select_width = min([frame_width, out_width])
|
||
|
select_height = min([frame_height, out_height])
|
||
|
|
||
|
out_shape[-1] = out_width
|
||
|
out_shape[-2] = out_height
|
||
|
|
||
|
out_torch = init_value * torch.ones(tuple(out_shape), device=torch_device)
|
||
|
|
||
|
out_width_idx0 = max([0, (out_width - frame_width) // 2])
|
||
|
out_height_idx0 = max([0, (out_height - frame_height) // 2])
|
||
|
|
||
|
out_torch[
|
||
|
...,
|
||
|
out_height_idx0 : (out_height_idx0 + select_height),
|
||
|
out_width_idx0 : (out_width_idx0 + select_width),
|
||
|
] = frame_torch[
|
||
|
...,
|
||
|
frame_height_idx0 : (frame_height_idx0 + select_height),
|
||
|
frame_width_idx0 : (frame_width_idx0 + select_width),
|
||
|
]
|
||
|
|
||
|
return out_torch
|
||
|
|
||
|
|
||
|
class OnlinePerception:
|
||
|
|
||
|
# SELFies...
|
||
|
#
|
||
|
# torch_device, default_dtype (fixed)
|
||
|
# canvas_size, features, phosphene, position_found (parameters)
|
||
|
# percept
|
||
|
#
|
||
|
# root, events, confdata, use_gui
|
||
|
#
|
||
|
|
||
|
def __init__(self, target, use_gui=False):
|
||
|
|
||
|
# CPU or GPU?
|
||
|
self.default_dtype = torch.float32
|
||
|
torch.set_default_dtype(self.default_dtype)
|
||
|
if torch.cuda.is_available():
|
||
|
self.torch_device = torch.device("cuda")
|
||
|
else:
|
||
|
self.torch_device = torch.device("cpu")
|
||
|
|
||
|
self.use_gui = use_gui
|
||
|
if self.use_gui:
|
||
|
self.root = tk.Tk()
|
||
|
self.confdata: GUICombiData = GUICombiData()
|
||
|
self.events = GUIEvents(tk_root=self.root, confdata=self.confdata)
|
||
|
|
||
|
# default dictionary parameters
|
||
|
n_xy_canvas = 400
|
||
|
n_xy_features = 41
|
||
|
n_features = 32
|
||
|
n_positions = 1
|
||
|
|
||
|
# populate internal parameters
|
||
|
self.canvas_size = [1, 8, n_xy_canvas, n_xy_canvas]
|
||
|
self.features = torch.rand(
|
||
|
(n_features, 1, n_xy_features, n_xy_features), device=self.torch_device
|
||
|
)
|
||
|
self.phosphene = torch.ones(
|
||
|
(1, 1, n_xy_features, n_xy_features), device=self.torch_device
|
||
|
)
|
||
|
self.position_found = torch.zeros((1, n_positions, 3), device=self.torch_device)
|
||
|
self.position_found[0, :, 0] = torch.randint(n_features, (n_positions,))
|
||
|
self.position_found[0, :, 1:] = torch.randint(n_xy_canvas, (n_positions, 2))
|
||
|
self.percept = torch.zeros(
|
||
|
(1, 1, n_xy_canvas, n_xy_canvas), device=self.torch_device
|
||
|
)
|
||
|
|
||
|
self.p_FEATperSECperPOS = 3.0
|
||
|
self.tau_SEC = 0.3 # percept time constant
|
||
|
self.t = time.time()
|
||
|
|
||
|
self.target = target # display target
|
||
|
self.display = target
|
||
|
|
||
|
self.selection = 0
|
||
|
|
||
|
return
|
||
|
|
||
|
def update(self, data_in: dict):
|
||
|
|
||
|
if data_in: # not NONE?
|
||
|
print(f"Parameter update requested for keys {data_in.keys()}!")
|
||
|
if "position_found" in data_in.keys():
|
||
|
self.position_found = data_in["position_found"]
|
||
|
if "canvas_size" in data_in.keys():
|
||
|
self.canvas_size = data_in["canvas_size"]
|
||
|
self.percept = embed_image(
|
||
|
self.percept,
|
||
|
self.canvas_size[-2],
|
||
|
self.canvas_size[-1],
|
||
|
torch_device=self.torch_device,
|
||
|
init_value=0,
|
||
|
)
|
||
|
if "features" in data_in.keys():
|
||
|
print(self.features.shape, self.features.max(), self.features.min())
|
||
|
self.features = data_in["features"]
|
||
|
self.features /= self.features.max()
|
||
|
print(self.features.shape, self.features.max(), self.features.min())
|
||
|
if "phosphene" in data_in.keys():
|
||
|
self.phosphene = data_in["phosphene"]
|
||
|
self.phosphene /= self.phosphene.max()
|
||
|
|
||
|
# parameters of (optional) GUI changed?
|
||
|
data_out = None
|
||
|
if self.use_gui:
|
||
|
self.root.update()
|
||
|
if not self.confdata.gui_running:
|
||
|
data_out = {"exit": 42}
|
||
|
# graceful exit
|
||
|
else:
|
||
|
if self.confdata.check_for_change() is True:
|
||
|
|
||
|
data_out = {
|
||
|
"value": self.confdata.yolo_class.value,
|
||
|
"sigma_kernel_DVA": self.confdata.contour_extraction.sigma_kernel_DVA,
|
||
|
"n_orientations": self.confdata.contour_extraction.n_orientations,
|
||
|
"size_DVA": self.confdata.alphabet.size_DVA,
|
||
|
# "tau_SEC": self.confdata.alphabet.tau_SEC,
|
||
|
# "p_FEATperSECperPOS": self.confdata.alphabet.p_FEATperSECperPOS,
|
||
|
"sigma_width": self.confdata.alphabet.phosphene_sigma_width,
|
||
|
"n_dir": self.confdata.alphabet.clocks_n_dir,
|
||
|
"pointer_width": self.confdata.alphabet.clocks_pointer_width,
|
||
|
"pointer_length": self.confdata.alphabet.clocks_pointer_length,
|
||
|
"number_of_patches": self.confdata.sparsifier.number_of_patches,
|
||
|
"use_exp_deadzone": self.confdata.sparsifier.use_exp_deadzone,
|
||
|
"use_cutout_deadzone": self.confdata.sparsifier.use_cutout_deadzone,
|
||
|
"size_exp_deadzone_DVA": self.confdata.sparsifier.size_exp_deadzone_DVA,
|
||
|
"size_cutout_deadzone_DVA": self.confdata.sparsifier.size_cutout_deadzone_DVA,
|
||
|
"enable_cam": self.confdata.output_mode.enable_cam,
|
||
|
"enable_yolo": self.confdata.output_mode.enable_yolo,
|
||
|
"enable_contour": self.confdata.output_mode.enable_contour,
|
||
|
}
|
||
|
print(data_out)
|
||
|
|
||
|
self.p_FEATperSECperPOS = self.confdata.alphabet.p_FEATperSECperPOS
|
||
|
self.tau_SEC = self.confdata.alphabet.tau_SEC
|
||
|
|
||
|
# print(self.confdata.alphabet.selection)
|
||
|
self.selection = self.confdata.alphabet.selection
|
||
|
# print(f"Selektion gemacht {self.selection}")
|
||
|
|
||
|
# print(self.confdata.output_mode.enable_percept)
|
||
|
if self.confdata.output_mode.enable_percept:
|
||
|
self.display = self.target
|
||
|
else:
|
||
|
self.display = None
|
||
|
if self.target == "cv2":
|
||
|
cv2.destroyWindow("Percept")
|
||
|
|
||
|
self.confdata.reset_change_detector()
|
||
|
|
||
|
# keep track of time, yields dt
|
||
|
t_new = time.time()
|
||
|
dt = t_new - self.t
|
||
|
self.t = t_new
|
||
|
|
||
|
# exponential decay
|
||
|
self.percept *= torch.exp(-torch.tensor(dt / self.tau_SEC))
|
||
|
|
||
|
# new stimulation
|
||
|
p_dt = self.p_FEATperSECperPOS * dt
|
||
|
n_positions = self.position_found.shape[-2]
|
||
|
position_select = torch.rand((n_positions,), device=self.torch_device) < p_dt
|
||
|
n_select = position_select.sum()
|
||
|
if n_select > 0:
|
||
|
# print(f"Selektion ausgewertet {self.selection}")
|
||
|
if self.selection:
|
||
|
dictionary = self.features
|
||
|
else:
|
||
|
dictionary = self.phosphene
|
||
|
percept_addon = BuildImage(
|
||
|
canvas_size=self.canvas_size,
|
||
|
dictionary=dictionary,
|
||
|
position_found=self.position_found[:, position_select, :],
|
||
|
default_dtype=self.default_dtype,
|
||
|
torch_device=self.torch_device,
|
||
|
)
|
||
|
self.percept += percept_addon
|
||
|
|
||
|
# prepare for display
|
||
|
display = self.percept[0, 0].cpu().numpy()
|
||
|
if self.display == "cv2":
|
||
|
|
||
|
# display, and update
|
||
|
cv2.namedWindow("Percept", cv2.WINDOW_NORMAL)
|
||
|
cv2.imshow("Percept", display)
|
||
|
q = cv2.waitKey(1)
|
||
|
|
||
|
if self.display == "pyplot":
|
||
|
|
||
|
# display, RUNS SLOWLY, just for testing
|
||
|
plt.imshow(display, cmap="gray", vmin=0, vmax=1)
|
||
|
plt.show()
|
||
|
time.sleep(0.5)
|
||
|
|
||
|
return data_out
|
||
|
|
||
|
def __del__(self):
|
||
|
print("Now I'm deleting me!")
|
||
|
try:
|
||
|
cv2.destroyAllWindows()
|
||
|
except:
|
||
|
pass
|
||
|
if self.use_gui:
|
||
|
print("...and the GUI!")
|
||
|
try:
|
||
|
if self.confdata.gui_running is True:
|
||
|
self.root.destroy()
|
||
|
except:
|
||
|
pass
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
|
||
|
use_gui = True
|
||
|
|
||
|
data_in = None
|
||
|
op = OnlinePerception("cv2", use_gui=use_gui)
|
||
|
|
||
|
t_max = 40.0
|
||
|
dt_update = 10.0
|
||
|
t_update = dt_update
|
||
|
|
||
|
t = time.time()
|
||
|
t0 = t
|
||
|
while t - t0 < t_max:
|
||
|
|
||
|
data_out = op.update(data_in)
|
||
|
data_in = None
|
||
|
if data_out:
|
||
|
print("Output given!")
|
||
|
if "exit" in data_out.keys():
|
||
|
break
|
||
|
|
||
|
t = time.time()
|
||
|
if t - t0 > t_update:
|
||
|
|
||
|
# new canvas size
|
||
|
n_xy_canvas = int(400 + torch.randint(200, (1,)))
|
||
|
canvas_size = [1, 8, n_xy_canvas, n_xy_canvas]
|
||
|
|
||
|
# new features/phosphenes
|
||
|
n_features = int(16 + torch.randint(16, (1,)))
|
||
|
n_xy_features = int(31 + 2 * torch.randint(10, (1,)))
|
||
|
features = torch.rand(
|
||
|
(n_features, 1, n_xy_features, n_xy_features), device=op.torch_device
|
||
|
)
|
||
|
phosphene = torch.ones(
|
||
|
(1, 1, n_xy_features, n_xy_features), device=op.torch_device
|
||
|
)
|
||
|
|
||
|
# new positions
|
||
|
n_positions = int(1 + torch.randint(3, (1,)))
|
||
|
position_found = torch.zeros((1, n_positions, 3), device=op.torch_device)
|
||
|
position_found[0, :, 0] = torch.randint(n_features, (n_positions,))
|
||
|
position_found[0, :, 1] = torch.randint(n_xy_canvas, (n_positions,))
|
||
|
position_found[0, :, 2] = torch.randint(n_xy_canvas, (n_positions,))
|
||
|
t_update += dt_update
|
||
|
data_in = {
|
||
|
"position_found": position_found,
|
||
|
"canvas_size": canvas_size,
|
||
|
"features": features,
|
||
|
"phosphene": phosphene,
|
||
|
}
|
||
|
|
||
|
print("done!")
|
||
|
del op
|
||
|
|
||
|
|
||
|
# %%
|