Add files via upload

This commit is contained in:
David Rotermund 2023-01-13 21:31:39 +01:00 committed by GitHub
parent ff64a9da14
commit f156fbf8bb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 558 additions and 101 deletions

View file

@ -15,6 +15,7 @@ class Adam(torch.optim.Optimizer):
self, self,
params, params,
sbs_setting: list[bool], sbs_setting: list[bool],
logging,
lr: float = 1e-3, lr: float = 1e-3,
beta1: float = 0.9, beta1: float = 0.9,
beta2: float = 0.999, beta2: float = 0.999,
@ -41,6 +42,7 @@ class Adam(torch.optim.Optimizer):
self.beta2 = beta2 self.beta2 = beta2
self.eps = eps self.eps = eps
self.maximize = maximize self.maximize = maximize
self._logging = logging
defaults = dict( defaults = dict(
lr=lr, lr=lr,
@ -149,8 +151,12 @@ class Adam(torch.optim.Optimizer):
if sbs_setting[i] is False: if sbs_setting[i] is False:
param -= step_size * (exp_avg / denom) param -= step_size * (exp_avg / denom)
else: else:
delta = torch.exp(-step_size * (exp_avg / denom)) # delta = torch.exp(-step_size * (exp_avg / denom))
print( delta = torch.tanh(-step_size * (exp_avg / denom))
f"{float(delta.min()) - 1.0:.4e} {float(delta.max()) - 1.0:.4e} {lr:.4e}" delta += 1.0
delta *= 0.5
delta += 0.5
self._logging.info(
f"ADAM: Layer {i} -> dw_min:{float(delta.min()):.4e} dw_max:{float(delta.max()):.4e} lr:{lr:.4e}"
) )
param *= delta param *= delta

View file

@ -3,6 +3,10 @@ import math
from network.CPP.PyMultiApp import MultiApp from network.CPP.PyMultiApp import MultiApp
global_multiapp_gpu_setting: list[torch.Tensor] = []
global_multiapp_size: list[torch.Tensor] = []
global_multiapp_cpp: list[MultiApp] = []
class Conv2dApproximation(torch.nn.Module): class Conv2dApproximation(torch.nn.Module):
@ -26,6 +30,9 @@ class Conv2dApproximation(torch.nn.Module):
device: torch.device device: torch.device
dtype: torch.dtype dtype: torch.dtype
multiapp_gpu_setting_position: int = -1
multiapp_cpp_position: int = -1
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
@ -68,6 +75,12 @@ class Conv2dApproximation(torch.nn.Module):
self.number_of_trunc_bits = number_of_trunc_bits self.number_of_trunc_bits = number_of_trunc_bits
self.number_of_frac = number_of_frac self.number_of_frac = number_of_frac
global_multiapp_gpu_setting.append(torch.tensor([0]))
global_multiapp_size.append(torch.tensor([0, 0, 0, 0]))
global_multiapp_cpp.append(MultiApp())
self.multiapp_gpu_setting_position = len(global_multiapp_gpu_setting) - 1
self.multiapp_cpp_position = len(global_multiapp_cpp) - 1
if self.use_bias is True: if self.use_bias is True:
self.bias: torch.nn.parameter.Parameter | None = ( self.bias: torch.nn.parameter.Parameter | None = (
torch.nn.parameter.Parameter( torch.nn.parameter.Parameter(
@ -190,6 +203,8 @@ class Conv2dApproximation(torch.nn.Module):
assert input.dim() == 4 assert input.dim() == 4
assert self.kernel_size is not None assert self.kernel_size is not None
assert self.multiapp_gpu_setting_position != -1
assert self.multiapp_cpp_position != -1
input_size = torch.Tensor([int(input.shape[-2]), int(input.shape[-1])]).type( input_size = torch.Tensor([int(input.shape[-2]), int(input.shape[-1])]).type(
dtype=torch.int64 dtype=torch.int64
@ -232,6 +247,8 @@ class Conv2dApproximation(torch.nn.Module):
int(self.number_of_trunc_bits), # 1 int(self.number_of_trunc_bits), # 1
int(self.number_of_frac), # 2 int(self.number_of_frac), # 2
int(number_of_cpu_processes), # 3 int(number_of_cpu_processes), # 3
int(self.multiapp_gpu_setting_position), # 4
int(self.multiapp_cpp_position), # 5
], ],
dtype=torch.int64, dtype=torch.int64,
) )
@ -267,6 +284,8 @@ class FunctionalMultiConv2d(torch.autograd.Function):
number_of_trunc_bits = int(parameter_list[1]) number_of_trunc_bits = int(parameter_list[1])
number_of_frac = int(parameter_list[2]) number_of_frac = int(parameter_list[2])
number_of_processes = int(parameter_list[3]) number_of_processes = int(parameter_list[3])
multiapp_gpu_setting_position = int(parameter_list[4])
multiapp_cpp_position = int(parameter_list[5])
assert input.device == weights.device assert input.device == weights.device
@ -278,9 +297,54 @@ class FunctionalMultiConv2d(torch.autograd.Function):
) )
assert output.is_contiguous() is True assert output.is_contiguous() is True
multiplier: MultiApp = MultiApp() multiapp_profile = global_multiapp_gpu_setting[
multiapp_gpu_setting_position
].clone()
multiplier.update_with_init_vector_multi_pattern( multiapp_size = global_multiapp_size[multiapp_gpu_setting_position].clone()
if input.device != torch.device("cpu"):
if (
(multiapp_profile.numel() == 1)
or (multiapp_size[0] != int(output.shape[0]))
or (multiapp_size[1] != int(output.shape[1]))
or (multiapp_size[2] != int(output.shape[2]))
or (multiapp_size[3] != int(output.shape[3]))
):
multiapp_profile = torch.zeros(
(1, 7), dtype=torch.int64, device=torch.device("cpu")
)
global_multiapp_cpp[multiapp_cpp_position].gpu_occupancy_export(
int(output.shape[2]),
int(output.shape[3]),
int(output.shape[0]),
int(output.shape[1]),
multiapp_profile.data_ptr(),
int(multiapp_profile.shape[0]),
int(multiapp_profile.shape[1]),
)
global_multiapp_gpu_setting[
multiapp_gpu_setting_position
] = multiapp_profile.clone()
multiapp_size[0] = int(output.shape[0])
multiapp_size[1] = int(output.shape[1])
multiapp_size[2] = int(output.shape[2])
multiapp_size[3] = int(output.shape[3])
global_multiapp_size[
multiapp_gpu_setting_position
] = multiapp_size.clone()
else:
global_multiapp_cpp[multiapp_cpp_position].gpu_occupancy_import(
multiapp_profile.data_ptr(),
int(multiapp_profile.shape[0]),
int(multiapp_profile.shape[1]),
)
global_multiapp_cpp[multiapp_cpp_position].update_entrypoint(
input.data_ptr(), input.data_ptr(),
weights.data_ptr(), weights.data_ptr(),
output.data_ptr(), output.data_ptr(),

View file

@ -47,8 +47,8 @@ class LearningParameters:
weight_noise_range: list[float] = field(default_factory=list) weight_noise_range: list[float] = field(default_factory=list)
eps_xy_intitial: float = field(default=0.1) eps_xy_intitial: float = field(default=0.1)
# disable_scale_grade: bool = field(default=False) disable_scale_grade: bool = field(default=False)
# kepp_last_grad_scale: bool = field(default=True) kepp_last_grad_scale: bool = field(default=True)
sbs_skip_gradient_calculation: list[bool] = field(default_factory=list) sbs_skip_gradient_calculation: list[bool] = field(default_factory=list)

View file

@ -4,6 +4,13 @@ from network.CPP.PySpikeGeneration2DManyIP import SpikeGeneration2DManyIP
from network.CPP.PyHDynamicCNNManyIP import HDynamicCNNManyIP from network.CPP.PyHDynamicCNNManyIP import HDynamicCNNManyIP
from network.calculate_output_size import calculate_output_size from network.calculate_output_size import calculate_output_size
global_sbs_gpu_setting: list[torch.Tensor] = []
global_sbs_size: list[torch.Tensor] = []
global_sbs_hdynamic_cpp: list[HDynamicCNNManyIP] = []
global_spike_generation_gpu_setting: list[torch.Tensor] = []
global_spike_size: list[torch.Tensor] = []
global_spike_generation_cpp: list[SpikeGeneration2DManyIP] = []
class SbS(torch.nn.Module): class SbS(torch.nn.Module):
@ -24,9 +31,9 @@ class SbS(torch.nn.Module):
_epsilon_xy_intitial: float _epsilon_xy_intitial: float
_h_initial: torch.Tensor | None = None _h_initial: torch.Tensor | None = None
_w_trainable: bool _w_trainable: bool
# _last_grad_scale: torch.nn.parameter.Parameter _last_grad_scale: torch.nn.parameter.Parameter
# _keep_last_grad_scale: bool _keep_last_grad_scale: bool
# _disable_scale_grade: bool _disable_scale_grade: bool
_forgetting_offset: torch.Tensor | None = None _forgetting_offset: torch.Tensor | None = None
_weight_noise_range: list[float] _weight_noise_range: list[float]
_skip_gradient_calculation: bool _skip_gradient_calculation: bool
@ -43,6 +50,14 @@ class SbS(torch.nn.Module):
_number_of_grad_weight_contributions: float = 0.0 _number_of_grad_weight_contributions: float = 0.0
last_input_store: bool = False
last_input_data: torch.Tensor | None = None
sbs_gpu_setting_position: int = -1
sbs_hdynamic_cpp_position: int = -1
spike_generation_cpp_position: int = -1
spike_generation_gpu_setting_position: int = -1
def __init__( def __init__(
self, self,
number_of_input_neurons: int, number_of_input_neurons: int,
@ -60,13 +75,14 @@ class SbS(torch.nn.Module):
padding: list[int] = [0, 0], padding: list[int] = [0, 0],
number_of_cpu_processes: int = 1, number_of_cpu_processes: int = 1,
w_trainable: bool = False, w_trainable: bool = False,
# keep_last_grad_scale: bool = False, keep_last_grad_scale: bool = False,
# disable_scale_grade: bool = True, disable_scale_grade: bool = True,
forgetting_offset: float = -1.0, forgetting_offset: float = -1.0,
skip_gradient_calculation: bool = False, skip_gradient_calculation: bool = False,
device: torch.device | None = None, device: torch.device | None = None,
default_dtype: torch.dtype | None = None, default_dtype: torch.dtype | None = None,
gpu_tuning_factor: int = 5, gpu_tuning_factor: int = 5,
layer_id: int = -1,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -76,9 +92,9 @@ class SbS(torch.nn.Module):
self.default_dtype = default_dtype self.default_dtype = default_dtype
self._w_trainable = bool(w_trainable) self._w_trainable = bool(w_trainable)
# self._keep_last_grad_scale = bool(keep_last_grad_scale) self._keep_last_grad_scale = bool(keep_last_grad_scale)
self._skip_gradient_calculation = bool(skip_gradient_calculation) self._skip_gradient_calculation = bool(skip_gradient_calculation)
# self._disable_scale_grade = bool(disable_scale_grade) self._disable_scale_grade = bool(disable_scale_grade)
self._epsilon_xy_intitial = float(epsilon_xy_intitial) self._epsilon_xy_intitial = float(epsilon_xy_intitial)
self._stride = strides self._stride = strides
self._dilation = dilation self._dilation = dilation
@ -95,6 +111,21 @@ class SbS(torch.nn.Module):
assert len(input_size) == 2 assert len(input_size) == 2
self._input_size = input_size self._input_size = input_size
global_sbs_gpu_setting.append(torch.tensor([0]))
global_spike_generation_gpu_setting.append(torch.tensor([0]))
global_sbs_size.append(torch.tensor([0, 0, 0, 0]))
global_spike_size.append(torch.tensor([0, 0, 0, 0]))
global_sbs_hdynamic_cpp.append(HDynamicCNNManyIP())
global_spike_generation_cpp.append(SpikeGeneration2DManyIP())
self.sbs_gpu_setting_position = len(global_sbs_gpu_setting) - 1
self.sbs_hdynamic_cpp_position = len(global_sbs_hdynamic_cpp) - 1
self.spike_generation_cpp_position = len(global_spike_generation_cpp) - 1
self.spike_generation_gpu_setting_position = (
len(global_spike_generation_gpu_setting) - 1
)
# The GPU hates me... # The GPU hates me...
# Too many SbS threads == bad # Too many SbS threads == bad
# Thus I need to limit them... # Thus I need to limit them...
@ -105,10 +136,10 @@ class SbS(torch.nn.Module):
else: else:
self._gpu_tuning_factor = 0 self._gpu_tuning_factor = 0
# self._last_grad_scale = torch.nn.parameter.Parameter( self._last_grad_scale = torch.nn.parameter.Parameter(
# torch.tensor(-1.0, dtype=self.default_dtype), torch.tensor(-1.0, dtype=self.default_dtype),
# requires_grad=True, requires_grad=True,
# ) )
self._forgetting_offset = torch.tensor( self._forgetting_offset = torch.tensor(
forgetting_offset, dtype=self.default_dtype, device=self.device forgetting_offset, dtype=self.default_dtype, device=self.device
@ -234,12 +265,12 @@ class SbS(torch.nn.Module):
self.threshold_weights(threshold_weight) self.threshold_weights(threshold_weight)
self.norm_weights() self.norm_weights()
# def after_batch(self, new_state: bool = False): def after_batch(self, new_state: bool = False):
# if self._keep_last_grad_scale is True: if self._keep_last_grad_scale is True:
# self._last_grad_scale.data = self._last_grad_scale.grad self._last_grad_scale.data = self._last_grad_scale.grad
# self._keep_last_grad_scale = new_state self._keep_last_grad_scale = new_state
# self._last_grad_scale.grad = torch.zeros_like(self._last_grad_scale.grad) self._last_grad_scale.grad = torch.zeros_like(self._last_grad_scale.grad)
#################################################################### ####################################################################
# Helper functions # # Helper functions #
@ -339,6 +370,20 @@ class SbS(torch.nn.Module):
assert self._weights_exists is True assert self._weights_exists is True
assert self._weights is not None assert self._weights is not None
assert self.sbs_gpu_setting_position != -1
assert self.sbs_hdynamic_cpp_position != -1
assert self.spike_generation_cpp_position != -1
assert self.spike_generation_gpu_setting_position != -1
if labels is None:
labels_copy: torch.Tensor = torch.tensor(
[], dtype=torch.int64, device=self.device
)
else:
labels_copy = (
labels.detach().clone().type(dtype=torch.int64).to(device=self.device)
)
input_convolved = torch.nn.functional.fold( input_convolved = torch.nn.functional.fold(
torch.nn.functional.unfold( torch.nn.functional.unfold(
input.requires_grad_(True), input.requires_grad_(True),
@ -354,6 +399,12 @@ class SbS(torch.nn.Module):
stride=(1, 1), stride=(1, 1),
) )
if self.last_input_store is True:
self.last_input_data = input_convolved.detach().clone()
self.last_input_data /= self.last_input_data.sum(dim=1, keepdim=True)
else:
self.last_input_data = None
epsilon_t_0: torch.Tensor = ( epsilon_t_0: torch.Tensor = (
(self._epsilon_t * self._epsilon_0).type(input.dtype).to(input.device) (self._epsilon_t * self._epsilon_0).type(input.dtype).to(input.device)
) )
@ -361,8 +412,8 @@ class SbS(torch.nn.Module):
parameter_list = torch.tensor( parameter_list = torch.tensor(
[ [
int(self._w_trainable), # 0 int(self._w_trainable), # 0
int(0), # int(self._disable_scale_grade), # 1 int(self._disable_scale_grade), # 1
int(0), # int(self._keep_last_grad_scale), # 2 int(self._keep_last_grad_scale), # 2
int(self._skip_gradient_calculation), # 3 int(self._skip_gradient_calculation), # 3
int(self._number_of_spikes), # 4 int(self._number_of_spikes), # 4
int(self._number_of_cpu_processes), # 5 int(self._number_of_cpu_processes), # 5
@ -371,6 +422,10 @@ class SbS(torch.nn.Module):
int(self._gpu_tuning_factor), # 8 int(self._gpu_tuning_factor), # 8
int(self._output_layer), # 9 int(self._output_layer), # 9
int(self._local_learning), # 10 int(self._local_learning), # 10
int(self.sbs_gpu_setting_position), # 11
int(self.sbs_hdynamic_cpp_position), # 12
int(self.spike_generation_cpp_position), # 13
int(self.spike_generation_gpu_setting_position), # 14
], ],
dtype=torch.int64, dtype=torch.int64,
) )
@ -401,8 +456,9 @@ class SbS(torch.nn.Module):
self._weights, self._weights,
self._h_initial, self._h_initial,
parameter_list, parameter_list,
# self._last_grad_scale, self._last_grad_scale,
self._forgetting_offset, self._forgetting_offset,
labels_copy,
) )
self._number_of_grad_weight_contributions += ( self._number_of_grad_weight_contributions += (
@ -422,8 +478,9 @@ class FunctionalSbS(torch.autograd.Function):
weights: torch.Tensor, weights: torch.Tensor,
h_initial: torch.Tensor, h_initial: torch.Tensor,
parameter_list: torch.Tensor, parameter_list: torch.Tensor,
# grad_output_scale: torch.Tensor, grad_output_scale: torch.Tensor,
forgetting_offset: torch.Tensor, forgetting_offset: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
assert input.dim() == 4 assert input.dim() == 4
@ -444,6 +501,11 @@ class FunctionalSbS(torch.autograd.Function):
output_size_1: int = int(parameter_list[7]) output_size_1: int = int(parameter_list[7])
gpu_tuning_factor: int = int(parameter_list[8]) gpu_tuning_factor: int = int(parameter_list[8])
sbs_gpu_setting_position = int(parameter_list[11])
sbs_hdynamic_cpp_position = int(parameter_list[12])
spike_generation_cpp_position = int(parameter_list[13])
spike_generation_gpu_setting_position = int(parameter_list[14])
# ########################################################### # ###########################################################
# Spike generation # Spike generation
# ########################################################### # ###########################################################
@ -480,9 +542,59 @@ class FunctionalSbS(torch.autograd.Function):
assert spikes.is_contiguous() is True assert spikes.is_contiguous() is True
# time_start: float = time.perf_counter() # time_start: float = time.perf_counter()
spike_generation: SpikeGeneration2DManyIP = SpikeGeneration2DManyIP() spike_generation_profile = global_spike_generation_gpu_setting[
spike_generation_gpu_setting_position
].clone()
spike_generation.spike_generation( spike_generation_size = global_spike_size[
spike_generation_gpu_setting_position
].clone()
if input.device != torch.device("cpu"):
if (
(spike_generation_profile.numel() == 1)
or (spike_generation_size[0] != int(spikes.shape[0]))
or (spike_generation_size[1] != int(spikes.shape[1]))
or (spike_generation_size[2] != int(spikes.shape[2]))
or (spike_generation_size[3] != int(spikes.shape[3]))
):
spike_generation_profile = torch.zeros(
(1, 7), dtype=torch.int64, device=torch.device("cpu")
)
global_spike_generation_cpp[
spike_generation_cpp_position
].gpu_occupancy_export(
int(spikes.shape[2]),
int(spikes.shape[3]),
int(spikes.shape[0]),
int(spikes.shape[1]),
spike_generation_profile.data_ptr(),
int(spike_generation_profile.shape[0]),
int(spike_generation_profile.shape[1]),
)
global_spike_generation_gpu_setting[
spike_generation_gpu_setting_position
] = spike_generation_profile.clone()
spike_generation_size[0] = int(spikes.shape[0])
spike_generation_size[1] = int(spikes.shape[1])
spike_generation_size[2] = int(spikes.shape[2])
spike_generation_size[3] = int(spikes.shape[3])
global_spike_size[
spike_generation_gpu_setting_position
] = spike_generation_size.clone()
else:
global_spike_generation_cpp[
spike_generation_cpp_position
].gpu_occupancy_import(
spike_generation_profile.data_ptr(),
int(spike_generation_profile.shape[0]),
int(spike_generation_profile.shape[1]),
)
global_spike_generation_cpp[spike_generation_cpp_position].spike_generation(
input_cumsum.data_ptr(), input_cumsum.data_ptr(),
int(input_cumsum.shape[0]), int(input_cumsum.shape[0]),
int(input_cumsum.shape[1]), int(input_cumsum.shape[1]),
@ -536,9 +648,46 @@ class FunctionalSbS(torch.autograd.Function):
assert weights.ndim == 2 assert weights.ndim == 2
assert h_initial.ndim == 1 assert h_initial.ndim == 1
h_dynamic: HDynamicCNNManyIP = HDynamicCNNManyIP() sbs_profile = global_sbs_gpu_setting[sbs_gpu_setting_position].clone()
h_dynamic.update( sbs_size = global_sbs_size[sbs_gpu_setting_position].clone()
if input.device != torch.device("cpu"):
if (
(sbs_profile.numel() == 1)
or (sbs_size[0] != int(output.shape[0]))
or (sbs_size[1] != int(output.shape[1]))
or (sbs_size[2] != int(output.shape[2]))
or (sbs_size[3] != int(output.shape[3]))
):
sbs_profile = torch.zeros(
(14, 7), dtype=torch.int64, device=torch.device("cpu")
)
global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_export(
int(output.shape[2]),
int(output.shape[3]),
int(output.shape[0]),
int(output.shape[1]),
sbs_profile.data_ptr(),
int(sbs_profile.shape[0]),
int(sbs_profile.shape[1]),
)
global_sbs_gpu_setting[sbs_gpu_setting_position] = sbs_profile.clone()
sbs_size[0] = int(output.shape[0])
sbs_size[1] = int(output.shape[1])
sbs_size[2] = int(output.shape[2])
sbs_size[3] = int(output.shape[3])
global_sbs_size[sbs_gpu_setting_position] = sbs_size.clone()
else:
global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_import(
sbs_profile.data_ptr(),
int(sbs_profile.shape[0]),
int(sbs_profile.shape[1]),
)
global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].update(
output.data_ptr(), output.data_ptr(),
int(output.shape[0]), int(output.shape[0]),
int(output.shape[1]), int(output.shape[1]),
@ -575,7 +724,8 @@ class FunctionalSbS(torch.autograd.Function):
weights, weights,
output, output,
parameter_list, parameter_list,
# grad_output_scale, grad_output_scale,
labels,
) )
return output return output
@ -590,9 +740,12 @@ class FunctionalSbS(torch.autograd.Function):
weights, weights,
output, output,
parameter_list, parameter_list,
# last_grad_scale, last_grad_scale,
labels,
) = ctx.saved_tensors ) = ctx.saved_tensors
assert labels.numel() > 0
# ############################################## # ##############################################
# Default output # Default output
# ############################################## # ##############################################
@ -603,13 +756,14 @@ class FunctionalSbS(torch.autograd.Function):
grad_h_initial = None grad_h_initial = None
grad_parameter_list = None grad_parameter_list = None
grad_forgetting_offset = None grad_forgetting_offset = None
grad_labels = None
# ############################################## # ##############################################
# Parameters # Parameters
# ############################################## # ##############################################
parameter_w_trainable: bool = bool(parameter_list[0]) parameter_w_trainable: bool = bool(parameter_list[0])
# parameter_disable_scale_grade: bool = bool(parameter_list[1]) parameter_disable_scale_grade: bool = bool(parameter_list[1])
# parameter_keep_last_grad_scale: bool = bool(parameter_list[2]) parameter_keep_last_grad_scale: bool = bool(parameter_list[2])
parameter_skip_gradient_calculation: bool = bool(parameter_list[3]) parameter_skip_gradient_calculation: bool = bool(parameter_list[3])
parameter_output_layer: bool = bool(parameter_list[9]) parameter_output_layer: bool = bool(parameter_list[9])
parameter_local_learning: bool = bool(parameter_list[10]) parameter_local_learning: bool = bool(parameter_list[10])
@ -617,13 +771,13 @@ class FunctionalSbS(torch.autograd.Function):
# ############################################## # ##############################################
# Dealing with overall scale of the gradient # Dealing with overall scale of the gradient
# ############################################## # ##############################################
# if parameter_disable_scale_grade is False: if parameter_disable_scale_grade is False:
# if parameter_keep_last_grad_scale is True: if parameter_keep_last_grad_scale is True:
# last_grad_scale = torch.tensor( last_grad_scale = torch.tensor(
# [torch.abs(grad_output).max(), last_grad_scale] [torch.abs(grad_output).max(), last_grad_scale]
# ).max() ).max()
# grad_output /= last_grad_scale grad_output /= last_grad_scale
# grad_output_scale = last_grad_scale.clone() grad_output_scale = last_grad_scale.clone()
input /= input.sum(dim=1, keepdim=True, dtype=weights.dtype) input /= input.sum(dim=1, keepdim=True, dtype=weights.dtype)
@ -640,8 +794,9 @@ class FunctionalSbS(torch.autograd.Function):
grad_weights, grad_weights,
grad_h_initial, grad_h_initial,
grad_parameter_list, grad_parameter_list,
# grad_output_scale, grad_output_scale,
grad_forgetting_offset, grad_forgetting_offset,
grad_labels,
) )
# ################################################# # #################################################
@ -682,6 +837,41 @@ class FunctionalSbS(torch.autograd.Function):
.sum(-1) .sum(-1)
) )
elif (parameter_output_layer is True) and (parameter_local_learning is True):
target_one_hot: torch.Tensor = torch.zeros(
(
labels.shape[0],
output.shape[1],
),
device=input.device,
dtype=input.dtype,
)
target_one_hot.scatter_(
1,
labels.to(input.device).unsqueeze(1),
torch.ones(
(labels.shape[0], 1),
device=input.device,
dtype=input.dtype,
),
)
target_one_hot = target_one_hot.unsqueeze(-1).unsqueeze(-1)
# (-2 * (input - backprop_bigr).unsqueeze(2) * (target_one_hot-output).unsqueeze(1))
# (-2 * input.unsqueeze(2) * (target_one_hot-output).unsqueeze(1))
grad_weights = (
(
-2
* (input - backprop_bigr).unsqueeze(2)
* target_one_hot.unsqueeze(1)
)
.sum(0)
.sum(-1)
.sum(-1)
)
else: else:
# ################################################# # #################################################
# Backprop # Backprop
@ -709,6 +899,7 @@ class FunctionalSbS(torch.autograd.Function):
grad_weights, grad_weights,
grad_h_initial, grad_h_initial,
grad_parameter_list, grad_parameter_list,
# grad_output_scale, grad_output_scale,
grad_forgetting_offset, grad_forgetting_offset,
grad_labels,
) )

View file

@ -0,0 +1,33 @@
import torch
from network.SbS import SbS
class SbSReconstruction(torch.nn.Module):
_the_sbs_layer: SbS
def __init__(
self,
the_sbs_layer: SbS,
) -> None:
super().__init__()
self._the_sbs_layer = the_sbs_layer
self.device = self._the_sbs_layer.device
self.default_dtype = self._the_sbs_layer.default_dtype
def forward(self, input: torch.Tensor) -> torch.Tensor:
assert self._the_sbs_layer._weights_exists is True
input_norm = input / input.sum(dim=1, keepdim=True)
output = (
self._the_sbs_layer._weights.data.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
* input_norm.unsqueeze(1)
).sum(dim=2)
output /= output.sum(dim=1, keepdim=True)
return output

View file

@ -6,6 +6,7 @@ from network.Parameter import Config
from network.SbS import SbS from network.SbS import SbS
from network.SplitOnOffLayer import SplitOnOffLayer from network.SplitOnOffLayer import SplitOnOffLayer
from network.Conv2dApproximation import Conv2dApproximation from network.Conv2dApproximation import Conv2dApproximation
from network.SbSReconstruction import SbSReconstruction
def build_network( def build_network(
@ -159,12 +160,13 @@ def build_network(
padding=padding, padding=padding,
number_of_cpu_processes=cfg.number_of_cpu_processes, number_of_cpu_processes=cfg.number_of_cpu_processes,
w_trainable=w_trainable, w_trainable=w_trainable,
# keep_last_grad_scale=cfg.learning_parameters.kepp_last_grad_scale, keep_last_grad_scale=cfg.learning_parameters.kepp_last_grad_scale,
# disable_scale_grade=cfg.learning_parameters.disable_scale_grade, disable_scale_grade=cfg.learning_parameters.disable_scale_grade,
forgetting_offset=cfg.forgetting_offset, forgetting_offset=cfg.forgetting_offset,
skip_gradient_calculation=sbs_skip_gradient_calculation, skip_gradient_calculation=sbs_skip_gradient_calculation,
device=device, device=device,
default_dtype=default_dtype, default_dtype=default_dtype,
layer_id=layer_id,
) )
) )
# Adding the x,y output dimensions # Adding the x,y output dimensions
@ -178,6 +180,25 @@ def build_network(
if cfg.network_structure.layer_type[layer_id].upper().find("LOCAL") != -1: if cfg.network_structure.layer_type[layer_id].upper().find("LOCAL") != -1:
network[-1]._local_learning = True network[-1]._local_learning = True
elif (
cfg.network_structure.layer_type[layer_id]
.upper()
.startswith("RECONSTRUCTION")
is True
):
logging.info(f"Layer: {layer_id} -> SbS Reconstruction Layer")
assert layer_id > 0
assert isinstance(network[-1], SbS) is True
network.append(SbSReconstruction(network[-1]))
network[-1]._w_trainable = False
if layer_id == len(cfg.network_structure.layer_type) - 1:
network[-2].last_input_store = True
input_size.append(input_size[-1])
# ############################################################# # #############################################################
# Split On Off Layer: # Split On Off Layer:
# ############################################################# # #############################################################

View file

@ -57,10 +57,13 @@ def build_optimizer(
optimizer_wf = Adam( optimizer_wf = Adam(
parameter_list_weights, parameter_list_weights,
parameter_list_sbs, parameter_list_sbs,
logging=logging,
lr=cfg.learning_parameters.learning_rate_gamma_w, lr=cfg.learning_parameters.learning_rate_gamma_w,
) )
else: else:
optimizer_wf = Adam(parameter_list_weights, parameter_list_sbs) optimizer_wf = Adam(
parameter_list_weights, parameter_list_sbs, logging=logging
)
elif cfg.learning_parameters.optimizer_name == "SGD": elif cfg.learning_parameters.optimizer_name == "SGD":
logging.info("Using optimizer: SGD") logging.info("Using optimizer: SGD")

View file

@ -5,6 +5,7 @@ from torch.utils.tensorboard import SummaryWriter
from network.SbS import SbS from network.SbS import SbS
from network.save_weight_and_bias import save_weight_and_bias from network.save_weight_and_bias import save_weight_and_bias
from network.SbSReconstruction import SbSReconstruction
def add_weight_and_bias_to_histogram( def add_weight_and_bias_to_histogram(
@ -94,7 +95,7 @@ def loss_function(
device=device, device=device,
dtype=default_dtype, dtype=default_dtype,
), ),
).unsqueeze(-1).unsqueeze(-1) )
h_y1 = torch.log(h + 1e-20) h_y1 = torch.log(h + 1e-20)
@ -119,6 +120,44 @@ def loss_function(
return None return None
def loss_function_reconstruction(
h_reco: torch.Tensor,
h_input: torch.Tensor,
loss_mode: int = 0,
loss_coeffs_mse: float = 0.0,
loss_coeffs_kldiv: float = 0.0,
) -> torch.Tensor | None:
assert loss_mode >= 0
assert loss_mode <= 0
assert h_reco.ndim == 4
assert h_input.ndim == 4
assert h_reco.shape[0] == h_input.shape[0]
assert h_reco.shape[1] == h_input.shape[1]
assert h_reco.shape[2] == h_input.shape[2]
assert h_reco.shape[3] == h_input.shape[3]
if loss_mode == 0:
h_reco_log = torch.log(h_reco + 1e-20)
my_loss: torch.Tensor = (
torch.nn.functional.mse_loss(
h_reco,
h_input,
reduction="sum",
)
* loss_coeffs_mse
+ torch.nn.functional.kl_div(h_reco_log, h_input + 1e-20, reduction="sum")
* loss_coeffs_kldiv
) / (loss_coeffs_kldiv + loss_coeffs_mse)
return my_loss
else:
return None
def forward_pass_train( def forward_pass_train(
input: torch.Tensor, input: torch.Tensor,
labels: torch.Tensor, labels: torch.Tensor,
@ -228,15 +267,15 @@ def run_lr_scheduler(
tb.flush() tb.flush()
# def deal_with_gradient_scale(epoch_id: int, mini_batch_number: int, network): def deal_with_gradient_scale(epoch_id: int, mini_batch_number: int, network):
# if (epoch_id == 0) and (mini_batch_number == 0): if (epoch_id == 0) and (mini_batch_number == 0):
# for id in range(0, len(network)): for id in range(0, len(network)):
# if isinstance(network[id], SbS) is True: if isinstance(network[id], SbS) is True:
# network[id].after_batch(True) network[id].after_batch(True)
# else: else:
# for id in range(0, len(network)): for id in range(0, len(network)):
# if isinstance(network[id], SbS) is True: if isinstance(network[id], SbS) is True:
# network[id].after_batch() network[id].after_batch()
def loop_train( def loop_train(
@ -318,12 +357,20 @@ def loop_train(
if last_test_performance < 0: if last_test_performance < 0:
logging.info("") logging.info("")
else: else:
if isinstance(network[-1], SbSReconstruction) is False:
logging.info( logging.info(
( (
f"\t\t\tLast test performance: " f"\t\t\tLast test performance: "
f"{last_test_performance/100.0:^6.2%}" f"{last_test_performance/100.0:^6.2%}"
) )
) )
else:
logging.info(
(
f"\t\t\tLast test performance: "
f"{last_test_performance:^6.2e}"
)
)
logging.info("----------------") logging.info("----------------")
number_of_pattern_in_minibatch += h_x_labels.shape[0] number_of_pattern_in_minibatch += h_x_labels.shape[0]
@ -345,6 +392,8 @@ def loop_train(
# ##################################################### # #####################################################
# Calculate the loss function # Calculate the loss function
# ##################################################### # #####################################################
if isinstance(network[-1], SbSReconstruction) is False:
my_loss: torch.Tensor | None = loss_function( my_loss: torch.Tensor | None = loss_function(
h=h_collection[-1], h=h_collection[-1],
labels=h_x_labels, labels=h_x_labels,
@ -357,6 +406,16 @@ def loop_train(
loss_coeffs_mse=float(cfg.learning_parameters.loss_coeffs_mse), loss_coeffs_mse=float(cfg.learning_parameters.loss_coeffs_mse),
loss_coeffs_kldiv=float(cfg.learning_parameters.loss_coeffs_kldiv), loss_coeffs_kldiv=float(cfg.learning_parameters.loss_coeffs_kldiv),
) )
else:
assert cfg.learning_parameters.lr_scheduler_use_performance is False
my_loss = loss_function_reconstruction(
h_reco=h_collection[-1],
h_input=network[-2].last_input_data,
loss_mode=cfg.learning_parameters.loss_mode,
loss_coeffs_mse=float(cfg.learning_parameters.loss_coeffs_mse),
loss_coeffs_kldiv=float(cfg.learning_parameters.loss_coeffs_kldiv),
)
assert my_loss is not None assert my_loss is not None
time_after_forward_and_loss: float = time.perf_counter() time_after_forward_and_loss: float = time.perf_counter()
@ -374,6 +433,7 @@ def loop_train(
# Performance measures # Performance measures
# ##################################################### # #####################################################
if isinstance(network[-1], SbSReconstruction) is False:
correct_in_minibatch += ( correct_in_minibatch += (
(h_collection[-1].argmax(dim=1).squeeze().cpu() == h_x_labels) (h_collection[-1].argmax(dim=1).squeeze().cpu() == h_x_labels)
.sum() .sum()
@ -391,11 +451,11 @@ def loop_train(
# the future error with it # the future error with it
# Kind of deals with the vanishing / # Kind of deals with the vanishing /
# exploding gradients # exploding gradients
# deal_with_gradient_scale( deal_with_gradient_scale(
# epoch_id=epoch_id, epoch_id=epoch_id,
# mini_batch_number=mini_batch_number, mini_batch_number=mini_batch_number,
# network=network, network=network,
# ) )
# Measure the time for one mini-batch # Measure the time for one mini-batch
time_forward += time_after_forward_and_loss - time_mini_batch_start time_forward += time_after_forward_and_loss - time_mini_batch_start
@ -403,6 +463,7 @@ def loop_train(
if number_of_pattern_in_minibatch >= cfg.get_update_after_x_pattern(): if number_of_pattern_in_minibatch >= cfg.get_update_after_x_pattern():
if isinstance(network[-1], SbSReconstruction) is False:
logging.info( logging.info(
( (
f"{epoch_id:^6}=>{mini_batch_number:^6} " f"{epoch_id:^6}=>{mini_batch_number:^6} "
@ -419,6 +480,22 @@ def loop_train(
) )
) )
else:
logging.info(
(
f"{epoch_id:^6}=>{mini_batch_number:^6} "
f"\t\tTraining {number_of_pattern_in_minibatch^6} pattern "
f"\t\t\tForward time: \t{time_forward:^6.2f}sec"
)
)
logging.info(
(
f"\t\t\tLoss: {loss_in_minibatch/number_of_pattern_in_minibatch:^15.3e} "
f"\t\t\tBackward time: \t{time_backward:^6.2f}sec "
)
)
my_loss_for_batch = loss_in_minibatch / number_of_pattern_in_minibatch my_loss_for_batch = loss_in_minibatch / number_of_pattern_in_minibatch
performance_for_batch = ( performance_for_batch = (
@ -510,3 +587,65 @@ def loop_test(
tb.flush() tb.flush()
return performance return performance
def loop_test_reconstruction(
epoch_id: int,
cfg: Config,
network: torch.nn.modules.container.Sequential,
my_loader_test: torch.utils.data.dataloader.DataLoader,
the_dataset_test,
device: torch.device,
default_dtype: torch.dtype,
logging,
tb: SummaryWriter,
) -> float:
test_count: int = 0
test_loss: float = 0.0
test_complete: int = the_dataset_test.__len__()
logging.info("")
logging.info("Testing:")
for h_x, h_x_labels in my_loader_test:
time_0 = time.perf_counter()
h_collection = forward_pass_test(
input=h_x,
the_dataset_test=the_dataset_test,
cfg=cfg,
network=network,
device=device,
default_dtype=default_dtype,
)
my_loss: torch.Tensor | None = loss_function_reconstruction(
h_reco=h_collection[-1],
h_input=network[-2].last_input_data,
loss_mode=cfg.learning_parameters.loss_mode,
loss_coeffs_mse=float(cfg.learning_parameters.loss_coeffs_mse),
loss_coeffs_kldiv=float(cfg.learning_parameters.loss_coeffs_kldiv),
)
assert my_loss is not None
test_count += h_x_labels.shape[0]
test_loss += my_loss.item()
performance = test_loss / test_count
time_1 = time.perf_counter()
time_measure_a = time_1 - time_0
logging.info(
(
f"\t\t{test_count} of {test_complete}"
f" with {performance:^6.2e} \t Time used: {time_measure_a:^6.2f}sec"
)
)
logging.info("")
tb.add_scalar("Test Error", performance, epoch_id)
tb.flush()
return performance