Add files via upload
Commit dab7dcb786
9 changed files with 2931 additions and 0 deletions
processing_chain/BuildImage.py (new file, 89 lines)
@@ -0,0 +1,89 @@
import torch


def clip_coordinates(x_canvas: int, dx_canvas: int, dx_dict: int):

    x_canvas = int(x_canvas)
    dx_canvas = int(dx_canvas)
    dx_dict = int(dx_dict)
    dr_dict = int(dx_dict // 2)

    x0_canvas = int(x_canvas - dr_dict)
    # placement outside right boundary?
    if x0_canvas >= dx_canvas:
        return None

    x1_canvas = int(x_canvas + dr_dict + (dx_dict % 2))
    # placement outside left boundary?
    if x1_canvas <= 0:
        return None

    # clip to the left?
    if x0_canvas < 0:
        x0_dict = -x0_canvas
        x0_canvas = 0
    else:
        x0_dict = 0

    # clip to the right?
    if x1_canvas > dx_canvas:
        x1_dict = dx_dict - (x1_canvas - dx_canvas)
        x1_canvas = dx_canvas
    else:
        x1_dict = dx_dict

    # print(x0_canvas, x1_canvas, x0_dict, x1_dict)
    assert (x1_canvas - x0_canvas) == (x1_dict - x0_dict)

    return x0_canvas, x1_canvas, x0_dict, x1_dict


def BuildImage(
    canvas_size: torch.Size,
    dictionary: torch.Tensor,
    position_found: torch.Tensor,
    default_dtype,
    torch_device,
):

    assert position_found is not None
    assert dictionary is not None

    canvas_size_copy = torch.tensor(canvas_size)
    assert canvas_size_copy.shape[0] == 4
    canvas_size_copy[1] = 1
    output = torch.zeros(
        canvas_size_copy.tolist(),
        device=torch_device,
        dtype=default_dtype,
    )

    dx_canvas = canvas_size[-2]
    dy_canvas = canvas_size[-1]
    dx_dict = dictionary.shape[-2]
    dy_dict = dictionary.shape[-1]

    for pattern_id in range(0, position_found.shape[0]):
        for patch_id in range(0, position_found.shape[1]):

            x_canvas = position_found[pattern_id, patch_id, 1]
            y_canvas = position_found[pattern_id, patch_id, 2]

            xv = clip_coordinates(x_canvas, dx_canvas, dx_dict)
            if xv is None:
                break

            yv = clip_coordinates(y_canvas, dy_canvas, dy_dict)
            if yv is None:
                break

            if dictionary.shape[0] > 1:
                elem_idx = int(position_found[pattern_id, patch_id, 0])
            else:
                elem_idx = 0

            output[pattern_id, 0, xv[0] : xv[1], yv[0] : yv[1]] += dictionary[
                elem_idx, 0, xv[2] : xv[3], yv[2] : yv[3]
            ]

    return output
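As a rough usage sketch of BuildImage (the tensor shapes and values below are illustrative assumptions inferred from the asserts and indexing above, not taken from the repository):

import torch
from processing_chain.BuildImage import BuildImage

dictionary = torch.rand((16, 1, 41, 41))        # 16 dictionary patches of 41 x 41 pixels
position_found = torch.tensor(
    [[[3, 100, 120],                            # each row: (element index, x, y)
      [7, 200, 50]]]
)
canvas = BuildImage(
    canvas_size=torch.Size((1, 8, 400, 400)),   # channel dimension is collapsed to 1 internally
    dictionary=dictionary,
    position_found=position_found,
    default_dtype=torch.float32,
    torch_device="cpu",
)
print(canvas.shape)                             # torch.Size([1, 1, 400, 400])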
processing_chain/ContourExtract.py (new file, 288 lines)
@@ -0,0 +1,288 @@
|
# ContourExtract.py
|
||||||
|
# ====================================
|
||||||
|
# extracts contours from gray-level images
|
||||||
|
#
|
||||||
|
# Version V1.0, pre-07.03.2023:
|
||||||
|
# no actual changes, is David's last code version...
|
||||||
|
#
|
||||||
|
# Version V1.1, 07.03.2023:
|
||||||
|
# merged David's rebuild code (GUI capable)
|
||||||
|
#
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class ContourExtract(torch.nn.Module):
|
||||||
|
image_width: int = -1
|
||||||
|
image_height: int = -1
|
||||||
|
output_threshold: torch.Tensor = torch.tensor(-1)
|
||||||
|
|
||||||
|
n_orientations: int
|
||||||
|
sigma_kernel: float = 0
|
||||||
|
lambda_kernel: float = 0
|
||||||
|
gamma_aspect_ratio: float = 0
|
||||||
|
|
||||||
|
kernel_axis_x: torch.Tensor | None = None
|
||||||
|
kernel_axis_y: torch.Tensor | None = None
|
||||||
|
target_orientations: torch.Tensor
|
||||||
|
weight_vector: torch.Tensor
|
||||||
|
|
||||||
|
image_scale: float | None
|
||||||
|
|
||||||
|
psi_phase_offset_cos: torch.Tensor
|
||||||
|
psi_phase_offset_sin: torch.Tensor
|
||||||
|
|
||||||
|
fft_gabor_cos_bank: torch.Tensor | None = None
|
||||||
|
fft_gabor_sin_bank: torch.Tensor | None = None
|
||||||
|
|
||||||
|
pi: torch.Tensor
|
||||||
|
torch_device: torch.device
|
||||||
|
default_dtype = torch.float32
|
||||||
|
|
||||||
|
rebuild_kernels: bool
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_orientations: int,
|
||||||
|
sigma_kernel: float,
|
||||||
|
lambda_kernel: float,
|
||||||
|
gamma_aspect_ratio: float = 1.0,
|
||||||
|
image_scale: float | None = 255.0,
|
||||||
|
torch_device: str = "cpu",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.torch_device = torch.device(torch_device)
|
||||||
|
|
||||||
|
self.pi = torch.tensor(
|
||||||
|
math.pi,
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.psi_phase_offset_cos = torch.tensor(
|
||||||
|
0.0,
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
self.psi_phase_offset_sin = -self.pi / 2
|
||||||
|
self.gamma_aspect_ratio = gamma_aspect_ratio
|
||||||
|
self.image_scale = image_scale
|
||||||
|
|
||||||
|
self.update_settings(
|
||||||
|
n_orientations=n_orientations,
|
||||||
|
sigma_kernel=sigma_kernel,
|
||||||
|
lambda_kernel=lambda_kernel,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
|
assert input.ndim == 4, "input must have 4 dims!"
|
||||||
|
|
||||||
|
# We can only handle one color channel
|
||||||
|
assert input.shape[1] == 1, "input.shape[1] must be 1!"
|
||||||
|
|
||||||
|
input = input.type(dtype=self.default_dtype).to(device=self.torch_device)
|
||||||
|
if self.image_scale is not None:
|
||||||
|
# scale grey level [0, 255] to range [0.0, 1.0]
|
||||||
|
input /= self.image_scale
|
||||||
|
|
||||||
|
# Do we have valid kernels?
|
||||||
|
if input.shape[-2] != self.image_width:
|
||||||
|
self.image_width = input.shape[-2]
|
||||||
|
self.rebuild_kernels = True
|
||||||
|
|
||||||
|
if input.shape[-1] != self.image_height:
|
||||||
|
self.image_height = input.shape[-1]
|
||||||
|
self.rebuild_kernels = True
|
||||||
|
|
||||||
|
assert self.image_width > 0
|
||||||
|
assert self.image_height > 0
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
# We need to rebuild the kernels
|
||||||
|
if self.rebuild_kernels is True:
|
||||||
|
|
||||||
|
self.kernel_axis_x = self.create_kernel_axis(self.image_width)
|
||||||
|
self.kernel_axis_y = self.create_kernel_axis(self.image_height)
|
||||||
|
|
||||||
|
assert self.kernel_axis_x is not None
|
||||||
|
assert self.kernel_axis_y is not None
|
||||||
|
|
||||||
|
gabor_cos_bank, gabor_sin_bank = self.create_gabor_filter_bank()
|
||||||
|
|
||||||
|
assert gabor_cos_bank is not None
|
||||||
|
assert gabor_sin_bank is not None
|
||||||
|
|
||||||
|
self.fft_gabor_cos_bank = torch.fft.rfft2(
|
||||||
|
gabor_cos_bank, s=None, dim=(-2, -1), norm=None
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
self.fft_gabor_sin_bank = torch.fft.rfft2(
|
||||||
|
gabor_sin_bank, s=None, dim=(-2, -1), norm=None
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
# compute threshold for ignoring non-zero outputs arising due
|
||||||
|
# to numerical imprecision (fft, kernel definition)
|
||||||
|
assert self.weight_vector is not None
|
||||||
|
norm_input = torch.full(
|
||||||
|
(1, 1, self.image_width, self.image_height),
|
||||||
|
1.0,
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
norm_fft_input = torch.fft.rfft2(
|
||||||
|
norm_input, s=None, dim=(-2, -1), norm=None
|
||||||
|
)
|
||||||
|
|
||||||
|
norm_output_sin = torch.fft.irfft2(
|
||||||
|
norm_fft_input * self.fft_gabor_sin_bank,
|
||||||
|
s=None,
|
||||||
|
dim=(-2, -1),
|
||||||
|
norm=None,
|
||||||
|
)
|
||||||
|
norm_output_cos = torch.fft.irfft2(
|
||||||
|
norm_fft_input * self.fft_gabor_cos_bank,
|
||||||
|
s=None,
|
||||||
|
dim=(-2, -1),
|
||||||
|
norm=None,
|
||||||
|
)
|
||||||
|
norm_value_matrix = torch.sqrt(norm_output_sin**2 + norm_output_cos**2)
|
||||||
|
|
||||||
|
# norm_output = torch.abs(
|
||||||
|
# (self.weight_vector * norm_value_matrix).sum(dim=-1)
|
||||||
|
# ).type(dtype=torch.float32)
|
||||||
|
|
||||||
|
self.output_threshold = norm_value_matrix.max()
|
||||||
|
|
||||||
|
self.rebuild_kernels = False
|
||||||
|
|
||||||
|
assert self.fft_gabor_cos_bank is not None
|
||||||
|
assert self.fft_gabor_sin_bank is not None
|
||||||
|
|
||||||
|
t1 = time.time()
|
||||||
|
|
||||||
|
fft_input = torch.fft.rfft2(input, s=None, dim=(-2, -1), norm=None)
|
||||||
|
|
||||||
|
output_sin = torch.fft.irfft2(
|
||||||
|
fft_input * self.fft_gabor_sin_bank, s=None, dim=(-2, -1), norm=None
|
||||||
|
)
|
||||||
|
|
||||||
|
output_cos = torch.fft.irfft2(
|
||||||
|
fft_input * self.fft_gabor_cos_bank, s=None, dim=(-2, -1), norm=None
|
||||||
|
)
|
||||||
|
|
||||||
|
t2 = time.time()
|
||||||
|
|
||||||
|
output = torch.sqrt(output_sin**2 + output_cos**2)
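# output_sin and output_cos form a quadrature pair per orientation, so this magnitude
# is the phase-invariant Gabor energy, kept separately for each orientation channel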
|
||||||
|
|
||||||
|
t3 = time.time()
|
||||||
|
|
||||||
|
print(
|
||||||
|
"ContourExtract {:.3f}s: prep-{:.3f}s, fft-{:.3f}s, out-{:.3f}s".format(
|
||||||
|
t3 - t0, t1 - t0, t2 - t1, t3 - t2
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def create_collapse(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
|
assert self.weight_vector is not None
|
||||||
|
|
||||||
|
output = torch.abs((self.weight_vector * input).sum(dim=1)).type(
|
||||||
|
dtype=self.default_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def create_kernel_axis(self, axis_size: int) -> torch.Tensor:
|
||||||
|
|
||||||
|
lower_bound_axis: int = -int(math.floor(axis_size / 2))
|
||||||
|
upper_bound_axis: int = int(math.ceil(axis_size / 2))
|
||||||
|
|
||||||
|
kernel_axis = torch.arange(
|
||||||
|
lower_bound_axis,
|
||||||
|
upper_bound_axis,
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
kernel_axis = torch.roll(kernel_axis, int(math.ceil(axis_size / 2)))
|
||||||
|
|
||||||
|
return kernel_axis
|
||||||
|
|
||||||
|
def create_gabor_filter_bank(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
assert self.kernel_axis_x is not None
|
||||||
|
assert self.kernel_axis_y is not None
|
||||||
|
assert self.target_orientations is not None
|
||||||
|
|
||||||
|
orientation_matrix = (
|
||||||
|
self.target_orientations.unsqueeze(-1).unsqueeze(-1).detach().clone()
|
||||||
|
)
|
||||||
|
x_kernel_matrix = self.kernel_axis_x.unsqueeze(0).unsqueeze(-1).detach().clone()
|
||||||
|
y_kernel_matrix = self.kernel_axis_y.unsqueeze(0).unsqueeze(0).detach().clone()
|
||||||
|
|
||||||
|
r2 = x_kernel_matrix**2 + self.gamma_aspect_ratio * y_kernel_matrix**2
|
||||||
|
|
||||||
|
kr = x_kernel_matrix * torch.cos(
|
||||||
|
orientation_matrix
|
||||||
|
) + y_kernel_matrix * torch.sin(orientation_matrix)
|
||||||
|
|
||||||
|
c0 = torch.exp(-2 * (self.pi * self.sigma_kernel / self.lambda_kernel) ** 2)
|
||||||
|
|
||||||
|
gauss: torch.Tensor = torch.exp(-r2 / 2 / self.sigma_kernel**2)
|
||||||
|
|
||||||
|
gabor_cos_bank: torch.Tensor = gauss * (
|
||||||
|
torch.cos(2 * self.pi * kr / self.lambda_kernel + self.psi_phase_offset_cos)
|
||||||
|
- c0 * torch.cos(self.psi_phase_offset_cos)
|
||||||
|
)
|
||||||
|
|
||||||
|
gabor_sin_bank: torch.Tensor = gauss * (
|
||||||
|
torch.cos(2 * self.pi * kr / self.lambda_kernel + self.psi_phase_offset_sin)
|
||||||
|
- c0 * torch.cos(self.psi_phase_offset_sin)
|
||||||
|
)
|
||||||
|
|
||||||
|
return gabor_cos_bank, gabor_sin_bank
|
||||||
|
|
||||||
|
def update_settings(
|
||||||
|
self,
|
||||||
|
n_orientations: int,
|
||||||
|
sigma_kernel: float,
|
||||||
|
lambda_kernel: float,
|
||||||
|
):
|
||||||
|
|
||||||
|
self.rebuild_kernels = True
|
||||||
|
|
||||||
|
self.n_orientations = n_orientations
|
||||||
|
self.sigma_kernel = sigma_kernel
|
||||||
|
self.lambda_kernel = lambda_kernel
|
||||||
|
|
||||||
|
# generate orientation axis and axis for complex summation
|
||||||
|
self.target_orientations: torch.Tensor = (
|
||||||
|
torch.arange(
|
||||||
|
start=0,
|
||||||
|
end=int(self.n_orientations),
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
* torch.tensor(
|
||||||
|
math.pi,
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
/ self.n_orientations
|
||||||
|
)
|
||||||
|
|
||||||
|
self.weight_vector: torch.Tensor = (
|
||||||
|
torch.exp(
|
||||||
|
2.0
|
||||||
|
* torch.complex(torch.tensor(0.0), torch.tensor(1.0))
|
||||||
|
* self.target_orientations
|
||||||
|
)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.unsqueeze(0)
|
||||||
|
)
|
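A rough sketch of how this module can be driven (the sigma/lambda values follow the DVA-to-pixel defaults used later in OnlineEncoding.py; the input shape is an assumption):

import torch
from processing_chain.ContourExtract import ContourExtract

extract = ContourExtract(
    n_orientations=8,
    sigma_kernel=2.4,     # 0.06 DVA * 40 pixels per DVA
    lambda_kernel=4.8,    # 0.12 DVA * 40 pixels per DVA
    torch_device="cpu",
)
image = torch.rand((1, 1, 200, 200)) * 255.0    # one grey-level image, one channel
contour = extract(image)                        # [1, n_orientations, 200, 200]
collapsed = extract.create_collapse(contour)    # orientation channels merged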
processing_chain/DiscardElements.py (new file, 162 lines)
@@ -0,0 +1,162 @@
|
#%%
|
||||||
|
# DiscardElements.py
|
||||||
|
# ====================================
|
||||||
|
# removes elements from a sparse image representation
|
||||||
|
# such that a 'most uniform' coverage still exists
|
||||||
|
#
|
||||||
|
# Version V1.0, pre-07.03.2023:
|
||||||
|
# no actual changes, is David's last code version...
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# assume locations is array [n, 3]
|
||||||
|
# the three entries are [shape_index, pos_x, pos_y]
|
||||||
|
|
||||||
|
|
||||||
|
def discard_elements_simple(locations: np.ndarray, target_number_elements: list):
|
||||||
|
|
||||||
|
n_locations: int = locations.shape[0]
|
||||||
|
locations_remain: list = []
|
||||||
|
|
||||||
|
# Loop across all target number of elements
|
||||||
|
for target_elem in target_number_elements:
|
||||||
|
|
||||||
|
assert target_elem > 0, "Number of target elements must be larger than 0!"
|
||||||
|
assert (
|
||||||
|
target_elem <= n_locations
|
||||||
|
), "Number of target elements must be <= number of available locations!"
|
||||||
|
|
||||||
|
# Build distance matrix between positions in locations_highest_res.
|
||||||
|
# Its diagonal is defined as Inf because we don't want to consider these values in our
|
||||||
|
# search for the minimum distances.
|
||||||
|
distance_matrix = np.sqrt(
|
||||||
|
((locations[np.newaxis, :, 1:] - locations[:, np.newaxis, 1:]) ** 2).sum(
|
||||||
|
axis=-1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
distance_matrix[np.arange(n_locations), np.arange(n_locations)] = np.inf
|
||||||
|
|
||||||
|
# Find the minimal distances in upper triangle of matrix.
|
||||||
|
idcs_remove: list = []
|
||||||
|
while (n_locations - len(idcs_remove)) != target_elem:
|
||||||
|
|
||||||
|
# Get index of matrix with minimal distance
|
||||||
|
row_idcs, col_idcs = np.where(
|
||||||
|
distance_matrix == distance_matrix[distance_matrix > 0].min()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the max index.
|
||||||
|
# It corresponds to the index of the element we will remove from the locations_highest_res list
|
||||||
|
sel_idx: int = max(row_idcs[0], col_idcs[0])
|
||||||
|
idcs_remove.append(sel_idx) # Save the index
|
||||||
|
|
||||||
|
# Set current distance as Inf because we don't want to consider it further in our search
|
||||||
|
distance_matrix[sel_idx, :] = np.inf
|
||||||
|
distance_matrix[:, sel_idx] = np.inf
|
||||||
|
|
||||||
|
idcs_remain: list = np.setdiff1d(np.arange(n_locations), idcs_remove)
|
||||||
|
locations_remain.append(locations[idcs_remain, :])
|
||||||
|
|
||||||
|
return locations_remain
|
||||||
|
|
||||||
|
|
||||||
|
# assume locations is array [n, 3]
|
||||||
|
# the three entries are [shape_index, pos_x, pos_y]
|
||||||
|
def discard_elements(
|
||||||
|
locations: np.ndarray, target_number_elements: list, prior: np.ndarray
|
||||||
|
):
|
||||||
|
|
||||||
|
n_locations: int = locations.shape[0]
|
||||||
|
locations_remain: list = []
|
||||||
|
disable_value: float = np.nan
|
||||||
|
|
||||||
|
# if type(prior) != np.ndarray:
|
||||||
|
# prior = np.ones((n_locations,))
|
||||||
|
assert prior.shape == (
|
||||||
|
n_locations,
|
||||||
|
), "Prior must have same number of entries as elements in locations!"
|
||||||
|
print(prior)
|
||||||
|
|
||||||
|
# Loop across all target number of elements
|
||||||
|
for target_elem in target_number_elements:
|
||||||
|
|
||||||
|
assert target_elem > 0, "Number of target elements must be larger than 0!"
|
||||||
|
assert (
|
||||||
|
target_elem <= n_locations
|
||||||
|
), "Number of target elements must be <= number of available locations!"
|
||||||
|
|
||||||
|
# Build distance matrix between positions in locations_highest_res.
|
||||||
|
# Its diagonal is defined as Inf because we don't want to consider these values in our
|
||||||
|
# search for the minimum distances.
|
||||||
|
distance_matrix = np.sqrt(
|
||||||
|
((locations[np.newaxis, :, 1:] - locations[:, np.newaxis, 1:]) ** 2).sum(
|
||||||
|
axis=-1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
prior_matrix = prior[np.newaxis, :] * prior[:, np.newaxis]
|
||||||
|
distance_matrix *= prior_matrix
|
||||||
|
distance_matrix[np.arange(n_locations), np.arange(n_locations)] = disable_value
|
||||||
|
print(distance_matrix)
|
||||||
|
|
||||||
|
# Find the minimal distances in upper triangle of matrix.
|
||||||
|
idcs_remove: list = []
|
||||||
|
while (n_locations - len(idcs_remove)) != target_elem:
|
||||||
|
|
||||||
|
# Get index of matrix with minimal distance
|
||||||
|
row_idcs, col_idcs = np.where(
|
||||||
|
# distance_matrix == distance_matrix[distance_matrix > 0].min()
|
||||||
|
distance_matrix
|
||||||
|
== np.nanmin(distance_matrix)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the max index.
|
||||||
|
# It corresponds to the index of the element we will remove from the locations_highest_res list
|
||||||
|
print(row_idcs[0], col_idcs[0])
|
||||||
|
# if prior[row_idcs[0]] >= prior[col_idcs[0]]:
|
||||||
|
# sel_idx = row_idcs[0]
|
||||||
|
# else:
|
||||||
|
# sel_idx = col_idcs[0]
|
||||||
|
d_row = np.nansum(distance_matrix[row_idcs[0], :])
|
||||||
|
d_col = np.nansum(distance_matrix[:, col_idcs[0]])
|
||||||
|
if d_row > d_col:
|
||||||
|
sel_idx = col_idcs[0]
|
||||||
|
else:
|
||||||
|
sel_idx = row_idcs[0]
|
||||||
|
# sel_idx: int = max(row_idcs[0], col_idcs[0])
|
||||||
|
idcs_remove.append(sel_idx) # Save the index
|
||||||
|
|
||||||
|
# Set current distance as Inf because we don't want to consider it further in our search
|
||||||
|
distance_matrix[sel_idx, :] = disable_value
|
||||||
|
distance_matrix[:, sel_idx] = disable_value
|
||||||
|
|
||||||
|
idcs_remain: list = np.setdiff1d(np.arange(n_locations), idcs_remove)
|
||||||
|
locations_remain.append(locations[idcs_remain, :])
|
||||||
|
|
||||||
|
return locations_remain
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
# generate a circle with n locations
|
||||||
|
n_locations: int = 20
|
||||||
|
phi = np.arange(n_locations) / n_locations * 2 * np.pi
|
||||||
|
locations = np.ones((n_locations, 3))
|
||||||
|
locations[:, 1] = np.cos(phi)
|
||||||
|
locations[:, 2] = np.sin(phi)
|
||||||
|
prior = np.ones((n_locations,))
|
||||||
|
prior[:10] = 0.1
|
||||||
|
locations_remain = discard_elements(locations, [n_locations // 5], prior=prior)
|
||||||
|
|
||||||
|
plt.plot(locations[:, 1], locations[:, 2], "ko")
|
||||||
|
plt.plot(locations_remain[0][:, 1], locations_remain[0][:, 2], "rx")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
locations_remain_simple = discard_elements_simple(locations, [n_locations // 5])
|
||||||
|
|
||||||
|
plt.plot(locations[:, 1], locations[:, 2], "ko")
|
||||||
|
plt.plot(locations_remain_simple[0][:, 1], locations_remain_simple[0][:, 2], "rx")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
processing_chain/OnlineEncoding.py (new file, 711 lines)
@@ -0,0 +1,711 @@
|
# %%
|
||||||
|
#
|
||||||
|
# OnlineEncoding.py
|
||||||
|
# ========================================================
|
||||||
|
# encode visual scenes into sparse representations using
|
||||||
|
# different kinds of dictionaries
|
||||||
|
#
|
||||||
|
# -> derived from test_PsychophysicsEncoding.py
|
||||||
|
#
|
||||||
|
# Version 1.0, 29.04.2023:
|
||||||
|
#
|
||||||
|
# Version 1.1, 21.06.2023:
|
||||||
|
# define proper class
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
# Import Python modules
|
||||||
|
# ========================================================
|
||||||
|
# import csv
|
||||||
|
# import time
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torch
|
||||||
|
import torchvision as tv
|
||||||
|
from PIL import Image
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# Import our modules
|
||||||
|
# ========================================================
|
||||||
|
from processing_chain.ContourExtract import ContourExtract
|
||||||
|
from processing_chain.PatchGenerator import PatchGenerator
|
||||||
|
from processing_chain.Sparsifier import Sparsifier
|
||||||
|
from processing_chain.DiscardElements import discard_elements_simple
|
||||||
|
from processing_chain.BuildImage import BuildImage
|
||||||
|
from processing_chain.WebCam import WebCam
|
||||||
|
from processing_chain.Yolo5Segmentation import Yolo5Segmentation
|
||||||
|
|
||||||
|
|
||||||
|
# TODO required?
|
||||||
|
def show_torch_frame(
|
||||||
|
frame_torch: torch.Tensor,
|
||||||
|
title: str = "",
|
||||||
|
cmap: str = "viridis",
|
||||||
|
target: str = "pyplot",
|
||||||
|
):
|
||||||
|
frame_numpy = (
|
||||||
|
(frame_torch.movedim(0, -1) * 255).type(dtype=torch.uint8).cpu().numpy()
|
||||||
|
)
|
||||||
|
if target == "pyplot":
|
||||||
|
plt.imshow(frame_numpy, cmap=cmap)
|
||||||
|
plt.title(title)
|
||||||
|
plt.show()
|
||||||
|
if target == "cv2":
|
||||||
|
if frame_numpy.ndim == 3:
|
||||||
|
if frame_numpy.shape[-1] == 1:
|
||||||
|
frame_numpy = np.tile(frame_numpy, [1, 1, 3])
|
||||||
|
frame_numpy = (frame_numpy - frame_numpy.min()) / (
|
||||||
|
frame_numpy.max() - frame_numpy.min()
|
||||||
|
)
|
||||||
|
# print(frame_numpy.shape, frame_numpy.max(), frame_numpy.min())
|
||||||
|
cv2.namedWindow(title, cv2.WINDOW_NORMAL)
|
||||||
|
cv2.imshow(title, frame_numpy[:, :, (2, 1, 0)])
|
||||||
|
cv2.waitKey(1)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# TODO required?
|
||||||
|
def embed_image(frame_torch, out_height, out_width, init_value=0):
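    # Centers frame_torch on an (out_height, out_width) canvas: larger inputs are
    # center-cropped, smaller ones are center-padded with init_value.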
|
||||||
|
|
||||||
|
out_shape = torch.tensor(frame_torch.shape)
|
||||||
|
|
||||||
|
frame_width = frame_torch.shape[-1]
|
||||||
|
frame_height = frame_torch.shape[-2]
|
||||||
|
|
||||||
|
frame_width_idx0 = max([0, (frame_width - out_width) // 2])
|
||||||
|
frame_height_idx0 = max([0, (frame_height - out_height) // 2])
|
||||||
|
|
||||||
|
select_width = min([frame_width, out_width])
|
||||||
|
select_height = min([frame_height, out_height])
|
||||||
|
|
||||||
|
out_shape[-1] = out_width
|
||||||
|
out_shape[-2] = out_height
|
||||||
|
|
||||||
|
out_torch = init_value * torch.ones(tuple(out_shape))
|
||||||
|
|
||||||
|
out_width_idx0 = max([0, (out_width - frame_width) // 2])
|
||||||
|
out_height_idx0 = max([0, (out_height - frame_height) // 2])
|
||||||
|
|
||||||
|
out_torch[
|
||||||
|
...,
|
||||||
|
out_height_idx0 : (out_height_idx0 + select_height),
|
||||||
|
out_width_idx0 : (out_width_idx0 + select_width),
|
||||||
|
] = frame_torch[
|
||||||
|
...,
|
||||||
|
frame_height_idx0 : (frame_height_idx0 + select_height),
|
||||||
|
frame_width_idx0 : (frame_width_idx0 + select_width),
|
||||||
|
]
|
||||||
|
|
||||||
|
return out_torch
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineEncoding:
|
||||||
|
|
||||||
|
# TODO: also pre-populate self-ies here?
|
||||||
|
#
|
||||||
|
# DEFINED IN "__init__":
|
||||||
|
#
|
||||||
|
# display (fixed)
|
||||||
|
# gabor (changeable)
|
||||||
|
# encoding (changeable)
|
||||||
|
# dictionary (changeable)
|
||||||
|
# control (fixed)
|
||||||
|
# path (fixed)
|
||||||
|
# verbose
|
||||||
|
# torch_device, default_dtype
|
||||||
|
# display_size_max_x_PIX, display_size_max_y_PIX
|
||||||
|
# padding_fill
|
||||||
|
# cap
|
||||||
|
# yolo
|
||||||
|
# classes_detect
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# DEFINED IN "apply_parameter_changes":
|
||||||
|
#
|
||||||
|
# padding_PIX
|
||||||
|
# sigma_kernel_PIX, lambda_kernel_PIX
|
||||||
|
# out_x, out_y
|
||||||
|
# clocks, phosphene, clocks_filter
|
||||||
|
#
|
||||||
|
|
||||||
|
def __init__(self, source=0, verbose=False):
|
||||||
|
|
||||||
|
# Define parameters
|
||||||
|
# ========================================================
|
||||||
|
# Unit abbreviations:
|
||||||
|
# dva = degrees of visual angle
|
||||||
|
# pix = pixels
|
||||||
|
print("OE-Init: Defining default parameters...")
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
# display: Defines geometry of target display
|
||||||
|
# ========================================================
|
||||||
|
# The encoded image will be scaled such that it optimally uses
|
||||||
|
# the max space available. If the original image has a different aspect
|
||||||
|
# ratio than the display region, it will only use one spatial
|
||||||
|
# dimension (horizontal or vertical) to its full extent
|
||||||
|
#
|
||||||
|
# If one DVA corresponds to different PIX_per_DVA on the display,
|
||||||
|
# (i.e. varying distance observers from screen), it should be set
|
||||||
|
# larger than the largest PIX_per_DVA required, for avoiding
|
||||||
|
# extrapolation artefacts or blur.
|
||||||
|
#
|
||||||
|
self.display = {
|
||||||
|
"size_max_x_DVA": 10.0, # maximum x size of encoded image
|
||||||
|
"size_max_y_DVA": 10.0, # minimum y size of encoded image
|
||||||
|
"PIX_per_DVA": 40.0, # scaling factor pixels to DVA
|
||||||
|
"scale": "same_range", # "same_luminance" or "same_range"
|
||||||
|
}
|
||||||
|
|
||||||
|
# gabor: Defines paras of Gabor filters for contour extraction
|
||||||
|
# ==============================================================
|
||||||
|
self.gabor = {
|
||||||
|
"sigma_kernel_DVA": 0.06,
|
||||||
|
"lambda_kernel_DVA": 0.12,
|
||||||
|
"n_orientations": 8,
|
||||||
|
}
|
||||||
|
|
||||||
|
# encoding: Defines parameters of sparse encoding process
|
||||||
|
# ========================================================
|
||||||
|
# Roughly speaking, after contour extraction dictionary elements
|
||||||
|
# will be placed starting from the position with the highest
|
||||||
|
# overlap with the contour. Elements placed can be surrounded
|
||||||
|
# by a dead or inhibitory zone to prevent placing further elements
|
||||||
|
# too closely. The procedure will map 'n_patches_compute' elements
|
||||||
|
# and then stop. For each element one obtains an overlap with the
|
||||||
|
# contour image.
|
||||||
|
#
|
||||||
|
# After placement, the overlaps found are normalized to the max
|
||||||
|
# overlap found, and then all elements with a larger normalized overlap
|
||||||
|
# than 'overlap_threshold' will be selected. These remaining
|
||||||
|
# elements will comprise a 'full' encoding of the contour.
|
||||||
|
#
|
||||||
|
# To generate even sparser representations, the full encoding can
|
||||||
|
# be reduced to a certain percentage of elements in the full encoding
|
||||||
|
# by setting the variable 'percentages'
|
||||||
|
#
|
||||||
|
# Example: n_patches_compute = 100 reduced by overlap_threshold = 0.1
|
||||||
|
# to 80 elements. Requesting a percentage of 30% yields a representation
|
||||||
|
# with 24 elements.
|
||||||
|
#
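# A small numeric sketch of that reduction (variable names here are illustrative):
#   overlap_norm = overlap_found / overlap_found.max()
#   kept = overlap_norm >= overlap_threshold        # e.g. 80 of the 100 computed patches
#   n_final = int(percentage / 100 * kept.sum())    # 30% of 80 -> 24 elements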
|
||||||
|
self.encoding = {
|
||||||
|
"n_patches_compute": 100, # this amount of patches will be placed
|
||||||
|
"use_exp_deadzone": True, # parameters of Gaussian deadzone
|
||||||
|
"size_exp_deadzone_DVA": 1.20, # PREVIOUSLY 1.4283
|
||||||
|
"use_cutout_deadzone": True, # parameters of cutout deadzone
|
||||||
|
"size_cutout_deadzone_DVA": 0.65, # PREVIOUSLY 0.7575
|
||||||
|
"overlap_threshold": 0.1, # relative overlap threshold
|
||||||
|
"percentages": torch.tensor([100]),
|
||||||
|
}
|
||||||
|
self.number_of_patches = self.encoding["n_patches_compute"]
|
||||||
|
|
||||||
|
# dictionary: Defines parameters of dictionary
|
||||||
|
# ========================================================
|
||||||
|
self.dictionary = {
|
||||||
|
"size_DVA": 1.0, # PREVIOUSLY 1.25,
|
||||||
|
"clocks": None, # parameters for clocks dictionary, see below
|
||||||
|
"phosphene": None, # paramters for phosphene dictionary, see below
|
||||||
|
}
|
||||||
|
|
||||||
|
self.dictionary["phosphene"]: dict[float] = {
|
||||||
|
"sigma_width": 0.18, # DEFAULT 0.15, # half-width of Gaussian
|
||||||
|
}
|
||||||
|
|
||||||
|
self.dictionary["clocks"]: dict[int, int, float, float] = {
|
||||||
|
"n_dir": 8, # number of directions for clock pointer segments
|
||||||
|
"n_open": 4, # number of opening angles between two clock pointer segments
|
||||||
|
"pointer_width": 0.07, # PREVIOUSLY 0.05, # relative width and size of tip extension of clock pointer
|
||||||
|
"pointer_length": 0.18, # PREVIOUSLY 0.15, # relative length of clock pointer
|
||||||
|
}
|
||||||
|
|
||||||
|
# control: For controlling plotting options and flow of script
|
||||||
|
# ========================================================
|
||||||
|
self.control = {
|
||||||
|
"force_torch_use_cpu": False, # force using CPU even if GPU available
|
||||||
|
"show_capture": True, # shows captured image
|
||||||
|
"show_object": True, # shows detected object
|
||||||
|
"show_contours": True, # shows extracted contours
|
||||||
|
"show_percept": True, # shows percept
|
||||||
|
}
|
||||||
|
|
||||||
|
# specify classes to detect
|
||||||
|
class_person = 0
|
||||||
|
self.classes_detect = [class_person]
|
||||||
|
|
||||||
|
print(
|
||||||
|
"OE-Init: Defining paths, creating dirs, setting default device and datatype"
|
||||||
|
)
|
||||||
|
|
||||||
|
# path: Path infos for input and output images
|
||||||
|
# ========================================================
|
||||||
|
self.path = {"output": "test/output/level1/", "input": "test/images_test/"}
|
||||||
|
# Make output directories, if necessary: the place where we dump the new images to...
|
||||||
|
# os.makedirs(self.path["output"], mode=0o777, exist_ok=True)
|
||||||
|
|
||||||
|
# Check if GPU is available and use it, if possible
|
||||||
|
# =================================================
|
||||||
|
self.default_dtype = torch.float32
|
||||||
|
torch.set_default_dtype(self.default_dtype)
|
||||||
|
if self.control["force_torch_use_cpu"]:
|
||||||
|
torch_device: str = "cpu"
|
||||||
|
else:
|
||||||
|
torch_device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
print(f"Using {torch_device} as TORCH device...")
|
||||||
|
self.torch_device = torch_device
|
||||||
|
|
||||||
|
print("OE-Init: Compute display scaling factors and padding RGB values")
|
||||||
|
|
||||||
|
# global scaling factors for all pixel-related length scales
|
||||||
|
self.display_size_max_x_PIX: float = (
|
||||||
|
self.display["size_max_x_DVA"] * self.display["PIX_per_DVA"]
|
||||||
|
)
|
||||||
|
self.display_size_max_y_PIX: float = (
|
||||||
|
self.display["size_max_y_DVA"] * self.display["PIX_per_DVA"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# determine padding fill value
|
||||||
|
tmp = tv.transforms.Grayscale(num_output_channels=1)
|
||||||
|
tmp_value = torch.full((3, 1, 1), 0)
|
||||||
|
self.padding_fill: int = int(tmp(tmp_value).squeeze())
|
||||||
|
|
||||||
|
print(f"OE-Init: Opening camera source or video file '{source}'")
|
||||||
|
|
||||||
|
# open source
|
||||||
|
self.cap = WebCam(source)
|
||||||
|
if not self.cap.open_cam():
|
||||||
|
raise OSError(f"Opening source {source} failed!")
|
||||||
|
|
||||||
|
# get the video frame size, frame count and frame rate
|
||||||
|
frame_width = self.cap.cap_frame_width
|
||||||
|
frame_height = self.cap.cap_frame_height
|
||||||
|
fps = self.cap.cap_fps
|
||||||
|
print(
|
||||||
|
f"OE-Init: Processing frames of {frame_width} x {frame_height} @ {fps} fps."
|
||||||
|
)
|
||||||
|
|
||||||
|
# open output file if we want to save frames
|
||||||
|
# if output_file != None:
|
||||||
|
# out = cv2.VideoWriter(
|
||||||
|
# output_file,
|
||||||
|
# cv2.VideoWriter_fourcc(*"MJPG"),
|
||||||
|
# fps,
|
||||||
|
# (out_x, out_y),
|
||||||
|
# )
|
||||||
|
# if out == None:
|
||||||
|
# raise OSError(f"Can not open file {output_file} for writing!")
|
||||||
|
|
||||||
|
# get an instance of the Yolo segmentation network
|
||||||
|
print("OE-Init: initialize YOLO")
|
||||||
|
self.yolo = Yolo5Segmentation(torch_device=self.torch_device)
|
||||||
|
|
||||||
|
self.send_dictionaries = False
|
||||||
|
|
||||||
|
self.apply_parameter_changes()
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def apply_parameter_changes(self):
|
||||||
|
|
||||||
|
# GET NEW PARAMETERS
|
||||||
|
print("OE-AppParChg: Computing sizes from new parameters")
|
||||||
|
|
||||||
|
### BLOCK: dictionary ----------------
|
||||||
|
# set patch size for both dictionaries, make sure it is odd number
|
||||||
|
dictionary_size_PIX: int = (
|
||||||
|
1
|
||||||
|
+ (int(self.dictionary["size_DVA"] * self.display["PIX_per_DVA"]) // 2) * 2
|
||||||
|
)
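# e.g. with size_DVA = 1.0 and PIX_per_DVA = 40.0: 1 + (int(40.0) // 2) * 2 = 41 pixels (odd)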
|
||||||
|
|
||||||
|
### BLOCK: gabor ---------------------
|
||||||
|
# convert contour-related parameters to pixel units
|
||||||
|
self.sigma_kernel_PIX: float = (
|
||||||
|
self.gabor["sigma_kernel_DVA"] * self.display["PIX_per_DVA"]
|
||||||
|
)
|
||||||
|
self.lambda_kernel_PIX: float = (
|
||||||
|
self.gabor["lambda_kernel_DVA"] * self.display["PIX_per_DVA"]
|
||||||
|
)
|
||||||
|
|
||||||
|
### BLOCK: gabor & dictionary ------------------
|
||||||
|
# Padding
|
||||||
|
# -------
|
||||||
|
self.padding_PIX: int = int(
|
||||||
|
max(3.0 * self.sigma_kernel_PIX, 1.1 * dictionary_size_PIX)
|
||||||
|
)
|
||||||
|
|
||||||
|
# define target video/representation width/height
|
||||||
|
multiple_of = 4
|
||||||
|
out_x = self.display_size_max_x_PIX + 2 * self.padding_PIX
|
||||||
|
out_y = self.display_size_max_y_PIX + 2 * self.padding_PIX
|
||||||
|
out_x += (multiple_of - (out_x % multiple_of)) % multiple_of
|
||||||
|
out_y += (multiple_of - (out_y % multiple_of)) % multiple_of
|
||||||
|
self.out_x = int(out_x)
|
||||||
|
self.out_y = int(out_y)
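# e.g. with the defaults above: padding_PIX = int(max(3 * 2.4, 1.1 * 41)) = 45,
# so out_x = 400 + 2 * 45 = 490, rounded up to 492 (the next multiple of 4)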
|
||||||
|
|
||||||
|
# generate dictionaries
|
||||||
|
# ---------------------
|
||||||
|
### BLOCK: dictionary --------------------------
|
||||||
|
print("OE-AppParChg: Generating dictionaries...")
|
||||||
|
patch_generator = PatchGenerator(torch_device=self.torch_device)
|
||||||
|
self.phosphene = patch_generator.alphabet_phosphene(
|
||||||
|
patch_size=dictionary_size_PIX,
|
||||||
|
sigma_width=self.dictionary["phosphene"]["sigma_width"]
|
||||||
|
* dictionary_size_PIX,
|
||||||
|
)
|
||||||
|
### BLOCK: dictionary & gabor --------------------------
|
||||||
|
self.clocks_filter, self.clocks, segments = patch_generator.alphabet_clocks(
|
||||||
|
patch_size=dictionary_size_PIX,
|
||||||
|
n_dir=self.dictionary["clocks"]["n_dir"],
|
||||||
|
n_filter=self.gabor["n_orientations"],
|
||||||
|
segment_width=self.dictionary["clocks"]["pointer_width"]
|
||||||
|
* dictionary_size_PIX,
|
||||||
|
segment_length=self.dictionary["clocks"]["pointer_length"]
|
||||||
|
* dictionary_size_PIX,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.send_dictionaries = True
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
# classes_detect, out_x, out_y
|
||||||
|
def update(self, data_in):
|
||||||
|
|
||||||
|
# handle parameter change
|
||||||
|
|
||||||
|
if data_in:
|
||||||
|
|
||||||
|
print("Incoming -----------> ", data_in)
|
||||||
|
|
||||||
|
self.number_of_patches = data_in["number_of_patches"]
|
||||||
|
|
||||||
|
self.classes_detect = data_in["value"]
|
||||||
|
|
||||||
|
self.gabor["sigma_kernel_DVA"] = data_in["sigma_kernel_DVA"]
|
||||||
|
self.gabor["lambda_kernel_DVA"] = data_in["sigma_kernel_DVA"] * 2
|
||||||
|
self.gabor["n_orientations"] = data_in["n_orientations"]
|
||||||
|
|
||||||
|
self.dictionary["size_DVA"] = data_in["size_DVA"]
|
||||||
|
self.dictionary["phosphene"]["sigma_width"] = data_in["sigma_width"]
|
||||||
|
self.dictionary["clocks"]["n_dir"] = data_in["n_dir"]
|
||||||
|
self.dictionary["clocks"]["n_open"] = data_in["n_dir"] // 2
|
||||||
|
self.dictionary["clocks"]["pointer_width"] = data_in["pointer_width"]
|
||||||
|
self.dictionary["clocks"]["pointer_length"] = data_in["pointer_length"]
|
||||||
|
|
||||||
|
self.encoding["use_exp_deadzone"] = data_in["use_exp_deadzone"]
|
||||||
|
self.encoding["size_exp_deadzone_DVA"] = data_in["size_exp_deadzone_DVA"]
|
||||||
|
self.encoding["use_cutout_deadzone"] = data_in["use_cutout_deadzone"]
|
||||||
|
self.encoding["size_cutout_deadzone_DVA"] = data_in[
|
||||||
|
"size_cutout_deadzone_DVA"
|
||||||
|
]
|
||||||
|
|
||||||
|
self.control["show_capture"] = data_in["enable_cam"]
|
||||||
|
self.control["show_object"] = data_in["enable_yolo"]
|
||||||
|
self.control["show_contours"] = data_in["enable_contour"]
|
||||||
|
# TODO: close the window
|
||||||
|
self.apply_parameter_changes()
|
||||||
|
|
||||||
|
# some constants for addressing specific components of output arrays
|
||||||
|
image_id_CONST: int = 0
|
||||||
|
overlap_index_CONST: int = 1
|
||||||
|
|
||||||
|
# format: color_RGB, height, width <class 'torch.tensor'> float, range=0,1
|
||||||
|
print("OE-ProcessFrame: capturing frame")
|
||||||
|
frame = self.cap.get_frame()
|
||||||
|
if frame is None:
    raise OSError("Cannot capture a frame from the source!")
|
||||||
|
if self.verbose:
|
||||||
|
if self.control["show_capture"]:
|
||||||
|
show_torch_frame(frame, title="Captured", target=self.verbose)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
cv2.destroyWindow("Captured")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# perform segmentation
|
||||||
|
|
||||||
|
frame = frame.to(device=self.torch_device)
|
||||||
|
print("OE-ProcessFrame: frame segmentation by YOLO")
|
||||||
|
frame_segmented = self.yolo(frame.unsqueeze(0), classes=self.classes_detect)
|
||||||
|
|
||||||
|
# This extracts the frame in x to convert the mask in a video format
|
||||||
|
if self.yolo.found_class_id is not None:
|
||||||
|
|
||||||
|
n_found = len(self.yolo.found_class_id)
|
||||||
|
print(
|
||||||
|
f"OE-ProcessFrame: {n_found} occurrences of desired object found in frame!"
|
||||||
|
)
|
||||||
|
|
||||||
|
mask = frame_segmented[0]
|
||||||
|
|
||||||
|
# is there something in the mask?
|
||||||
|
if not mask.sum() == 0:
|
||||||
|
|
||||||
|
# yes, cut only the part of the frame that has our object of interest
|
||||||
|
frame_masked = mask * frame
|
||||||
|
|
||||||
|
x_height = mask.sum(axis=-2)
|
||||||
|
x_indices = torch.where(x_height > 0)
|
||||||
|
x_max = x_indices[0].max() + 1
|
||||||
|
x_min = x_indices[0].min()
|
||||||
|
|
||||||
|
y_height = mask.sum(axis=-1)
|
||||||
|
y_indices = torch.where(y_height > 0)
|
||||||
|
y_max = y_indices[0].max() + 1
|
||||||
|
y_min = y_indices[0].min()
|
||||||
|
|
||||||
|
frame_cut = frame_masked[:, y_min:y_max, x_min:x_max]
|
||||||
|
else:
|
||||||
|
print(f"OE-ProcessFrame: Mask contains all zeros in current frame!")
|
||||||
|
frame_cut = None
|
||||||
|
else:
|
||||||
|
print(f"OE-ProcessFrame: No objects found in current frame!")
|
||||||
|
frame_cut = None
|
||||||
|
|
||||||
|
if frame_cut is None:
|
||||||
|
# out_torch = torch.zeros([self.out_y, self.out_x])
|
||||||
|
position_selection = torch.zeros((1, 0, 3))
|
||||||
|
contour_shape = [1, self.gabor["n_orientations"], 1, 1]
|
||||||
|
else:
|
||||||
|
if self.verbose:
|
||||||
|
if self.control["show_object"]:
|
||||||
|
show_torch_frame(
|
||||||
|
frame_cut, title="Selected Object", target=self.verbose
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
cv2.destroyWindow("Selected Object")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# UDO: from here on, we proceed as before, just handing
|
||||||
|
# UDO: over the frame_cut --> image
|
||||||
|
image = frame_cut
|
||||||
|
|
||||||
|
# Determine target size of image
|
||||||
|
# image: [RGB, Height, Width], dtype= tensor.torch.uint8
|
||||||
|
print("OE-ProcessFrame: Computing downsampling factor image -> display")
|
||||||
|
f_x: float = self.display_size_max_x_PIX / image.shape[-1]
|
||||||
|
f_y: float = self.display_size_max_y_PIX / image.shape[-2]
|
||||||
|
f_xy_min: float = min(f_x, f_y)
|
||||||
|
downsampling_x: int = int(f_xy_min * image.shape[-1])
|
||||||
|
downsampling_y: int = int(f_xy_min * image.shape[-2])
|
||||||
|
|
||||||
|
# CURRENTLY we do not crop in the end...
|
||||||
|
# Image size for removing the fft crop later
|
||||||
|
# center_crop_x: int = downsampling_x
|
||||||
|
# center_crop_y: int = downsampling_y
|
||||||
|
|
||||||
|
# define contour extraction processing chain
|
||||||
|
# ------------------------------------------
|
||||||
|
print("OE-ProcessFrame: Extracting contours")
|
||||||
|
train_processing_chain = tv.transforms.Compose(
|
||||||
|
transforms=[
|
||||||
|
tv.transforms.Grayscale(num_output_channels=1), # RGB to grayscale
|
||||||
|
tv.transforms.Resize(
|
||||||
|
size=(downsampling_y, downsampling_x)
|
||||||
|
), # downsampling
|
||||||
|
tv.transforms.Pad( # extra white padding around the picture
|
||||||
|
padding=(self.padding_PIX, self.padding_PIX),
|
||||||
|
fill=self.padding_fill,
|
||||||
|
),
|
||||||
|
ContourExtract( # contour extraction
|
||||||
|
n_orientations=self.gabor["n_orientations"],
|
||||||
|
sigma_kernel=self.sigma_kernel_PIX,
|
||||||
|
lambda_kernel=self.lambda_kernel_PIX,
|
||||||
|
torch_device=self.torch_device,
|
||||||
|
),
|
||||||
|
# CURRENTLY we do not crop in the end!
|
||||||
|
# tv.transforms.CenterCrop( # Remove the padding
|
||||||
|
# size=(center_crop_x, center_crop_y)
|
||||||
|
# ),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# ...with and without orientation channels
|
||||||
|
contour = train_processing_chain(image.unsqueeze(0))
|
||||||
|
contour_collapse = train_processing_chain.transforms[-1].create_collapse(
|
||||||
|
contour
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
if self.control["show_contours"]:
|
||||||
|
show_torch_frame(
|
||||||
|
contour_collapse,
|
||||||
|
title="Contours Extracted",
|
||||||
|
cmap="gray",
|
||||||
|
target=self.verbose,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
cv2.destroyWindow("Contours Extracted")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# generate a prior for mapping the contour to the dictionary
|
||||||
|
# CURRENTLY we use an uniform prior...
|
||||||
|
# ----------------------------------------------------------
|
||||||
|
dictionary_prior = torch.ones(
|
||||||
|
(self.clocks_filter.shape[0]),
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
device=torch.device(self.torch_device),
|
||||||
|
)
|
||||||
|
|
||||||
|
# instantiate and execute sparsifier
|
||||||
|
# ----------------------------------
|
||||||
|
print("OE-ProcessFrame: Performing sparsification")
|
||||||
|
sparsifier = Sparsifier(
|
||||||
|
dictionary_filter=self.clocks_filter,
|
||||||
|
dictionary=self.clocks,
|
||||||
|
dictionary_prior=dictionary_prior,
|
||||||
|
number_of_patches=self.encoding["n_patches_compute"],
|
||||||
|
size_exp_deadzone=self.encoding["size_exp_deadzone_DVA"]
|
||||||
|
* self.display["PIX_per_DVA"],
|
||||||
|
plot_use_map=False, # self.control["plot_deadzone"],
|
||||||
|
deadzone_exp=self.encoding["use_exp_deadzone"],
|
||||||
|
deadzone_hard_cutout=self.encoding["use_cutout_deadzone"],
|
||||||
|
deadzone_hard_cutout_size=self.encoding["size_cutout_deadzone_DVA"]
|
||||||
|
* self.display["PIX_per_DVA"],
|
||||||
|
padding_deadzone_size_x=self.padding_PIX,
|
||||||
|
padding_deadzone_size_y=self.padding_PIX,
|
||||||
|
torch_device=self.torch_device,
|
||||||
|
)
|
||||||
|
sparsifier(contour)
|
||||||
|
assert sparsifier.position_found is not None
|
||||||
|
|
||||||
|
# extract and normalize the overlap found
|
||||||
|
overlap_found = sparsifier.overlap_found[
|
||||||
|
image_id_CONST, :, overlap_index_CONST
|
||||||
|
]
|
||||||
|
overlap_found = overlap_found / overlap_found.max()
|
||||||
|
|
||||||
|
# get overlap above certain threshold, extract corresponding elements
|
||||||
|
overlap_idcs_valid = torch.where(
|
||||||
|
overlap_found >= self.encoding["overlap_threshold"]
|
||||||
|
)[0]
|
||||||
|
position_selection = sparsifier.position_found[
|
||||||
|
image_id_CONST : image_id_CONST + 1, overlap_idcs_valid, :
|
||||||
|
]
|
||||||
|
n_elements = len(overlap_idcs_valid)
|
||||||
|
print(f"OE-ProcessFrame: {n_elements} elements positioned!")
|
||||||
|
|
||||||
|
contour_shape = contour.shape
|
||||||
|
|
||||||
|
n_cut = min(position_selection.shape[-2], self.number_of_patches)
|
||||||
|
|
||||||
|
data_out = {
|
||||||
|
"position_found": position_selection[:, :n_cut, :],
|
||||||
|
"canvas_size": contour_shape,
|
||||||
|
}
|
||||||
|
if self.send_dictionaries:
|
||||||
|
data_out["features"] = self.clocks
|
||||||
|
data_out["phosphene"] = self.phosphene
|
||||||
|
self.send_dictionaries = False
|
||||||
|
|
||||||
|
return data_out
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
|
||||||
|
print("OE-Delete: exiting gracefully!")
|
||||||
|
self.cap.close_cam()
|
||||||
|
try:
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# TODO no output file
|
||||||
|
# TODO detect end of file if input is video file
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
verbose = "cv2"
|
||||||
|
source = 0 # "GoProWireless"
|
||||||
|
frame_count = 20
|
||||||
|
i_frame = 0
|
||||||
|
|
||||||
|
data_in = None
|
||||||
|
|
||||||
|
oe = OnlineEncoding(source=source, verbose=verbose)
|
||||||
|
|
||||||
|
# Loop over the frames
|
||||||
|
while i_frame < frame_count:
|
||||||
|
|
||||||
|
i_frame += 1
|
||||||
|
|
||||||
|
if i_frame == (frame_count // 3):
|
||||||
|
oe.dictionary["size_DVA"] = 0.5
|
||||||
|
oe.apply_parameter_changes()
|
||||||
|
|
||||||
|
if i_frame == (frame_count * 2 // 3):
|
||||||
|
oe.dictionary["size_DVA"] = 2.0
|
||||||
|
oe.apply_parameter_changes()
|
||||||
|
|
||||||
|
data_out = oe.update(data_in)
|
||||||
|
position_selection = data_out["position_found"]
|
||||||
|
contour_shape = data_out["canvas_size"]
|
||||||
|
|
||||||
|
# SEND/RECEIVE LOGIC:
#
# <- receive PACKET
#    parameter changes?
#        transfer them into the instance se
#        call "apply_parameter_changes"
#        put the following variables into the send packet:
#            se.clocks, se.phosphene, se.out_x, se.out_y
#    "process_frame"
#        put the following variables into the send packet:
#            position_selection, contour_shape
# -> return PACKET
|
||||||
|
|
||||||
|
# build the full image!
|
||||||
|
image_clocks = BuildImage(
|
||||||
|
canvas_size=contour_shape,
|
||||||
|
dictionary=oe.clocks,
|
||||||
|
position_found=position_selection,
|
||||||
|
default_dtype=oe.default_dtype,
|
||||||
|
torch_device=oe.torch_device,
|
||||||
|
)
|
||||||
|
# image_phosphenes = BuildImage(
|
||||||
|
# canvas_size=contour.shape,
|
||||||
|
# dictionary=dictionary_phosphene,
|
||||||
|
# position_found=position_selection,
|
||||||
|
# default_dtype=default_dtype,
|
||||||
|
# torch_device=torch_device,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# normalize to range [0...1]
|
||||||
|
m = image_clocks[0].max()
|
||||||
|
if m == 0:
|
||||||
|
m = 1
|
||||||
|
image_clocks_normalized = image_clocks[0] / m
|
||||||
|
|
||||||
|
# embed into frame of desired output size
|
||||||
|
out_torch = embed_image(
|
||||||
|
image_clocks_normalized, out_height=oe.out_y, out_width=oe.out_x
|
||||||
|
)
|
||||||
|
|
||||||
|
# show, if desired
|
||||||
|
if verbose:
|
||||||
|
if oe.control["show_percept"]:
|
||||||
|
show_torch_frame(
|
||||||
|
out_torch, title="Percept", cmap="gray", target=verbose
|
||||||
|
)
|
||||||
|
|
||||||
|
# if output_file != None:
|
||||||
|
# out_pixel = (
|
||||||
|
# (out_torch * torch.ones([3, 1, 1]) * 255)
|
||||||
|
# .type(dtype=torch.uint8)
|
||||||
|
# .movedim(0, -1)
|
||||||
|
# .numpy()
|
||||||
|
# )
|
||||||
|
# out.write(out_pixel)
|
||||||
|
|
||||||
|
del oe
|
||||||
|
|
||||||
|
# if output_file != None:
|
||||||
|
# out.release()
|
||||||
|
|
||||||
|
# %%
|
processing_chain/OnlinePerception.py (new file, 299 lines)
@@ -0,0 +1,299 @@
|
#%%
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from processing_chain.BuildImage import BuildImage
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import time
|
||||||
|
import cv2
|
||||||
|
from gui.GUIEvents import GUIEvents
|
||||||
|
from gui.GUICombiData import GUICombiData
|
||||||
|
import tkinter as tk
|
||||||
|
|
||||||
|
|
||||||
|
# TODO required?
|
||||||
|
def embed_image(frame_torch, out_height, out_width, torch_device, init_value=0):
|
||||||
|
|
||||||
|
out_shape = torch.tensor(frame_torch.shape)
|
||||||
|
|
||||||
|
frame_width = frame_torch.shape[-1]
|
||||||
|
frame_height = frame_torch.shape[-2]
|
||||||
|
|
||||||
|
frame_width_idx0 = max([0, (frame_width - out_width) // 2])
|
||||||
|
frame_height_idx0 = max([0, (frame_height - out_height) // 2])
|
||||||
|
|
||||||
|
select_width = min([frame_width, out_width])
|
||||||
|
select_height = min([frame_height, out_height])
|
||||||
|
|
||||||
|
out_shape[-1] = out_width
|
||||||
|
out_shape[-2] = out_height
|
||||||
|
|
||||||
|
out_torch = init_value * torch.ones(tuple(out_shape), device=torch_device)
|
||||||
|
|
||||||
|
out_width_idx0 = max([0, (out_width - frame_width) // 2])
|
||||||
|
out_height_idx0 = max([0, (out_height - frame_height) // 2])
|
||||||
|
|
||||||
|
out_torch[
|
||||||
|
...,
|
||||||
|
out_height_idx0 : (out_height_idx0 + select_height),
|
||||||
|
out_width_idx0 : (out_width_idx0 + select_width),
|
||||||
|
] = frame_torch[
|
||||||
|
...,
|
||||||
|
frame_height_idx0 : (frame_height_idx0 + select_height),
|
||||||
|
frame_width_idx0 : (frame_width_idx0 + select_width),
|
||||||
|
]
|
||||||
|
|
||||||
|
return out_torch
|
||||||
|
|
||||||
|
|
||||||
|
class OnlinePerception:
|
||||||
|
|
||||||
|
# SELFies...
|
||||||
|
#
|
||||||
|
# torch_device, default_dtype (fixed)
|
||||||
|
# canvas_size, features, phosphene, position_found (parameters)
|
||||||
|
# percept
|
||||||
|
#
|
||||||
|
# root, events, confdata, use_gui
|
||||||
|
#
|
||||||
|
|
||||||
|
def __init__(self, target, use_gui=False):
|
||||||
|
|
||||||
|
# CPU or GPU?
|
||||||
|
self.default_dtype = torch.float32
|
||||||
|
torch.set_default_dtype(self.default_dtype)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.torch_device = torch.device("cuda")
|
||||||
|
else:
|
||||||
|
self.torch_device = torch.device("cpu")
|
||||||
|
|
||||||
|
self.use_gui = use_gui
|
||||||
|
if self.use_gui:
|
||||||
|
self.root = tk.Tk()
|
||||||
|
self.confdata: GUICombiData = GUICombiData()
|
||||||
|
self.events = GUIEvents(tk_root=self.root, confdata=self.confdata)
|
||||||
|
|
||||||
|
# default dictionary parameters
|
||||||
|
n_xy_canvas = 400
|
||||||
|
n_xy_features = 41
|
||||||
|
n_features = 32
|
||||||
|
n_positions = 1
|
||||||
|
|
||||||
|
# populate internal parameters
|
||||||
|
self.canvas_size = [1, 8, n_xy_canvas, n_xy_canvas]
|
||||||
|
self.features = torch.rand(
|
||||||
|
(n_features, 1, n_xy_features, n_xy_features), device=self.torch_device
|
||||||
|
)
|
||||||
|
self.phosphene = torch.ones(
|
||||||
|
(1, 1, n_xy_features, n_xy_features), device=self.torch_device
|
||||||
|
)
|
||||||
|
self.position_found = torch.zeros((1, n_positions, 3), device=self.torch_device)
|
||||||
|
self.position_found[0, :, 0] = torch.randint(n_features, (n_positions,))
|
||||||
|
self.position_found[0, :, 1:] = torch.randint(n_xy_canvas, (n_positions, 2))
|
||||||
|
self.percept = torch.zeros(
|
||||||
|
(1, 1, n_xy_canvas, n_xy_canvas), device=self.torch_device
|
||||||
|
)
|
||||||
|
|
||||||
|
self.p_FEATperSECperPOS = 3.0
|
||||||
|
self.tau_SEC = 0.3 # percept time constant
|
||||||
|
self.t = time.time()
|
||||||
|
|
||||||
|
self.target = target # display target
|
||||||
|
self.display = target
|
||||||
|
|
||||||
|
self.selection = 0
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def update(self, data_in: dict):
|
||||||
|
|
||||||
|
if data_in: # not NONE?
|
||||||
|
print(f"Parameter update requested for keys {data_in.keys()}!")
|
||||||
|
if "position_found" in data_in.keys():
|
||||||
|
self.position_found = data_in["position_found"]
|
||||||
|
if "canvas_size" in data_in.keys():
|
||||||
|
self.canvas_size = data_in["canvas_size"]
|
||||||
|
self.percept = embed_image(
|
||||||
|
self.percept,
|
||||||
|
self.canvas_size[-2],
|
||||||
|
self.canvas_size[-1],
|
||||||
|
torch_device=self.torch_device,
|
||||||
|
init_value=0,
|
||||||
|
)
|
||||||
|
if "features" in data_in.keys():
|
||||||
|
print(self.features.shape, self.features.max(), self.features.min())
|
||||||
|
self.features = data_in["features"]
|
||||||
|
self.features /= self.features.max()
|
||||||
|
print(self.features.shape, self.features.max(), self.features.min())
|
||||||
|
if "phosphene" in data_in.keys():
|
||||||
|
self.phosphene = data_in["phosphene"]
|
||||||
|
self.phosphene /= self.phosphene.max()
|
||||||
|
|
||||||
|
# parameters of (optional) GUI changed?
|
||||||
|
data_out = None
|
||||||
|
if self.use_gui:
|
||||||
|
self.root.update()
|
||||||
|
if not self.confdata.gui_running:
|
||||||
|
data_out = {"exit": 42}
|
||||||
|
# graceful exit
|
||||||
|
else:
|
||||||
|
if self.confdata.check_for_change() is True:
|
||||||
|
|
||||||
|
data_out = {
|
||||||
|
"value": self.confdata.yolo_class.value,
|
||||||
|
"sigma_kernel_DVA": self.confdata.contour_extraction.sigma_kernel_DVA,
|
||||||
|
"n_orientations": self.confdata.contour_extraction.n_orientations,
|
||||||
|
"size_DVA": self.confdata.alphabet.size_DVA,
|
||||||
|
# "tau_SEC": self.confdata.alphabet.tau_SEC,
|
||||||
|
# "p_FEATperSECperPOS": self.confdata.alphabet.p_FEATperSECperPOS,
|
||||||
|
"sigma_width": self.confdata.alphabet.phosphene_sigma_width,
|
||||||
|
"n_dir": self.confdata.alphabet.clocks_n_dir,
|
||||||
|
"pointer_width": self.confdata.alphabet.clocks_pointer_width,
|
||||||
|
"pointer_length": self.confdata.alphabet.clocks_pointer_length,
|
||||||
|
"number_of_patches": self.confdata.sparsifier.number_of_patches,
|
||||||
|
"use_exp_deadzone": self.confdata.sparsifier.use_exp_deadzone,
|
||||||
|
"use_cutout_deadzone": self.confdata.sparsifier.use_cutout_deadzone,
|
||||||
|
"size_exp_deadzone_DVA": self.confdata.sparsifier.size_exp_deadzone_DVA,
|
||||||
|
"size_cutout_deadzone_DVA": self.confdata.sparsifier.size_cutout_deadzone_DVA,
|
||||||
|
"enable_cam": self.confdata.output_mode.enable_cam,
|
||||||
|
"enable_yolo": self.confdata.output_mode.enable_yolo,
|
||||||
|
"enable_contour": self.confdata.output_mode.enable_contour,
|
||||||
|
}
|
||||||
|
print(data_out)
|
||||||
|
|
||||||
|
self.p_FEATperSECperPOS = self.confdata.alphabet.p_FEATperSECperPOS
|
||||||
|
self.tau_SEC = self.confdata.alphabet.tau_SEC
|
||||||
|
|
||||||
|
# print(self.confdata.alphabet.selection)
|
||||||
|
self.selection = self.confdata.alphabet.selection
|
||||||
|
# print(f"Selektion gemacht {self.selection}")
|
||||||
|
|
||||||
|
# print(self.confdata.output_mode.enable_percept)
|
||||||
|
if self.confdata.output_mode.enable_percept:
|
||||||
|
self.display = self.target
|
||||||
|
else:
|
||||||
|
self.display = None
|
||||||
|
if self.target == "cv2":
|
||||||
|
cv2.destroyWindow("Percept")
|
||||||
|
|
||||||
|
self.confdata.reset_change_detector()
|
||||||
|
|
||||||
|
# keep track of time, yields dt
|
||||||
|
t_new = time.time()
|
||||||
|
dt = t_new - self.t
|
||||||
|
self.t = t_new
|
||||||
|
|
||||||
|
# exponential decay
|
||||||
|
self.percept *= torch.exp(-torch.tensor(dt / self.tau_SEC))
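# i.e. without new input the percept fades to 1/e of its value after tau_SEC seconds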
|
||||||
|
|
||||||
|
# new stimulation
|
||||||
|
p_dt = self.p_FEATperSECperPOS * dt
|
||||||
|
n_positions = self.position_found.shape[-2]
|
||||||
|
position_select = torch.rand((n_positions,), device=self.torch_device) < p_dt
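# Bernoulli draw per position: for small dt, each position is re-stimulated with
# probability p_FEATperSECperPOS * dt, i.e. on average p_FEATperSECperPOS times per second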
|
||||||
|
n_select = position_select.sum()
|
||||||
|
if n_select > 0:
|
||||||
|
# print(f"Selektion ausgewertet {self.selection}")
|
||||||
|
if self.selection:
|
||||||
|
dictionary = self.features
|
||||||
|
else:
|
||||||
|
dictionary = self.phosphene
|
||||||
|
percept_addon = BuildImage(
|
||||||
|
canvas_size=self.canvas_size,
|
||||||
|
dictionary=dictionary,
|
||||||
|
position_found=self.position_found[:, position_select, :],
|
||||||
|
default_dtype=self.default_dtype,
|
||||||
|
torch_device=self.torch_device,
|
||||||
|
)
|
||||||
|
self.percept += percept_addon
|
||||||
|
|
||||||
|
# prepare for display
|
||||||
|
display = self.percept[0, 0].cpu().numpy()
|
||||||
|
if self.display == "cv2":
|
||||||
|
|
||||||
|
# display, and update
|
||||||
|
cv2.namedWindow("Percept", cv2.WINDOW_NORMAL)
|
||||||
|
cv2.imshow("Percept", display)
|
||||||
|
q = cv2.waitKey(1)
|
||||||
|
|
||||||
|
if self.display == "pyplot":
|
||||||
|
|
||||||
|
# display, RUNS SLOWLY, just for testing
|
||||||
|
plt.imshow(display, cmap="gray", vmin=0, vmax=1)
|
||||||
|
plt.show()
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
return data_out
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
print("Now I'm deleting me!")
|
||||||
|
try:
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if self.use_gui:
|
||||||
|
print("...and the GUI!")
|
||||||
|
try:
|
||||||
|
if self.confdata.gui_running is True:
|
||||||
|
self.root.destroy()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
use_gui = True
|
||||||
|
|
||||||
|
data_in = None
|
||||||
|
op = OnlinePerception("cv2", use_gui=use_gui)
|
||||||
|
|
||||||
|
t_max = 40.0
|
||||||
|
dt_update = 10.0
|
||||||
|
t_update = dt_update
|
||||||
|
|
||||||
|
t = time.time()
|
||||||
|
t0 = t
|
||||||
|
while t - t0 < t_max:
|
||||||
|
|
||||||
|
data_out = op.update(data_in)
|
||||||
|
data_in = None
|
||||||
|
if data_out:
|
||||||
|
print("Output given!")
|
||||||
|
if "exit" in data_out.keys():
|
||||||
|
break
|
||||||
|
|
||||||
|
t = time.time()
|
||||||
|
if t - t0 > t_update:
|
||||||
|
|
||||||
|
# new canvas size
|
||||||
|
n_xy_canvas = int(400 + torch.randint(200, (1,)))
|
||||||
|
canvas_size = [1, 8, n_xy_canvas, n_xy_canvas]
|
||||||
|
|
||||||
|
# new features/phosphenes
|
||||||
|
n_features = int(16 + torch.randint(16, (1,)))
|
||||||
|
n_xy_features = int(31 + 2 * torch.randint(10, (1,)))
|
||||||
|
features = torch.rand(
|
||||||
|
(n_features, 1, n_xy_features, n_xy_features), device=op.torch_device
|
||||||
|
)
|
||||||
|
phosphene = torch.ones(
|
||||||
|
(1, 1, n_xy_features, n_xy_features), device=op.torch_device
|
||||||
|
)
|
||||||
|
|
||||||
|
# new positions
|
||||||
|
n_positions = int(1 + torch.randint(3, (1,)))
|
||||||
|
position_found = torch.zeros((1, n_positions, 3), device=op.torch_device)
|
||||||
|
position_found[0, :, 0] = torch.randint(n_features, (n_positions,))
|
||||||
|
position_found[0, :, 1] = torch.randint(n_xy_canvas, (n_positions,))
|
||||||
|
position_found[0, :, 2] = torch.randint(n_xy_canvas, (n_positions,))
|
||||||
|
t_update += dt_update
|
||||||
|
data_in = {
|
||||||
|
"position_found": position_found,
|
||||||
|
"canvas_size": canvas_size,
|
||||||
|
"features": features,
|
||||||
|
"phosphene": phosphene,
|
||||||
|
}
|
||||||
|
|
||||||
|
print("done!")
|
||||||
|
del op
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
237
processing_chain/PatchGenerator.py
Normal file
237
processing_chain/PatchGenerator.py
Normal file
|
@ -0,0 +1,237 @@
|
||||||
|
# %%
|
||||||
|
# PatchGenerator.py
|
||||||
|
# ====================================
|
||||||
|
# generates dictionaries (currently: phosphenes or clocks)
|
||||||
|
#
|
||||||
|
# Version V1.0, pre-07.03.2023:
|
||||||
|
# no actual changes, is David's last code version...
|
||||||
|
#
|
||||||
|
# Version V1.1, 07.03.2023:
|
||||||
|
# merged David's rebuild code (GUI capable)
|
||||||
|
# (there was not really anything to merge :-))
|
||||||
|
#
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class PatchGenerator:
|
||||||
|
|
||||||
|
pi: torch.Tensor
|
||||||
|
torch_device: torch.device
|
||||||
|
default_dtype = torch.float32
|
||||||
|
|
||||||
|
def __init__(self, torch_device: str = "cpu"):
|
||||||
|
self.torch_device = torch.device(torch_device)
|
||||||
|
|
||||||
|
self.pi = torch.tensor(
|
||||||
|
math.pi,
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
def alphabet_phosphene(
|
||||||
|
self,
|
||||||
|
sigma_width: float = 2.5,
|
||||||
|
patch_size: int = 41,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
n: int = int(patch_size // 2)
|
||||||
|
temp_grid: torch.Tensor = torch.arange(
|
||||||
|
start=-n,
|
||||||
|
end=n + 1,
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
x, y = torch.meshgrid(temp_grid, temp_grid, indexing="ij")
|
||||||
|
|
||||||
|
phosphene: torch.Tensor = torch.exp(-(x**2 + y**2) / (2 * sigma_width**2))
|
||||||
|
phosphene /= phosphene.sum()
|
||||||
|
|
||||||
|
return phosphene.unsqueeze(0).unsqueeze(0)
|
||||||
|
|
||||||
|
def alphabet_clocks(
|
||||||
|
self,
|
||||||
|
n_dir: int = 8,
|
||||||
|
n_open: int = 4,
|
||||||
|
n_filter: int = 4,
|
||||||
|
patch_size: int = 41,
|
||||||
|
segment_width: float = 2.5,
|
||||||
|
segment_length: float = 15.0,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
# for n_dir directions, there are n_open_max opening angles possible...
|
||||||
|
assert n_dir % 2 == 0, "n_dir must be multiple of 2"
|
||||||
|
n_open_max: int = n_dir // 2
|
||||||
|
|
||||||
|
# ...but this number can be reduced by integer multiples:
|
||||||
|
assert (
|
||||||
|
n_open_max % n_open == 0
|
||||||
|
), "n_open_max = n_dir // 2 must be multiple of n_open"
|
||||||
|
mul_open: int = n_open_max // n_open
|
||||||
|
|
||||||
|
# filter planes must be multiple of number of orientations implied by n_dir
|
||||||
|
assert n_filter % n_open_max == 0, "n_filter must be multiple of (n_dir // 2)"
|
||||||
|
mul_filter: int = n_filter // n_open_max
|
||||||
|
# compute single segments
|
||||||
|
segments: torch.Tensor = torch.zeros(
|
||||||
|
(n_dir, patch_size, patch_size),
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
dirs: torch.Tensor = (
|
||||||
|
2
|
||||||
|
* self.pi
|
||||||
|
* torch.arange(
|
||||||
|
start=0, end=n_dir, device=self.torch_device, dtype=self.default_dtype
|
||||||
|
)
|
||||||
|
/ n_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
for i_dir in range(n_dir):
|
||||||
|
segments[i_dir] = self.draw_segment(
|
||||||
|
patch_size=patch_size,
|
||||||
|
phi=float(dirs[i_dir]),
|
||||||
|
segment_length=segment_length,
|
||||||
|
segment_width=segment_width,
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute patches from segments
|
||||||
|
clocks = torch.zeros(
|
||||||
|
(n_open, n_dir, patch_size, patch_size),
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
clocks_filter = torch.zeros(
|
||||||
|
(n_open, n_dir, n_filter, patch_size, patch_size),
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i_dir in range(n_dir):
|
||||||
|
for i_open in range(n_open):
|
||||||
|
|
||||||
|
seg1 = segments[i_dir]
|
||||||
|
seg2 = segments[(i_dir + (i_open + 1) * mul_open) % n_dir]
|
||||||
|
clocks[i_open, i_dir] = torch.where(seg1 > seg2, seg1, seg2)
|
||||||
|
|
||||||
|
i_filter_seg1 = (i_dir * mul_filter) % n_filter
|
||||||
|
i_filter_seg2 = (
|
||||||
|
(i_dir + (i_open + 1) * mul_open) * mul_filter
|
||||||
|
) % n_filter
|
||||||
|
|
||||||
|
if i_filter_seg1 == i_filter_seg2:
|
||||||
|
clock_merged = torch.where(seg1 > seg2, seg1, seg2)
|
||||||
|
clocks_filter[i_open, i_dir, i_filter_seg1] = clock_merged
|
||||||
|
else:
|
||||||
|
clocks_filter[i_open, i_dir, i_filter_seg1] = seg1
|
||||||
|
clocks_filter[i_open, i_dir, i_filter_seg2] = seg2
|
||||||
|
|
||||||
|
clocks_filter = clocks_filter.reshape(
|
||||||
|
(n_open * n_dir, n_filter, patch_size, patch_size)
|
||||||
|
)
|
||||||
|
clocks_filter = clocks_filter / clocks_filter.sum(
|
||||||
|
axis=(-3, -2, -1), keepdims=True
|
||||||
|
)
|
||||||
|
clocks = clocks.reshape((n_open * n_dir, 1, patch_size, patch_size))
|
||||||
|
clocks = clocks / clocks.sum(axis=(-2, -1), keepdims=True)
|
||||||
|
|
||||||
|
return clocks_filter, clocks, segments
|
||||||
|
|
||||||
|
def draw_segment(
|
||||||
|
self,
|
||||||
|
patch_size: float,
|
||||||
|
phi: float,
|
||||||
|
segment_length: float,
|
||||||
|
segment_width: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# extension of patch beyond origin
|
||||||
|
n: int = int(patch_size // 2)
|
||||||
|
|
||||||
|
# grid for the patch
|
||||||
|
temp_grid: torch.Tensor = torch.arange(
|
||||||
|
start=-n,
|
||||||
|
end=n + 1,
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
x, y = torch.meshgrid(temp_grid, temp_grid, indexing="ij")
|
||||||
|
|
||||||
|
r: torch.Tensor = torch.cat((x.unsqueeze(-1), y.unsqueeze(-1)), dim=2)
|
||||||
|
|
||||||
|
# target orientation of basis segment
|
||||||
|
phi90: torch.Tensor = phi + self.pi / 2
|
||||||
|
|
||||||
|
# vector pointing to the ending point of segment (direction)
|
||||||
|
#
|
||||||
|
# when result is displayed with plt.imshow(segment),
|
||||||
|
# phi=0 points to the right, and increasing phi rotates
|
||||||
|
# the segment counterclockwise
|
||||||
|
#
|
||||||
|
e: torch.Tensor = torch.tensor(
|
||||||
|
[torch.cos(phi90), torch.sin(phi90)],
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# tangential vectors
|
||||||
|
e_tang: torch.Tensor = e.flip(dims=[0]) * torch.tensor(
|
||||||
|
[-1, 1], device=self.torch_device, dtype=self.default_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute distances to segment: parallel/tangential
|
||||||
|
d = torch.maximum(
|
||||||
|
torch.zeros(
|
||||||
|
(r.shape[0], r.shape[1]),
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
),
|
||||||
|
torch.abs(
|
||||||
|
(r * e.unsqueeze(0).unsqueeze(0)).sum(dim=-1) - segment_length / 2
|
||||||
|
)
|
||||||
|
- segment_length / 2,
|
||||||
|
)
|
||||||
|
|
||||||
|
d_tang = (r * e_tang.unsqueeze(0).unsqueeze(0)).sum(dim=-1)
|
||||||
|
|
||||||
|
# compute minimum distance to any of the two pointers
|
||||||
|
dr = torch.sqrt(d**2 + d_tang**2)
|
||||||
|
|
||||||
|
segment = torch.exp(-(dr**2) / 2 / segment_width**2)
|
||||||
|
segment = segment / segment.sum()
|
||||||
|
|
||||||
|
return segment
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
pg = PatchGenerator()
|
||||||
|
|
||||||
|
patch1 = pg.draw_segment(
|
||||||
|
patch_size=81, phi=math.pi / 4, segment_length=30, segment_width=5
|
||||||
|
)
|
||||||
|
patch2 = pg.draw_segment(patch_size=81, phi=0, segment_length=30, segment_width=5)
|
||||||
|
plt.imshow(torch.where(patch1 > patch2, patch1, patch2).cpu())
|
||||||
|
|
||||||
|
phos = pg.alphabet_phosphene()
|
||||||
|
plt.imshow(phos[0, 0].cpu())
|
||||||
|
|
||||||
|
n_filter = 8
|
||||||
|
n_dir = 8
|
||||||
|
clocks_filter, clocks, segments = pg.alphabet_clocks(n_dir=n_dir, n_filter=n_filter)
|
||||||
|
|
||||||
|
n_features = clocks_filter.shape[0]
|
||||||
|
print(n_features, "clock features generated!")
|
||||||
|
for i_feature in range(n_features):
|
||||||
|
for i_filter in range(n_filter):
|
||||||
|
plt.subplot(1, n_filter, i_filter + 1)
|
||||||
|
plt.imshow(clocks_filter[i_feature, i_filter].cpu())
|
||||||
|
plt.title("Feature #{}, Dir #{}".format(i_feature, i_filter))
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
418
processing_chain/Sparsifier.py
Normal file
418
processing_chain/Sparsifier.py
Normal file
|
@ -0,0 +1,418 @@
|
||||||
|
# Sparsifier.py
|
||||||
|
# ====================================
|
||||||
|
# matches dictionary patches to contour images
|
||||||
|
#
|
||||||
|
# Version V1.0, 07.03.2023:
|
||||||
|
# slight parameter scaling changes to David's last code version...
|
||||||
|
#
|
||||||
|
# Version V1.1, 07.03.2023:
|
||||||
|
# merged David's rebuild code (GUI capable)
|
||||||
|
#
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class Sparsifier(torch.nn.Module):
|
||||||
|
|
||||||
|
dictionary_filter_fft: torch.Tensor | None = None
|
||||||
|
dictionary_filter: torch.Tensor
|
||||||
|
dictionary: torch.Tensor
|
||||||
|
|
||||||
|
parameter_ready: bool
|
||||||
|
dictionary_ready: bool
|
||||||
|
|
||||||
|
contour_convolved_sum: torch.Tensor | None = None
|
||||||
|
use_map: torch.Tensor | None = None
|
||||||
|
position_found: torch.Tensor | None = None
|
||||||
|
|
||||||
|
size_exp_deadzone: float
|
||||||
|
|
||||||
|
number_of_patches: int
|
||||||
|
padding_deadzone_size_x: int
|
||||||
|
padding_deadzone_size_y: int
|
||||||
|
|
||||||
|
plot_use_map: bool
|
||||||
|
|
||||||
|
deadzone_exp: bool
|
||||||
|
deadzone_hard_cutout: int # 0 = not, 1 = round, 2 = box
|
||||||
|
deadzone_hard_cutout_size: float
|
||||||
|
|
||||||
|
dictionary_prior: torch.Tensor | None
|
||||||
|
|
||||||
|
pi: torch.Tensor
|
||||||
|
torch_device: torch.device
|
||||||
|
default_dtype = torch.float32
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dictionary_filter: torch.Tensor,
|
||||||
|
dictionary: torch.Tensor,
|
||||||
|
dictionary_prior: torch.Tensor | None = None,
|
||||||
|
number_of_patches: int = 10, #
|
||||||
|
size_exp_deadzone: float = 1.0, #
|
||||||
|
padding_deadzone_size_x: int = 0, #
|
||||||
|
padding_deadzone_size_y: int = 0, #
|
||||||
|
plot_use_map: bool = False,
|
||||||
|
deadzone_exp: bool = True, #
|
||||||
|
deadzone_hard_cutout: int = 1, # 0 = not, 1 = round
|
||||||
|
deadzone_hard_cutout_size: float = 1.0, #
|
||||||
|
torch_device: str = "cpu",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dictionary_ready = False
|
||||||
|
self.parameter_ready = False
|
||||||
|
|
||||||
|
self.torch_device = torch.device(torch_device)
|
||||||
|
|
||||||
|
self.pi = torch.tensor(
|
||||||
|
math.pi,
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.plot_use_map = plot_use_map
|
||||||
|
|
||||||
|
self.update_parameter(
|
||||||
|
number_of_patches,
|
||||||
|
size_exp_deadzone,
|
||||||
|
padding_deadzone_size_x,
|
||||||
|
padding_deadzone_size_y,
|
||||||
|
deadzone_exp,
|
||||||
|
deadzone_hard_cutout,
|
||||||
|
deadzone_hard_cutout_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.update_dictionary(dictionary_filter, dictionary, dictionary_prior)
|
||||||
|
|
||||||
|
def update_parameter(
|
||||||
|
self,
|
||||||
|
number_of_patches: int,
|
||||||
|
size_exp_deadzone: float,
|
||||||
|
padding_deadzone_size_x: int,
|
||||||
|
padding_deadzone_size_y: int,
|
||||||
|
deadzone_exp: bool,
|
||||||
|
deadzone_hard_cutout: int,
|
||||||
|
deadzone_hard_cutout_size: float,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
self.parameter_ready = False
|
||||||
|
|
||||||
|
assert size_exp_deadzone > 0.0
|
||||||
|
assert number_of_patches > 0
|
||||||
|
assert padding_deadzone_size_x >= 0
|
||||||
|
assert padding_deadzone_size_y >= 0
|
||||||
|
|
||||||
|
self.number_of_patches = number_of_patches
|
||||||
|
self.size_exp_deadzone = size_exp_deadzone
|
||||||
|
self.padding_deadzone_size_x = padding_deadzone_size_x
|
||||||
|
self.padding_deadzone_size_y = padding_deadzone_size_y
|
||||||
|
|
||||||
|
self.deadzone_exp = deadzone_exp
|
||||||
|
self.deadzone_hard_cutout = deadzone_hard_cutout
|
||||||
|
self.deadzone_hard_cutout_size = deadzone_hard_cutout_size
|
||||||
|
|
||||||
|
self.parameter_ready = True
|
||||||
|
|
||||||
|
def update_dictionary(
|
||||||
|
self,
|
||||||
|
dictionary_filter: torch.Tensor,
|
||||||
|
dictionary: torch.Tensor,
|
||||||
|
dictionary_prior: torch.Tensor | None = None,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
self.dictionary_ready = False
|
||||||
|
|
||||||
|
assert dictionary_filter.ndim == 4
|
||||||
|
assert dictionary.ndim == 4
|
||||||
|
|
||||||
|
# Odd number of pixels. Please!
|
||||||
|
assert (dictionary_filter.shape[-2] % 2) == 1
|
||||||
|
assert (dictionary_filter.shape[-1] % 2) == 1
|
||||||
|
|
||||||
|
self.dictionary_filter = dictionary_filter.detach().clone()
|
||||||
|
self.dictionary = dictionary
|
||||||
|
|
||||||
|
if dictionary_prior is not None:
|
||||||
|
assert dictionary_prior.ndim == 1
|
||||||
|
assert dictionary_prior.shape[0] == dictionary_filter.shape[0]
|
||||||
|
|
||||||
|
self.dictionary_prior = dictionary_prior
|
||||||
|
|
||||||
|
self.dictionary_filter_fft = None
|
||||||
|
|
||||||
|
self.dictionary_ready = True
|
||||||
|
|
||||||
|
def fourier_convolution(self, contour: torch.Tensor):
|
||||||
|
# Pattern, X, Y
|
||||||
|
assert contour.dim() == 4
|
||||||
|
assert self.dictionary_filter is not None
|
||||||
|
assert contour.shape[-2] >= self.dictionary_filter.shape[-2]
|
||||||
|
assert contour.shape[-1] >= self.dictionary_filter.shape[-1]
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
contour_fft = torch.fft.rfft2(contour, dim=(-2, -1))
|
||||||
|
|
||||||
|
t1 = time.time()
|
||||||
|
|
||||||
|
if (
|
||||||
|
(self.dictionary_filter_fft is None)
|
||||||
|
or (self.dictionary_filter_fft.dim() != 4)
|
||||||
|
or (self.dictionary_filter_fft.shape[-2] != contour.shape[-2])
|
||||||
|
or (self.dictionary_filter_fft.shape[-1] != contour.shape[-1])
|
||||||
|
):
|
||||||
|
dictionary_padded: torch.Tensor = torch.zeros(
|
||||||
|
(
|
||||||
|
self.dictionary_filter.shape[0],
|
||||||
|
self.dictionary_filter.shape[1],
|
||||||
|
contour.shape[-2],
|
||||||
|
contour.shape[-1],
|
||||||
|
),
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
dictionary_padded[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
: self.dictionary_filter.shape[-2],
|
||||||
|
: self.dictionary_filter.shape[-1],
|
||||||
|
] = self.dictionary_filter.flip(dims=[-2, -1])
|
||||||
|
|
||||||
|
dictionary_padded = dictionary_padded.roll(
|
||||||
|
(
|
||||||
|
-(self.dictionary_filter.shape[-2] // 2),
|
||||||
|
-(self.dictionary_filter.shape[-1] // 2),
|
||||||
|
),
|
||||||
|
(-2, -1),
|
||||||
|
)
|
||||||
|
self.dictionary_filter_fft = torch.fft.rfft2(
|
||||||
|
dictionary_padded, dim=(-2, -1)
|
||||||
|
)
|
||||||
|
|
||||||
|
t2 = time.time()
|
||||||
|
|
||||||
|
assert self.dictionary_filter_fft is not None
|
||||||
|
|
||||||
|
# dimension order for multiplication: [pat, feat, ori, x, y]
|
||||||
|
self.contour_convolved_sum = torch.fft.irfft2(
|
||||||
|
contour_fft.unsqueeze(1) * self.dictionary_filter_fft.unsqueeze(0),
|
||||||
|
dim=(-2, -1),
|
||||||
|
).sum(dim=2)
|
||||||
|
# --> [pat, feat, x, y]
|
||||||
|
t3 = time.time()
|
||||||
|
|
||||||
|
self.use_map: torch.Tensor = torch.ones(
|
||||||
|
(contour.shape[0], 1, contour.shape[-2], contour.shape[-1]),
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.padding_deadzone_size_x > 0:
|
||||||
|
self.use_map[:, 0, : self.padding_deadzone_size_x, :] = 0.0
|
||||||
|
self.use_map[:, 0, -self.padding_deadzone_size_x :, :] = 0.0
|
||||||
|
|
||||||
|
if self.padding_deadzone_size_y > 0:
|
||||||
|
self.use_map[:, 0, :, : self.padding_deadzone_size_y] = 0.0
|
||||||
|
self.use_map[:, 0, :, -self.padding_deadzone_size_y :] = 0.0
|
||||||
|
|
||||||
|
t4 = time.time()
|
||||||
|
print(
|
||||||
|
"Sparsifier-convol {:.3f}s: fft-img-{:.3f}s, fft-fil-{:.3f}s, convol-{:.3f}s, pad-{:.3f}s".format(
|
||||||
|
t4 - t0, t1 - t0, t2 - t1, t3 - t2, t4 - t3
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def find_next_element(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
assert self.use_map is not None
|
||||||
|
assert self.contour_convolved_sum is not None
|
||||||
|
assert self.dictionary_filter is not None
|
||||||
|
|
||||||
|
# store feature index, x pos and y pos (3 entries)
|
||||||
|
position_found = torch.zeros(
|
||||||
|
(self.contour_convolved_sum.shape[0], 3),
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=self.torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
overlap_found = torch.zeros(
|
||||||
|
(self.contour_convolved_sum.shape[0], 2),
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
for pattern_id in range(0, self.contour_convolved_sum.shape[0]):
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
# search_tensor: torch.Tensor = (
|
||||||
|
# self.contour_convolved[pattern_id] * self.use_map[pattern_id]
|
||||||
|
# ).sum(dim=1)
|
||||||
|
search_tensor: torch.Tensor = (
|
||||||
|
self.contour_convolved_sum[pattern_id] * self.use_map[pattern_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
t1 = time.time()
|
||||||
|
|
||||||
|
if self.dictionary_prior is not None:
|
||||||
|
search_tensor *= self.dictionary_prior.unsqueeze(-1).unsqueeze(-1)
|
||||||
|
|
||||||
|
t2 = time.time()
|
||||||
|
|
||||||
|
temp, max_0 = search_tensor.max(dim=0)
|
||||||
|
temp, max_1 = temp.max(dim=0)
|
||||||
|
temp_overlap, max_2 = temp.max(dim=0)
|
||||||
|
|
||||||
|
position_base_3: int = int(max_2)
|
||||||
|
position_base_2: int = int(max_1[position_base_3])
|
||||||
|
position_base_1: int = int(max_0[position_base_2, position_base_3])
|
||||||
|
position_base_0: int = int(pattern_id)
|
||||||
|
|
||||||
|
position_found[position_base_0, 0] = position_base_1
|
||||||
|
position_found[position_base_0, 1] = position_base_2
|
||||||
|
position_found[position_base_0, 2] = position_base_3
|
||||||
|
|
||||||
|
overlap_found[pattern_id, 0] = temp_overlap
|
||||||
|
overlap_found[pattern_id, 1] = self.contour_convolved_sum[
|
||||||
|
position_base_0, position_base_1, position_base_2, position_base_3
|
||||||
|
]
|
||||||
|
|
||||||
|
t3 = time.time()
|
||||||
|
|
||||||
|
x_max: int = int(self.contour_convolved_sum.shape[-2])
|
||||||
|
y_max: int = int(self.contour_convolved_sum.shape[-1])
|
||||||
|
|
||||||
|
# Center arround the max position
|
||||||
|
x_0 = int(-position_base_2)
|
||||||
|
x_1 = int(x_max - position_base_2)
|
||||||
|
y_0 = int(-position_base_3)
|
||||||
|
y_1 = int(y_max - position_base_3)
|
||||||
|
|
||||||
|
temp_grid_x: torch.Tensor = torch.arange(
|
||||||
|
start=x_0,
|
||||||
|
end=x_1,
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
temp_grid_y: torch.Tensor = torch.arange(
|
||||||
|
start=y_0,
|
||||||
|
end=y_1,
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
x, y = torch.meshgrid(temp_grid_x, temp_grid_y, indexing="ij")
|
||||||
|
|
||||||
|
# discourage the neigbourhood around for the future
|
||||||
|
if self.deadzone_exp is True:
|
||||||
|
self.temp_map: torch.Tensor = 1.0 - torch.exp(
|
||||||
|
-(x**2 + y**2) / (2 * self.size_exp_deadzone**2)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.temp_map = torch.ones(
|
||||||
|
(x.shape[0], x.shape[1]),
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert self.deadzone_hard_cutout >= 0
|
||||||
|
assert self.deadzone_hard_cutout <= 1
|
||||||
|
assert self.deadzone_hard_cutout_size >= 0
|
||||||
|
|
||||||
|
if self.deadzone_hard_cutout == 1:
|
||||||
|
temp = x**2 + y**2
|
||||||
|
self.temp_map *= torch.where(
|
||||||
|
temp <= self.deadzone_hard_cutout_size**2,
|
||||||
|
0.0,
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.use_map[position_base_0, 0, :, :] *= self.temp_map
|
||||||
|
|
||||||
|
t4 = time.time()
|
||||||
|
|
||||||
|
# Only for keeping it in the float32 range:
|
||||||
|
self.use_map[position_base_0, 0, :, :] /= self.use_map[
|
||||||
|
position_base_0, 0, :, :
|
||||||
|
].max()
|
||||||
|
|
||||||
|
t5 = time.time()
|
||||||
|
|
||||||
|
# print(
|
||||||
|
# "{}, {}, {}, {}, {}".format(t1 - t0, t2 - t1, t3 - t2, t4 - t3, t5 - t4)
|
||||||
|
# )
|
||||||
|
|
||||||
|
return (
|
||||||
|
position_found,
|
||||||
|
overlap_found,
|
||||||
|
torch.tensor((t1 - t0, t2 - t1, t3 - t2, t4 - t3, t5 - t4)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> None:
|
||||||
|
assert self.number_of_patches > 0
|
||||||
|
|
||||||
|
assert self.dictionary_ready is True
|
||||||
|
assert self.parameter_ready is True
|
||||||
|
|
||||||
|
self.position_found = torch.zeros(
|
||||||
|
(input.shape[0], self.number_of_patches, 3),
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=self.torch_device,
|
||||||
|
)
|
||||||
|
self.overlap_found = torch.zeros(
|
||||||
|
(input.shape[0], self.number_of_patches, 2),
|
||||||
|
device=self.torch_device,
|
||||||
|
dtype=self.default_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Folding the images and the dictionary
|
||||||
|
self.fourier_convolution(input)
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
dt = torch.tensor([0, 0, 0, 0, 0])
|
||||||
|
for patch_id in range(0, self.number_of_patches):
|
||||||
|
(
|
||||||
|
self.position_found[:, patch_id, :],
|
||||||
|
self.overlap_found[:, patch_id, :],
|
||||||
|
dt_tmp,
|
||||||
|
) = self.find_next_element()
|
||||||
|
|
||||||
|
dt = dt + dt_tmp
|
||||||
|
|
||||||
|
if self.plot_use_map is True:
|
||||||
|
|
||||||
|
assert self.position_found.shape[0] == 1
|
||||||
|
assert self.use_map is not None
|
||||||
|
|
||||||
|
print("Position Saliency:")
|
||||||
|
print(self.overlap_found[0, :, 0])
|
||||||
|
print("Overlap with Contour Image:")
|
||||||
|
print(self.overlap_found[0, :, 1])
|
||||||
|
plt.subplot(1, 2, 1)
|
||||||
|
plt.imshow(self.use_map[0, 0].cpu(), cmap="gray")
|
||||||
|
plt.title(f"patch-number: {patch_id}")
|
||||||
|
plt.axis("off")
|
||||||
|
plt.colorbar(shrink=0.5)
|
||||||
|
plt.subplot(1, 2, 2)
|
||||||
|
plt.imshow(
|
||||||
|
(self.use_map[0, 0] * input[0].sum(dim=0)).cpu(), cmap="gray"
|
||||||
|
)
|
||||||
|
plt.title(f"patch-number: {patch_id}")
|
||||||
|
plt.axis("off")
|
||||||
|
plt.colorbar(shrink=0.5)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# self.overlap_found /= self.overlap_found.max(dim=1, keepdim=True)[0]
|
||||||
|
t1 = time.time()
|
||||||
|
print(
|
||||||
|
"Sparsifier-forward {:.3f}s: usemap-{:.3f}s, prior-{:.3f}s, findmax-{:.3f}s, notch-{:.3f}s, norm-{:.3f}s: (sum-{:.3f})".format(
|
||||||
|
t1 - t0, dt[0], dt[1], dt[2], dt[3], dt[4], dt.sum()
|
||||||
|
)
|
||||||
|
)
|
325
processing_chain/WebCam.py
Normal file
325
processing_chain/WebCam.py
Normal file
|
@ -0,0 +1,325 @@
|
||||||
|
#%%
|
||||||
|
#
|
||||||
|
# WebCam.py
|
||||||
|
# ========================================================
|
||||||
|
# interface to cv2 for using a webcam or for reading from
|
||||||
|
# a video file
|
||||||
|
#
|
||||||
|
# Version 1.0, before 30.03.2023:
|
||||||
|
# written by David...
|
||||||
|
#
|
||||||
|
# Version 1.1, 30.03.2023:
|
||||||
|
# thrown out test image
|
||||||
|
# added test code
|
||||||
|
# added code to "capture" from video file
|
||||||
|
#
|
||||||
|
# Version 1.2, 20.06.2023:
|
||||||
|
# added code to capture wirelessly from "GoPro" camera
|
||||||
|
# added test code for "GoPro"
|
||||||
|
#
|
||||||
|
# Version 1.3, 23.06.2023
|
||||||
|
# test display in pyplot or cv2
|
||||||
|
#
|
||||||
|
# Version 1.4, 28.06.2023
|
||||||
|
# solved Windows DirectShow problem
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
import torchvision as tv
|
||||||
|
|
||||||
|
# for GoPro
|
||||||
|
import time
|
||||||
|
import socket
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("Trying to import GoPro modules...")
|
||||||
|
from goprocam import GoProCamera, constants
|
||||||
|
|
||||||
|
gopro_exists = True
|
||||||
|
except:
|
||||||
|
print("...not found, continuing!")
|
||||||
|
gopro_exists = False
|
||||||
|
import platform
|
||||||
|
|
||||||
|
|
||||||
|
class WebCam:
|
||||||
|
|
||||||
|
# test_pattern: torch.Tensor
|
||||||
|
# test_pattern_gray: torch.Tensor
|
||||||
|
|
||||||
|
source: int
|
||||||
|
framesize: tuple[int, int]
|
||||||
|
fps: float
|
||||||
|
cap_frame_width: int
|
||||||
|
cap_frame_height: int
|
||||||
|
cap_fps: float
|
||||||
|
webcam_is_ready: bool
|
||||||
|
cap_frames_available: int
|
||||||
|
|
||||||
|
default_dtype = torch.float32
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
source: str | int = 1,
|
||||||
|
framesize: tuple[int, int] = (720, 1280), # (1920, 1080), # (640, 480),
|
||||||
|
fps: float = 30.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert fps > 0
|
||||||
|
|
||||||
|
self.source = source
|
||||||
|
self.framesize = framesize
|
||||||
|
self.cap = None
|
||||||
|
self.fps = fps
|
||||||
|
self.webcam_is_ready = False
|
||||||
|
|
||||||
|
def open_cam(self) -> bool:
|
||||||
|
if self.cap is not None:
|
||||||
|
self.cap.release()
|
||||||
|
self.cap = None
|
||||||
|
self.webcam_is_ready = False
|
||||||
|
|
||||||
|
# handle GoPro...
|
||||||
|
if self.source == "GoProWireless":
|
||||||
|
|
||||||
|
if not gopro_exists:
|
||||||
|
print("No GoPro driver/support!")
|
||||||
|
self.webcam_is_ready = False
|
||||||
|
return False
|
||||||
|
|
||||||
|
print("GoPro: Starting access")
|
||||||
|
gpCam = GoProCamera.GoPro()
|
||||||
|
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
print("GoPro: Socket created!!!")
|
||||||
|
|
||||||
|
self.t = time.time()
|
||||||
|
gpCam.livestream("start")
|
||||||
|
gpCam.video_settings(res="1080p", fps="30")
|
||||||
|
gpCam.gpControlSet(
|
||||||
|
constants.Stream.WINDOW_SIZE, constants.Stream.WindowSize.R720
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cap = cv2.VideoCapture("udp://10.5.5.9:8554", cv2.CAP_FFMPEG)
|
||||||
|
print("GoPro: Video capture started!!!")
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.sock = None
|
||||||
|
self.t = -1
|
||||||
|
if platform.system().lower() == "windows":
|
||||||
|
self.cap = cv2.VideoCapture(self.source, cv2.CAP_DSHOW)
|
||||||
|
else:
|
||||||
|
self.cap = cv2.VideoCapture(self.source)
|
||||||
|
print("Normal capture started!!!")
|
||||||
|
|
||||||
|
assert self.cap is not None
|
||||||
|
|
||||||
|
if self.cap.isOpened() is not True:
|
||||||
|
self.webcam_is_ready = False
|
||||||
|
return False
|
||||||
|
|
||||||
|
if type(self.source) != str:
|
||||||
|
self.cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG"))
|
||||||
|
self.cap.set(cv2.CAP_PROP_FPS, self.fps)
|
||||||
|
self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.framesize[0])
|
||||||
|
self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.framesize[1])
|
||||||
|
self.cap_frames_available = None
|
||||||
|
else:
|
||||||
|
# ...for files and GoPro...
|
||||||
|
self.cap_frames_available = self.cap.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||||
|
|
||||||
|
self.cap_frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||||
|
self.cap_frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||||
|
self.cap_fps = float(self.cap.get(cv2.CAP_PROP_FPS))
|
||||||
|
|
||||||
|
print(
|
||||||
|
(
|
||||||
|
f"Capturing or reading with: {self.cap_frame_width:.0f} x "
|
||||||
|
f"{self.cap_frame_height:.0f} @ "
|
||||||
|
f"{self.cap_fps:.1f}."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.webcam_is_ready = True
|
||||||
|
return True
|
||||||
|
|
||||||
|
def close_cam(self) -> None:
|
||||||
|
if self.cap is not None:
|
||||||
|
self.cap.release()
|
||||||
|
|
||||||
|
def get_frame(self) -> torch.Tensor | None:
|
||||||
|
if self.cap is None:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
if self.sock:
|
||||||
|
dt_min = 0.015
|
||||||
|
success, frame = self.cap.read()
|
||||||
|
t_next = time.time()
|
||||||
|
t_prev = t_next
|
||||||
|
while t_next - t_prev < dt_min:
|
||||||
|
t_prev = t_next
|
||||||
|
success, frame = self.cap.read()
|
||||||
|
t_next = time.time()
|
||||||
|
|
||||||
|
if self.t >= 0:
|
||||||
|
if time.time() - self.t > 2.5:
|
||||||
|
print("GoPro-Stream must be kept awake!...")
|
||||||
|
self.sock.sendto(
|
||||||
|
"_GPHD_:0:0:2:0.000000\n".encode(), ("10.5.5.9", 8554)
|
||||||
|
)
|
||||||
|
self.t = time.time()
|
||||||
|
else:
|
||||||
|
success, frame = self.cap.read()
|
||||||
|
|
||||||
|
if success is False:
|
||||||
|
self.webcam_is_ready = False
|
||||||
|
return None
|
||||||
|
|
||||||
|
output = (
|
||||||
|
torch.tensor(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||||
|
.movedim(-1, 0)
|
||||||
|
.type(dtype=self.default_dtype)
|
||||||
|
/ 255.0
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
# for testing the code if module is executed from command line
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
TEST_FILEREAD = False
|
||||||
|
TEST_WEBCAM = True
|
||||||
|
TEST_GOPRO = False
|
||||||
|
|
||||||
|
display = "cv2"
|
||||||
|
n_capture = 200
|
||||||
|
delay_capture = 0.001
|
||||||
|
|
||||||
|
print("Testing the WebCam interface")
|
||||||
|
|
||||||
|
if TEST_FILEREAD:
|
||||||
|
|
||||||
|
file_name = "level1.mp4"
|
||||||
|
|
||||||
|
# open
|
||||||
|
print("Opening video file")
|
||||||
|
w = WebCam(file_name)
|
||||||
|
if not w.open_cam():
|
||||||
|
raise OSError(f"Opening file with name {file_name} failed!")
|
||||||
|
|
||||||
|
# print information
|
||||||
|
print(
|
||||||
|
f"Frame size {w.cap_frame_width} x {w.cap_frame_height} at {w.cap_fps} fps."
|
||||||
|
)
|
||||||
|
|
||||||
|
# capture three frames and show them
|
||||||
|
for i in range(min([n_capture, w.cap_frames_available])): # TODO: available?
|
||||||
|
frame = w.get_frame()
|
||||||
|
if frame == None:
|
||||||
|
raise OSError(f"Can not get frame from file with name {file_name}!")
|
||||||
|
print(f"frame {i} has shape {frame.shape}")
|
||||||
|
|
||||||
|
frame_numpy = (frame.movedim(0, -1) * 255).type(dtype=torch.uint8).numpy()
|
||||||
|
|
||||||
|
if display == "pyplot":
|
||||||
|
plt.imshow(frame_numpy)
|
||||||
|
plt.show()
|
||||||
|
if display == "cv2":
|
||||||
|
cv2.imshow("File", frame_numpy[:, :, (2, 1, 0)])
|
||||||
|
cv2.waitKey(1)
|
||||||
|
time.sleep(delay_capture)
|
||||||
|
|
||||||
|
# close
|
||||||
|
print("Closing file")
|
||||||
|
w.close_cam()
|
||||||
|
|
||||||
|
if TEST_WEBCAM:
|
||||||
|
|
||||||
|
camera_index = 0
|
||||||
|
|
||||||
|
# open
|
||||||
|
print("Opening camera")
|
||||||
|
w = WebCam(camera_index)
|
||||||
|
if not w.open_cam():
|
||||||
|
raise OSError(f"Opening web cam with index {camera_index} failed!")
|
||||||
|
|
||||||
|
# print information
|
||||||
|
print(
|
||||||
|
f"Frame size {w.cap_frame_width} x {w.cap_frame_height} at {w.cap_fps} fps."
|
||||||
|
)
|
||||||
|
|
||||||
|
# capture three frames and show them
|
||||||
|
for i in range(n_capture):
|
||||||
|
frame = w.get_frame()
|
||||||
|
if frame == None:
|
||||||
|
raise OSError(
|
||||||
|
f"Can not get frame from camera with index {camera_index}!"
|
||||||
|
)
|
||||||
|
print(f"frame {i} has shape {frame.shape}")
|
||||||
|
|
||||||
|
frame_numpy = (frame.movedim(0, -1) * 255).type(dtype=torch.uint8).numpy()
|
||||||
|
if display == "pyplot":
|
||||||
|
plt.imshow(frame_numpy)
|
||||||
|
plt.show()
|
||||||
|
if display == "cv2":
|
||||||
|
cv2.imshow("WebCam", frame_numpy[:, :, (2, 1, 0)])
|
||||||
|
cv2.waitKey(1)
|
||||||
|
time.sleep(delay_capture)
|
||||||
|
|
||||||
|
# close
|
||||||
|
print("Closing camera")
|
||||||
|
w.close_cam()
|
||||||
|
|
||||||
|
if TEST_GOPRO:
|
||||||
|
|
||||||
|
camera_name = "GoProWireless"
|
||||||
|
|
||||||
|
# open
|
||||||
|
print("Opening GoPro")
|
||||||
|
w = WebCam(camera_name)
|
||||||
|
if not w.open_cam():
|
||||||
|
raise OSError(f"Opening GoPro with index {camera_index} failed!")
|
||||||
|
|
||||||
|
# print information
|
||||||
|
print(
|
||||||
|
f"Frame size {w.cap_frame_width} x {w.cap_frame_height} at {w.cap_fps} fps."
|
||||||
|
)
|
||||||
|
w.cap.set(cv2.CAP_PROP_BUFFERSIZE, 0)
|
||||||
|
|
||||||
|
# capture three frames and show them
|
||||||
|
# print("Empty Buffer...")
|
||||||
|
# for i in range(500):
|
||||||
|
# print(i)
|
||||||
|
# frame = w.get_frame()
|
||||||
|
# print("Buffer Emptied...")
|
||||||
|
|
||||||
|
for i in range(n_capture):
|
||||||
|
frame = w.get_frame()
|
||||||
|
if frame == None:
|
||||||
|
raise OSError(
|
||||||
|
f"Can not get frame from camera with index {camera_index}!"
|
||||||
|
)
|
||||||
|
print(f"frame {i} has shape {frame.shape}")
|
||||||
|
|
||||||
|
frame_numpy = (frame.movedim(0, -1) * 255).type(dtype=torch.uint8).numpy()
|
||||||
|
if display == "pyplot":
|
||||||
|
plt.imshow(frame_numpy)
|
||||||
|
plt.show()
|
||||||
|
if display == "cv2":
|
||||||
|
cv2.imshow("GoPro", frame_numpy[:, :, (2, 1, 0)])
|
||||||
|
cv2.waitKey(1)
|
||||||
|
time.sleep(delay_capture)
|
||||||
|
|
||||||
|
# close
|
||||||
|
print("Closing Cam/File/GoPro")
|
||||||
|
w.close_cam()
|
||||||
|
|
||||||
|
if display == "cv2":
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
# %%
|
402
processing_chain/Yolo5Segmentation.py
Normal file
402
processing_chain/Yolo5Segmentation.py
Normal file
|
@ -0,0 +1,402 @@
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import torchvision as tv
|
||||||
|
import time
|
||||||
|
|
||||||
|
from models.yolo import Detect, Model
|
||||||
|
|
||||||
|
# Warning: The line "#matplotlib.use('Agg') # for writing to files only"
|
||||||
|
# in utils/plots.py prevents the further use of matplotlib
|
||||||
|
|
||||||
|
|
||||||
|
class Yolo5Segmentation(torch.nn.Module):
|
||||||
|
|
||||||
|
default_dtype = torch.float32
|
||||||
|
|
||||||
|
conf: float = 0.25 # NMS confidence threshold
|
||||||
|
iou: float = 0.45 # NMS IoU threshold
|
||||||
|
agnostic: bool = False # NMS class-agnostic
|
||||||
|
multi_label: bool = False # NMS multiple labels per box
|
||||||
|
max_det: int = 1000 # maximum number of detections per image
|
||||||
|
number_of_maps: int = 32
|
||||||
|
imgsz: tuple[int, int] = (640, 640) # inference size (height, width)
|
||||||
|
|
||||||
|
device: torch.device = torch.device("cpu")
|
||||||
|
weigh_path: str = ""
|
||||||
|
|
||||||
|
class_names: dict
|
||||||
|
stride: int
|
||||||
|
|
||||||
|
found_class_id: torch.Tensor | None = None
|
||||||
|
|
||||||
|
def __init__(self, mode: int = 3, torch_device: str = "cpu"):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
model_pretrained_path: str = "segment_pretrained"
|
||||||
|
assert mode < 5
|
||||||
|
assert mode >= 0
|
||||||
|
if mode == 0:
|
||||||
|
model_pretrained_weights: str = "yolov5n-seg.pt"
|
||||||
|
elif mode == 1:
|
||||||
|
model_pretrained_weights = "yolov5s-seg.pt"
|
||||||
|
elif mode == 2:
|
||||||
|
model_pretrained_weights = "yolov5m-seg.pt"
|
||||||
|
elif mode == 3:
|
||||||
|
model_pretrained_weights = "yolov5l-seg.pt"
|
||||||
|
elif mode == 4:
|
||||||
|
model_pretrained_weights = "yolov5x-seg.pt"
|
||||||
|
|
||||||
|
self.weigh_path = os.path.join(model_pretrained_path, model_pretrained_weights)
|
||||||
|
|
||||||
|
self.device = torch.device(torch_device)
|
||||||
|
|
||||||
|
self.network = self.attempt_load(
|
||||||
|
self.weigh_path, device=self.device, inplace=True, fuse=True
|
||||||
|
)
|
||||||
|
self.stride = max(int(self.network.stride.max()), 32) # model stride
|
||||||
|
self.network.float()
|
||||||
|
self.class_names = dict(self.network.names) # type: ignore
|
||||||
|
|
||||||
|
# classes: (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
|
||||||
|
def forward(self, input: torch.Tensor, classes=None) -> torch.Tensor:
|
||||||
|
|
||||||
|
assert input.ndim == 4
|
||||||
|
assert input.shape[0] == 1
|
||||||
|
assert input.shape[1] == 3
|
||||||
|
|
||||||
|
input_resized, (
|
||||||
|
remove_left,
|
||||||
|
remove_top,
|
||||||
|
remove_height,
|
||||||
|
remove_width,
|
||||||
|
) = self.scale_and_pad(
|
||||||
|
input,
|
||||||
|
)
|
||||||
|
|
||||||
|
network_output = self.network(input_resized)
|
||||||
|
number_of_classes = network_output[0].shape[2] - self.number_of_maps - 5
|
||||||
|
assert len(self.class_names) == number_of_classes
|
||||||
|
|
||||||
|
maps = network_output[1]
|
||||||
|
|
||||||
|
# results matrix:
|
||||||
|
# Fist Dimension:
|
||||||
|
# Image Number
|
||||||
|
# ...
|
||||||
|
# Last Dimension:
|
||||||
|
# center_x: 0
|
||||||
|
# center_y: 1
|
||||||
|
# width: 2
|
||||||
|
# height: 3
|
||||||
|
# obj_conf (object): 4
|
||||||
|
# cls_conf (class): 5
|
||||||
|
|
||||||
|
results = non_max_suppression(
|
||||||
|
network_output[0],
|
||||||
|
self.conf,
|
||||||
|
self.iou,
|
||||||
|
classes,
|
||||||
|
self.agnostic,
|
||||||
|
self.multi_label,
|
||||||
|
max_det=self.max_det,
|
||||||
|
nm=self.number_of_maps,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_id = 0
|
||||||
|
|
||||||
|
if results[image_id].shape[0] > 0:
|
||||||
|
masks = self.process_mask_native(
|
||||||
|
maps[image_id],
|
||||||
|
results[image_id][:, 6:],
|
||||||
|
results[image_id][:, :4],
|
||||||
|
)
|
||||||
|
self.found_class_id = results[image_id][:, 5]
|
||||||
|
|
||||||
|
output = tv.transforms.functional.resize(
|
||||||
|
tv.transforms.functional.crop(
|
||||||
|
img=masks,
|
||||||
|
top=int(remove_top),
|
||||||
|
left=int(remove_left),
|
||||||
|
height=int(remove_height),
|
||||||
|
width=int(remove_width),
|
||||||
|
),
|
||||||
|
size=(input.shape[-2], input.shape[-1]),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output = None
|
||||||
|
self.found_class_id = None
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# code stolen and/or modified from Yolov5 ->
|
||||||
|
def scale_and_pad(
|
||||||
|
self,
|
||||||
|
input,
|
||||||
|
):
|
||||||
|
ratio = min(self.imgsz[0] / input.shape[-2], self.imgsz[1] / input.shape[-1])
|
||||||
|
|
||||||
|
shape_new_x = int(input.shape[-2] * ratio)
|
||||||
|
shape_new_y = int(input.shape[-1] * ratio)
|
||||||
|
|
||||||
|
dx = self.imgsz[0] - shape_new_x
|
||||||
|
dy = self.imgsz[1] - shape_new_y
|
||||||
|
|
||||||
|
dx_0 = dx // 2
|
||||||
|
dy_0 = dy // 2
|
||||||
|
|
||||||
|
image_resize = tv.transforms.functional.pad(
|
||||||
|
tv.transforms.functional.resize(input, size=(shape_new_x, shape_new_y)),
|
||||||
|
padding=[dy_0, dx_0, int(dy - dy_0), int(dx - dx_0)],
|
||||||
|
fill=float(114.0 / 255.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
return image_resize, (dy_0, dx_0, shape_new_x, shape_new_y)
|
||||||
|
|
||||||
|
def process_mask_native(self, protos, masks_in, bboxes):
|
||||||
|
masks = (
|
||||||
|
(masks_in @ protos.float().view(protos.shape[0], -1))
|
||||||
|
.sigmoid()
|
||||||
|
.view(-1, protos.shape[1], protos.shape[2])
|
||||||
|
)
|
||||||
|
|
||||||
|
masks = torch.nn.functional.interpolate(
|
||||||
|
masks[None],
|
||||||
|
(self.imgsz[0], self.imgsz[1]),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=False,
|
||||||
|
)[
|
||||||
|
0
|
||||||
|
] # CHW
|
||||||
|
|
||||||
|
x1, y1, x2, y2 = torch.chunk(bboxes[:, :, None], 4, 1) # x1 shape(1,1,n)
|
||||||
|
r = torch.arange(masks.shape[2], device=masks.device, dtype=x1.dtype)[
|
||||||
|
None, None, :
|
||||||
|
] # rows shape(1,w,1)
|
||||||
|
c = torch.arange(masks.shape[1], device=masks.device, dtype=x1.dtype)[
|
||||||
|
None, :, None
|
||||||
|
] # cols shape(h,1,1)
|
||||||
|
|
||||||
|
masks = masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
|
||||||
|
|
||||||
|
return masks.gt_(0.5)
|
||||||
|
|
||||||
|
def attempt_load(self, weights, device=None, inplace=True, fuse=True):
|
||||||
|
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
||||||
|
|
||||||
|
model = Ensemble()
|
||||||
|
for w in weights if isinstance(weights, list) else [weights]:
|
||||||
|
ckpt = torch.load(w, map_location="cpu") # load
|
||||||
|
ckpt = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
|
||||||
|
|
||||||
|
# Model compatibility updates
|
||||||
|
if not hasattr(ckpt, "stride"):
|
||||||
|
ckpt.stride = torch.tensor([32.0])
|
||||||
|
if hasattr(ckpt, "names") and isinstance(ckpt.names, (list, tuple)):
|
||||||
|
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
|
||||||
|
|
||||||
|
model.append(
|
||||||
|
ckpt.fuse().eval() if fuse and hasattr(ckpt, "fuse") else ckpt.eval()
|
||||||
|
) # model in eval mode
|
||||||
|
|
||||||
|
# Module compatibility updates
|
||||||
|
for m in model.modules():
|
||||||
|
t = type(m)
|
||||||
|
if t in (
|
||||||
|
torch.nn.Hardswish,
|
||||||
|
torch.nn.LeakyReLU,
|
||||||
|
torch.nn.ReLU,
|
||||||
|
torch.nn.ReLU6,
|
||||||
|
torch.nn.SiLU,
|
||||||
|
Detect,
|
||||||
|
Model,
|
||||||
|
):
|
||||||
|
m.inplace = inplace # torch 1.7.0 compatibility
|
||||||
|
if t is Detect and not isinstance(m.anchor_grid, list):
|
||||||
|
delattr(m, "anchor_grid")
|
||||||
|
setattr(m, "anchor_grid", [torch.zeros(1)] * m.nl)
|
||||||
|
elif t is torch.nn.Upsample and not hasattr(m, "recompute_scale_factor"):
|
||||||
|
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||||
|
|
||||||
|
# Return model
|
||||||
|
if len(model) == 1:
|
||||||
|
return model[-1]
|
||||||
|
|
||||||
|
# Return detection ensemble
|
||||||
|
print(f"Ensemble created with {weights}\n")
|
||||||
|
for k in "names", "nc", "yaml":
|
||||||
|
setattr(model, k, getattr(model[0], k))
|
||||||
|
model.stride = model[
|
||||||
|
torch.argmax(torch.tensor([m.stride.max() for m in model])).int()
|
||||||
|
].stride # max stride
|
||||||
|
assert all(
|
||||||
|
model[0].nc == m.nc for m in model
|
||||||
|
), f"Models have different class counts: {[m.nc for m in model]}"
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class Ensemble(torch.nn.ModuleList):
|
||||||
|
# Ensemble of models
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x, augment=False, profile=False, visualize=False):
|
||||||
|
y = [module(x, augment, profile, visualize)[0] for module in self]
|
||||||
|
# y = torch.stack(y).max(0)[0] # max ensemble
|
||||||
|
# y = torch.stack(y).mean(0) # mean ensemble
|
||||||
|
y = torch.cat(y, 1) # nms ensemble
|
||||||
|
return y, None # inference, train output
|
||||||
|
|
||||||
|
|
||||||
|
def non_max_suppression(
|
||||||
|
prediction,
|
||||||
|
conf_thres=0.25,
|
||||||
|
iou_thres=0.45,
|
||||||
|
classes=None,
|
||||||
|
agnostic=False,
|
||||||
|
multi_label=False,
|
||||||
|
labels=(),
|
||||||
|
max_det=300,
|
||||||
|
nm=0, # number of masks
|
||||||
|
):
|
||||||
|
"""Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(
|
||||||
|
prediction, (list, tuple)
|
||||||
|
): # YOLOv5 model in validation model, output = (inference_out, loss_out)
|
||||||
|
prediction = prediction[0] # select only inference output
|
||||||
|
|
||||||
|
device = prediction.device
|
||||||
|
mps = "mps" in device.type # Apple MPS
|
||||||
|
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
|
||||||
|
prediction = prediction.cpu()
|
||||||
|
bs = prediction.shape[0] # batch size
|
||||||
|
nc = prediction.shape[2] - nm - 5 # number of classes
|
||||||
|
xc = prediction[..., 4] > conf_thres # candidates
|
||||||
|
|
||||||
|
# Checks
|
||||||
|
assert (
|
||||||
|
0 <= conf_thres <= 1
|
||||||
|
), f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
|
||||||
|
assert (
|
||||||
|
0 <= iou_thres <= 1
|
||||||
|
), f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
|
||||||
|
|
||||||
|
# Settings
|
||||||
|
# min_wh = 2 # (pixels) minimum box width and height
|
||||||
|
max_wh = 7680 # (pixels) maximum box width and height
|
||||||
|
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
|
||||||
|
time_limit = 0.5 + 0.05 * bs # seconds to quit after
|
||||||
|
redundant = True # require redundant detections
|
||||||
|
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
||||||
|
merge = False # use merge-NMS
|
||||||
|
|
||||||
|
t = time.time()
|
||||||
|
mi = 5 + nc # mask start index
|
||||||
|
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
|
||||||
|
for xi, x in enumerate(prediction): # image index, image inference
|
||||||
|
# Apply constraints
|
||||||
|
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
||||||
|
x = x[xc[xi]] # confidence
|
||||||
|
|
||||||
|
# Cat apriori labels if autolabelling
|
||||||
|
if labels and len(labels[xi]):
|
||||||
|
lb = labels[xi]
|
||||||
|
v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
|
||||||
|
v[:, :4] = lb[:, 1:5] # box
|
||||||
|
v[:, 4] = 1.0 # conf
|
||||||
|
v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
|
||||||
|
x = torch.cat((x, v), 0)
|
||||||
|
|
||||||
|
# If none remain process next image
|
||||||
|
if not x.shape[0]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Compute conf
|
||||||
|
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
|
||||||
|
|
||||||
|
# Box/Mask
|
||||||
|
box = xywh2xyxy(
|
||||||
|
x[:, :4]
|
||||||
|
) # center_x, center_y, width, height) to (x1, y1, x2, y2)
|
||||||
|
mask = x[:, mi:] # zero columns if no masks
|
||||||
|
|
||||||
|
# Detections matrix nx6 (xyxy, conf, cls)
|
||||||
|
if multi_label:
|
||||||
|
i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
|
||||||
|
x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
|
||||||
|
else: # best class only
|
||||||
|
conf, j = x[:, 5:mi].max(1, keepdim=True)
|
||||||
|
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
|
||||||
|
|
||||||
|
# Filter by class
|
||||||
|
if classes is not None:
|
||||||
|
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
|
||||||
|
|
||||||
|
# Apply finite constraint
|
||||||
|
# if not torch.isfinite(x).all():
|
||||||
|
# x = x[torch.isfinite(x).all(1)]
|
||||||
|
|
||||||
|
# Check shape
|
||||||
|
n = x.shape[0] # number of boxes
|
||||||
|
if not n: # no boxes
|
||||||
|
continue
|
||||||
|
elif n > max_nms: # excess boxes
|
||||||
|
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
|
||||||
|
else:
|
||||||
|
x = x[x[:, 4].argsort(descending=True)] # sort by confidence
|
||||||
|
|
||||||
|
# Batched NMS
|
||||||
|
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
||||||
|
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
||||||
|
i = tv.ops.nms(boxes, scores, iou_thres) # NMS
|
||||||
|
if i.shape[0] > max_det: # limit detections
|
||||||
|
i = i[:max_det]
|
||||||
|
if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
|
||||||
|
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
|
||||||
|
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
|
||||||
|
weights = iou * scores[None] # box weights
|
||||||
|
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(
|
||||||
|
1, keepdim=True
|
||||||
|
) # merged boxes
|
||||||
|
if redundant:
|
||||||
|
i = i[iou.sum(1) > 1] # require redundancy
|
||||||
|
|
||||||
|
output[xi] = x[i]
|
||||||
|
if mps:
|
||||||
|
output[xi] = output[xi].to(device)
|
||||||
|
if (time.time() - t) > time_limit:
|
||||||
|
print(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
|
||||||
|
break # time limit exceeded
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def box_iou(box1, box2, eps=1e-7):
|
||||||
|
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
|
||||||
|
(a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
|
||||||
|
inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
|
||||||
|
|
||||||
|
# IoU = inter / (area1 + area2 - inter)
|
||||||
|
return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
|
||||||
|
|
||||||
|
|
||||||
|
def xywh2xyxy(x):
|
||||||
|
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
||||||
|
y = x.clone()
|
||||||
|
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
|
||||||
|
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
|
||||||
|
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
|
||||||
|
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
# <- code stolen and/or modified from Yolov5
|
Loading…
Reference in a new issue