diff --git a/PyHDynamicCNNManyIP.pyi b/PyHDynamicCNNManyIP.pyi
index d86ff54..d5d1f6f 100644
--- a/PyHDynamicCNNManyIP.pyi
+++ b/PyHDynamicCNNManyIP.pyi
@@ -14,5 +14,5 @@ __all__ = [
 
 class HDynamicCNNManyIP():
     def __init__(self) -> None: ...
-    def update_with_init_vector_multi_pattern(self, arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, arg7: int, arg8: int, arg9: int, arg10: int, arg11: int, arg12: int, arg13: int, arg14: int, arg15: int, arg16: int, arg17: int, arg18: int, arg19: int, arg20: int) -> bool: ...
+    def update_with_init_vector_multi_pattern(self, arg0: int, arg1: int, arg2: int, arg3: int, arg4: int, arg5: int, arg6: int, arg7: int, arg8: int, arg9: int, arg10: int, arg11: int, arg12: int, arg13: int, arg14: int, arg15: int, arg16: int, arg17: int, arg18: int, arg19: int, arg20: int, arg21: int) -> bool: ...
     pass
diff --git a/SbS.py b/SbS.py
index 7b56a08..70c5a1f 100644
--- a/SbS.py
+++ b/SbS.py
@@ -154,7 +154,7 @@ class SbS(torch.nn.Module):
     def epsilon_xy(self, value: torch.Tensor):
         assert value is not None
         assert torch.is_tensor(value) is True
-        assert value.dim() == 2
+        assert value.dim() == 4
         assert value.dtype == torch.float64
         if self._epsilon_xy_exists is False:
             self._epsilon_xy = torch.nn.parameter.Parameter(
@@ -617,10 +617,16 @@ class SbS(torch.nn.Module):
         """Creates initial epsilon xy matrices"""
 
         assert self._output_size is not None
+        assert self._kernel_size is not None
         assert eps_xy_intitial > 0
 
         eps_xy_temp: torch.Tensor = torch.full(
-            (int(self._output_size[0]), int(self._output_size[1])),
+            (
+                int(self._output_size[0]),
+                int(self._output_size[1]),
+                int(self._kernel_size[0]),
+                int(self._kernel_size[1]),
+            ),
             eps_xy_intitial,
             dtype=torch.float64,
         )
@@ -759,6 +765,36 @@ class FunctionalSbS(torch.autograd.Function):
             stride=(1, 1),
         ).requires_grad_(True)
 
+        epsilon_xy_convolved: torch.Tensor = (
+            (
+                torch.nn.functional.unfold(
+                    epsilon_xy.reshape(
+                        (
+                            int(epsilon_xy.shape[0]) * int(epsilon_xy.shape[1]),
+                            int(epsilon_xy.shape[2]),
+                            int(epsilon_xy.shape[3]),
+                        )
+                    )
+                    .unsqueeze(1)
+                    .tile((1, input.shape[1], 1, 1)),
+                    kernel_size=tuple(kernel_size.tolist()),
+                    dilation=1,
+                    padding=0,
+                    stride=1,
+                )
+                .squeeze(-1)
+                .reshape(
+                    (
+                        int(epsilon_xy.shape[0]),
+                        int(epsilon_xy.shape[1]),
+                        int(input_convolved.shape[1]),
+                    )
+                )
+            )
+            .moveaxis(-1, 0)
+            .contiguous(memory_format=torch.contiguous_format)
+        )
+
         ############################################################
         # Spike generation                                         #
         ############################################################
@@ -864,7 +900,12 @@ class FunctionalSbS(torch.autograd.Function):
         )
 
         epsilon_scale: torch.Tensor = torch.ones(
-            size=[1, int(epsilon_xy.shape[0]), int(epsilon_xy.shape[1]), 1],
+            size=[
+                int(spikes.shape[0]),
+                int(spikes.shape[2]),
+                int(spikes.shape[3]),
+                1,
+            ],
             dtype=torch.float32,
         )
 
@@ -875,9 +916,28 @@ class FunctionalSbS(torch.autograd.Function):
                     epsilon_scale = torch.ones_like(epsilon_scale)
 
                 h_temp: torch.Tensor = weights[spikes[:, t, :, :], :] * h
+                wx = 0
+                wy = 0
+
+                if t == 0:
+                    epsilon_temp: torch.Tensor = torch.empty(
+                        (
+                            int(spikes.shape[0]),
+                            int(spikes.shape[2]),
+                            int(spikes.shape[3]),
+                        ),
+                        dtype=torch.float32,
+                    )
+                for wx in range(0, int(spikes.shape[2])):
+                    for wy in range(0, int(spikes.shape[3])):
+                        epsilon_temp[:, wx, wy] = epsilon_xy_convolved[
+                            spikes[:, t, wx, wy], wx, wy
+                        ]
+
                 epsilon_subsegment: torch.Tensor = (
-                    epsilon_xy.unsqueeze(0).unsqueeze(-1) * epsilon_t[t] * epsilon_0
+                    epsilon_temp.unsqueeze(-1) * epsilon_t[t] * epsilon_0
                 )
+
                 h_temp_sum: torch.Tensor = (
                     epsilon_scale * epsilon_subsegment / h_temp.sum(dim=3, keepdim=True)
                 )
@@ -891,6 +951,7 @@ class FunctionalSbS(torch.autograd.Function):
 
             h /= epsilon_scale
             output = h.movedim(3, 1)
+
         else:
 
             epsilon_t_0: torch.Tensor = epsilon_t * epsilon_0
@@ -909,10 +970,10 @@ class FunctionalSbS(torch.autograd.Function):
             assert np_h.flags["C_CONTIGUOUS"] is True
             assert np_h.ndim == 4
 
-            np_epsilon_xy: np.ndarray = epsilon_xy.detach().numpy()
+            np_epsilon_xy: np.ndarray = epsilon_xy_convolved.detach().numpy()
             assert epsilon_xy.dtype == torch.float32
             assert np_epsilon_xy.flags["C_CONTIGUOUS"] is True
-            assert np_epsilon_xy.ndim == 2
+            assert np_epsilon_xy.ndim == 3
 
             np_epsilon_t: np.ndarray = epsilon_t_0.detach().numpy()
             assert epsilon_t_0.dtype == torch.float32
@@ -948,6 +1009,7 @@ class FunctionalSbS(torch.autograd.Function):
                 np_epsilon_xy.__array_interface__["data"][0],
                 int(np_epsilon_xy.shape[0]),
                 int(np_epsilon_xy.shape[1]),
+                int(np_epsilon_xy.shape[2]),
                 np_epsilon_t.__array_interface__["data"][0],
                 int(np_epsilon_t.shape[0]),
                 np_weights.__array_interface__["data"][0],
@@ -1056,7 +1118,7 @@ class FunctionalSbS(torch.autograd.Function):
 
         ctx.save_for_backward(
             input_convolved,
-            epsilon_xy_float64,
+            epsilon_xy_convolved,
             epsilon_0_float64,
             weights_float64,
             output,
@@ -1075,7 +1137,7 @@ class FunctionalSbS(torch.autograd.Function):
         # Get the variables back
         (
             input_float32,
-            epsilon_xy,
+            epsilon_xy_float32,
             epsilon_0,
             weights,
             output,
@@ -1088,6 +1150,7 @@ class FunctionalSbS(torch.autograd.Function):
 
         input = input_float32.type(dtype=torch.float64)
         input /= input.sum(dim=1, keepdim=True, dtype=torch.float64)
+        epsilon_xy = epsilon_xy_float32.type(dtype=torch.float64)
 
         # For debugging:
         # print(
@@ -1125,10 +1188,6 @@ class FunctionalSbS(torch.autograd.Function):
         )
         torch.clip(backprop_z, out=backprop_z, min=-1e300, max=1e300)
 
-        backprop_y: torch.Tensor = (
-            torch.einsum("bijxy,bixy->bjxy", backprop_z, input) - output
-        )
-
         result_omega: torch.Tensor = backprop_bigr.unsqueeze(2) * grad_output.unsqueeze(
             1
         )
@@ -1136,29 +1195,25 @@ class FunctionalSbS(torch.autograd.Function):
             "bijxy,bjxy->bixy", backprop_r, grad_output
         ).unsqueeze(2)
         result_omega *= backprop_f
-        torch.nan_to_num(
-            result_omega, out=result_omega, nan=1e300, posinf=1e300, neginf=-1e300
-        )
-        torch.clip(result_omega, out=result_omega, min=-1e300, max=1e300)
 
         result_eps_xy: torch.Tensor = (
-            torch.einsum("bixy,bixy->bxy", backprop_y, grad_output) * eps_b
-        )
-        torch.nan_to_num(
-            result_eps_xy, out=result_eps_xy, nan=1e300, posinf=1e300, neginf=-1e300
-        )
-        torch.clip(result_eps_xy, out=result_eps_xy, min=-1e300, max=1e300)
+            (
+                (backprop_z * input.unsqueeze(2) - output.unsqueeze(1))
+                * grad_output.unsqueeze(1)
+            )
+            .sum(dim=2)
+            .sum(dim=0)
+        ) * eps_b
 
         result_phi: torch.Tensor = torch.einsum(
             "bijxy,bjxy->bixy", backprop_z, grad_output
-        ) * eps_a.unsqueeze(0).unsqueeze(0)
-        torch.nan_to_num(
-            result_phi, out=result_phi, nan=1e300, posinf=1e300, neginf=-1e300
-        )
-        torch.clip(result_phi, out=result_phi, min=-1e300, max=1e300)
+        ) * eps_a.unsqueeze(0)
 
         grad_weights = result_omega.sum(0).sum(-1).sum(-1)
-        grad_eps_xy = result_eps_xy.sum(0)
+        torch.nan_to_num(
+            grad_weights, out=grad_weights, nan=1e300, posinf=1e300, neginf=-1e300
+        )
+        torch.clip(grad_weights, out=grad_weights, min=-1e300, max=1e300)
 
         grad_input = torch.nn.functional.fold(
             torch.nn.functional.unfold(
@@ -1174,22 +1229,41 @@ class FunctionalSbS(torch.autograd.Function):
             padding=padding,
             stride=stride,
         )
-
         torch.nan_to_num(
             grad_input, out=grad_input, nan=1e300, posinf=1e300, neginf=-1e300
        )
         torch.clip(grad_input, out=grad_input, min=-1e300, max=1e300)
 
+        grad_eps_xy_temp = torch.nn.functional.fold(
+            result_eps_xy.moveaxis(0, -1)
+            .reshape(
+                (
+                    int(result_eps_xy.shape[1]) * int(result_eps_xy.shape[2]),
+                    int(result_eps_xy.shape[0]),
+                )
+            )
+            .unsqueeze(-1),
+            output_size=kernel_size,
+            kernel_size=kernel_size,
+        )
+
+        grad_eps_xy = (
+            grad_eps_xy_temp.sum(dim=1)
+            .reshape(
+                (
+                    int(result_eps_xy.shape[1]),
+                    int(result_eps_xy.shape[2]),
+                    int(grad_eps_xy_temp.shape[-2]),
+                    int(grad_eps_xy_temp.shape[-1]),
+                )
+            )
+            .contiguous(memory_format=torch.contiguous_format)
+        )
         torch.nan_to_num(
             grad_eps_xy, out=grad_eps_xy, nan=1e300, posinf=1e300, neginf=-1e300
         )
         torch.clip(grad_eps_xy, out=grad_eps_xy, min=-1e300, max=1e300)
 
-        torch.nan_to_num(
-            grad_weights, out=grad_weights, nan=1e300, posinf=1e300, neginf=-1e300
-        )
-        torch.clip(grad_weights, out=grad_weights, min=-1e300, max=1e300)
-
         grad_epsilon_0 = None
         grad_epsilon_t = None
         grad_kernel_size = None
@@ -1218,3 +1292,6 @@ class FunctionalSbS(torch.autograd.Function):
             grad_h_initial,
             grad_alpha_number_of_iterations,
        )
+
+
+# %%
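
Note on the new epsilon_xy layout: epsilon_xy now carries one (kernel_size[0], kernel_size[1]) map per output position, and the forward pass flattens it with torch.nn.functional.unfold so that its first axis lines up with the channel axis of the unfolded input (input channels x kernel height x kernel width). A minimal sketch of that shape bookkeeping; the sizes h_out, w_out, kh, kw, c are made up for illustration and are not taken from the repository:

import torch

# Assumed illustrative sizes (hypothetical, not from SbS.py):
h_out, w_out = 3, 3   # output grid (output_size)
kh, kw = 2, 2         # kernel window (kernel_size)
c = 4                 # input channels

epsilon_xy = torch.rand((h_out, w_out, kh, kw), dtype=torch.float64)

# Same unfold pipeline as the new forward code: one epsilon per
# (output position, kernel offset), tiled over the input channels so the
# leading axis matches the channel axis of the unfolded input (c * kh * kw).
eps_conv = (
    torch.nn.functional.unfold(
        epsilon_xy.reshape((h_out * w_out, kh, kw))
        .unsqueeze(1)
        .tile((1, c, 1, 1)),
        kernel_size=(kh, kw),
        dilation=1,
        padding=0,
        stride=1,
    )
    .squeeze(-1)
    .reshape((h_out, w_out, c * kh * kw))
    .moveaxis(-1, 0)
)

print(eps_conv.shape)  # torch.Size([16, 3, 3]) == (c * kh * kw, h_out, w_out)

# For one output position, the c * kh * kw entries are just the kh * kw
# epsilons of that position repeated across the c channel copies.
assert torch.equal(
    eps_conv[:, 0, 0].reshape(c, kh * kw),
    epsilon_xy[0, 0].reshape(1, kh * kw).tile((c, 1)),
)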
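The backward pass reverses that layout: torch.nn.functional.fold with output_size equal to kernel_size regroups the flattened channel axis of result_eps_xy into per-position (c, kh, kw) blocks, and summing over the c channel copies undoes the tile() from the forward pass. A sketch under the same made-up sizes as above, with a random stand-in for result_eps_xy:

import torch

# Assumed illustrative sizes (hypothetical, not from SbS.py):
h_out, w_out = 3, 3
kh, kw = 2, 2
c = 4

# Stand-in for result_eps_xy from backward(): one gradient entry per
# unfolded-input channel (c * kh * kw) and per output position (x, y).
result_eps_xy = torch.rand((c * kh * kw, h_out, w_out), dtype=torch.float64)

# Same fold pipeline as the new backward code: regroup the c * kh * kw axis
# into (c, kh, kw) per output position, then sum away the c channel copies.
grad_eps_xy_temp = torch.nn.functional.fold(
    result_eps_xy.moveaxis(0, -1)
    .reshape((h_out * w_out, c * kh * kw))
    .unsqueeze(-1),
    output_size=(kh, kw),
    kernel_size=(kh, kw),
)  # -> (h_out * w_out, c, kh, kw)

grad_eps_xy = grad_eps_xy_temp.sum(dim=1).reshape((h_out, w_out, kh, kw))

# Cross-check against a plain reshape-and-sum over the channel copies.
reference = (
    result_eps_xy.reshape((c, kh, kw, h_out, w_out)).sum(dim=0).permute(2, 3, 0, 1)
)
assert torch.allclose(grad_eps_xy, reference)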
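Inside the time loop, epsilon_temp is filled by picking, for every batch entry and output position, the epsilon of the input channel that spiked at that step. A sketch of that gather with made-up sizes; the torch.gather variant at the end only illustrates an equivalent vectorized form and is not part of the patch:

import torch

# Assumed illustrative sizes (hypothetical, not from SbS.py):
b = 2                  # batch size
n_in = 16              # c * kh * kw, channel axis of the unfolded input
h_out, w_out = 3, 3

eps_conv = torch.rand((n_in, h_out, w_out), dtype=torch.float32)
spikes_t = torch.randint(0, n_in, (b, h_out, w_out))  # spike indices at one time step

# Same gather as the new forward loop: for every batch entry and output
# position, pick the epsilon belonging to the input channel that spiked.
epsilon_temp = torch.empty((b, h_out, w_out), dtype=torch.float32)
for wx in range(h_out):
    for wy in range(w_out):
        epsilon_temp[:, wx, wy] = eps_conv[spikes_t[:, wx, wy], wx, wy]

# The same gather without the Python loops, via torch.gather:
flat = eps_conv.unsqueeze(0).expand(b, -1, -1, -1)
gathered = torch.gather(flat, 1, spikes_t.unsqueeze(1)).squeeze(1)
assert torch.equal(epsilon_temp, gathered)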