Add files via upload
This commit is contained in:
commit
dab7dcb786
9 changed files with 2931 additions and 0 deletions
89
processing_chain/BuildImage.py
Normal file
89
processing_chain/BuildImage.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
import torch
|
||||
|
||||
|
||||
def clip_coordinates(x_canvas: int, dx_canvas: int, dx_dict: int):
|
||||
|
||||
x_canvas = int(x_canvas)
|
||||
dx_canvas = int(dx_canvas)
|
||||
dx_dict = int(dx_dict)
|
||||
dr_dict = int(dx_dict // 2)
|
||||
|
||||
x0_canvas = int(x_canvas - dr_dict)
|
||||
# placement outside right boundary?
|
||||
if x0_canvas >= dx_canvas:
|
||||
return None
|
||||
|
||||
x1_canvas = int(x_canvas + dr_dict + (dx_dict % 2))
|
||||
# placement outside left boundary?
|
||||
if x1_canvas <= 0:
|
||||
return None
|
||||
|
||||
# clip to the left?
|
||||
if x0_canvas < 0:
|
||||
x0_dict = -x0_canvas
|
||||
x0_canvas = 0
|
||||
else:
|
||||
x0_dict = 0
|
||||
|
||||
# clip to the right?
|
||||
if x1_canvas > dx_canvas:
|
||||
x1_dict = dx_dict - (x1_canvas - dx_canvas)
|
||||
x1_canvas = dx_canvas
|
||||
else:
|
||||
x1_dict = dx_dict
|
||||
|
||||
# print(x0_canvas, x1_canvas, x0_dict, x1_dict)
|
||||
assert (x1_canvas - x0_canvas) == (x1_dict - x0_dict)
|
||||
|
||||
return x0_canvas, x1_canvas, x0_dict, x1_dict
|
||||
|
||||
|
||||
def BuildImage(
|
||||
canvas_size: torch.Size,
|
||||
dictionary: torch.Tensor,
|
||||
position_found: torch.Tensor,
|
||||
default_dtype,
|
||||
torch_device,
|
||||
):
|
||||
|
||||
assert position_found is not None
|
||||
assert dictionary is not None
|
||||
|
||||
canvas_size_copy = torch.tensor(canvas_size)
|
||||
assert canvas_size_copy.shape[0] == 4
|
||||
canvas_size_copy[1] = 1
|
||||
output = torch.zeros(
|
||||
canvas_size_copy.tolist(),
|
||||
device=torch_device,
|
||||
dtype=default_dtype,
|
||||
)
|
||||
|
||||
dx_canvas = canvas_size[-2]
|
||||
dy_canvas = canvas_size[-1]
|
||||
dx_dict = dictionary.shape[-2]
|
||||
dy_dict = dictionary.shape[-1]
|
||||
|
||||
for pattern_id in range(0, position_found.shape[0]):
|
||||
for patch_id in range(0, position_found.shape[1]):
|
||||
|
||||
x_canvas = position_found[pattern_id, patch_id, 1]
|
||||
y_canvas = position_found[pattern_id, patch_id, 2]
|
||||
|
||||
xv = clip_coordinates(x_canvas, dx_canvas, dx_dict)
|
||||
if xv == None:
|
||||
break
|
||||
|
||||
yv = clip_coordinates(y_canvas, dy_canvas, dy_dict)
|
||||
if yv == None:
|
||||
break
|
||||
|
||||
if dictionary.shape[0] > 1:
|
||||
elem_idx = int(position_found[pattern_id, patch_id, 0])
|
||||
else:
|
||||
elem_idx = 0
|
||||
|
||||
output[pattern_id, 0, xv[0] : xv[1], yv[0] : yv[1]] += dictionary[
|
||||
elem_idx, 0, xv[2] : xv[3], yv[2] : yv[3]
|
||||
]
|
||||
|
||||
return output
|
288
processing_chain/ContourExtract.py
Normal file
288
processing_chain/ContourExtract.py
Normal file
|
@ -0,0 +1,288 @@
|
|||
# ContourExtract.py
|
||||
# ====================================
|
||||
# extracts contours from gray-level images
|
||||
#
|
||||
# Version V1.0, pre-07.03.2023:
|
||||
# no actual changes, is David's last code version...
|
||||
#
|
||||
# Version V1.1, 07.03.2023:
|
||||
# merged David's rebuild code (GUI capable)
|
||||
#
|
||||
|
||||
import torch
|
||||
import time
|
||||
import math
|
||||
|
||||
|
||||
class ContourExtract(torch.nn.Module):
|
||||
image_width: int = -1
|
||||
image_height: int = -1
|
||||
output_threshold: torch.Tensor = torch.tensor(-1)
|
||||
|
||||
n_orientations: int
|
||||
sigma_kernel: float = 0
|
||||
lambda_kernel: float = 0
|
||||
gamma_aspect_ratio: float = 0
|
||||
|
||||
kernel_axis_x: torch.Tensor | None = None
|
||||
kernel_axis_y: torch.Tensor | None = None
|
||||
target_orientations: torch.Tensor
|
||||
weight_vector: torch.Tensor
|
||||
|
||||
image_scale: float | None
|
||||
|
||||
psi_phase_offset_cos: torch.Tensor
|
||||
psi_phase_offset_sin: torch.Tensor
|
||||
|
||||
fft_gabor_cos_bank: torch.Tensor | None = None
|
||||
fft_gabor_sin_bank: torch.Tensor | None = None
|
||||
|
||||
pi: torch.Tensor
|
||||
torch_device: torch.device
|
||||
default_dtype = torch.float32
|
||||
|
||||
rebuild_kernels: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_orientations: int,
|
||||
sigma_kernel: float,
|
||||
lambda_kernel: float,
|
||||
gamma_aspect_ratio: float = 1.0,
|
||||
image_scale: float | None = 255.0,
|
||||
torch_device: str = "cpu",
|
||||
):
|
||||
super().__init__()
|
||||
self.torch_device = torch.device(torch_device)
|
||||
|
||||
self.pi = torch.tensor(
|
||||
math.pi,
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
|
||||
self.psi_phase_offset_cos = torch.tensor(
|
||||
0.0,
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
self.psi_phase_offset_sin = -self.pi / 2
|
||||
self.gamma_aspect_ratio = gamma_aspect_ratio
|
||||
self.image_scale = image_scale
|
||||
|
||||
self.update_settings(
|
||||
n_orientations=n_orientations,
|
||||
sigma_kernel=sigma_kernel,
|
||||
lambda_kernel=lambda_kernel,
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
assert input.ndim == 4, "input must have 4 dims!"
|
||||
|
||||
# We can only handle one color channel
|
||||
assert input.shape[1] == 1, "input.shape[1] must be 1!"
|
||||
|
||||
input = input.type(dtype=self.default_dtype).to(device=self.torch_device)
|
||||
if self.image_scale is not None:
|
||||
# scale grey level [0, 255] to range [0.0, 1.0]
|
||||
input /= self.image_scale
|
||||
|
||||
# Do we have valid kernels?
|
||||
if input.shape[-2] != self.image_width:
|
||||
self.image_width = input.shape[-2]
|
||||
self.rebuild_kernels = True
|
||||
|
||||
if input.shape[-1] != self.image_height:
|
||||
self.image_height = input.shape[-1]
|
||||
self.rebuild_kernels = True
|
||||
|
||||
assert self.image_width > 0
|
||||
assert self.image_height > 0
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
# We need to rebuild the kernels
|
||||
if self.rebuild_kernels is True:
|
||||
|
||||
self.kernel_axis_x = self.create_kernel_axis(self.image_width)
|
||||
self.kernel_axis_y = self.create_kernel_axis(self.image_height)
|
||||
|
||||
assert self.kernel_axis_x is not None
|
||||
assert self.kernel_axis_y is not None
|
||||
|
||||
gabor_cos_bank, gabor_sin_bank = self.create_gabor_filter_bank()
|
||||
|
||||
assert gabor_cos_bank is not None
|
||||
assert gabor_sin_bank is not None
|
||||
|
||||
self.fft_gabor_cos_bank = torch.fft.rfft2(
|
||||
gabor_cos_bank, s=None, dim=(-2, -1), norm=None
|
||||
).unsqueeze(0)
|
||||
|
||||
self.fft_gabor_sin_bank = torch.fft.rfft2(
|
||||
gabor_sin_bank, s=None, dim=(-2, -1), norm=None
|
||||
).unsqueeze(0)
|
||||
|
||||
# compute threshold for ignoring non-zero outputs arising due
|
||||
# to numerical imprecision (fft, kernel definition)
|
||||
assert self.weight_vector is not None
|
||||
norm_input = torch.full(
|
||||
(1, 1, self.image_width, self.image_height),
|
||||
1.0,
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
norm_fft_input = torch.fft.rfft2(
|
||||
norm_input, s=None, dim=(-2, -1), norm=None
|
||||
)
|
||||
|
||||
norm_output_sin = torch.fft.irfft2(
|
||||
norm_fft_input * self.fft_gabor_sin_bank,
|
||||
s=None,
|
||||
dim=(-2, -1),
|
||||
norm=None,
|
||||
)
|
||||
norm_output_cos = torch.fft.irfft2(
|
||||
norm_fft_input * self.fft_gabor_cos_bank,
|
||||
s=None,
|
||||
dim=(-2, -1),
|
||||
norm=None,
|
||||
)
|
||||
norm_value_matrix = torch.sqrt(norm_output_sin**2 + norm_output_cos**2)
|
||||
|
||||
# norm_output = torch.abs(
|
||||
# (self.weight_vector * norm_value_matrix).sum(dim=-1)
|
||||
# ).type(dtype=torch.float32)
|
||||
|
||||
self.output_threshold = norm_value_matrix.max()
|
||||
|
||||
self.rebuild_kernels = False
|
||||
|
||||
assert self.fft_gabor_cos_bank is not None
|
||||
assert self.fft_gabor_sin_bank is not None
|
||||
|
||||
t1 = time.time()
|
||||
|
||||
fft_input = torch.fft.rfft2(input, s=None, dim=(-2, -1), norm=None)
|
||||
|
||||
output_sin = torch.fft.irfft2(
|
||||
fft_input * self.fft_gabor_sin_bank, s=None, dim=(-2, -1), norm=None
|
||||
)
|
||||
|
||||
output_cos = torch.fft.irfft2(
|
||||
fft_input * self.fft_gabor_cos_bank, s=None, dim=(-2, -1), norm=None
|
||||
)
|
||||
|
||||
t2 = time.time()
|
||||
|
||||
output = torch.sqrt(output_sin**2 + output_cos**2)
|
||||
|
||||
t3 = time.time()
|
||||
|
||||
print(
|
||||
"ContourExtract {:.3f}s: prep-{:.3f}s, fft-{:.3f}s, out-{:.3f}s".format(
|
||||
t3 - t0, t1 - t0, t2 - t1, t3 - t2
|
||||
)
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def create_collapse(self, input: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
assert self.weight_vector is not None
|
||||
|
||||
output = torch.abs((self.weight_vector * input).sum(dim=1)).type(
|
||||
dtype=self.default_dtype
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def create_kernel_axis(self, axis_size: int) -> torch.Tensor:
|
||||
|
||||
lower_bound_axis: int = -int(math.floor(axis_size / 2))
|
||||
upper_bound_axis: int = int(math.ceil(axis_size / 2))
|
||||
|
||||
kernel_axis = torch.arange(
|
||||
lower_bound_axis,
|
||||
upper_bound_axis,
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
kernel_axis = torch.roll(kernel_axis, int(math.ceil(axis_size / 2)))
|
||||
|
||||
return kernel_axis
|
||||
|
||||
def create_gabor_filter_bank(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
assert self.kernel_axis_x is not None
|
||||
assert self.kernel_axis_y is not None
|
||||
assert self.target_orientations is not None
|
||||
|
||||
orientation_matrix = (
|
||||
self.target_orientations.unsqueeze(-1).unsqueeze(-1).detach().clone()
|
||||
)
|
||||
x_kernel_matrix = self.kernel_axis_x.unsqueeze(0).unsqueeze(-1).detach().clone()
|
||||
y_kernel_matrix = self.kernel_axis_y.unsqueeze(0).unsqueeze(0).detach().clone()
|
||||
|
||||
r2 = x_kernel_matrix**2 + self.gamma_aspect_ratio * y_kernel_matrix**2
|
||||
|
||||
kr = x_kernel_matrix * torch.cos(
|
||||
orientation_matrix
|
||||
) + y_kernel_matrix * torch.sin(orientation_matrix)
|
||||
|
||||
c0 = torch.exp(-2 * (self.pi * self.sigma_kernel / self.lambda_kernel) ** 2)
|
||||
|
||||
gauss: torch.Tensor = torch.exp(-r2 / 2 / self.sigma_kernel**2)
|
||||
|
||||
gabor_cos_bank: torch.Tensor = gauss * (
|
||||
torch.cos(2 * self.pi * kr / self.lambda_kernel + self.psi_phase_offset_cos)
|
||||
- c0 * torch.cos(self.psi_phase_offset_cos)
|
||||
)
|
||||
|
||||
gabor_sin_bank: torch.Tensor = gauss * (
|
||||
torch.cos(2 * self.pi * kr / self.lambda_kernel + self.psi_phase_offset_sin)
|
||||
- c0 * torch.cos(self.psi_phase_offset_sin)
|
||||
)
|
||||
|
||||
return gabor_cos_bank, gabor_sin_bank
|
||||
|
||||
def update_settings(
|
||||
self,
|
||||
n_orientations: int,
|
||||
sigma_kernel: float,
|
||||
lambda_kernel: float,
|
||||
):
|
||||
|
||||
self.rebuild_kernels = True
|
||||
|
||||
self.n_orientations = n_orientations
|
||||
self.sigma_kernel = sigma_kernel
|
||||
self.lambda_kernel = lambda_kernel
|
||||
|
||||
# generate orientation axis and axis for complex summation
|
||||
self.target_orientations: torch.Tensor = (
|
||||
torch.arange(
|
||||
start=0,
|
||||
end=int(self.n_orientations),
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
* torch.tensor(
|
||||
math.pi,
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
/ self.n_orientations
|
||||
)
|
||||
|
||||
self.weight_vector: torch.Tensor = (
|
||||
torch.exp(
|
||||
2.0
|
||||
* torch.complex(torch.tensor(0.0), torch.tensor(1.0))
|
||||
* self.target_orientations
|
||||
)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(0)
|
||||
)
|
162
processing_chain/DiscardElements.py
Normal file
162
processing_chain/DiscardElements.py
Normal file
|
@ -0,0 +1,162 @@
|
|||
#%%
|
||||
# DiscardElements.py
|
||||
# ====================================
|
||||
# removes elements from a sparse image representation
|
||||
# such that a 'most uniform' coverage still exists
|
||||
#
|
||||
# Version V1.0, pre-07.03.2023:
|
||||
# no actual changes, is David's last code version...
|
||||
|
||||
import numpy as np
|
||||
|
||||
# assume locations is array [n, 3]
|
||||
# the three entries are [shape_index, pos_x, pos_y]
|
||||
|
||||
|
||||
def discard_elements_simple(locations: np.ndarray, target_number_elements: list):
|
||||
|
||||
n_locations: int = locations.shape[0]
|
||||
locations_remain: list = []
|
||||
|
||||
# Loop across all target number of elements
|
||||
for target_elem in target_number_elements:
|
||||
|
||||
assert target_elem > 0, "Number of target elements must be larger than 0!"
|
||||
assert (
|
||||
target_elem <= n_locations
|
||||
), "Number of target elements must be <= number of available locations!"
|
||||
|
||||
# Build distance matrix between positions in locations_highest_res.
|
||||
# Its diagonal is defined as Inf because we don't want to consider these values in our
|
||||
# search for the minimum distances.
|
||||
distance_matrix = np.sqrt(
|
||||
((locations[np.newaxis, :, 1:] - locations[:, np.newaxis, 1:]) ** 2).sum(
|
||||
axis=-1
|
||||
)
|
||||
)
|
||||
distance_matrix[np.arange(n_locations), np.arange(n_locations)] = np.inf
|
||||
|
||||
# Find the minimal distances in upper triangle of matrix.
|
||||
idcs_remove: list = []
|
||||
while (n_locations - len(idcs_remove)) != target_elem:
|
||||
|
||||
# Get index of matrix with minimal distance
|
||||
row_idcs, col_idcs = np.where(
|
||||
distance_matrix == distance_matrix[distance_matrix > 0].min()
|
||||
)
|
||||
|
||||
# Get the max index.
|
||||
# It correspond to the index of the element we will remove in the locations_highest_res list
|
||||
sel_idx: int = max(row_idcs[0], col_idcs[0])
|
||||
idcs_remove.append(sel_idx) # Save the index
|
||||
|
||||
# Set current distance as Inf because we don't want to consider it further in our search
|
||||
distance_matrix[sel_idx, :] = np.inf
|
||||
distance_matrix[:, sel_idx] = np.inf
|
||||
|
||||
idcs_remain: list = np.setdiff1d(np.arange(n_locations), idcs_remove)
|
||||
locations_remain.append(locations[idcs_remain, :])
|
||||
|
||||
return locations_remain
|
||||
|
||||
|
||||
# assume locations is array [n, 3]
|
||||
# the three entries are [shape_index, pos_x, pos_y]
|
||||
def discard_elements(
|
||||
locations: np.ndarray, target_number_elements: list, prior: np.ndarray
|
||||
):
|
||||
|
||||
n_locations: int = locations.shape[0]
|
||||
locations_remain: list = []
|
||||
disable_value: float = np.nan
|
||||
|
||||
# if type(prior) != np.ndarray:
|
||||
# prior = np.ones((n_locations,))
|
||||
assert prior.shape == (
|
||||
n_locations,
|
||||
), "Prior must have same number of entries as elements in locations!"
|
||||
print(prior)
|
||||
|
||||
# Loop across all target number of elements
|
||||
for target_elem in target_number_elements:
|
||||
|
||||
assert target_elem > 0, "Number of target elements must be larger than 0!"
|
||||
assert (
|
||||
target_elem <= n_locations
|
||||
), "Number of target elements must be <= number of available locations!"
|
||||
|
||||
# Build distance matrix between positions in locations_highest_res.
|
||||
# Its diagonal is defined as Inf because we don't want to consider these values in our
|
||||
# search for the minimum distances.
|
||||
distance_matrix = np.sqrt(
|
||||
((locations[np.newaxis, :, 1:] - locations[:, np.newaxis, 1:]) ** 2).sum(
|
||||
axis=-1
|
||||
)
|
||||
)
|
||||
prior_matrix = prior[np.newaxis, :] * prior[:, np.newaxis]
|
||||
distance_matrix *= prior_matrix
|
||||
distance_matrix[np.arange(n_locations), np.arange(n_locations)] = disable_value
|
||||
print(distance_matrix)
|
||||
|
||||
# Find the minimal distances in upper triangle of matrix.
|
||||
idcs_remove: list = []
|
||||
while (n_locations - len(idcs_remove)) != target_elem:
|
||||
|
||||
# Get index of matrix with minimal distance
|
||||
row_idcs, col_idcs = np.where(
|
||||
# distance_matrix == distance_matrix[distance_matrix > 0].min()
|
||||
distance_matrix
|
||||
== np.nanmin(distance_matrix)
|
||||
)
|
||||
|
||||
# Get the max index.
|
||||
# It correspond to the index of the element we will remove in the locations_highest_res list
|
||||
print(row_idcs[0], col_idcs[0])
|
||||
# if prior[row_idcs[0]] >= prior[col_idcs[0]]:
|
||||
# sel_idx = row_idcs[0]
|
||||
# else:
|
||||
# sel_idx = col_idcs[0]
|
||||
d_row = np.nansum(distance_matrix[row_idcs[0], :])
|
||||
d_col = np.nansum(distance_matrix[:, col_idcs[0]])
|
||||
if d_row > d_col:
|
||||
sel_idx = col_idcs[0]
|
||||
else:
|
||||
sel_idx = row_idcs[0]
|
||||
# sel_idx: int = max(row_idcs[0], col_idcs[0])
|
||||
idcs_remove.append(sel_idx) # Save the index
|
||||
|
||||
# Set current distance as Inf because we don't want to consider it further in our search
|
||||
distance_matrix[sel_idx, :] = disable_value
|
||||
distance_matrix[:, sel_idx] = disable_value
|
||||
|
||||
idcs_remain: list = np.setdiff1d(np.arange(n_locations), idcs_remove)
|
||||
locations_remain.append(locations[idcs_remain, :])
|
||||
|
||||
return locations_remain
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# generate a circle with n locations
|
||||
n_locations: int = 20
|
||||
phi = np.arange(n_locations) / n_locations * 2 * np.pi
|
||||
locations = np.ones((n_locations, 3))
|
||||
locations[:, 1] = np.cos(phi)
|
||||
locations[:, 2] = np.sin(phi)
|
||||
prior = np.ones((n_locations,))
|
||||
prior[:10] = 0.1
|
||||
locations_remain = discard_elements(locations, [n_locations // 5], prior=prior)
|
||||
|
||||
plt.plot(locations[:, 1], locations[:, 2], "ko")
|
||||
plt.plot(locations_remain[0][:, 1], locations_remain[0][:, 2], "rx")
|
||||
plt.show()
|
||||
|
||||
locations_remain_simple = discard_elements_simple(locations, [n_locations // 5])
|
||||
|
||||
plt.plot(locations[:, 1], locations[:, 2], "ko")
|
||||
plt.plot(locations_remain_simple[0][:, 1], locations_remain_simple[0][:, 2], "rx")
|
||||
plt.show()
|
||||
|
||||
|
||||
# %%
|
711
processing_chain/OnlineEncoding.py
Normal file
711
processing_chain/OnlineEncoding.py
Normal file
|
@ -0,0 +1,711 @@
|
|||
# %%
|
||||
#
|
||||
# test_OnlineEncoding.py
|
||||
# ========================================================
|
||||
# encode visual scenes into sparse representations using
|
||||
# different kinds of dictionaries
|
||||
#
|
||||
# -> derived from test_PsychophysicsEncoding.py
|
||||
#
|
||||
# Version 1.0, 29.04.2023:
|
||||
#
|
||||
# Version 1.1, 21.06.2023:
|
||||
# define proper class
|
||||
#
|
||||
|
||||
|
||||
# Import Python modules
|
||||
# ========================================================
|
||||
# import csv
|
||||
# import time
|
||||
import os
|
||||
import glob
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torchvision as tv
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Import our modules
|
||||
# ========================================================
|
||||
from processing_chain.ContourExtract import ContourExtract
|
||||
from processing_chain.PatchGenerator import PatchGenerator
|
||||
from processing_chain.Sparsifier import Sparsifier
|
||||
from processing_chain.DiscardElements import discard_elements_simple
|
||||
from processing_chain.BuildImage import BuildImage
|
||||
from processing_chain.WebCam import WebCam
|
||||
from processing_chain.Yolo5Segmentation import Yolo5Segmentation
|
||||
|
||||
|
||||
# TODO required?
|
||||
def show_torch_frame(
|
||||
frame_torch: torch.Tensor,
|
||||
title: str = "",
|
||||
cmap: str = "viridis",
|
||||
target: str = "pyplot",
|
||||
):
|
||||
frame_numpy = (
|
||||
(frame_torch.movedim(0, -1) * 255).type(dtype=torch.uint8).cpu().numpy()
|
||||
)
|
||||
if target == "pyplot":
|
||||
plt.imshow(frame_numpy, cmap=cmap)
|
||||
plt.title(title)
|
||||
plt.show()
|
||||
if target == "cv2":
|
||||
if frame_numpy.ndim == 3:
|
||||
if frame_numpy.shape[-1] == 1:
|
||||
frame_numpy = np.tile(frame_numpy, [1, 1, 3])
|
||||
frame_numpy = (frame_numpy - frame_numpy.min()) / (
|
||||
frame_numpy.max() - frame_numpy.min()
|
||||
)
|
||||
# print(frame_numpy.shape, frame_numpy.max(), frame_numpy.min())
|
||||
cv2.namedWindow(title, cv2.WINDOW_NORMAL)
|
||||
cv2.imshow(title, frame_numpy[:, :, (2, 1, 0)])
|
||||
cv2.waitKey(1)
|
||||
|
||||
return
|
||||
|
||||
|
||||
# TODO required?
|
||||
def embed_image(frame_torch, out_height, out_width, init_value=0):
|
||||
|
||||
out_shape = torch.tensor(frame_torch.shape)
|
||||
|
||||
frame_width = frame_torch.shape[-1]
|
||||
frame_height = frame_torch.shape[-2]
|
||||
|
||||
frame_width_idx0 = max([0, (frame_width - out_width) // 2])
|
||||
frame_height_idx0 = max([0, (frame_height - out_height) // 2])
|
||||
|
||||
select_width = min([frame_width, out_width])
|
||||
select_height = min([frame_height, out_height])
|
||||
|
||||
out_shape[-1] = out_width
|
||||
out_shape[-2] = out_height
|
||||
|
||||
out_torch = init_value * torch.ones(tuple(out_shape))
|
||||
|
||||
out_width_idx0 = max([0, (out_width - frame_width) // 2])
|
||||
out_height_idx0 = max([0, (out_height - frame_height) // 2])
|
||||
|
||||
out_torch[
|
||||
...,
|
||||
out_height_idx0 : (out_height_idx0 + select_height),
|
||||
out_width_idx0 : (out_width_idx0 + select_width),
|
||||
] = frame_torch[
|
||||
...,
|
||||
frame_height_idx0 : (frame_height_idx0 + select_height),
|
||||
frame_width_idx0 : (frame_width_idx0 + select_width),
|
||||
]
|
||||
|
||||
return out_torch
|
||||
|
||||
|
||||
class OnlineEncoding:
|
||||
|
||||
# TODO: also pre-populate self-ies here?
|
||||
#
|
||||
# DEFINED IN "__init__":
|
||||
#
|
||||
# display (fixed)
|
||||
# gabor (changeable)
|
||||
# encoding (changeable)
|
||||
# dictionary (changeable)
|
||||
# control (fixed)
|
||||
# path (fixed)
|
||||
# verbose
|
||||
# torch_device, default_dtype
|
||||
# display_size_max_x_PIX, display_size_max_y_PIX
|
||||
# padding_fill
|
||||
# cap
|
||||
# yolo
|
||||
# classes_detect
|
||||
#
|
||||
#
|
||||
# DEFINED IN "apply_parameter_changes":
|
||||
#
|
||||
# padding_PIX
|
||||
# sigma_kernel_PIX, lambda_kernel_PIX
|
||||
# out_x, out_y
|
||||
# clocks, phosphene, clocks_filter
|
||||
#
|
||||
|
||||
def __init__(self, source=0, verbose=False):
|
||||
|
||||
# Define parameters
|
||||
# ========================================================
|
||||
# Unit abbreviations:
|
||||
# dva = degrees of visual angle
|
||||
# pix = pixels
|
||||
print("OE-Init: Defining default parameters...")
|
||||
self.verbose = verbose
|
||||
|
||||
# display: Defines geometry of target display
|
||||
# ========================================================
|
||||
# The encoded image will be scaled such that it optimally uses
|
||||
# the max space available. If the orignal image has a different aspect
|
||||
# ratio than the display region, it will only use one spatial
|
||||
# dimension (horizontal or vertical) to its full extent
|
||||
#
|
||||
# If one DVA corresponds to different PIX_per_DVA on the display,
|
||||
# (i.e. varying distance observers from screen), it should be set
|
||||
# larger than the largest PIX_per_DVA required, for avoiding
|
||||
# extrapolation artefacts or blur.
|
||||
#
|
||||
self.display = {
|
||||
"size_max_x_DVA": 10.0, # maximum x size of encoded image
|
||||
"size_max_y_DVA": 10.0, # minimum y size of encoded image
|
||||
"PIX_per_DVA": 40.0, # scaling factor pixels to DVA
|
||||
"scale": "same_range", # "same_luminance" or "same_range"
|
||||
}
|
||||
|
||||
# gabor: Defines paras of Gabor filters for contour extraction
|
||||
# ==============================================================
|
||||
self.gabor = {
|
||||
"sigma_kernel_DVA": 0.06,
|
||||
"lambda_kernel_DVA": 0.12,
|
||||
"n_orientations": 8,
|
||||
}
|
||||
|
||||
# encoding: Defines parameters of sparse encoding process
|
||||
# ========================================================
|
||||
# Roughly speaking, after contour extraction dictionary elements
|
||||
# will be placed starting from the position with the highest
|
||||
# overlap with the contour. Elements placed can be surrounded
|
||||
# by a dead or inhibitory zone to prevent placing further elements
|
||||
# too closely. The procedure will map 'n_patches_compute' elements
|
||||
# and then stop. For each element one obtains an overlap with the
|
||||
# contour image.
|
||||
#
|
||||
# After placement, the overlaps found are normalized to the max
|
||||
# overlap found, and then all elements with a larger normalized overlap
|
||||
# than 'overlap_threshold' will be selected. These remaining
|
||||
# elements will comprise a 'full' encoding of the contour.
|
||||
#
|
||||
# To generate even sparser representations, the full encoding can
|
||||
# be reduced to a certain percentage of elements in the full encoding
|
||||
# by setting the variable 'percentages'
|
||||
#
|
||||
# Example: n_patches_compute = 100 reduced by overlap_threshold = 0.1
|
||||
# to 80 elements. Requesting a percentage of 30% yields a representation
|
||||
# with 24 elements.
|
||||
#
|
||||
self.encoding = {
|
||||
"n_patches_compute": 100, # this amount of patches will be placed
|
||||
"use_exp_deadzone": True, # parameters of Gaussian deadzone
|
||||
"size_exp_deadzone_DVA": 1.20, # PREVIOUSLY 1.4283
|
||||
"use_cutout_deadzone": True, # parameters of cutout deadzone
|
||||
"size_cutout_deadzone_DVA": 0.65, # PREVIOUSLY 0.7575
|
||||
"overlap_threshold": 0.1, # relative overlap threshold
|
||||
"percentages": torch.tensor([100]),
|
||||
}
|
||||
self.number_of_patches = self.encoding["n_patches_compute"]
|
||||
|
||||
# dictionary: Defines parameters of dictionary
|
||||
# ========================================================
|
||||
self.dictionary = {
|
||||
"size_DVA": 1.0, # PREVIOUSLY 1.25,
|
||||
"clocks": None, # parameters for clocks dictionary, see below
|
||||
"phosphene": None, # paramters for phosphene dictionary, see below
|
||||
}
|
||||
|
||||
self.dictionary["phosphene"]: dict[float] = {
|
||||
"sigma_width": 0.18, # DEFAULT 0.15, # half-width of Gaussian
|
||||
}
|
||||
|
||||
self.dictionary["clocks"]: dict[int, int, float, float] = {
|
||||
"n_dir": 8, # number of directions for clock pointer segments
|
||||
"n_open": 4, # number of opening angles between two clock pointer segments
|
||||
"pointer_width": 0.07, # PREVIOUSLY 0.05, # relative width and size of tip extension of clock pointer
|
||||
"pointer_length": 0.18, # PREVIOUSLY 0.15, # relative length of clock pointer
|
||||
}
|
||||
|
||||
# control: For controlling plotting options and flow of script
|
||||
# ========================================================
|
||||
self.control = {
|
||||
"force_torch_use_cpu": False, # force using CPU even if GPU available
|
||||
"show_capture": True, # shows captured image
|
||||
"show_object": True, # shows detected object
|
||||
"show_contours": True, # shows extracted contours
|
||||
"show_percept": True, # shows percept
|
||||
}
|
||||
|
||||
# specify classes to detect
|
||||
class_person = 0
|
||||
self.classes_detect = [class_person]
|
||||
|
||||
print(
|
||||
"OE-Init: Defining paths, creating dirs, setting default device and datatype"
|
||||
)
|
||||
|
||||
# path: Path infos for input and output images
|
||||
# ========================================================
|
||||
self.path = {"output": "test/output/level1/", "input": "test/images_test/"}
|
||||
# Make output directories, if necessary: the place were we dump the new images to...
|
||||
# os.makedirs(self.path["output"], mode=0o777, exist_ok=True)
|
||||
|
||||
# Check if GPU is available and use it, if possible
|
||||
# =================================================
|
||||
self.default_dtype = torch.float32
|
||||
torch.set_default_dtype(self.default_dtype)
|
||||
if self.control["force_torch_use_cpu"]:
|
||||
torch_device: str = "cpu"
|
||||
else:
|
||||
torch_device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"Using {torch_device} as TORCH device...")
|
||||
self.torch_device = torch_device
|
||||
|
||||
print("OE-Init: Compute display scaling factors and padding RGB values")
|
||||
|
||||
# global scaling factors for all pixel-related length scales
|
||||
self.display_size_max_x_PIX: float = (
|
||||
self.display["size_max_x_DVA"] * self.display["PIX_per_DVA"]
|
||||
)
|
||||
self.display_size_max_y_PIX: float = (
|
||||
self.display["size_max_y_DVA"] * self.display["PIX_per_DVA"]
|
||||
)
|
||||
|
||||
# determine padding fill value
|
||||
tmp = tv.transforms.Grayscale(num_output_channels=1)
|
||||
tmp_value = torch.full((3, 1, 1), 0)
|
||||
self.padding_fill: int = int(tmp(tmp_value).squeeze())
|
||||
|
||||
print(f"OE-Init: Opening camera source or video file '{source}'")
|
||||
|
||||
# open source
|
||||
self.cap = WebCam(source)
|
||||
if not self.cap.open_cam():
|
||||
raise OSError(f"Opening source {source} failed!")
|
||||
|
||||
# get the video frame size, frame count and frame rate
|
||||
frame_width = self.cap.cap_frame_width
|
||||
frame_height = self.cap.cap_frame_height
|
||||
fps = self.cap.cap_fps
|
||||
print(
|
||||
f"OE-Init: Processing frames of {frame_width} x {frame_height} @ {fps} fps."
|
||||
)
|
||||
|
||||
# open output file if we want to save frames
|
||||
# if output_file != None:
|
||||
# out = cv2.VideoWriter(
|
||||
# output_file,
|
||||
# cv2.VideoWriter_fourcc(*"MJPG"),
|
||||
# fps,
|
||||
# (out_x, out_y),
|
||||
# )
|
||||
# if out == None:
|
||||
# raise OSError(f"Can not open file {output_file} for writing!")
|
||||
|
||||
# get an instance of the Yolo segmentation network
|
||||
print("OE-Init: initialize YOLO")
|
||||
self.yolo = Yolo5Segmentation(torch_device=self.torch_device)
|
||||
|
||||
self.send_dictionaries = False
|
||||
|
||||
self.apply_parameter_changes()
|
||||
|
||||
return
|
||||
|
||||
def apply_parameter_changes(self):
|
||||
|
||||
# GET NEW PARAMETERS
|
||||
print("OE-AppParChg: Computing sizes from new parameters")
|
||||
|
||||
### BLOCK: dictionary ----------------
|
||||
# set patch size for both dictionaries, make sure it is odd number
|
||||
dictionary_size_PIX: int = (
|
||||
1
|
||||
+ (int(self.dictionary["size_DVA"] * self.display["PIX_per_DVA"]) // 2) * 2
|
||||
)
|
||||
|
||||
### BLOCK: gabor ---------------------
|
||||
# convert contour-related parameters to pixel units
|
||||
self.sigma_kernel_PIX: float = (
|
||||
self.gabor["sigma_kernel_DVA"] * self.display["PIX_per_DVA"]
|
||||
)
|
||||
self.lambda_kernel_PIX: float = (
|
||||
self.gabor["lambda_kernel_DVA"] * self.display["PIX_per_DVA"]
|
||||
)
|
||||
|
||||
### BLOCK: gabor & dictionary ------------------
|
||||
# Padding
|
||||
# -------
|
||||
self.padding_PIX: int = int(
|
||||
max(3.0 * self.sigma_kernel_PIX, 1.1 * dictionary_size_PIX)
|
||||
)
|
||||
|
||||
# define target video/representation width/height
|
||||
multiple_of = 4
|
||||
out_x = self.display_size_max_x_PIX + 2 * self.padding_PIX
|
||||
out_y = self.display_size_max_y_PIX + 2 * self.padding_PIX
|
||||
out_x += (multiple_of - (out_x % multiple_of)) % multiple_of
|
||||
out_y += (multiple_of - (out_y % multiple_of)) % multiple_of
|
||||
self.out_x = int(out_x)
|
||||
self.out_y = int(out_y)
|
||||
|
||||
# generate dictionaries
|
||||
# ---------------------
|
||||
### BLOCK: dictionary --------------------------
|
||||
print("OE-AppParChg: Generating dictionaries...")
|
||||
patch_generator = PatchGenerator(torch_device=self.torch_device)
|
||||
self.phosphene = patch_generator.alphabet_phosphene(
|
||||
patch_size=dictionary_size_PIX,
|
||||
sigma_width=self.dictionary["phosphene"]["sigma_width"]
|
||||
* dictionary_size_PIX,
|
||||
)
|
||||
### BLOCK: dictionary & gabor --------------------------
|
||||
self.clocks_filter, self.clocks, segments = patch_generator.alphabet_clocks(
|
||||
patch_size=dictionary_size_PIX,
|
||||
n_dir=self.dictionary["clocks"]["n_dir"],
|
||||
n_filter=self.gabor["n_orientations"],
|
||||
segment_width=self.dictionary["clocks"]["pointer_width"]
|
||||
* dictionary_size_PIX,
|
||||
segment_length=self.dictionary["clocks"]["pointer_length"]
|
||||
* dictionary_size_PIX,
|
||||
)
|
||||
|
||||
self.send_dictionaries = True
|
||||
|
||||
return
|
||||
|
||||
# classes_detect, out_x, out_y
|
||||
def update(self, data_in):
|
||||
|
||||
# handle parameter change
|
||||
|
||||
if data_in:
|
||||
|
||||
print("Incoming -----------> ", data_in)
|
||||
|
||||
self.number_of_patches = data_in["number_of_patches"]
|
||||
|
||||
self.classes_detect = data_in["value"]
|
||||
|
||||
self.gabor["sigma_kernel_DVA"] = data_in["sigma_kernel_DVA"]
|
||||
self.gabor["lambda_kernel_DVA"] = data_in["sigma_kernel_DVA"] * 2
|
||||
self.gabor["n_orientations"] = data_in["n_orientations"]
|
||||
|
||||
self.dictionary["size_DVA"] = data_in["size_DVA"]
|
||||
self.dictionary["phosphene"]["sigma_width"] = data_in["sigma_width"]
|
||||
self.dictionary["clocks"]["n_dir"] = data_in["n_dir"]
|
||||
self.dictionary["clocks"]["n_open"] = data_in["n_dir"] // 2
|
||||
self.dictionary["clocks"]["pointer_width"] = data_in["pointer_width"]
|
||||
self.dictionary["clocks"]["pointer_length"] = data_in["pointer_length"]
|
||||
|
||||
self.encoding["use_exp_deadzone"] = data_in["use_exp_deadzone"]
|
||||
self.encoding["size_exp_deadzone_DVA"] = data_in["size_exp_deadzone_DVA"]
|
||||
self.encoding["use_cutout_deadzone"] = data_in["use_cutout_deadzone"]
|
||||
self.encoding["size_cutout_deadzone_DVA"] = data_in[
|
||||
"size_cutout_deadzone_DVA"
|
||||
]
|
||||
|
||||
self.control["show_capture"] = data_in["enable_cam"]
|
||||
self.control["show_object"] = data_in["enable_yolo"]
|
||||
self.control["show_contours"] = data_in["enable_contour"]
|
||||
# TODO Fenster zumachen
|
||||
self.apply_parameter_changes()
|
||||
|
||||
# some constants for addressing specific components of output arrays
|
||||
image_id_CONST: int = 0
|
||||
overlap_index_CONST: int = 1
|
||||
|
||||
# format: color_RGB, height, width <class 'torch.tensor'> float, range=0,1
|
||||
print("OE-ProcessFrame: capturing frame")
|
||||
frame = self.cap.get_frame()
|
||||
if frame == None:
|
||||
raise OSError(f"Can not capture frame {i_frame}")
|
||||
if self.verbose:
|
||||
if self.control["show_capture"]:
|
||||
show_torch_frame(frame, title="Captured", target=self.verbose)
|
||||
else:
|
||||
try:
|
||||
cv2.destroyWindow("Captured")
|
||||
except:
|
||||
pass
|
||||
|
||||
# perform segmentation
|
||||
|
||||
frame = frame.to(device=self.torch_device)
|
||||
print("OE-ProcessFrame: frame segmentation by YOLO")
|
||||
frame_segmented = self.yolo(frame.unsqueeze(0), classes=self.classes_detect)
|
||||
|
||||
# This extracts the frame in x to convert the mask in a video format
|
||||
if self.yolo.found_class_id != None:
|
||||
|
||||
n_found = len(self.yolo.found_class_id)
|
||||
print(
|
||||
f"OE-ProcessFrame: {n_found} occurrences of desired object found in frame!"
|
||||
)
|
||||
|
||||
mask = frame_segmented[0]
|
||||
|
||||
# is there something in the mask?
|
||||
if not mask.sum() == 0:
|
||||
|
||||
# yes, cut only the part of the frame that has our object of interest
|
||||
frame_masked = mask * frame
|
||||
|
||||
x_height = mask.sum(axis=-2)
|
||||
x_indices = torch.where(x_height > 0)
|
||||
x_max = x_indices[0].max() + 1
|
||||
x_min = x_indices[0].min()
|
||||
|
||||
y_height = mask.sum(axis=-1)
|
||||
y_indices = torch.where(y_height > 0)
|
||||
y_max = y_indices[0].max() + 1
|
||||
y_min = y_indices[0].min()
|
||||
|
||||
frame_cut = frame_masked[:, y_min:y_max, x_min:x_max]
|
||||
else:
|
||||
print(f"OE-ProcessFrame: Mask contains all zeros in current frame!")
|
||||
frame_cut = None
|
||||
else:
|
||||
print(f"OE-ProcessFrame: No objects found in current frame!")
|
||||
frame_cut = None
|
||||
|
||||
if frame_cut == None:
|
||||
# out_torch = torch.zeros([self.out_y, self.out_x])
|
||||
position_selection = torch.zeros((1, 0, 3))
|
||||
contour_shape = [1, self.gabor["n_orientations"], 1, 1]
|
||||
else:
|
||||
if self.verbose:
|
||||
if self.control["show_object"]:
|
||||
show_torch_frame(
|
||||
frame_cut, title="Selected Object", target=self.verbose
|
||||
)
|
||||
else:
|
||||
try:
|
||||
cv2.destroyWindow("Selected Object")
|
||||
except:
|
||||
pass
|
||||
|
||||
# UDO: from here on, we proceed as before, just handing
|
||||
# UDO: over the frame_cut --> image
|
||||
image = frame_cut
|
||||
|
||||
# Determine target size of image
|
||||
# image: [RGB, Height, Width], dtype= tensor.torch.uint8
|
||||
print("OE-ProcessFrame: Computing downsampling factor image -> display")
|
||||
f_x: float = self.display_size_max_x_PIX / image.shape[-1]
|
||||
f_y: float = self.display_size_max_y_PIX / image.shape[-2]
|
||||
f_xy_min: float = min(f_x, f_y)
|
||||
downsampling_x: int = int(f_xy_min * image.shape[-1])
|
||||
downsampling_y: int = int(f_xy_min * image.shape[-2])
|
||||
|
||||
# CURRENTLY we do not crop in the end...
|
||||
# Image size for removing the fft crop later
|
||||
# center_crop_x: int = downsampling_x
|
||||
# center_crop_y: int = downsampling_y
|
||||
|
||||
# define contour extraction processing chain
|
||||
# ------------------------------------------
|
||||
print("OE-ProcessFrame: Extracting contours")
|
||||
train_processing_chain = tv.transforms.Compose(
|
||||
transforms=[
|
||||
tv.transforms.Grayscale(num_output_channels=1), # RGB to grayscale
|
||||
tv.transforms.Resize(
|
||||
size=(downsampling_y, downsampling_x)
|
||||
), # downsampling
|
||||
tv.transforms.Pad( # extra white padding around the picture
|
||||
padding=(self.padding_PIX, self.padding_PIX),
|
||||
fill=self.padding_fill,
|
||||
),
|
||||
ContourExtract( # contour extraction
|
||||
n_orientations=self.gabor["n_orientations"],
|
||||
sigma_kernel=self.sigma_kernel_PIX,
|
||||
lambda_kernel=self.lambda_kernel_PIX,
|
||||
torch_device=self.torch_device,
|
||||
),
|
||||
# CURRENTLY we do not crop in the end!
|
||||
# tv.transforms.CenterCrop( # Remove the padding
|
||||
# size=(center_crop_x, center_crop_y)
|
||||
# ),
|
||||
],
|
||||
)
|
||||
# ...with and without orientation channels
|
||||
contour = train_processing_chain(image.unsqueeze(0))
|
||||
contour_collapse = train_processing_chain.transforms[-1].create_collapse(
|
||||
contour
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
if self.control["show_contours"]:
|
||||
show_torch_frame(
|
||||
contour_collapse,
|
||||
title="Contours Extracted",
|
||||
cmap="gray",
|
||||
target=self.verbose,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
cv2.destroyWindow("Contours Extracted")
|
||||
except:
|
||||
pass
|
||||
|
||||
# generate a prior for mapping the contour to the dictionary
|
||||
# CURRENTLY we use an uniform prior...
|
||||
# ----------------------------------------------------------
|
||||
dictionary_prior = torch.ones(
|
||||
(self.clocks_filter.shape[0]),
|
||||
dtype=self.default_dtype,
|
||||
device=torch.device(self.torch_device),
|
||||
)
|
||||
|
||||
# instantiate and execute sparsifier
|
||||
# ----------------------------------
|
||||
print("OE-ProcessFrame: Performing sparsification")
|
||||
sparsifier = Sparsifier(
|
||||
dictionary_filter=self.clocks_filter,
|
||||
dictionary=self.clocks,
|
||||
dictionary_prior=dictionary_prior,
|
||||
number_of_patches=self.encoding["n_patches_compute"],
|
||||
size_exp_deadzone=self.encoding["size_exp_deadzone_DVA"]
|
||||
* self.display["PIX_per_DVA"],
|
||||
plot_use_map=False, # self.control["plot_deadzone"],
|
||||
deadzone_exp=self.encoding["use_exp_deadzone"],
|
||||
deadzone_hard_cutout=self.encoding["use_cutout_deadzone"],
|
||||
deadzone_hard_cutout_size=self.encoding["size_cutout_deadzone_DVA"]
|
||||
* self.display["PIX_per_DVA"],
|
||||
padding_deadzone_size_x=self.padding_PIX,
|
||||
padding_deadzone_size_y=self.padding_PIX,
|
||||
torch_device=self.torch_device,
|
||||
)
|
||||
sparsifier(contour)
|
||||
assert sparsifier.position_found is not None
|
||||
|
||||
# extract and normalize the overlap found
|
||||
overlap_found = sparsifier.overlap_found[
|
||||
image_id_CONST, :, overlap_index_CONST
|
||||
]
|
||||
overlap_found = overlap_found / overlap_found.max()
|
||||
|
||||
# get overlap above certain threshold, extract corresponding elements
|
||||
overlap_idcs_valid = torch.where(
|
||||
overlap_found >= self.encoding["overlap_threshold"]
|
||||
)[0]
|
||||
position_selection = sparsifier.position_found[
|
||||
image_id_CONST : image_id_CONST + 1, overlap_idcs_valid, :
|
||||
]
|
||||
n_elements = len(overlap_idcs_valid)
|
||||
print(f"OE-ProcessFrame: {n_elements} elements positioned!")
|
||||
|
||||
contour_shape = contour.shape
|
||||
|
||||
n_cut = min(position_selection.shape[-2], self.number_of_patches)
|
||||
|
||||
data_out = {
|
||||
"position_found": position_selection[:, :n_cut, :],
|
||||
"canvas_size": contour_shape,
|
||||
}
|
||||
if self.send_dictionaries:
|
||||
data_out["features"] = self.clocks
|
||||
data_out["phosphene"] = self.phosphene
|
||||
self.send_dictionaries = False
|
||||
|
||||
return data_out
|
||||
|
||||
def __del__(self):
|
||||
|
||||
print("OE-Delete: exiting gracefully!")
|
||||
self.cap.close_cam()
|
||||
try:
|
||||
cv2.destroyAllWindows()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
# TODO no output file
|
||||
# TODO detect end of file if input is video file
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
verbose = "cv2"
|
||||
source = 0 # "GoProWireless"
|
||||
frame_count = 20
|
||||
i_frame = 0
|
||||
|
||||
data_in = None
|
||||
|
||||
oe = OnlineEncoding(source=source, verbose=verbose)
|
||||
|
||||
# Loop over the frames
|
||||
while i_frame < frame_count:
|
||||
|
||||
i_frame += 1
|
||||
|
||||
if i_frame == (frame_count // 3):
|
||||
oe.dictionary["size_DVA"] = 0.5
|
||||
oe.apply_parameter_changes()
|
||||
|
||||
if i_frame == (frame_count * 2 // 3):
|
||||
oe.dictionary["size_DVA"] = 2.0
|
||||
oe.apply_parameter_changes()
|
||||
|
||||
data_out = oe.update(data_in)
|
||||
position_selection = data_out["position_found"]
|
||||
contour_shape = data_out["canvas_size"]
|
||||
|
||||
# SENDE/EMPANGSLOGIK:
|
||||
#
|
||||
# <- PACKET empfangen
|
||||
# Parameteränderungen?
|
||||
# in Instanz se übertragen
|
||||
# "apply_parameter_changes" aufrufen
|
||||
# folgende variablen in sendepacket:
|
||||
# se.clocks, se.phosphene, se.out_x, se.out_y
|
||||
# "process_frame"
|
||||
# folgende variablen in sendepacket:
|
||||
# position_selection, contour_shape
|
||||
# -> PACKET zurückgeben
|
||||
|
||||
# build the full image!
|
||||
image_clocks = BuildImage(
|
||||
canvas_size=contour_shape,
|
||||
dictionary=oe.clocks,
|
||||
position_found=position_selection,
|
||||
default_dtype=oe.default_dtype,
|
||||
torch_device=oe.torch_device,
|
||||
)
|
||||
# image_phosphenes = BuildImage(
|
||||
# canvas_size=contour.shape,
|
||||
# dictionary=dictionary_phosphene,
|
||||
# position_found=position_selection,
|
||||
# default_dtype=default_dtype,
|
||||
# torch_device=torch_device,
|
||||
# )
|
||||
|
||||
# normalize to range [0...1]
|
||||
m = image_clocks[0].max()
|
||||
if m == 0:
|
||||
m = 1
|
||||
image_clocks_normalized = image_clocks[0] / m
|
||||
|
||||
# embed into frame of desired output size
|
||||
out_torch = embed_image(
|
||||
image_clocks_normalized, out_height=oe.out_y, out_width=oe.out_x
|
||||
)
|
||||
|
||||
# show, if desired
|
||||
if verbose:
|
||||
if oe.control["show_percept"]:
|
||||
show_torch_frame(
|
||||
out_torch, title="Percept", cmap="gray", target=verbose
|
||||
)
|
||||
|
||||
# if output_file != None:
|
||||
# out_pixel = (
|
||||
# (out_torch * torch.ones([3, 1, 1]) * 255)
|
||||
# .type(dtype=torch.uint8)
|
||||
# .movedim(0, -1)
|
||||
# .numpy()
|
||||
# )
|
||||
# out.write(out_pixel)
|
||||
|
||||
del oe
|
||||
|
||||
# if output_file != None:
|
||||
# out.release()
|
||||
|
||||
# %%
|
299
processing_chain/OnlinePerception.py
Normal file
299
processing_chain/OnlinePerception.py
Normal file
|
@ -0,0 +1,299 @@
|
|||
#%%
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from processing_chain.BuildImage import BuildImage
|
||||
import matplotlib.pyplot as plt
|
||||
import time
|
||||
import cv2
|
||||
from gui.GUIEvents import GUIEvents
|
||||
from gui.GUICombiData import GUICombiData
|
||||
import tkinter as tk
|
||||
|
||||
|
||||
# TODO required?
|
||||
def embed_image(frame_torch, out_height, out_width, torch_device, init_value=0):
|
||||
|
||||
out_shape = torch.tensor(frame_torch.shape)
|
||||
|
||||
frame_width = frame_torch.shape[-1]
|
||||
frame_height = frame_torch.shape[-2]
|
||||
|
||||
frame_width_idx0 = max([0, (frame_width - out_width) // 2])
|
||||
frame_height_idx0 = max([0, (frame_height - out_height) // 2])
|
||||
|
||||
select_width = min([frame_width, out_width])
|
||||
select_height = min([frame_height, out_height])
|
||||
|
||||
out_shape[-1] = out_width
|
||||
out_shape[-2] = out_height
|
||||
|
||||
out_torch = init_value * torch.ones(tuple(out_shape), device=torch_device)
|
||||
|
||||
out_width_idx0 = max([0, (out_width - frame_width) // 2])
|
||||
out_height_idx0 = max([0, (out_height - frame_height) // 2])
|
||||
|
||||
out_torch[
|
||||
...,
|
||||
out_height_idx0 : (out_height_idx0 + select_height),
|
||||
out_width_idx0 : (out_width_idx0 + select_width),
|
||||
] = frame_torch[
|
||||
...,
|
||||
frame_height_idx0 : (frame_height_idx0 + select_height),
|
||||
frame_width_idx0 : (frame_width_idx0 + select_width),
|
||||
]
|
||||
|
||||
return out_torch
|
||||
|
||||
|
||||
class OnlinePerception:
|
||||
|
||||
# SELFies...
|
||||
#
|
||||
# torch_device, default_dtype (fixed)
|
||||
# canvas_size, features, phosphene, position_found (parameters)
|
||||
# percept
|
||||
#
|
||||
# root, events, confdata, use_gui
|
||||
#
|
||||
|
||||
def __init__(self, target, use_gui=False):
|
||||
|
||||
# CPU or GPU?
|
||||
self.default_dtype = torch.float32
|
||||
torch.set_default_dtype(self.default_dtype)
|
||||
if torch.cuda.is_available():
|
||||
self.torch_device = torch.device("cuda")
|
||||
else:
|
||||
self.torch_device = torch.device("cpu")
|
||||
|
||||
self.use_gui = use_gui
|
||||
if self.use_gui:
|
||||
self.root = tk.Tk()
|
||||
self.confdata: GUICombiData = GUICombiData()
|
||||
self.events = GUIEvents(tk_root=self.root, confdata=self.confdata)
|
||||
|
||||
# default dictionary parameters
|
||||
n_xy_canvas = 400
|
||||
n_xy_features = 41
|
||||
n_features = 32
|
||||
n_positions = 1
|
||||
|
||||
# populate internal parameters
|
||||
self.canvas_size = [1, 8, n_xy_canvas, n_xy_canvas]
|
||||
self.features = torch.rand(
|
||||
(n_features, 1, n_xy_features, n_xy_features), device=self.torch_device
|
||||
)
|
||||
self.phosphene = torch.ones(
|
||||
(1, 1, n_xy_features, n_xy_features), device=self.torch_device
|
||||
)
|
||||
self.position_found = torch.zeros((1, n_positions, 3), device=self.torch_device)
|
||||
self.position_found[0, :, 0] = torch.randint(n_features, (n_positions,))
|
||||
self.position_found[0, :, 1:] = torch.randint(n_xy_canvas, (n_positions, 2))
|
||||
self.percept = torch.zeros(
|
||||
(1, 1, n_xy_canvas, n_xy_canvas), device=self.torch_device
|
||||
)
|
||||
|
||||
self.p_FEATperSECperPOS = 3.0
|
||||
self.tau_SEC = 0.3 # percept time constant
|
||||
self.t = time.time()
|
||||
|
||||
self.target = target # display target
|
||||
self.display = target
|
||||
|
||||
self.selection = 0
|
||||
|
||||
return
|
||||
|
||||
def update(self, data_in: dict):
|
||||
|
||||
if data_in: # not NONE?
|
||||
print(f"Parameter update requested for keys {data_in.keys()}!")
|
||||
if "position_found" in data_in.keys():
|
||||
self.position_found = data_in["position_found"]
|
||||
if "canvas_size" in data_in.keys():
|
||||
self.canvas_size = data_in["canvas_size"]
|
||||
self.percept = embed_image(
|
||||
self.percept,
|
||||
self.canvas_size[-2],
|
||||
self.canvas_size[-1],
|
||||
torch_device=self.torch_device,
|
||||
init_value=0,
|
||||
)
|
||||
if "features" in data_in.keys():
|
||||
print(self.features.shape, self.features.max(), self.features.min())
|
||||
self.features = data_in["features"]
|
||||
self.features /= self.features.max()
|
||||
print(self.features.shape, self.features.max(), self.features.min())
|
||||
if "phosphene" in data_in.keys():
|
||||
self.phosphene = data_in["phosphene"]
|
||||
self.phosphene /= self.phosphene.max()
|
||||
|
||||
# parameters of (optional) GUI changed?
|
||||
data_out = None
|
||||
if self.use_gui:
|
||||
self.root.update()
|
||||
if not self.confdata.gui_running:
|
||||
data_out = {"exit": 42}
|
||||
# graceful exit
|
||||
else:
|
||||
if self.confdata.check_for_change() is True:
|
||||
|
||||
data_out = {
|
||||
"value": self.confdata.yolo_class.value,
|
||||
"sigma_kernel_DVA": self.confdata.contour_extraction.sigma_kernel_DVA,
|
||||
"n_orientations": self.confdata.contour_extraction.n_orientations,
|
||||
"size_DVA": self.confdata.alphabet.size_DVA,
|
||||
# "tau_SEC": self.confdata.alphabet.tau_SEC,
|
||||
# "p_FEATperSECperPOS": self.confdata.alphabet.p_FEATperSECperPOS,
|
||||
"sigma_width": self.confdata.alphabet.phosphene_sigma_width,
|
||||
"n_dir": self.confdata.alphabet.clocks_n_dir,
|
||||
"pointer_width": self.confdata.alphabet.clocks_pointer_width,
|
||||
"pointer_length": self.confdata.alphabet.clocks_pointer_length,
|
||||
"number_of_patches": self.confdata.sparsifier.number_of_patches,
|
||||
"use_exp_deadzone": self.confdata.sparsifier.use_exp_deadzone,
|
||||
"use_cutout_deadzone": self.confdata.sparsifier.use_cutout_deadzone,
|
||||
"size_exp_deadzone_DVA": self.confdata.sparsifier.size_exp_deadzone_DVA,
|
||||
"size_cutout_deadzone_DVA": self.confdata.sparsifier.size_cutout_deadzone_DVA,
|
||||
"enable_cam": self.confdata.output_mode.enable_cam,
|
||||
"enable_yolo": self.confdata.output_mode.enable_yolo,
|
||||
"enable_contour": self.confdata.output_mode.enable_contour,
|
||||
}
|
||||
print(data_out)
|
||||
|
||||
self.p_FEATperSECperPOS = self.confdata.alphabet.p_FEATperSECperPOS
|
||||
self.tau_SEC = self.confdata.alphabet.tau_SEC
|
||||
|
||||
# print(self.confdata.alphabet.selection)
|
||||
self.selection = self.confdata.alphabet.selection
|
||||
# print(f"Selektion gemacht {self.selection}")
|
||||
|
||||
# print(self.confdata.output_mode.enable_percept)
|
||||
if self.confdata.output_mode.enable_percept:
|
||||
self.display = self.target
|
||||
else:
|
||||
self.display = None
|
||||
if self.target == "cv2":
|
||||
cv2.destroyWindow("Percept")
|
||||
|
||||
self.confdata.reset_change_detector()
|
||||
|
||||
# keep track of time, yields dt
|
||||
t_new = time.time()
|
||||
dt = t_new - self.t
|
||||
self.t = t_new
|
||||
|
||||
# exponential decay
|
||||
self.percept *= torch.exp(-torch.tensor(dt / self.tau_SEC))
|
||||
|
||||
# new stimulation
|
||||
p_dt = self.p_FEATperSECperPOS * dt
|
||||
n_positions = self.position_found.shape[-2]
|
||||
position_select = torch.rand((n_positions,), device=self.torch_device) < p_dt
|
||||
n_select = position_select.sum()
|
||||
if n_select > 0:
|
||||
# print(f"Selektion ausgewertet {self.selection}")
|
||||
if self.selection:
|
||||
dictionary = self.features
|
||||
else:
|
||||
dictionary = self.phosphene
|
||||
percept_addon = BuildImage(
|
||||
canvas_size=self.canvas_size,
|
||||
dictionary=dictionary,
|
||||
position_found=self.position_found[:, position_select, :],
|
||||
default_dtype=self.default_dtype,
|
||||
torch_device=self.torch_device,
|
||||
)
|
||||
self.percept += percept_addon
|
||||
|
||||
# prepare for display
|
||||
display = self.percept[0, 0].cpu().numpy()
|
||||
if self.display == "cv2":
|
||||
|
||||
# display, and update
|
||||
cv2.namedWindow("Percept", cv2.WINDOW_NORMAL)
|
||||
cv2.imshow("Percept", display)
|
||||
q = cv2.waitKey(1)
|
||||
|
||||
if self.display == "pyplot":
|
||||
|
||||
# display, RUNS SLOWLY, just for testing
|
||||
plt.imshow(display, cmap="gray", vmin=0, vmax=1)
|
||||
plt.show()
|
||||
time.sleep(0.5)
|
||||
|
||||
return data_out
|
||||
|
||||
def __del__(self):
|
||||
print("Now I'm deleting me!")
|
||||
try:
|
||||
cv2.destroyAllWindows()
|
||||
except:
|
||||
pass
|
||||
if self.use_gui:
|
||||
print("...and the GUI!")
|
||||
try:
|
||||
if self.confdata.gui_running is True:
|
||||
self.root.destroy()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
use_gui = True
|
||||
|
||||
data_in = None
|
||||
op = OnlinePerception("cv2", use_gui=use_gui)
|
||||
|
||||
t_max = 40.0
|
||||
dt_update = 10.0
|
||||
t_update = dt_update
|
||||
|
||||
t = time.time()
|
||||
t0 = t
|
||||
while t - t0 < t_max:
|
||||
|
||||
data_out = op.update(data_in)
|
||||
data_in = None
|
||||
if data_out:
|
||||
print("Output given!")
|
||||
if "exit" in data_out.keys():
|
||||
break
|
||||
|
||||
t = time.time()
|
||||
if t - t0 > t_update:
|
||||
|
||||
# new canvas size
|
||||
n_xy_canvas = int(400 + torch.randint(200, (1,)))
|
||||
canvas_size = [1, 8, n_xy_canvas, n_xy_canvas]
|
||||
|
||||
# new features/phosphenes
|
||||
n_features = int(16 + torch.randint(16, (1,)))
|
||||
n_xy_features = int(31 + 2 * torch.randint(10, (1,)))
|
||||
features = torch.rand(
|
||||
(n_features, 1, n_xy_features, n_xy_features), device=op.torch_device
|
||||
)
|
||||
phosphene = torch.ones(
|
||||
(1, 1, n_xy_features, n_xy_features), device=op.torch_device
|
||||
)
|
||||
|
||||
# new positions
|
||||
n_positions = int(1 + torch.randint(3, (1,)))
|
||||
position_found = torch.zeros((1, n_positions, 3), device=op.torch_device)
|
||||
position_found[0, :, 0] = torch.randint(n_features, (n_positions,))
|
||||
position_found[0, :, 1] = torch.randint(n_xy_canvas, (n_positions,))
|
||||
position_found[0, :, 2] = torch.randint(n_xy_canvas, (n_positions,))
|
||||
t_update += dt_update
|
||||
data_in = {
|
||||
"position_found": position_found,
|
||||
"canvas_size": canvas_size,
|
||||
"features": features,
|
||||
"phosphene": phosphene,
|
||||
}
|
||||
|
||||
print("done!")
|
||||
del op
|
||||
|
||||
|
||||
# %%
|
237
processing_chain/PatchGenerator.py
Normal file
237
processing_chain/PatchGenerator.py
Normal file
|
@ -0,0 +1,237 @@
|
|||
# %%
|
||||
# PatchGenerator.py
|
||||
# ====================================
|
||||
# generates dictionaries (currently: phosphenes or clocks)
|
||||
#
|
||||
# Version V1.0, pre-07.03.2023:
|
||||
# no actual changes, is David's last code version...
|
||||
#
|
||||
# Version V1.1, 07.03.2023:
|
||||
# merged David's rebuild code (GUI capable)
|
||||
# (there was not really anything to merge :-))
|
||||
#
|
||||
|
||||
import torch
|
||||
import math
|
||||
|
||||
|
||||
class PatchGenerator:
|
||||
|
||||
pi: torch.Tensor
|
||||
torch_device: torch.device
|
||||
default_dtype = torch.float32
|
||||
|
||||
def __init__(self, torch_device: str = "cpu"):
|
||||
self.torch_device = torch.device(torch_device)
|
||||
|
||||
self.pi = torch.tensor(
|
||||
math.pi,
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
|
||||
def alphabet_phosphene(
|
||||
self,
|
||||
sigma_width: float = 2.5,
|
||||
patch_size: int = 41,
|
||||
) -> torch.Tensor:
|
||||
|
||||
n: int = int(patch_size // 2)
|
||||
temp_grid: torch.Tensor = torch.arange(
|
||||
start=-n,
|
||||
end=n + 1,
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
|
||||
x, y = torch.meshgrid(temp_grid, temp_grid, indexing="ij")
|
||||
|
||||
phosphene: torch.Tensor = torch.exp(-(x**2 + y**2) / (2 * sigma_width**2))
|
||||
phosphene /= phosphene.sum()
|
||||
|
||||
return phosphene.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
def alphabet_clocks(
|
||||
self,
|
||||
n_dir: int = 8,
|
||||
n_open: int = 4,
|
||||
n_filter: int = 4,
|
||||
patch_size: int = 41,
|
||||
segment_width: float = 2.5,
|
||||
segment_length: float = 15.0,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
|
||||
# for n_dir directions, there are n_open_max opening angles possible...
|
||||
assert n_dir % 2 == 0, "n_dir must be multiple of 2"
|
||||
n_open_max: int = n_dir // 2
|
||||
|
||||
# ...but this number can be reduced by integer multiples:
|
||||
assert (
|
||||
n_open_max % n_open == 0
|
||||
), "n_open_max = n_dir // 2 must be multiple of n_open"
|
||||
mul_open: int = n_open_max // n_open
|
||||
|
||||
# filter planes must be multiple of number of orientations implied by n_dir
|
||||
assert n_filter % n_open_max == 0, "n_filter must be multiple of (n_dir // 2)"
|
||||
mul_filter: int = n_filter // n_open_max
|
||||
# compute single segments
|
||||
segments: torch.Tensor = torch.zeros(
|
||||
(n_dir, patch_size, patch_size),
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
dirs: torch.Tensor = (
|
||||
2
|
||||
* self.pi
|
||||
* torch.arange(
|
||||
start=0, end=n_dir, device=self.torch_device, dtype=self.default_dtype
|
||||
)
|
||||
/ n_dir
|
||||
)
|
||||
|
||||
for i_dir in range(n_dir):
|
||||
segments[i_dir] = self.draw_segment(
|
||||
patch_size=patch_size,
|
||||
phi=float(dirs[i_dir]),
|
||||
segment_length=segment_length,
|
||||
segment_width=segment_width,
|
||||
)
|
||||
|
||||
# compute patches from segments
|
||||
clocks = torch.zeros(
|
||||
(n_open, n_dir, patch_size, patch_size),
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
clocks_filter = torch.zeros(
|
||||
(n_open, n_dir, n_filter, patch_size, patch_size),
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
|
||||
for i_dir in range(n_dir):
|
||||
for i_open in range(n_open):
|
||||
|
||||
seg1 = segments[i_dir]
|
||||
seg2 = segments[(i_dir + (i_open + 1) * mul_open) % n_dir]
|
||||
clocks[i_open, i_dir] = torch.where(seg1 > seg2, seg1, seg2)
|
||||
|
||||
i_filter_seg1 = (i_dir * mul_filter) % n_filter
|
||||
i_filter_seg2 = (
|
||||
(i_dir + (i_open + 1) * mul_open) * mul_filter
|
||||
) % n_filter
|
||||
|
||||
if i_filter_seg1 == i_filter_seg2:
|
||||
clock_merged = torch.where(seg1 > seg2, seg1, seg2)
|
||||
clocks_filter[i_open, i_dir, i_filter_seg1] = clock_merged
|
||||
else:
|
||||
clocks_filter[i_open, i_dir, i_filter_seg1] = seg1
|
||||
clocks_filter[i_open, i_dir, i_filter_seg2] = seg2
|
||||
|
||||
clocks_filter = clocks_filter.reshape(
|
||||
(n_open * n_dir, n_filter, patch_size, patch_size)
|
||||
)
|
||||
clocks_filter = clocks_filter / clocks_filter.sum(
|
||||
axis=(-3, -2, -1), keepdims=True
|
||||
)
|
||||
clocks = clocks.reshape((n_open * n_dir, 1, patch_size, patch_size))
|
||||
clocks = clocks / clocks.sum(axis=(-2, -1), keepdims=True)
|
||||
|
||||
return clocks_filter, clocks, segments
|
||||
|
||||
def draw_segment(
|
||||
self,
|
||||
patch_size: float,
|
||||
phi: float,
|
||||
segment_length: float,
|
||||
segment_width: float,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# extension of patch beyond origin
|
||||
n: int = int(patch_size // 2)
|
||||
|
||||
# grid for the patch
|
||||
temp_grid: torch.Tensor = torch.arange(
|
||||
start=-n,
|
||||
end=n + 1,
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
|
||||
x, y = torch.meshgrid(temp_grid, temp_grid, indexing="ij")
|
||||
|
||||
r: torch.Tensor = torch.cat((x.unsqueeze(-1), y.unsqueeze(-1)), dim=2)
|
||||
|
||||
# target orientation of basis segment
|
||||
phi90: torch.Tensor = phi + self.pi / 2
|
||||
|
||||
# vector pointing to the ending point of segment (direction)
|
||||
#
|
||||
# when result is displayed with plt.imshow(segment),
|
||||
# phi=0 points to the right, and increasing phi rotates
|
||||
# the segment counterclockwise
|
||||
#
|
||||
e: torch.Tensor = torch.tensor(
|
||||
[torch.cos(phi90), torch.sin(phi90)],
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
|
||||
# tangential vectors
|
||||
e_tang: torch.Tensor = e.flip(dims=[0]) * torch.tensor(
|
||||
[-1, 1], device=self.torch_device, dtype=self.default_dtype
|
||||
)
|
||||
|
||||
# compute distances to segment: parallel/tangential
|
||||
d = torch.maximum(
|
||||
torch.zeros(
|
||||
(r.shape[0], r.shape[1]),
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
),
|
||||
torch.abs(
|
||||
(r * e.unsqueeze(0).unsqueeze(0)).sum(dim=-1) - segment_length / 2
|
||||
)
|
||||
- segment_length / 2,
|
||||
)
|
||||
|
||||
d_tang = (r * e_tang.unsqueeze(0).unsqueeze(0)).sum(dim=-1)
|
||||
|
||||
# compute minimum distance to any of the two pointers
|
||||
dr = torch.sqrt(d**2 + d_tang**2)
|
||||
|
||||
segment = torch.exp(-(dr**2) / 2 / segment_width**2)
|
||||
segment = segment / segment.sum()
|
||||
|
||||
return segment
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
pg = PatchGenerator()
|
||||
|
||||
patch1 = pg.draw_segment(
|
||||
patch_size=81, phi=math.pi / 4, segment_length=30, segment_width=5
|
||||
)
|
||||
patch2 = pg.draw_segment(patch_size=81, phi=0, segment_length=30, segment_width=5)
|
||||
plt.imshow(torch.where(patch1 > patch2, patch1, patch2).cpu())
|
||||
|
||||
phos = pg.alphabet_phosphene()
|
||||
plt.imshow(phos[0, 0].cpu())
|
||||
|
||||
n_filter = 8
|
||||
n_dir = 8
|
||||
clocks_filter, clocks, segments = pg.alphabet_clocks(n_dir=n_dir, n_filter=n_filter)
|
||||
|
||||
n_features = clocks_filter.shape[0]
|
||||
print(n_features, "clock features generated!")
|
||||
for i_feature in range(n_features):
|
||||
for i_filter in range(n_filter):
|
||||
plt.subplot(1, n_filter, i_filter + 1)
|
||||
plt.imshow(clocks_filter[i_feature, i_filter].cpu())
|
||||
plt.title("Feature #{}, Dir #{}".format(i_feature, i_filter))
|
||||
plt.show()
|
||||
|
||||
|
||||
# %%
|
418
processing_chain/Sparsifier.py
Normal file
418
processing_chain/Sparsifier.py
Normal file
|
@ -0,0 +1,418 @@
|
|||
# Sparsifier.py
|
||||
# ====================================
|
||||
# matches dictionary patches to contour images
|
||||
#
|
||||
# Version V1.0, 07.03.2023:
|
||||
# slight parameter scaling changes to David's last code version...
|
||||
#
|
||||
# Version V1.1, 07.03.2023:
|
||||
# merged David's rebuild code (GUI capable)
|
||||
#
|
||||
|
||||
import torch
|
||||
import math
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import time
|
||||
|
||||
|
||||
class Sparsifier(torch.nn.Module):
|
||||
|
||||
dictionary_filter_fft: torch.Tensor | None = None
|
||||
dictionary_filter: torch.Tensor
|
||||
dictionary: torch.Tensor
|
||||
|
||||
parameter_ready: bool
|
||||
dictionary_ready: bool
|
||||
|
||||
contour_convolved_sum: torch.Tensor | None = None
|
||||
use_map: torch.Tensor | None = None
|
||||
position_found: torch.Tensor | None = None
|
||||
|
||||
size_exp_deadzone: float
|
||||
|
||||
number_of_patches: int
|
||||
padding_deadzone_size_x: int
|
||||
padding_deadzone_size_y: int
|
||||
|
||||
plot_use_map: bool
|
||||
|
||||
deadzone_exp: bool
|
||||
deadzone_hard_cutout: int # 0 = not, 1 = round, 2 = box
|
||||
deadzone_hard_cutout_size: float
|
||||
|
||||
dictionary_prior: torch.Tensor | None
|
||||
|
||||
pi: torch.Tensor
|
||||
torch_device: torch.device
|
||||
default_dtype = torch.float32
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dictionary_filter: torch.Tensor,
|
||||
dictionary: torch.Tensor,
|
||||
dictionary_prior: torch.Tensor | None = None,
|
||||
number_of_patches: int = 10, #
|
||||
size_exp_deadzone: float = 1.0, #
|
||||
padding_deadzone_size_x: int = 0, #
|
||||
padding_deadzone_size_y: int = 0, #
|
||||
plot_use_map: bool = False,
|
||||
deadzone_exp: bool = True, #
|
||||
deadzone_hard_cutout: int = 1, # 0 = not, 1 = round
|
||||
deadzone_hard_cutout_size: float = 1.0, #
|
||||
torch_device: str = "cpu",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dictionary_ready = False
|
||||
self.parameter_ready = False
|
||||
|
||||
self.torch_device = torch.device(torch_device)
|
||||
|
||||
self.pi = torch.tensor(
|
||||
math.pi,
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
|
||||
self.plot_use_map = plot_use_map
|
||||
|
||||
self.update_parameter(
|
||||
number_of_patches,
|
||||
size_exp_deadzone,
|
||||
padding_deadzone_size_x,
|
||||
padding_deadzone_size_y,
|
||||
deadzone_exp,
|
||||
deadzone_hard_cutout,
|
||||
deadzone_hard_cutout_size,
|
||||
)
|
||||
|
||||
self.update_dictionary(dictionary_filter, dictionary, dictionary_prior)
|
||||
|
||||
def update_parameter(
|
||||
self,
|
||||
number_of_patches: int,
|
||||
size_exp_deadzone: float,
|
||||
padding_deadzone_size_x: int,
|
||||
padding_deadzone_size_y: int,
|
||||
deadzone_exp: bool,
|
||||
deadzone_hard_cutout: int,
|
||||
deadzone_hard_cutout_size: float,
|
||||
) -> None:
|
||||
|
||||
self.parameter_ready = False
|
||||
|
||||
assert size_exp_deadzone > 0.0
|
||||
assert number_of_patches > 0
|
||||
assert padding_deadzone_size_x >= 0
|
||||
assert padding_deadzone_size_y >= 0
|
||||
|
||||
self.number_of_patches = number_of_patches
|
||||
self.size_exp_deadzone = size_exp_deadzone
|
||||
self.padding_deadzone_size_x = padding_deadzone_size_x
|
||||
self.padding_deadzone_size_y = padding_deadzone_size_y
|
||||
|
||||
self.deadzone_exp = deadzone_exp
|
||||
self.deadzone_hard_cutout = deadzone_hard_cutout
|
||||
self.deadzone_hard_cutout_size = deadzone_hard_cutout_size
|
||||
|
||||
self.parameter_ready = True
|
||||
|
||||
def update_dictionary(
|
||||
self,
|
||||
dictionary_filter: torch.Tensor,
|
||||
dictionary: torch.Tensor,
|
||||
dictionary_prior: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
|
||||
self.dictionary_ready = False
|
||||
|
||||
assert dictionary_filter.ndim == 4
|
||||
assert dictionary.ndim == 4
|
||||
|
||||
# Odd number of pixels. Please!
|
||||
assert (dictionary_filter.shape[-2] % 2) == 1
|
||||
assert (dictionary_filter.shape[-1] % 2) == 1
|
||||
|
||||
self.dictionary_filter = dictionary_filter.detach().clone()
|
||||
self.dictionary = dictionary
|
||||
|
||||
if dictionary_prior is not None:
|
||||
assert dictionary_prior.ndim == 1
|
||||
assert dictionary_prior.shape[0] == dictionary_filter.shape[0]
|
||||
|
||||
self.dictionary_prior = dictionary_prior
|
||||
|
||||
self.dictionary_filter_fft = None
|
||||
|
||||
self.dictionary_ready = True
|
||||
|
||||
def fourier_convolution(self, contour: torch.Tensor):
|
||||
# Pattern, X, Y
|
||||
assert contour.dim() == 4
|
||||
assert self.dictionary_filter is not None
|
||||
assert contour.shape[-2] >= self.dictionary_filter.shape[-2]
|
||||
assert contour.shape[-1] >= self.dictionary_filter.shape[-1]
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
contour_fft = torch.fft.rfft2(contour, dim=(-2, -1))
|
||||
|
||||
t1 = time.time()
|
||||
|
||||
if (
|
||||
(self.dictionary_filter_fft is None)
|
||||
or (self.dictionary_filter_fft.dim() != 4)
|
||||
or (self.dictionary_filter_fft.shape[-2] != contour.shape[-2])
|
||||
or (self.dictionary_filter_fft.shape[-1] != contour.shape[-1])
|
||||
):
|
||||
dictionary_padded: torch.Tensor = torch.zeros(
|
||||
(
|
||||
self.dictionary_filter.shape[0],
|
||||
self.dictionary_filter.shape[1],
|
||||
contour.shape[-2],
|
||||
contour.shape[-1],
|
||||
),
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
dictionary_padded[
|
||||
:,
|
||||
:,
|
||||
: self.dictionary_filter.shape[-2],
|
||||
: self.dictionary_filter.shape[-1],
|
||||
] = self.dictionary_filter.flip(dims=[-2, -1])
|
||||
|
||||
dictionary_padded = dictionary_padded.roll(
|
||||
(
|
||||
-(self.dictionary_filter.shape[-2] // 2),
|
||||
-(self.dictionary_filter.shape[-1] // 2),
|
||||
),
|
||||
(-2, -1),
|
||||
)
|
||||
self.dictionary_filter_fft = torch.fft.rfft2(
|
||||
dictionary_padded, dim=(-2, -1)
|
||||
)
|
||||
|
||||
t2 = time.time()
|
||||
|
||||
assert self.dictionary_filter_fft is not None
|
||||
|
||||
# dimension order for multiplication: [pat, feat, ori, x, y]
|
||||
self.contour_convolved_sum = torch.fft.irfft2(
|
||||
contour_fft.unsqueeze(1) * self.dictionary_filter_fft.unsqueeze(0),
|
||||
dim=(-2, -1),
|
||||
).sum(dim=2)
|
||||
# --> [pat, feat, x, y]
|
||||
t3 = time.time()
|
||||
|
||||
self.use_map: torch.Tensor = torch.ones(
|
||||
(contour.shape[0], 1, contour.shape[-2], contour.shape[-1]),
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
|
||||
if self.padding_deadzone_size_x > 0:
|
||||
self.use_map[:, 0, : self.padding_deadzone_size_x, :] = 0.0
|
||||
self.use_map[:, 0, -self.padding_deadzone_size_x :, :] = 0.0
|
||||
|
||||
if self.padding_deadzone_size_y > 0:
|
||||
self.use_map[:, 0, :, : self.padding_deadzone_size_y] = 0.0
|
||||
self.use_map[:, 0, :, -self.padding_deadzone_size_y :] = 0.0
|
||||
|
||||
t4 = time.time()
|
||||
print(
|
||||
"Sparsifier-convol {:.3f}s: fft-img-{:.3f}s, fft-fil-{:.3f}s, convol-{:.3f}s, pad-{:.3f}s".format(
|
||||
t4 - t0, t1 - t0, t2 - t1, t3 - t2, t4 - t3
|
||||
)
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def find_next_element(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert self.use_map is not None
|
||||
assert self.contour_convolved_sum is not None
|
||||
assert self.dictionary_filter is not None
|
||||
|
||||
# store feature index, x pos and y pos (3 entries)
|
||||
position_found = torch.zeros(
|
||||
(self.contour_convolved_sum.shape[0], 3),
|
||||
dtype=torch.int64,
|
||||
device=self.torch_device,
|
||||
)
|
||||
|
||||
overlap_found = torch.zeros(
|
||||
(self.contour_convolved_sum.shape[0], 2),
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
|
||||
for pattern_id in range(0, self.contour_convolved_sum.shape[0]):
|
||||
|
||||
t0 = time.time()
|
||||
# search_tensor: torch.Tensor = (
|
||||
# self.contour_convolved[pattern_id] * self.use_map[pattern_id]
|
||||
# ).sum(dim=1)
|
||||
search_tensor: torch.Tensor = (
|
||||
self.contour_convolved_sum[pattern_id] * self.use_map[pattern_id]
|
||||
)
|
||||
|
||||
t1 = time.time()
|
||||
|
||||
if self.dictionary_prior is not None:
|
||||
search_tensor *= self.dictionary_prior.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
t2 = time.time()
|
||||
|
||||
temp, max_0 = search_tensor.max(dim=0)
|
||||
temp, max_1 = temp.max(dim=0)
|
||||
temp_overlap, max_2 = temp.max(dim=0)
|
||||
|
||||
position_base_3: int = int(max_2)
|
||||
position_base_2: int = int(max_1[position_base_3])
|
||||
position_base_1: int = int(max_0[position_base_2, position_base_3])
|
||||
position_base_0: int = int(pattern_id)
|
||||
|
||||
position_found[position_base_0, 0] = position_base_1
|
||||
position_found[position_base_0, 1] = position_base_2
|
||||
position_found[position_base_0, 2] = position_base_3
|
||||
|
||||
overlap_found[pattern_id, 0] = temp_overlap
|
||||
overlap_found[pattern_id, 1] = self.contour_convolved_sum[
|
||||
position_base_0, position_base_1, position_base_2, position_base_3
|
||||
]
|
||||
|
||||
t3 = time.time()
|
||||
|
||||
x_max: int = int(self.contour_convolved_sum.shape[-2])
|
||||
y_max: int = int(self.contour_convolved_sum.shape[-1])
|
||||
|
||||
# Center arround the max position
|
||||
x_0 = int(-position_base_2)
|
||||
x_1 = int(x_max - position_base_2)
|
||||
y_0 = int(-position_base_3)
|
||||
y_1 = int(y_max - position_base_3)
|
||||
|
||||
temp_grid_x: torch.Tensor = torch.arange(
|
||||
start=x_0,
|
||||
end=x_1,
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
|
||||
temp_grid_y: torch.Tensor = torch.arange(
|
||||
start=y_0,
|
||||
end=y_1,
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
|
||||
x, y = torch.meshgrid(temp_grid_x, temp_grid_y, indexing="ij")
|
||||
|
||||
# discourage the neigbourhood around for the future
|
||||
if self.deadzone_exp is True:
|
||||
self.temp_map: torch.Tensor = 1.0 - torch.exp(
|
||||
-(x**2 + y**2) / (2 * self.size_exp_deadzone**2)
|
||||
)
|
||||
else:
|
||||
self.temp_map = torch.ones(
|
||||
(x.shape[0], x.shape[1]),
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
|
||||
assert self.deadzone_hard_cutout >= 0
|
||||
assert self.deadzone_hard_cutout <= 1
|
||||
assert self.deadzone_hard_cutout_size >= 0
|
||||
|
||||
if self.deadzone_hard_cutout == 1:
|
||||
temp = x**2 + y**2
|
||||
self.temp_map *= torch.where(
|
||||
temp <= self.deadzone_hard_cutout_size**2,
|
||||
0.0,
|
||||
1.0,
|
||||
)
|
||||
|
||||
self.use_map[position_base_0, 0, :, :] *= self.temp_map
|
||||
|
||||
t4 = time.time()
|
||||
|
||||
# Only for keeping it in the float32 range:
|
||||
self.use_map[position_base_0, 0, :, :] /= self.use_map[
|
||||
position_base_0, 0, :, :
|
||||
].max()
|
||||
|
||||
t5 = time.time()
|
||||
|
||||
# print(
|
||||
# "{}, {}, {}, {}, {}".format(t1 - t0, t2 - t1, t3 - t2, t4 - t3, t5 - t4)
|
||||
# )
|
||||
|
||||
return (
|
||||
position_found,
|
||||
overlap_found,
|
||||
torch.tensor((t1 - t0, t2 - t1, t3 - t2, t4 - t3, t5 - t4)),
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> None:
|
||||
assert self.number_of_patches > 0
|
||||
|
||||
assert self.dictionary_ready is True
|
||||
assert self.parameter_ready is True
|
||||
|
||||
self.position_found = torch.zeros(
|
||||
(input.shape[0], self.number_of_patches, 3),
|
||||
dtype=torch.int64,
|
||||
device=self.torch_device,
|
||||
)
|
||||
self.overlap_found = torch.zeros(
|
||||
(input.shape[0], self.number_of_patches, 2),
|
||||
device=self.torch_device,
|
||||
dtype=self.default_dtype,
|
||||
)
|
||||
|
||||
# Folding the images and the dictionary
|
||||
self.fourier_convolution(input)
|
||||
|
||||
t0 = time.time()
|
||||
dt = torch.tensor([0, 0, 0, 0, 0])
|
||||
for patch_id in range(0, self.number_of_patches):
|
||||
(
|
||||
self.position_found[:, patch_id, :],
|
||||
self.overlap_found[:, patch_id, :],
|
||||
dt_tmp,
|
||||
) = self.find_next_element()
|
||||
|
||||
dt = dt + dt_tmp
|
||||
|
||||
if self.plot_use_map is True:
|
||||
|
||||
assert self.position_found.shape[0] == 1
|
||||
assert self.use_map is not None
|
||||
|
||||
print("Position Saliency:")
|
||||
print(self.overlap_found[0, :, 0])
|
||||
print("Overlap with Contour Image:")
|
||||
print(self.overlap_found[0, :, 1])
|
||||
plt.subplot(1, 2, 1)
|
||||
plt.imshow(self.use_map[0, 0].cpu(), cmap="gray")
|
||||
plt.title(f"patch-number: {patch_id}")
|
||||
plt.axis("off")
|
||||
plt.colorbar(shrink=0.5)
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.imshow(
|
||||
(self.use_map[0, 0] * input[0].sum(dim=0)).cpu(), cmap="gray"
|
||||
)
|
||||
plt.title(f"patch-number: {patch_id}")
|
||||
plt.axis("off")
|
||||
plt.colorbar(shrink=0.5)
|
||||
plt.show()
|
||||
|
||||
# self.overlap_found /= self.overlap_found.max(dim=1, keepdim=True)[0]
|
||||
t1 = time.time()
|
||||
print(
|
||||
"Sparsifier-forward {:.3f}s: usemap-{:.3f}s, prior-{:.3f}s, findmax-{:.3f}s, notch-{:.3f}s, norm-{:.3f}s: (sum-{:.3f})".format(
|
||||
t1 - t0, dt[0], dt[1], dt[2], dt[3], dt[4], dt.sum()
|
||||
)
|
||||
)
|
325
processing_chain/WebCam.py
Normal file
325
processing_chain/WebCam.py
Normal file
|
@ -0,0 +1,325 @@
|
|||
#%%
|
||||
#
|
||||
# WebCam.py
|
||||
# ========================================================
|
||||
# interface to cv2 for using a webcam or for reading from
|
||||
# a video file
|
||||
#
|
||||
# Version 1.0, before 30.03.2023:
|
||||
# written by David...
|
||||
#
|
||||
# Version 1.1, 30.03.2023:
|
||||
# thrown out test image
|
||||
# added test code
|
||||
# added code to "capture" from video file
|
||||
#
|
||||
# Version 1.2, 20.06.2023:
|
||||
# added code to capture wirelessly from "GoPro" camera
|
||||
# added test code for "GoPro"
|
||||
#
|
||||
# Version 1.3, 23.06.2023
|
||||
# test display in pyplot or cv2
|
||||
#
|
||||
# Version 1.4, 28.06.2023
|
||||
# solved Windows DirectShow problem
|
||||
#
|
||||
|
||||
|
||||
from PIL import Image
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
import torchvision as tv
|
||||
|
||||
# for GoPro
|
||||
import time
|
||||
import socket
|
||||
|
||||
try:
|
||||
print("Trying to import GoPro modules...")
|
||||
from goprocam import GoProCamera, constants
|
||||
|
||||
gopro_exists = True
|
||||
except:
|
||||
print("...not found, continuing!")
|
||||
gopro_exists = False
|
||||
import platform
|
||||
|
||||
|
||||
class WebCam:
|
||||
|
||||
# test_pattern: torch.Tensor
|
||||
# test_pattern_gray: torch.Tensor
|
||||
|
||||
source: int
|
||||
framesize: tuple[int, int]
|
||||
fps: float
|
||||
cap_frame_width: int
|
||||
cap_frame_height: int
|
||||
cap_fps: float
|
||||
webcam_is_ready: bool
|
||||
cap_frames_available: int
|
||||
|
||||
default_dtype = torch.float32
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source: str | int = 1,
|
||||
framesize: tuple[int, int] = (720, 1280), # (1920, 1080), # (640, 480),
|
||||
fps: float = 30.0,
|
||||
):
|
||||
super().__init__()
|
||||
assert fps > 0
|
||||
|
||||
self.source = source
|
||||
self.framesize = framesize
|
||||
self.cap = None
|
||||
self.fps = fps
|
||||
self.webcam_is_ready = False
|
||||
|
||||
def open_cam(self) -> bool:
|
||||
if self.cap is not None:
|
||||
self.cap.release()
|
||||
self.cap = None
|
||||
self.webcam_is_ready = False
|
||||
|
||||
# handle GoPro...
|
||||
if self.source == "GoProWireless":
|
||||
|
||||
if not gopro_exists:
|
||||
print("No GoPro driver/support!")
|
||||
self.webcam_is_ready = False
|
||||
return False
|
||||
|
||||
print("GoPro: Starting access")
|
||||
gpCam = GoProCamera.GoPro()
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
print("GoPro: Socket created!!!")
|
||||
|
||||
self.t = time.time()
|
||||
gpCam.livestream("start")
|
||||
gpCam.video_settings(res="1080p", fps="30")
|
||||
gpCam.gpControlSet(
|
||||
constants.Stream.WINDOW_SIZE, constants.Stream.WindowSize.R720
|
||||
)
|
||||
|
||||
self.cap = cv2.VideoCapture("udp://10.5.5.9:8554", cv2.CAP_FFMPEG)
|
||||
print("GoPro: Video capture started!!!")
|
||||
|
||||
else:
|
||||
self.sock = None
|
||||
self.t = -1
|
||||
if platform.system().lower() == "windows":
|
||||
self.cap = cv2.VideoCapture(self.source, cv2.CAP_DSHOW)
|
||||
else:
|
||||
self.cap = cv2.VideoCapture(self.source)
|
||||
print("Normal capture started!!!")
|
||||
|
||||
assert self.cap is not None
|
||||
|
||||
if self.cap.isOpened() is not True:
|
||||
self.webcam_is_ready = False
|
||||
return False
|
||||
|
||||
if type(self.source) != str:
|
||||
self.cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG"))
|
||||
self.cap.set(cv2.CAP_PROP_FPS, self.fps)
|
||||
self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.framesize[0])
|
||||
self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.framesize[1])
|
||||
self.cap_frames_available = None
|
||||
else:
|
||||
# ...for files and GoPro...
|
||||
self.cap_frames_available = self.cap.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
|
||||
self.cap_frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
self.cap_frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
self.cap_fps = float(self.cap.get(cv2.CAP_PROP_FPS))
|
||||
|
||||
print(
|
||||
(
|
||||
f"Capturing or reading with: {self.cap_frame_width:.0f} x "
|
||||
f"{self.cap_frame_height:.0f} @ "
|
||||
f"{self.cap_fps:.1f}."
|
||||
)
|
||||
)
|
||||
self.webcam_is_ready = True
|
||||
return True
|
||||
|
||||
def close_cam(self) -> None:
|
||||
if self.cap is not None:
|
||||
self.cap.release()
|
||||
|
||||
def get_frame(self) -> torch.Tensor | None:
|
||||
if self.cap is None:
|
||||
return None
|
||||
else:
|
||||
if self.sock:
|
||||
dt_min = 0.015
|
||||
success, frame = self.cap.read()
|
||||
t_next = time.time()
|
||||
t_prev = t_next
|
||||
while t_next - t_prev < dt_min:
|
||||
t_prev = t_next
|
||||
success, frame = self.cap.read()
|
||||
t_next = time.time()
|
||||
|
||||
if self.t >= 0:
|
||||
if time.time() - self.t > 2.5:
|
||||
print("GoPro-Stream must be kept awake!...")
|
||||
self.sock.sendto(
|
||||
"_GPHD_:0:0:2:0.000000\n".encode(), ("10.5.5.9", 8554)
|
||||
)
|
||||
self.t = time.time()
|
||||
else:
|
||||
success, frame = self.cap.read()
|
||||
|
||||
if success is False:
|
||||
self.webcam_is_ready = False
|
||||
return None
|
||||
|
||||
output = (
|
||||
torch.tensor(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||
.movedim(-1, 0)
|
||||
.type(dtype=self.default_dtype)
|
||||
/ 255.0
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
# for testing the code if module is executed from command line
|
||||
if __name__ == "__main__":
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
|
||||
TEST_FILEREAD = False
|
||||
TEST_WEBCAM = True
|
||||
TEST_GOPRO = False
|
||||
|
||||
display = "cv2"
|
||||
n_capture = 200
|
||||
delay_capture = 0.001
|
||||
|
||||
print("Testing the WebCam interface")
|
||||
|
||||
if TEST_FILEREAD:
|
||||
|
||||
file_name = "level1.mp4"
|
||||
|
||||
# open
|
||||
print("Opening video file")
|
||||
w = WebCam(file_name)
|
||||
if not w.open_cam():
|
||||
raise OSError(f"Opening file with name {file_name} failed!")
|
||||
|
||||
# print information
|
||||
print(
|
||||
f"Frame size {w.cap_frame_width} x {w.cap_frame_height} at {w.cap_fps} fps."
|
||||
)
|
||||
|
||||
# capture three frames and show them
|
||||
for i in range(min([n_capture, w.cap_frames_available])): # TODO: available?
|
||||
frame = w.get_frame()
|
||||
if frame == None:
|
||||
raise OSError(f"Can not get frame from file with name {file_name}!")
|
||||
print(f"frame {i} has shape {frame.shape}")
|
||||
|
||||
frame_numpy = (frame.movedim(0, -1) * 255).type(dtype=torch.uint8).numpy()
|
||||
|
||||
if display == "pyplot":
|
||||
plt.imshow(frame_numpy)
|
||||
plt.show()
|
||||
if display == "cv2":
|
||||
cv2.imshow("File", frame_numpy[:, :, (2, 1, 0)])
|
||||
cv2.waitKey(1)
|
||||
time.sleep(delay_capture)
|
||||
|
||||
# close
|
||||
print("Closing file")
|
||||
w.close_cam()
|
||||
|
||||
if TEST_WEBCAM:
|
||||
|
||||
camera_index = 0
|
||||
|
||||
# open
|
||||
print("Opening camera")
|
||||
w = WebCam(camera_index)
|
||||
if not w.open_cam():
|
||||
raise OSError(f"Opening web cam with index {camera_index} failed!")
|
||||
|
||||
# print information
|
||||
print(
|
||||
f"Frame size {w.cap_frame_width} x {w.cap_frame_height} at {w.cap_fps} fps."
|
||||
)
|
||||
|
||||
# capture three frames and show them
|
||||
for i in range(n_capture):
|
||||
frame = w.get_frame()
|
||||
if frame == None:
|
||||
raise OSError(
|
||||
f"Can not get frame from camera with index {camera_index}!"
|
||||
)
|
||||
print(f"frame {i} has shape {frame.shape}")
|
||||
|
||||
frame_numpy = (frame.movedim(0, -1) * 255).type(dtype=torch.uint8).numpy()
|
||||
if display == "pyplot":
|
||||
plt.imshow(frame_numpy)
|
||||
plt.show()
|
||||
if display == "cv2":
|
||||
cv2.imshow("WebCam", frame_numpy[:, :, (2, 1, 0)])
|
||||
cv2.waitKey(1)
|
||||
time.sleep(delay_capture)
|
||||
|
||||
# close
|
||||
print("Closing camera")
|
||||
w.close_cam()
|
||||
|
||||
if TEST_GOPRO:
|
||||
|
||||
camera_name = "GoProWireless"
|
||||
|
||||
# open
|
||||
print("Opening GoPro")
|
||||
w = WebCam(camera_name)
|
||||
if not w.open_cam():
|
||||
raise OSError(f"Opening GoPro with index {camera_index} failed!")
|
||||
|
||||
# print information
|
||||
print(
|
||||
f"Frame size {w.cap_frame_width} x {w.cap_frame_height} at {w.cap_fps} fps."
|
||||
)
|
||||
w.cap.set(cv2.CAP_PROP_BUFFERSIZE, 0)
|
||||
|
||||
# capture three frames and show them
|
||||
# print("Empty Buffer...")
|
||||
# for i in range(500):
|
||||
# print(i)
|
||||
# frame = w.get_frame()
|
||||
# print("Buffer Emptied...")
|
||||
|
||||
for i in range(n_capture):
|
||||
frame = w.get_frame()
|
||||
if frame == None:
|
||||
raise OSError(
|
||||
f"Can not get frame from camera with index {camera_index}!"
|
||||
)
|
||||
print(f"frame {i} has shape {frame.shape}")
|
||||
|
||||
frame_numpy = (frame.movedim(0, -1) * 255).type(dtype=torch.uint8).numpy()
|
||||
if display == "pyplot":
|
||||
plt.imshow(frame_numpy)
|
||||
plt.show()
|
||||
if display == "cv2":
|
||||
cv2.imshow("GoPro", frame_numpy[:, :, (2, 1, 0)])
|
||||
cv2.waitKey(1)
|
||||
time.sleep(delay_capture)
|
||||
|
||||
# close
|
||||
print("Closing Cam/File/GoPro")
|
||||
w.close_cam()
|
||||
|
||||
if display == "cv2":
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# %%
|
402
processing_chain/Yolo5Segmentation.py
Normal file
402
processing_chain/Yolo5Segmentation.py
Normal file
|
@ -0,0 +1,402 @@
|
|||
import torch
|
||||
import os
|
||||
import torchvision as tv
|
||||
import time
|
||||
|
||||
from models.yolo import Detect, Model
|
||||
|
||||
# Warning: The line "#matplotlib.use('Agg') # for writing to files only"
|
||||
# in utils/plots.py prevents the further use of matplotlib
|
||||
|
||||
|
||||
class Yolo5Segmentation(torch.nn.Module):
|
||||
|
||||
default_dtype = torch.float32
|
||||
|
||||
conf: float = 0.25 # NMS confidence threshold
|
||||
iou: float = 0.45 # NMS IoU threshold
|
||||
agnostic: bool = False # NMS class-agnostic
|
||||
multi_label: bool = False # NMS multiple labels per box
|
||||
max_det: int = 1000 # maximum number of detections per image
|
||||
number_of_maps: int = 32
|
||||
imgsz: tuple[int, int] = (640, 640) # inference size (height, width)
|
||||
|
||||
device: torch.device = torch.device("cpu")
|
||||
weigh_path: str = ""
|
||||
|
||||
class_names: dict
|
||||
stride: int
|
||||
|
||||
found_class_id: torch.Tensor | None = None
|
||||
|
||||
def __init__(self, mode: int = 3, torch_device: str = "cpu"):
|
||||
super().__init__()
|
||||
|
||||
model_pretrained_path: str = "segment_pretrained"
|
||||
assert mode < 5
|
||||
assert mode >= 0
|
||||
if mode == 0:
|
||||
model_pretrained_weights: str = "yolov5n-seg.pt"
|
||||
elif mode == 1:
|
||||
model_pretrained_weights = "yolov5s-seg.pt"
|
||||
elif mode == 2:
|
||||
model_pretrained_weights = "yolov5m-seg.pt"
|
||||
elif mode == 3:
|
||||
model_pretrained_weights = "yolov5l-seg.pt"
|
||||
elif mode == 4:
|
||||
model_pretrained_weights = "yolov5x-seg.pt"
|
||||
|
||||
self.weigh_path = os.path.join(model_pretrained_path, model_pretrained_weights)
|
||||
|
||||
self.device = torch.device(torch_device)
|
||||
|
||||
self.network = self.attempt_load(
|
||||
self.weigh_path, device=self.device, inplace=True, fuse=True
|
||||
)
|
||||
self.stride = max(int(self.network.stride.max()), 32) # model stride
|
||||
self.network.float()
|
||||
self.class_names = dict(self.network.names) # type: ignore
|
||||
|
||||
# classes: (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
|
||||
def forward(self, input: torch.Tensor, classes=None) -> torch.Tensor:
|
||||
|
||||
assert input.ndim == 4
|
||||
assert input.shape[0] == 1
|
||||
assert input.shape[1] == 3
|
||||
|
||||
input_resized, (
|
||||
remove_left,
|
||||
remove_top,
|
||||
remove_height,
|
||||
remove_width,
|
||||
) = self.scale_and_pad(
|
||||
input,
|
||||
)
|
||||
|
||||
network_output = self.network(input_resized)
|
||||
number_of_classes = network_output[0].shape[2] - self.number_of_maps - 5
|
||||
assert len(self.class_names) == number_of_classes
|
||||
|
||||
maps = network_output[1]
|
||||
|
||||
# results matrix:
|
||||
# Fist Dimension:
|
||||
# Image Number
|
||||
# ...
|
||||
# Last Dimension:
|
||||
# center_x: 0
|
||||
# center_y: 1
|
||||
# width: 2
|
||||
# height: 3
|
||||
# obj_conf (object): 4
|
||||
# cls_conf (class): 5
|
||||
|
||||
results = non_max_suppression(
|
||||
network_output[0],
|
||||
self.conf,
|
||||
self.iou,
|
||||
classes,
|
||||
self.agnostic,
|
||||
self.multi_label,
|
||||
max_det=self.max_det,
|
||||
nm=self.number_of_maps,
|
||||
)
|
||||
|
||||
image_id = 0
|
||||
|
||||
if results[image_id].shape[0] > 0:
|
||||
masks = self.process_mask_native(
|
||||
maps[image_id],
|
||||
results[image_id][:, 6:],
|
||||
results[image_id][:, :4],
|
||||
)
|
||||
self.found_class_id = results[image_id][:, 5]
|
||||
|
||||
output = tv.transforms.functional.resize(
|
||||
tv.transforms.functional.crop(
|
||||
img=masks,
|
||||
top=int(remove_top),
|
||||
left=int(remove_left),
|
||||
height=int(remove_height),
|
||||
width=int(remove_width),
|
||||
),
|
||||
size=(input.shape[-2], input.shape[-1]),
|
||||
)
|
||||
else:
|
||||
output = None
|
||||
self.found_class_id = None
|
||||
|
||||
return output
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# -------------------------------------------------------------------------
|
||||
# -------------------------------------------------------------------------
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# code stolen and/or modified from Yolov5 ->
|
||||
def scale_and_pad(
|
||||
self,
|
||||
input,
|
||||
):
|
||||
ratio = min(self.imgsz[0] / input.shape[-2], self.imgsz[1] / input.shape[-1])
|
||||
|
||||
shape_new_x = int(input.shape[-2] * ratio)
|
||||
shape_new_y = int(input.shape[-1] * ratio)
|
||||
|
||||
dx = self.imgsz[0] - shape_new_x
|
||||
dy = self.imgsz[1] - shape_new_y
|
||||
|
||||
dx_0 = dx // 2
|
||||
dy_0 = dy // 2
|
||||
|
||||
image_resize = tv.transforms.functional.pad(
|
||||
tv.transforms.functional.resize(input, size=(shape_new_x, shape_new_y)),
|
||||
padding=[dy_0, dx_0, int(dy - dy_0), int(dx - dx_0)],
|
||||
fill=float(114.0 / 255.0),
|
||||
)
|
||||
|
||||
return image_resize, (dy_0, dx_0, shape_new_x, shape_new_y)
|
||||
|
||||
def process_mask_native(self, protos, masks_in, bboxes):
|
||||
masks = (
|
||||
(masks_in @ protos.float().view(protos.shape[0], -1))
|
||||
.sigmoid()
|
||||
.view(-1, protos.shape[1], protos.shape[2])
|
||||
)
|
||||
|
||||
masks = torch.nn.functional.interpolate(
|
||||
masks[None],
|
||||
(self.imgsz[0], self.imgsz[1]),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)[
|
||||
0
|
||||
] # CHW
|
||||
|
||||
x1, y1, x2, y2 = torch.chunk(bboxes[:, :, None], 4, 1) # x1 shape(1,1,n)
|
||||
r = torch.arange(masks.shape[2], device=masks.device, dtype=x1.dtype)[
|
||||
None, None, :
|
||||
] # rows shape(1,w,1)
|
||||
c = torch.arange(masks.shape[1], device=masks.device, dtype=x1.dtype)[
|
||||
None, :, None
|
||||
] # cols shape(h,1,1)
|
||||
|
||||
masks = masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
|
||||
|
||||
return masks.gt_(0.5)
|
||||
|
||||
def attempt_load(self, weights, device=None, inplace=True, fuse=True):
|
||||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
||||
|
||||
model = Ensemble()
|
||||
for w in weights if isinstance(weights, list) else [weights]:
|
||||
ckpt = torch.load(w, map_location="cpu") # load
|
||||
ckpt = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
|
||||
|
||||
# Model compatibility updates
|
||||
if not hasattr(ckpt, "stride"):
|
||||
ckpt.stride = torch.tensor([32.0])
|
||||
if hasattr(ckpt, "names") and isinstance(ckpt.names, (list, tuple)):
|
||||
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
|
||||
|
||||
model.append(
|
||||
ckpt.fuse().eval() if fuse and hasattr(ckpt, "fuse") else ckpt.eval()
|
||||
) # model in eval mode
|
||||
|
||||
# Module compatibility updates
|
||||
for m in model.modules():
|
||||
t = type(m)
|
||||
if t in (
|
||||
torch.nn.Hardswish,
|
||||
torch.nn.LeakyReLU,
|
||||
torch.nn.ReLU,
|
||||
torch.nn.ReLU6,
|
||||
torch.nn.SiLU,
|
||||
Detect,
|
||||
Model,
|
||||
):
|
||||
m.inplace = inplace # torch 1.7.0 compatibility
|
||||
if t is Detect and not isinstance(m.anchor_grid, list):
|
||||
delattr(m, "anchor_grid")
|
||||
setattr(m, "anchor_grid", [torch.zeros(1)] * m.nl)
|
||||
elif t is torch.nn.Upsample and not hasattr(m, "recompute_scale_factor"):
|
||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||
|
||||
# Return model
|
||||
if len(model) == 1:
|
||||
return model[-1]
|
||||
|
||||
# Return detection ensemble
|
||||
print(f"Ensemble created with {weights}\n")
|
||||
for k in "names", "nc", "yaml":
|
||||
setattr(model, k, getattr(model[0], k))
|
||||
model.stride = model[
|
||||
torch.argmax(torch.tensor([m.stride.max() for m in model])).int()
|
||||
].stride # max stride
|
||||
assert all(
|
||||
model[0].nc == m.nc for m in model
|
||||
), f"Models have different class counts: {[m.nc for m in model]}"
|
||||
return model
|
||||
|
||||
|
||||
class Ensemble(torch.nn.ModuleList):
|
||||
# Ensemble of models
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, augment=False, profile=False, visualize=False):
|
||||
y = [module(x, augment, profile, visualize)[0] for module in self]
|
||||
# y = torch.stack(y).max(0)[0] # max ensemble
|
||||
# y = torch.stack(y).mean(0) # mean ensemble
|
||||
y = torch.cat(y, 1) # nms ensemble
|
||||
return y, None # inference, train output
|
||||
|
||||
|
||||
def non_max_suppression(
|
||||
prediction,
|
||||
conf_thres=0.25,
|
||||
iou_thres=0.45,
|
||||
classes=None,
|
||||
agnostic=False,
|
||||
multi_label=False,
|
||||
labels=(),
|
||||
max_det=300,
|
||||
nm=0, # number of masks
|
||||
):
|
||||
"""Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
|
||||
|
||||
Returns:
|
||||
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
|
||||
"""
|
||||
|
||||
if isinstance(
|
||||
prediction, (list, tuple)
|
||||
): # YOLOv5 model in validation model, output = (inference_out, loss_out)
|
||||
prediction = prediction[0] # select only inference output
|
||||
|
||||
device = prediction.device
|
||||
mps = "mps" in device.type # Apple MPS
|
||||
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
|
||||
prediction = prediction.cpu()
|
||||
bs = prediction.shape[0] # batch size
|
||||
nc = prediction.shape[2] - nm - 5 # number of classes
|
||||
xc = prediction[..., 4] > conf_thres # candidates
|
||||
|
||||
# Checks
|
||||
assert (
|
||||
0 <= conf_thres <= 1
|
||||
), f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
|
||||
assert (
|
||||
0 <= iou_thres <= 1
|
||||
), f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
|
||||
|
||||
# Settings
|
||||
# min_wh = 2 # (pixels) minimum box width and height
|
||||
max_wh = 7680 # (pixels) maximum box width and height
|
||||
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
|
||||
time_limit = 0.5 + 0.05 * bs # seconds to quit after
|
||||
redundant = True # require redundant detections
|
||||
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
||||
merge = False # use merge-NMS
|
||||
|
||||
t = time.time()
|
||||
mi = 5 + nc # mask start index
|
||||
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
|
||||
for xi, x in enumerate(prediction): # image index, image inference
|
||||
# Apply constraints
|
||||
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
||||
x = x[xc[xi]] # confidence
|
||||
|
||||
# Cat apriori labels if autolabelling
|
||||
if labels and len(labels[xi]):
|
||||
lb = labels[xi]
|
||||
v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
|
||||
v[:, :4] = lb[:, 1:5] # box
|
||||
v[:, 4] = 1.0 # conf
|
||||
v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
|
||||
x = torch.cat((x, v), 0)
|
||||
|
||||
# If none remain process next image
|
||||
if not x.shape[0]:
|
||||
continue
|
||||
|
||||
# Compute conf
|
||||
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
|
||||
|
||||
# Box/Mask
|
||||
box = xywh2xyxy(
|
||||
x[:, :4]
|
||||
) # center_x, center_y, width, height) to (x1, y1, x2, y2)
|
||||
mask = x[:, mi:] # zero columns if no masks
|
||||
|
||||
# Detections matrix nx6 (xyxy, conf, cls)
|
||||
if multi_label:
|
||||
i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
|
||||
x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
|
||||
else: # best class only
|
||||
conf, j = x[:, 5:mi].max(1, keepdim=True)
|
||||
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
|
||||
|
||||
# Filter by class
|
||||
if classes is not None:
|
||||
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
|
||||
|
||||
# Apply finite constraint
|
||||
# if not torch.isfinite(x).all():
|
||||
# x = x[torch.isfinite(x).all(1)]
|
||||
|
||||
# Check shape
|
||||
n = x.shape[0] # number of boxes
|
||||
if not n: # no boxes
|
||||
continue
|
||||
elif n > max_nms: # excess boxes
|
||||
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
|
||||
else:
|
||||
x = x[x[:, 4].argsort(descending=True)] # sort by confidence
|
||||
|
||||
# Batched NMS
|
||||
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
||||
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
||||
i = tv.ops.nms(boxes, scores, iou_thres) # NMS
|
||||
if i.shape[0] > max_det: # limit detections
|
||||
i = i[:max_det]
|
||||
if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
|
||||
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
|
||||
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
|
||||
weights = iou * scores[None] # box weights
|
||||
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(
|
||||
1, keepdim=True
|
||||
) # merged boxes
|
||||
if redundant:
|
||||
i = i[iou.sum(1) > 1] # require redundancy
|
||||
|
||||
output[xi] = x[i]
|
||||
if mps:
|
||||
output[xi] = output[xi].to(device)
|
||||
if (time.time() - t) > time_limit:
|
||||
print(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
|
||||
break # time limit exceeded
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def box_iou(box1, box2, eps=1e-7):
|
||||
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
|
||||
(a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
|
||||
inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
|
||||
|
||||
# IoU = inter / (area1 + area2 - inter)
|
||||
return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
|
||||
|
||||
|
||||
def xywh2xyxy(x):
|
||||
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
||||
y = x.clone()
|
||||
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
|
||||
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
|
||||
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
|
||||
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
|
||||
return y
|
||||
|
||||
|
||||
# <- code stolen and/or modified from Yolov5
|
Loading…
Reference in a new issue