4D epsilon xy tensor

parent c9f79f2e13
commit e596c761d1

2 changed files with 112 additions and 35 deletions
@@ -14,5 +14,5 @@ __all__ = [

 class HDynamicCNNManyIP():
     def __init__(self) -> None: ...
-    def update_with_init_vector_multi_pattern(self, arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, arg7: int, arg8: int, arg9: int, arg10: int, arg11: int, arg12: int, arg13: int, arg14: int, arg15: int, arg16: int, arg17: int, arg18: int, arg19: int, arg20: int) -> bool: ...
+    def update_with_init_vector_multi_pattern(self, arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, arg7: int, arg8: int, arg9: int, arg10: int, arg11: int, arg12: int, arg13: int, arg14: int, arg15: int, arg16: int, arg17: int, arg18: int, arg19: int, arg20: int, arg21: int) -> bool: ...
     pass
SbS.py (145 changes)
@@ -154,7 +154,7 @@ class SbS(torch.nn.Module):
     def epsilon_xy(self, value: torch.Tensor):
         assert value is not None
         assert torch.is_tensor(value) is True
-        assert value.dim() == 2
+        assert value.dim() == 4
         assert value.dtype == torch.float64
         if self._epsilon_xy_exists is False:
             self._epsilon_xy = torch.nn.parameter.Parameter(
@@ -617,10 +617,16 @@ class SbS(torch.nn.Module):
         """Creates initial epsilon xy matrices"""

         assert self._output_size is not None
+        assert self._kernel_size is not None
         assert eps_xy_intitial > 0

         eps_xy_temp: torch.Tensor = torch.full(
-            (int(self._output_size[0]), int(self._output_size[1])),
+            (
+                int(self._output_size[0]),
+                int(self._output_size[1]),
+                int(self._kernel_size[0]),
+                int(self._kernel_size[1]),
+            ),
             eps_xy_intitial,
             dtype=torch.float64,
         )
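A minimal sketch of the new layout (not part of the commit; the sizes below are made-up example values): epsilon_xy now carries one value per output position and per kernel position, so both the setter's dim() == 4 check and the torch.full initialization above operate on an (output_x, output_y, kernel_x, kernel_y) tensor.

import torch

output_size = (20, 20)  # hypothetical output grid
kernel_size = (5, 5)    # hypothetical kernel
eps_xy_initial = 0.1    # hypothetical initial value

eps_xy_temp = torch.full(
    (output_size[0], output_size[1], kernel_size[0], kernel_size[1]),
    eps_xy_initial,
    dtype=torch.float64,
)
assert eps_xy_temp.dim() == 4  # satisfies the updated epsilon_xy setter check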
@@ -759,6 +765,36 @@ class FunctionalSbS(torch.autograd.Function):
             stride=(1, 1),
         ).requires_grad_(True)

+        epsilon_xy_convolved: torch.Tensor = (
+            (
+                torch.nn.functional.unfold(
+                    epsilon_xy.reshape(
+                        (
+                            int(epsilon_xy.shape[0]) * int(epsilon_xy.shape[1]),
+                            int(epsilon_xy.shape[2]),
+                            int(epsilon_xy.shape[3]),
+                        )
+                    )
+                    .unsqueeze(1)
+                    .tile((1, input.shape[1], 1, 1)),
+                    kernel_size=tuple(kernel_size.tolist()),
+                    dilation=1,
+                    padding=0,
+                    stride=1,
+                )
+                .squeeze(-1)
+                .reshape(
+                    (
+                        int(epsilon_xy.shape[0]),
+                        int(epsilon_xy.shape[1]),
+                        int(input_convolved.shape[1]),
+                    )
+                )
+            )
+            .moveaxis(-1, 0)
+            .contiguous(memory_format=torch.contiguous_format)
+        )
+
         ############################################################
         # Spike generation                                         #
         ############################################################
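A shape walk-through of the new epsilon_xy_convolved step, as a sketch with made-up sizes (not part of the commit): each output position contributes one (kernel_x, kernel_y) patch of epsilon values, the patch is tiled over the input channels, and unfolding with the full kernel size produces exactly one column per patch; after the reshape and moveaxis, the leading dimension matches input_convolved.shape[1].

import torch

X, Y = 4, 4      # hypothetical output grid
kh, kw = 3, 3    # hypothetical kernel size
C_in = 2         # hypothetical number of input channels

epsilon_xy = torch.rand(X, Y, kh, kw)

# one (kh, kw) patch per output position, tiled over the input channels
patches = epsilon_xy.reshape(X * Y, kh, kw).unsqueeze(1).tile((1, C_in, 1, 1))
unfolded = torch.nn.functional.unfold(patches, kernel_size=(kh, kw))
print(unfolded.shape)  # torch.Size([16, 18, 1]): one column per output position

epsilon_xy_convolved = (
    unfolded.squeeze(-1).reshape(X, Y, C_in * kh * kw).moveaxis(-1, 0)
)
print(epsilon_xy_convolved.shape)  # torch.Size([18, 4, 4])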
@@ -864,7 +900,12 @@ class FunctionalSbS(torch.autograd.Function):
             )

             epsilon_scale: torch.Tensor = torch.ones(
-                size=[1, int(epsilon_xy.shape[0]), int(epsilon_xy.shape[1]), 1],
+                size=[
+                    int(spikes.shape[0]),
+                    int(spikes.shape[2]),
+                    int(spikes.shape[3]),
+                    1,
+                ],
                 dtype=torch.float32,
             )

@@ -875,9 +916,28 @@ class FunctionalSbS(torch.autograd.Function):
                     epsilon_scale = torch.ones_like(epsilon_scale)

                 h_temp: torch.Tensor = weights[spikes[:, t, :, :], :] * h
-                epsilon_subsegment: torch.Tensor = (
-                    epsilon_xy.unsqueeze(0).unsqueeze(-1) * epsilon_t[t] * epsilon_0
-                )
+                wx = 0
+                wy = 0
+
+                if t == 0:
+                    epsilon_temp: torch.Tensor = torch.empty(
+                        (
+                            int(spikes.shape[0]),
+                            int(spikes.shape[2]),
+                            int(spikes.shape[3]),
+                        ),
+                        dtype=torch.float32,
+                    )
+
+                for wx in range(0, int(spikes.shape[2])):
+                    for wy in range(0, int(spikes.shape[3])):
+                        epsilon_temp[:, wx, wy] = epsilon_xy_convolved[
+                            spikes[:, t, wx, wy], wx, wy
+                        ]
+
+                epsilon_subsegment: torch.Tensor = (
+                    epsilon_temp.unsqueeze(-1) * epsilon_t[t] * epsilon_0
+                )
+
                 h_temp_sum: torch.Tensor = (
                     epsilon_scale * epsilon_subsegment / h_temp.sum(dim=3, keepdim=True)
                 )
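What the new lookup loop computes, on toy tensors with made-up sizes (a sketch, not part of the commit): for each batch entry and each position (wx, wy), the epsilon belonging to the weight row that spiked there at time t is read out of epsilon_xy_convolved. The advanced-indexing form at the end is shown only to illustrate that it is the same lookup.

import torch

B, T, W, H, D = 2, 5, 3, 3, 18   # hypothetical batch, time steps, grid, weight rows
epsilon_xy_convolved = torch.rand(D, W, H)
spikes = torch.randint(0, D, (B, T, W, H))
t = 0

epsilon_temp = torch.empty((B, W, H))
for wx in range(W):
    for wy in range(H):
        # pick the epsilon of the row that spiked at (wx, wy) for every batch entry
        epsilon_temp[:, wx, wy] = epsilon_xy_convolved[spikes[:, t, wx, wy], wx, wy]

# equivalent advanced-indexing form (illustration only)
xs = torch.arange(W).unsqueeze(-1)   # (W, 1)
ys = torch.arange(H).unsqueeze(0)    # (1, H)
assert torch.equal(epsilon_temp, epsilon_xy_convolved[spikes[:, t], xs, ys])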
@@ -891,6 +951,7 @@ class FunctionalSbS(torch.autograd.Function):

             h /= epsilon_scale
             output = h.movedim(3, 1)

         else:
             epsilon_t_0: torch.Tensor = epsilon_t * epsilon_0
+
@@ -909,10 +970,10 @@ class FunctionalSbS(torch.autograd.Function):
             assert np_h.flags["C_CONTIGUOUS"] is True
             assert np_h.ndim == 4

-            np_epsilon_xy: np.ndarray = epsilon_xy.detach().numpy()
+            np_epsilon_xy: np.ndarray = epsilon_xy_convolved.detach().numpy()
             assert epsilon_xy.dtype == torch.float32
             assert np_epsilon_xy.flags["C_CONTIGUOUS"] is True
-            assert np_epsilon_xy.ndim == 2
+            assert np_epsilon_xy.ndim == 3

             np_epsilon_t: np.ndarray = epsilon_t_0.detach().numpy()
             assert epsilon_t_0.dtype == torch.float32
@@ -948,6 +1009,7 @@ class FunctionalSbS(torch.autograd.Function):
                 np_epsilon_xy.__array_interface__["data"][0],
                 int(np_epsilon_xy.shape[0]),
                 int(np_epsilon_xy.shape[1]),
+                int(np_epsilon_xy.shape[2]),
                 np_epsilon_t.__array_interface__["data"][0],
                 int(np_epsilon_t.shape[0]),
                 np_weights.__array_interface__["data"][0],
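For reference, a small sketch (made-up sizes, not part of the commit) of the interface change above: the array handed to the extension is now the convolved epsilon tensor, which is three-dimensional, so a third shape value is forwarded along with the data pointer.

import numpy as np
import torch

# hypothetical sizes; after the forward change the array is (weight rows, x, y)
epsilon_xy_convolved = torch.rand(18, 4, 4, dtype=torch.float32)

np_epsilon_xy: np.ndarray = epsilon_xy_convolved.detach().numpy()
assert np_epsilon_xy.flags["C_CONTIGUOUS"] is True
assert np_epsilon_xy.ndim == 3
dims = [int(s) for s in np_epsilon_xy.shape]  # three sizes are now passed on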
@@ -1056,7 +1118,7 @@ class FunctionalSbS(torch.autograd.Function):

         ctx.save_for_backward(
             input_convolved,
-            epsilon_xy_float64,
+            epsilon_xy_convolved,
             epsilon_0_float64,
             weights_float64,
             output,
|
@ -1075,7 +1137,7 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
# Get the variables back
|
# Get the variables back
|
||||||
(
|
(
|
||||||
input_float32,
|
input_float32,
|
||||||
epsilon_xy,
|
epsilon_xy_float32,
|
||||||
epsilon_0,
|
epsilon_0,
|
||||||
weights,
|
weights,
|
||||||
output,
|
output,
|
||||||
|
@ -1088,6 +1150,7 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
|
|
||||||
input = input_float32.type(dtype=torch.float64)
|
input = input_float32.type(dtype=torch.float64)
|
||||||
input /= input.sum(dim=1, keepdim=True, dtype=torch.float64)
|
input /= input.sum(dim=1, keepdim=True, dtype=torch.float64)
|
||||||
|
epsilon_xy = epsilon_xy_float32.type(dtype=torch.float64)
|
||||||
|
|
||||||
# For debugging:
|
# For debugging:
|
||||||
# print(
|
# print(
|
||||||
|
@@ -1125,10 +1188,6 @@ class FunctionalSbS(torch.autograd.Function):
         )
         torch.clip(backprop_z, out=backprop_z, min=-1e300, max=1e300)

-        backprop_y: torch.Tensor = (
-            torch.einsum("bijxy,bixy->bjxy", backprop_z, input) - output
-        )
-
         result_omega: torch.Tensor = backprop_bigr.unsqueeze(2) * grad_output.unsqueeze(
             1
         )
@@ -1136,29 +1195,25 @@ class FunctionalSbS(torch.autograd.Function):
             "bijxy,bjxy->bixy", backprop_r, grad_output
         ).unsqueeze(2)
         result_omega *= backprop_f
-        torch.nan_to_num(
-            result_omega, out=result_omega, nan=1e300, posinf=1e300, neginf=-1e300
-        )
-        torch.clip(result_omega, out=result_omega, min=-1e300, max=1e300)

         result_eps_xy: torch.Tensor = (
-            torch.einsum("bixy,bixy->bxy", backprop_y, grad_output) * eps_b
-        )
-        torch.nan_to_num(
-            result_eps_xy, out=result_eps_xy, nan=1e300, posinf=1e300, neginf=-1e300
-        )
-        torch.clip(result_eps_xy, out=result_eps_xy, min=-1e300, max=1e300)
+            (
+                (backprop_z * input.unsqueeze(2) - output.unsqueeze(1))
+                * grad_output.unsqueeze(1)
+            )
+            .sum(dim=2)
+            .sum(dim=0)
+        ) * eps_b

         result_phi: torch.Tensor = torch.einsum(
             "bijxy,bjxy->bixy", backprop_z, grad_output
-        ) * eps_a.unsqueeze(0).unsqueeze(0)
-        torch.nan_to_num(
-            result_phi, out=result_phi, nan=1e300, posinf=1e300, neginf=-1e300
-        )
-        torch.clip(result_phi, out=result_phi, min=-1e300, max=1e300)
+        ) * eps_a.unsqueeze(0)

         grad_weights = result_omega.sum(0).sum(-1).sum(-1)
-        grad_eps_xy = result_eps_xy.sum(0)
+        torch.nan_to_num(
+            grad_weights, out=grad_weights, nan=1e300, posinf=1e300, neginf=-1e300
+        )
+        torch.clip(grad_weights, out=grad_weights, min=-1e300, max=1e300)

         grad_input = torch.nn.functional.fold(
             torch.nn.functional.unfold(
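My reading of the rewritten result_eps_xy, as a self-contained sketch with made-up sizes (not part of the commit): with the index layout implied by the surrounding einsum strings (backprop_z as [b, i, j, x, y], input as [b, i, x, y], output and grad_output as [b, j, x, y]), the new expression sums the per-pixel gradient over the batch b and over j, which can equivalently be written with einsum.

import torch

B, I, J, X, Y = 2, 6, 4, 3, 3   # hypothetical sizes; names mirror the diff's variables
backprop_z = torch.rand(B, I, J, X, Y, dtype=torch.float64)
input = torch.rand(B, I, X, Y, dtype=torch.float64)
output = torch.rand(B, J, X, Y, dtype=torch.float64)
grad_output = torch.rand(B, J, X, Y, dtype=torch.float64)
eps_b = 0.1

# expression as committed
result_eps_xy = (
    (
        (backprop_z * input.unsqueeze(2) - output.unsqueeze(1))
        * grad_output.unsqueeze(1)
    )
    .sum(dim=2)
    .sum(dim=0)
) * eps_b

# equivalent einsum form (illustration only)
alt = (
    torch.einsum("bijxy,bixy,bjxy->ixy", backprop_z, input, grad_output)
    - torch.einsum("bjxy,bjxy->xy", output, grad_output).unsqueeze(0)
) * eps_b

assert torch.allclose(result_eps_xy, alt)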
@@ -1174,22 +1229,41 @@ class FunctionalSbS(torch.autograd.Function):
             padding=padding,
             stride=stride,
         )

         torch.nan_to_num(
             grad_input, out=grad_input, nan=1e300, posinf=1e300, neginf=-1e300
         )
         torch.clip(grad_input, out=grad_input, min=-1e300, max=1e300)

+        grad_eps_xy_temp = torch.nn.functional.fold(
+            result_eps_xy.moveaxis(0, -1)
+            .reshape(
+                (
+                    int(result_eps_xy.shape[1]) * int(result_eps_xy.shape[2]),
+                    int(result_eps_xy.shape[0]),
+                )
+            )
+            .unsqueeze(-1),
+            output_size=kernel_size,
+            kernel_size=kernel_size,
+        )
+
+        grad_eps_xy = (
+            grad_eps_xy_temp.sum(dim=1)
+            .reshape(
+                (
+                    int(result_eps_xy.shape[1]),
+                    int(result_eps_xy.shape[2]),
+                    int(grad_eps_xy_temp.shape[-2]),
+                    int(grad_eps_xy_temp.shape[-1]),
+                )
+            )
+            .contiguous(memory_format=torch.contiguous_format)
+        )
         torch.nan_to_num(
             grad_eps_xy, out=grad_eps_xy, nan=1e300, posinf=1e300, neginf=-1e300
         )
         torch.clip(grad_eps_xy, out=grad_eps_xy, min=-1e300, max=1e300)
-
-        torch.nan_to_num(
-            grad_weights, out=grad_weights, nan=1e300, posinf=1e300, neginf=-1e300
-        )
-        torch.clip(grad_weights, out=grad_weights, min=-1e300, max=1e300)

         grad_epsilon_0 = None
         grad_epsilon_t = None
         grad_kernel_size = None
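The new grad_eps_xy step relies on torch.nn.functional.fold with output_size equal to kernel_size, which leaves exactly one block per sample and therefore just reshapes each flattened column back into (channels, kernel_x, kernel_y); summing over the channel dimension then appears to undo the tiling applied in the forward pass. A minimal sketch with made-up sizes (not part of the commit):

import torch

N, C, kh, kw = 3, 2, 5, 5             # hypothetical: N output positions, C input channels
flat = torch.rand(N, C * kh * kw, 1)  # one unfolded block per position

patches = torch.nn.functional.fold(
    flat, output_size=(kh, kw), kernel_size=(kh, kw)
)
print(patches.shape)                  # torch.Size([3, 2, 5, 5])

grad_per_position = patches.sum(dim=1)  # sum over channels -> (3, 5, 5)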
@@ -1218,3 +1292,6 @@ class FunctionalSbS(torch.autograd.Function):
             grad_h_initial,
             grad_alpha_number_of_iterations,
         )
+
+
+# %%