4D epsilon xy tensor

This commit is contained in:
David Rotermund 2022-05-03 11:31:29 +02:00 committed by GitHub
parent c9f79f2e13
commit e596c761d1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 112 additions and 35 deletions

View file

@ -14,5 +14,5 @@ __all__ = [
class HDynamicCNNManyIP(): class HDynamicCNNManyIP():
def __init__(self) -> None: ... def __init__(self) -> None: ...
def update_with_init_vector_multi_pattern(self, arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, arg7: int, arg8: int, arg9: int, arg10: int, arg11: int, arg12: int, arg13: int, arg14: int, arg15: int, arg16: int, arg17: int, arg18: int, arg19: int, arg20: int) -> bool: ... def update_with_init_vector_multi_pattern(self, arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, arg7: int, arg8: int, arg9: int, arg10: int, arg11: int, arg12: int, arg13: int, arg14: int, arg15: int, arg16: int, arg17: int, arg18: int, arg19: int, arg20: int, arg21: int) -> bool: ...
pass pass

145
SbS.py
View file

@ -154,7 +154,7 @@ class SbS(torch.nn.Module):
def epsilon_xy(self, value: torch.Tensor): def epsilon_xy(self, value: torch.Tensor):
assert value is not None assert value is not None
assert torch.is_tensor(value) is True assert torch.is_tensor(value) is True
assert value.dim() == 2 assert value.dim() == 4
assert value.dtype == torch.float64 assert value.dtype == torch.float64
if self._epsilon_xy_exists is False: if self._epsilon_xy_exists is False:
self._epsilon_xy = torch.nn.parameter.Parameter( self._epsilon_xy = torch.nn.parameter.Parameter(
@ -617,10 +617,16 @@ class SbS(torch.nn.Module):
"""Creates initial epsilon xy matrices""" """Creates initial epsilon xy matrices"""
assert self._output_size is not None assert self._output_size is not None
assert self._kernel_size is not None
assert eps_xy_intitial > 0 assert eps_xy_intitial > 0
eps_xy_temp: torch.Tensor = torch.full( eps_xy_temp: torch.Tensor = torch.full(
(int(self._output_size[0]), int(self._output_size[1])), (
int(self._output_size[0]),
int(self._output_size[1]),
int(self._kernel_size[0]),
int(self._kernel_size[1]),
),
eps_xy_intitial, eps_xy_intitial,
dtype=torch.float64, dtype=torch.float64,
) )
@ -759,6 +765,36 @@ class FunctionalSbS(torch.autograd.Function):
stride=(1, 1), stride=(1, 1),
).requires_grad_(True) ).requires_grad_(True)
epsilon_xy_convolved: torch.Tensor = (
(
torch.nn.functional.unfold(
epsilon_xy.reshape(
(
int(epsilon_xy.shape[0]) * int(epsilon_xy.shape[1]),
int(epsilon_xy.shape[2]),
int(epsilon_xy.shape[3]),
)
)
.unsqueeze(1)
.tile((1, input.shape[1], 1, 1)),
kernel_size=tuple(kernel_size.tolist()),
dilation=1,
padding=0,
stride=1,
)
.squeeze(-1)
.reshape(
(
int(epsilon_xy.shape[0]),
int(epsilon_xy.shape[1]),
int(input_convolved.shape[1]),
)
)
)
.moveaxis(-1, 0)
.contiguous(memory_format=torch.contiguous_format)
)
############################################################ ############################################################
# Spike generation # # Spike generation #
############################################################ ############################################################
@ -864,7 +900,12 @@ class FunctionalSbS(torch.autograd.Function):
) )
epsilon_scale: torch.Tensor = torch.ones( epsilon_scale: torch.Tensor = torch.ones(
size=[1, int(epsilon_xy.shape[0]), int(epsilon_xy.shape[1]), 1], size=[
int(spikes.shape[0]),
int(spikes.shape[2]),
int(spikes.shape[3]),
1,
],
dtype=torch.float32, dtype=torch.float32,
) )
@ -875,9 +916,28 @@ class FunctionalSbS(torch.autograd.Function):
epsilon_scale = torch.ones_like(epsilon_scale) epsilon_scale = torch.ones_like(epsilon_scale)
h_temp: torch.Tensor = weights[spikes[:, t, :, :], :] * h h_temp: torch.Tensor = weights[spikes[:, t, :, :], :] * h
wx = 0
wy = 0
if t == 0:
epsilon_temp: torch.Tensor = torch.empty(
(
int(spikes.shape[0]),
int(spikes.shape[2]),
int(spikes.shape[3]),
),
dtype=torch.float32,
)
for wx in range(0, int(spikes.shape[2])):
for wy in range(0, int(spikes.shape[3])):
epsilon_temp[:, wx, wy] = epsilon_xy_convolved[
spikes[:, t, wx, wy], wx, wy
]
epsilon_subsegment: torch.Tensor = ( epsilon_subsegment: torch.Tensor = (
epsilon_xy.unsqueeze(0).unsqueeze(-1) * epsilon_t[t] * epsilon_0 epsilon_temp.unsqueeze(-1) * epsilon_t[t] * epsilon_0
) )
h_temp_sum: torch.Tensor = ( h_temp_sum: torch.Tensor = (
epsilon_scale * epsilon_subsegment / h_temp.sum(dim=3, keepdim=True) epsilon_scale * epsilon_subsegment / h_temp.sum(dim=3, keepdim=True)
) )
@ -891,6 +951,7 @@ class FunctionalSbS(torch.autograd.Function):
h /= epsilon_scale h /= epsilon_scale
output = h.movedim(3, 1) output = h.movedim(3, 1)
else: else:
epsilon_t_0: torch.Tensor = epsilon_t * epsilon_0 epsilon_t_0: torch.Tensor = epsilon_t * epsilon_0
@ -909,10 +970,10 @@ class FunctionalSbS(torch.autograd.Function):
assert np_h.flags["C_CONTIGUOUS"] is True assert np_h.flags["C_CONTIGUOUS"] is True
assert np_h.ndim == 4 assert np_h.ndim == 4
np_epsilon_xy: np.ndarray = epsilon_xy.detach().numpy() np_epsilon_xy: np.ndarray = epsilon_xy_convolved.detach().numpy()
assert epsilon_xy.dtype == torch.float32 assert epsilon_xy.dtype == torch.float32
assert np_epsilon_xy.flags["C_CONTIGUOUS"] is True assert np_epsilon_xy.flags["C_CONTIGUOUS"] is True
assert np_epsilon_xy.ndim == 2 assert np_epsilon_xy.ndim == 3
np_epsilon_t: np.ndarray = epsilon_t_0.detach().numpy() np_epsilon_t: np.ndarray = epsilon_t_0.detach().numpy()
assert epsilon_t_0.dtype == torch.float32 assert epsilon_t_0.dtype == torch.float32
@ -948,6 +1009,7 @@ class FunctionalSbS(torch.autograd.Function):
np_epsilon_xy.__array_interface__["data"][0], np_epsilon_xy.__array_interface__["data"][0],
int(np_epsilon_xy.shape[0]), int(np_epsilon_xy.shape[0]),
int(np_epsilon_xy.shape[1]), int(np_epsilon_xy.shape[1]),
int(np_epsilon_xy.shape[2]),
np_epsilon_t.__array_interface__["data"][0], np_epsilon_t.__array_interface__["data"][0],
int(np_epsilon_t.shape[0]), int(np_epsilon_t.shape[0]),
np_weights.__array_interface__["data"][0], np_weights.__array_interface__["data"][0],
@ -1056,7 +1118,7 @@ class FunctionalSbS(torch.autograd.Function):
ctx.save_for_backward( ctx.save_for_backward(
input_convolved, input_convolved,
epsilon_xy_float64, epsilon_xy_convolved,
epsilon_0_float64, epsilon_0_float64,
weights_float64, weights_float64,
output, output,
@ -1075,7 +1137,7 @@ class FunctionalSbS(torch.autograd.Function):
# Get the variables back # Get the variables back
( (
input_float32, input_float32,
epsilon_xy, epsilon_xy_float32,
epsilon_0, epsilon_0,
weights, weights,
output, output,
@ -1088,6 +1150,7 @@ class FunctionalSbS(torch.autograd.Function):
input = input_float32.type(dtype=torch.float64) input = input_float32.type(dtype=torch.float64)
input /= input.sum(dim=1, keepdim=True, dtype=torch.float64) input /= input.sum(dim=1, keepdim=True, dtype=torch.float64)
epsilon_xy = epsilon_xy_float32.type(dtype=torch.float64)
# For debugging: # For debugging:
# print( # print(
@ -1125,10 +1188,6 @@ class FunctionalSbS(torch.autograd.Function):
) )
torch.clip(backprop_z, out=backprop_z, min=-1e300, max=1e300) torch.clip(backprop_z, out=backprop_z, min=-1e300, max=1e300)
backprop_y: torch.Tensor = (
torch.einsum("bijxy,bixy->bjxy", backprop_z, input) - output
)
result_omega: torch.Tensor = backprop_bigr.unsqueeze(2) * grad_output.unsqueeze( result_omega: torch.Tensor = backprop_bigr.unsqueeze(2) * grad_output.unsqueeze(
1 1
) )
@ -1136,29 +1195,25 @@ class FunctionalSbS(torch.autograd.Function):
"bijxy,bjxy->bixy", backprop_r, grad_output "bijxy,bjxy->bixy", backprop_r, grad_output
).unsqueeze(2) ).unsqueeze(2)
result_omega *= backprop_f result_omega *= backprop_f
torch.nan_to_num(
result_omega, out=result_omega, nan=1e300, posinf=1e300, neginf=-1e300
)
torch.clip(result_omega, out=result_omega, min=-1e300, max=1e300)
result_eps_xy: torch.Tensor = ( result_eps_xy: torch.Tensor = (
torch.einsum("bixy,bixy->bxy", backprop_y, grad_output) * eps_b (
) (backprop_z * input.unsqueeze(2) - output.unsqueeze(1))
torch.nan_to_num( * grad_output.unsqueeze(1)
result_eps_xy, out=result_eps_xy, nan=1e300, posinf=1e300, neginf=-1e300 )
) .sum(dim=2)
torch.clip(result_eps_xy, out=result_eps_xy, min=-1e300, max=1e300) .sum(dim=0)
) * eps_b
result_phi: torch.Tensor = torch.einsum( result_phi: torch.Tensor = torch.einsum(
"bijxy,bjxy->bixy", backprop_z, grad_output "bijxy,bjxy->bixy", backprop_z, grad_output
) * eps_a.unsqueeze(0).unsqueeze(0) ) * eps_a.unsqueeze(0)
torch.nan_to_num(
result_phi, out=result_phi, nan=1e300, posinf=1e300, neginf=-1e300
)
torch.clip(result_phi, out=result_phi, min=-1e300, max=1e300)
grad_weights = result_omega.sum(0).sum(-1).sum(-1) grad_weights = result_omega.sum(0).sum(-1).sum(-1)
grad_eps_xy = result_eps_xy.sum(0) torch.nan_to_num(
grad_weights, out=grad_weights, nan=1e300, posinf=1e300, neginf=-1e300
)
torch.clip(grad_weights, out=grad_weights, min=-1e300, max=1e300)
grad_input = torch.nn.functional.fold( grad_input = torch.nn.functional.fold(
torch.nn.functional.unfold( torch.nn.functional.unfold(
@ -1174,22 +1229,41 @@ class FunctionalSbS(torch.autograd.Function):
padding=padding, padding=padding,
stride=stride, stride=stride,
) )
torch.nan_to_num( torch.nan_to_num(
grad_input, out=grad_input, nan=1e300, posinf=1e300, neginf=-1e300 grad_input, out=grad_input, nan=1e300, posinf=1e300, neginf=-1e300
) )
torch.clip(grad_input, out=grad_input, min=-1e300, max=1e300) torch.clip(grad_input, out=grad_input, min=-1e300, max=1e300)
grad_eps_xy_temp = torch.nn.functional.fold(
result_eps_xy.moveaxis(0, -1)
.reshape(
(
int(result_eps_xy.shape[1]) * int(result_eps_xy.shape[2]),
int(result_eps_xy.shape[0]),
)
)
.unsqueeze(-1),
output_size=kernel_size,
kernel_size=kernel_size,
)
grad_eps_xy = (
grad_eps_xy_temp.sum(dim=1)
.reshape(
(
int(result_eps_xy.shape[1]),
int(result_eps_xy.shape[2]),
int(grad_eps_xy_temp.shape[-2]),
int(grad_eps_xy_temp.shape[-1]),
)
)
.contiguous(memory_format=torch.contiguous_format)
)
torch.nan_to_num( torch.nan_to_num(
grad_eps_xy, out=grad_eps_xy, nan=1e300, posinf=1e300, neginf=-1e300 grad_eps_xy, out=grad_eps_xy, nan=1e300, posinf=1e300, neginf=-1e300
) )
torch.clip(grad_eps_xy, out=grad_eps_xy, min=-1e300, max=1e300) torch.clip(grad_eps_xy, out=grad_eps_xy, min=-1e300, max=1e300)
torch.nan_to_num(
grad_weights, out=grad_weights, nan=1e300, posinf=1e300, neginf=-1e300
)
torch.clip(grad_weights, out=grad_weights, min=-1e300, max=1e300)
grad_epsilon_0 = None grad_epsilon_0 = None
grad_epsilon_t = None grad_epsilon_t = None
grad_kernel_size = None grad_kernel_size = None
@ -1218,3 +1292,6 @@ class FunctionalSbS(torch.autograd.Function):
grad_h_initial, grad_h_initial,
grad_alpha_number_of_iterations, grad_alpha_number_of_iterations,
) )
# %%