4D epsilon xy tensor
This commit is contained in:
parent
c9f79f2e13
commit
e596c761d1
2 changed files with 112 additions and 35 deletions
|
@ -14,5 +14,5 @@ __all__ = [
|
|||
|
||||
class HDynamicCNNManyIP():
|
||||
def __init__(self) -> None: ...
|
||||
def update_with_init_vector_multi_pattern(self, arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, arg7: int, arg8: int, arg9: int, arg10: int, arg11: int, arg12: int, arg13: int, arg14: int, arg15: int, arg16: int, arg17: int, arg18: int, arg19: int, arg20: int) -> bool: ...
|
||||
def update_with_init_vector_multi_pattern(self, arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, arg7: int, arg8: int, arg9: int, arg10: int, arg11: int, arg12: int, arg13: int, arg14: int, arg15: int, arg16: int, arg17: int, arg18: int, arg19: int, arg20: int, arg21: int) -> bool: ...
|
||||
pass
|
||||
|
|
145
SbS.py
145
SbS.py
|
@ -154,7 +154,7 @@ class SbS(torch.nn.Module):
|
|||
def epsilon_xy(self, value: torch.Tensor):
|
||||
assert value is not None
|
||||
assert torch.is_tensor(value) is True
|
||||
assert value.dim() == 2
|
||||
assert value.dim() == 4
|
||||
assert value.dtype == torch.float64
|
||||
if self._epsilon_xy_exists is False:
|
||||
self._epsilon_xy = torch.nn.parameter.Parameter(
|
||||
|
@ -617,10 +617,16 @@ class SbS(torch.nn.Module):
|
|||
"""Creates initial epsilon xy matrices"""
|
||||
|
||||
assert self._output_size is not None
|
||||
assert self._kernel_size is not None
|
||||
assert eps_xy_intitial > 0
|
||||
|
||||
eps_xy_temp: torch.Tensor = torch.full(
|
||||
(int(self._output_size[0]), int(self._output_size[1])),
|
||||
(
|
||||
int(self._output_size[0]),
|
||||
int(self._output_size[1]),
|
||||
int(self._kernel_size[0]),
|
||||
int(self._kernel_size[1]),
|
||||
),
|
||||
eps_xy_intitial,
|
||||
dtype=torch.float64,
|
||||
)
|
||||
|
@ -759,6 +765,36 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
stride=(1, 1),
|
||||
).requires_grad_(True)
|
||||
|
||||
epsilon_xy_convolved: torch.Tensor = (
|
||||
(
|
||||
torch.nn.functional.unfold(
|
||||
epsilon_xy.reshape(
|
||||
(
|
||||
int(epsilon_xy.shape[0]) * int(epsilon_xy.shape[1]),
|
||||
int(epsilon_xy.shape[2]),
|
||||
int(epsilon_xy.shape[3]),
|
||||
)
|
||||
)
|
||||
.unsqueeze(1)
|
||||
.tile((1, input.shape[1], 1, 1)),
|
||||
kernel_size=tuple(kernel_size.tolist()),
|
||||
dilation=1,
|
||||
padding=0,
|
||||
stride=1,
|
||||
)
|
||||
.squeeze(-1)
|
||||
.reshape(
|
||||
(
|
||||
int(epsilon_xy.shape[0]),
|
||||
int(epsilon_xy.shape[1]),
|
||||
int(input_convolved.shape[1]),
|
||||
)
|
||||
)
|
||||
)
|
||||
.moveaxis(-1, 0)
|
||||
.contiguous(memory_format=torch.contiguous_format)
|
||||
)
|
||||
|
||||
############################################################
|
||||
# Spike generation #
|
||||
############################################################
|
||||
|
@ -864,7 +900,12 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
)
|
||||
|
||||
epsilon_scale: torch.Tensor = torch.ones(
|
||||
size=[1, int(epsilon_xy.shape[0]), int(epsilon_xy.shape[1]), 1],
|
||||
size=[
|
||||
int(spikes.shape[0]),
|
||||
int(spikes.shape[2]),
|
||||
int(spikes.shape[3]),
|
||||
1,
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
|
@ -875,9 +916,28 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
epsilon_scale = torch.ones_like(epsilon_scale)
|
||||
|
||||
h_temp: torch.Tensor = weights[spikes[:, t, :, :], :] * h
|
||||
epsilon_subsegment: torch.Tensor = (
|
||||
epsilon_xy.unsqueeze(0).unsqueeze(-1) * epsilon_t[t] * epsilon_0
|
||||
wx = 0
|
||||
wy = 0
|
||||
|
||||
if t == 0:
|
||||
epsilon_temp: torch.Tensor = torch.empty(
|
||||
(
|
||||
int(spikes.shape[0]),
|
||||
int(spikes.shape[2]),
|
||||
int(spikes.shape[3]),
|
||||
),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
for wx in range(0, int(spikes.shape[2])):
|
||||
for wy in range(0, int(spikes.shape[3])):
|
||||
epsilon_temp[:, wx, wy] = epsilon_xy_convolved[
|
||||
spikes[:, t, wx, wy], wx, wy
|
||||
]
|
||||
|
||||
epsilon_subsegment: torch.Tensor = (
|
||||
epsilon_temp.unsqueeze(-1) * epsilon_t[t] * epsilon_0
|
||||
)
|
||||
|
||||
h_temp_sum: torch.Tensor = (
|
||||
epsilon_scale * epsilon_subsegment / h_temp.sum(dim=3, keepdim=True)
|
||||
)
|
||||
|
@ -891,6 +951,7 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
|
||||
h /= epsilon_scale
|
||||
output = h.movedim(3, 1)
|
||||
|
||||
else:
|
||||
epsilon_t_0: torch.Tensor = epsilon_t * epsilon_0
|
||||
|
||||
|
@ -909,10 +970,10 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
assert np_h.flags["C_CONTIGUOUS"] is True
|
||||
assert np_h.ndim == 4
|
||||
|
||||
np_epsilon_xy: np.ndarray = epsilon_xy.detach().numpy()
|
||||
np_epsilon_xy: np.ndarray = epsilon_xy_convolved.detach().numpy()
|
||||
assert epsilon_xy.dtype == torch.float32
|
||||
assert np_epsilon_xy.flags["C_CONTIGUOUS"] is True
|
||||
assert np_epsilon_xy.ndim == 2
|
||||
assert np_epsilon_xy.ndim == 3
|
||||
|
||||
np_epsilon_t: np.ndarray = epsilon_t_0.detach().numpy()
|
||||
assert epsilon_t_0.dtype == torch.float32
|
||||
|
@ -948,6 +1009,7 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
np_epsilon_xy.__array_interface__["data"][0],
|
||||
int(np_epsilon_xy.shape[0]),
|
||||
int(np_epsilon_xy.shape[1]),
|
||||
int(np_epsilon_xy.shape[2]),
|
||||
np_epsilon_t.__array_interface__["data"][0],
|
||||
int(np_epsilon_t.shape[0]),
|
||||
np_weights.__array_interface__["data"][0],
|
||||
|
@ -1056,7 +1118,7 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
|
||||
ctx.save_for_backward(
|
||||
input_convolved,
|
||||
epsilon_xy_float64,
|
||||
epsilon_xy_convolved,
|
||||
epsilon_0_float64,
|
||||
weights_float64,
|
||||
output,
|
||||
|
@ -1075,7 +1137,7 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
# Get the variables back
|
||||
(
|
||||
input_float32,
|
||||
epsilon_xy,
|
||||
epsilon_xy_float32,
|
||||
epsilon_0,
|
||||
weights,
|
||||
output,
|
||||
|
@ -1088,6 +1150,7 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
|
||||
input = input_float32.type(dtype=torch.float64)
|
||||
input /= input.sum(dim=1, keepdim=True, dtype=torch.float64)
|
||||
epsilon_xy = epsilon_xy_float32.type(dtype=torch.float64)
|
||||
|
||||
# For debugging:
|
||||
# print(
|
||||
|
@ -1125,10 +1188,6 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
)
|
||||
torch.clip(backprop_z, out=backprop_z, min=-1e300, max=1e300)
|
||||
|
||||
backprop_y: torch.Tensor = (
|
||||
torch.einsum("bijxy,bixy->bjxy", backprop_z, input) - output
|
||||
)
|
||||
|
||||
result_omega: torch.Tensor = backprop_bigr.unsqueeze(2) * grad_output.unsqueeze(
|
||||
1
|
||||
)
|
||||
|
@ -1136,29 +1195,25 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
"bijxy,bjxy->bixy", backprop_r, grad_output
|
||||
).unsqueeze(2)
|
||||
result_omega *= backprop_f
|
||||
torch.nan_to_num(
|
||||
result_omega, out=result_omega, nan=1e300, posinf=1e300, neginf=-1e300
|
||||
)
|
||||
torch.clip(result_omega, out=result_omega, min=-1e300, max=1e300)
|
||||
|
||||
result_eps_xy: torch.Tensor = (
|
||||
torch.einsum("bixy,bixy->bxy", backprop_y, grad_output) * eps_b
|
||||
(
|
||||
(backprop_z * input.unsqueeze(2) - output.unsqueeze(1))
|
||||
* grad_output.unsqueeze(1)
|
||||
)
|
||||
torch.nan_to_num(
|
||||
result_eps_xy, out=result_eps_xy, nan=1e300, posinf=1e300, neginf=-1e300
|
||||
)
|
||||
torch.clip(result_eps_xy, out=result_eps_xy, min=-1e300, max=1e300)
|
||||
.sum(dim=2)
|
||||
.sum(dim=0)
|
||||
) * eps_b
|
||||
|
||||
result_phi: torch.Tensor = torch.einsum(
|
||||
"bijxy,bjxy->bixy", backprop_z, grad_output
|
||||
) * eps_a.unsqueeze(0).unsqueeze(0)
|
||||
torch.nan_to_num(
|
||||
result_phi, out=result_phi, nan=1e300, posinf=1e300, neginf=-1e300
|
||||
)
|
||||
torch.clip(result_phi, out=result_phi, min=-1e300, max=1e300)
|
||||
) * eps_a.unsqueeze(0)
|
||||
|
||||
grad_weights = result_omega.sum(0).sum(-1).sum(-1)
|
||||
grad_eps_xy = result_eps_xy.sum(0)
|
||||
torch.nan_to_num(
|
||||
grad_weights, out=grad_weights, nan=1e300, posinf=1e300, neginf=-1e300
|
||||
)
|
||||
torch.clip(grad_weights, out=grad_weights, min=-1e300, max=1e300)
|
||||
|
||||
grad_input = torch.nn.functional.fold(
|
||||
torch.nn.functional.unfold(
|
||||
|
@ -1174,22 +1229,41 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
padding=padding,
|
||||
stride=stride,
|
||||
)
|
||||
|
||||
torch.nan_to_num(
|
||||
grad_input, out=grad_input, nan=1e300, posinf=1e300, neginf=-1e300
|
||||
)
|
||||
torch.clip(grad_input, out=grad_input, min=-1e300, max=1e300)
|
||||
|
||||
grad_eps_xy_temp = torch.nn.functional.fold(
|
||||
result_eps_xy.moveaxis(0, -1)
|
||||
.reshape(
|
||||
(
|
||||
int(result_eps_xy.shape[1]) * int(result_eps_xy.shape[2]),
|
||||
int(result_eps_xy.shape[0]),
|
||||
)
|
||||
)
|
||||
.unsqueeze(-1),
|
||||
output_size=kernel_size,
|
||||
kernel_size=kernel_size,
|
||||
)
|
||||
|
||||
grad_eps_xy = (
|
||||
grad_eps_xy_temp.sum(dim=1)
|
||||
.reshape(
|
||||
(
|
||||
int(result_eps_xy.shape[1]),
|
||||
int(result_eps_xy.shape[2]),
|
||||
int(grad_eps_xy_temp.shape[-2]),
|
||||
int(grad_eps_xy_temp.shape[-1]),
|
||||
)
|
||||
)
|
||||
.contiguous(memory_format=torch.contiguous_format)
|
||||
)
|
||||
torch.nan_to_num(
|
||||
grad_eps_xy, out=grad_eps_xy, nan=1e300, posinf=1e300, neginf=-1e300
|
||||
)
|
||||
torch.clip(grad_eps_xy, out=grad_eps_xy, min=-1e300, max=1e300)
|
||||
|
||||
torch.nan_to_num(
|
||||
grad_weights, out=grad_weights, nan=1e300, posinf=1e300, neginf=-1e300
|
||||
)
|
||||
torch.clip(grad_weights, out=grad_weights, min=-1e300, max=1e300)
|
||||
|
||||
grad_epsilon_0 = None
|
||||
grad_epsilon_t = None
|
||||
grad_kernel_size = None
|
||||
|
@ -1218,3 +1292,6 @@ class FunctionalSbS(torch.autograd.Function):
|
|||
grad_h_initial,
|
||||
grad_alpha_number_of_iterations,
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
|
|
Loading…
Reference in a new issue