237 lines
7.4 KiB
Python
237 lines
7.4 KiB
Python
# %%
|
|
# PatchGenerator.py
|
|
# ====================================
|
|
# generates dictionaries (currently: phosphenes or clocks)
|
|
#
|
|
# Version V1.0, pre-07.03.2023:
|
|
# no actual changes, is David's last code version...
|
|
#
|
|
# Version V1.1, 07.03.2023:
|
|
# merged David's rebuild code (GUI capable)
|
|
# (there was not really anything to merge :-))
|
|
#
|
|
|
|
import torch
|
|
import math
|
|
|
|
|
|
class PatchGenerator:
    """Generate dictionaries of small image patches (phosphenes or "clocks").

    All tensors are created on ``torch_device`` with dtype ``default_dtype``.

    Note on patch sizes: every patch is sampled on an integer grid that is
    symmetric around the origin, ``-n ... n`` with ``n = patch_size // 2``.
    The side length of a generated patch is therefore ``2 * n + 1``, which
    equals ``patch_size`` for odd values and ``patch_size + 1`` for even
    values.
    """

    pi: torch.Tensor
    torch_device: torch.device
    default_dtype = torch.float32

    def __init__(self, torch_device: str = "cpu"):
        """Store the target device and a tensor-valued pi on that device.

        Args:
            torch_device: Device specification passed to ``torch.device``.
        """
        self.torch_device = torch.device(torch_device)

        self.pi = torch.tensor(
            math.pi,
            device=self.torch_device,
            dtype=self.default_dtype,
        )

    def _centered_grid(self, patch_size) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(x, y)`` coordinate grids on ``-n ... n``, ``n = patch_size // 2``.

        With ``indexing="ij"``, ``x`` varies along the first tensor dimension
        and ``y`` along the second one.
        """
        n: int = int(patch_size // 2)
        axis: torch.Tensor = torch.arange(
            start=-n,
            end=n + 1,
            device=self.torch_device,
            dtype=self.default_dtype,
        )
        x, y = torch.meshgrid(axis, axis, indexing="ij")
        return x, y

    def alphabet_phosphene(
        self,
        sigma_width: float = 2.5,
        patch_size: int = 41,
    ) -> torch.Tensor:
        """Create a single isotropic Gaussian blob normalized to unit sum.

        Args:
            sigma_width: Standard deviation of the Gaussian, in pixels.
            patch_size: Requested side length of the patch (see class note
                for even values).

        Returns:
            Tensor of shape ``(1, 1, P, P)`` with ``P = 2 * (patch_size // 2) + 1``.
        """
        x, y = self._centered_grid(patch_size)

        phosphene: torch.Tensor = torch.exp(-(x**2 + y**2) / (2 * sigma_width**2))
        phosphene /= phosphene.sum()

        return phosphene.unsqueeze(0).unsqueeze(0)

    def alphabet_clocks(
        self,
        n_dir: int = 8,
        n_open: int = 4,
        n_filter: int = 4,
        patch_size: int = 41,
        segment_width: float = 2.5,
        segment_length: float = 15.0,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Create "clock" patches: pairs of segments sharing the patch center.

        For each of the ``n_dir`` directions a segment is drawn. A clock
        combines the segment of direction ``i_dir`` with the segment that is
        ``(i_open + 1) * mul_open`` direction steps further on, taking the
        pixel-wise maximum of both.

        Args:
            n_dir: Number of segment directions; must be even.
            n_open: Number of opening angles; must divide ``n_dir // 2``.
            n_filter: Number of filter planes; must be a multiple of
                ``n_dir // 2``.
            patch_size: Requested side length of the patches (see class note
                for even values).
            segment_width: Gaussian width of a segment, in pixels.
            segment_length: Length of a segment, in pixels.

        Returns:
            ``(clocks_filter, clocks, segments)`` where, with
            ``P = 2 * (patch_size // 2) + 1``:

            * ``clocks_filter`` has shape ``(n_open * n_dir, n_filter, P, P)``;
              the two segments of a clock are placed in the filter planes
              selected by their direction index, and each feature is
              normalized to unit sum over its last three dimensions.
            * ``clocks`` has shape ``(n_open * n_dir, 1, P, P)``, each
              normalized to unit sum.
            * ``segments`` has shape ``(n_dir, P, P)``.

            The feature index is ``i_open * n_dir + i_dir``.

        Raises:
            AssertionError: If the divisibility constraints above are violated.
        """
        # for n_dir directions, there are n_open_max opening angles possible...
        assert n_dir % 2 == 0, "n_dir must be multiple of 2"
        n_open_max: int = n_dir // 2

        # ...but this number can be reduced by integer multiples:
        assert (
            n_open_max % n_open == 0
        ), "n_open_max = n_dir // 2 must be multiple of n_open"
        mul_open: int = n_open_max // n_open

        # filter planes must be multiple of number of orientations implied by n_dir
        assert n_filter % n_open_max == 0, "n_filter must be multiple of (n_dir // 2)"
        mul_filter: int = n_filter // n_open_max

        # Side length actually produced by draw_segment. Using it for the
        # buffers (instead of patch_size itself) keeps even patch sizes from
        # failing with a shape mismatch and matches the other methods.
        size: int = 2 * int(patch_size // 2) + 1

        # compute single segments
        segments: torch.Tensor = torch.zeros(
            (n_dir, size, size),
            device=self.torch_device,
            dtype=self.default_dtype,
        )
        dirs: torch.Tensor = (
            2
            * self.pi
            * torch.arange(
                start=0, end=n_dir, device=self.torch_device, dtype=self.default_dtype
            )
            / n_dir
        )

        for i_dir in range(n_dir):
            segments[i_dir] = self.draw_segment(
                patch_size=patch_size,
                phi=float(dirs[i_dir]),
                segment_length=segment_length,
                segment_width=segment_width,
            )

        # compute patches from segments
        clocks = torch.zeros(
            (n_open, n_dir, size, size),
            device=self.torch_device,
            dtype=self.default_dtype,
        )
        clocks_filter = torch.zeros(
            (n_open, n_dir, n_filter, size, size),
            device=self.torch_device,
            dtype=self.default_dtype,
        )

        for i_dir in range(n_dir):
            seg1 = segments[i_dir]
            i_filter_seg1 = (i_dir * mul_filter) % n_filter

            for i_open in range(n_open):
                # direction index of the second pointer (not yet wrapped)
                i_dir2 = i_dir + (i_open + 1) * mul_open
                seg2 = segments[i_dir2 % n_dir]

                # pixel-wise maximum of both pointers
                merged = torch.where(seg1 > seg2, seg1, seg2)
                clocks[i_open, i_dir] = merged

                i_filter_seg2 = (i_dir2 * mul_filter) % n_filter

                if i_filter_seg1 == i_filter_seg2:
                    # both pointers fall into the same filter plane
                    clocks_filter[i_open, i_dir, i_filter_seg1] = merged
                else:
                    clocks_filter[i_open, i_dir, i_filter_seg1] = seg1
                    clocks_filter[i_open, i_dir, i_filter_seg2] = seg2

        # flatten (n_open, n_dir) into one feature axis and normalize
        clocks_filter = clocks_filter.reshape((n_open * n_dir, n_filter, size, size))
        clocks_filter = clocks_filter / clocks_filter.sum(
            dim=(-3, -2, -1), keepdim=True
        )
        clocks = clocks.reshape((n_open * n_dir, 1, size, size))
        clocks = clocks / clocks.sum(dim=(-2, -1), keepdim=True)

        return clocks_filter, clocks, segments

    def draw_segment(
        self,
        patch_size: int,
        phi: float,
        segment_length: float,
        segment_width: float,
    ) -> torch.Tensor:
        """Draw one Gaussian-blurred line segment starting at the patch center.

        The segment runs from the origin over ``segment_length`` pixels in the
        direction given by ``phi``. Each pixel value is a Gaussian (width
        ``segment_width``) of its distance to the segment, and the patch is
        normalized to unit sum.

        Args:
            patch_size: Requested side length of the patch (see class note
                for even values).
            phi: Direction of the segment in radians.
            segment_length: Length of the segment, in pixels.
            segment_width: Gaussian width of the segment, in pixels.

        Returns:
            Tensor of shape ``(P, P)`` with ``P = 2 * (patch_size // 2) + 1``.
        """
        # grid for the patch; last dimension of r holds the (x, y) coordinates
        x, y = self._centered_grid(patch_size)
        r: torch.Tensor = torch.stack((x, y), dim=-1)

        # target orientation of basis segment
        phi90: torch.Tensor = phi + self.pi / 2

        # vector pointing to the ending point of segment (direction)
        #
        # when result is displayed with plt.imshow(segment),
        # phi=0 points to the right, and increasing phi rotates
        # the segment counterclockwise
        #
        e: torch.Tensor = torch.stack((torch.cos(phi90), torch.sin(phi90))).to(
            device=self.torch_device, dtype=self.default_dtype
        )

        # tangential vector: e rotated by 90 degrees, i.e. (-e_y, e_x)
        e_tang: torch.Tensor = torch.stack((-e[1], e[0]))

        # distance along the segment axis: zero for projections that fall
        # inside [0, segment_length], growing linearly outside of it
        along = (r * e).sum(dim=-1)
        d = torch.clamp(
            torch.abs(along - segment_length / 2) - segment_length / 2, min=0.0
        )

        # distance perpendicular to the segment axis
        d_tang = (r * e_tang).sum(dim=-1)

        # Euclidean distance of every pixel to the segment
        dr = torch.sqrt(d**2 + d_tang**2)

        segment = torch.exp(-(dr**2) / 2 / segment_width**2)
        segment = segment / segment.sum()

        return segment
|
|
|
|
|
|
if __name__ == "__main__":
    # Visual smoke test: plots example patches from every generator.
    import matplotlib.pyplot as plt

    pg = PatchGenerator()

    # Two single segments, merged by taking the pixel-wise maximum.
    patch1 = pg.draw_segment(
        patch_size=81, phi=math.pi / 4, segment_length=30, segment_width=5
    )
    patch2 = pg.draw_segment(patch_size=81, phi=0, segment_length=30, segment_width=5)
    # Own figure + show(), so this image is not overdrawn by the later plots.
    plt.figure()
    plt.imshow(torch.where(patch1 > patch2, patch1, patch2).cpu())
    plt.show()

    # Default phosphene patch.
    phos = pg.alphabet_phosphene()
    plt.figure()
    plt.imshow(phos[0, 0].cpu())
    plt.show()

    # Clock dictionary: one figure per feature, one subplot per filter plane.
    n_filter = 8
    n_dir = 8
    clocks_filter, clocks, segments = pg.alphabet_clocks(n_dir=n_dir, n_filter=n_filter)

    n_features = clocks_filter.shape[0]
    print(n_features, "clock features generated!")
    for i_feature in range(n_features):
        for i_filter in range(n_filter):
            plt.subplot(1, n_filter, i_filter + 1)
            plt.imshow(clocks_filter[i_feature, i_filter].cpu())
            plt.title("Feature #{}, Dir #{}".format(i_feature, i_filter))
        plt.show()
|
|
|
|
|
|
# %%
|