Add files via upload
This commit is contained in:
parent
ff64a9da14
commit
f156fbf8bb
8 changed files with 558 additions and 101 deletions
|
@ -15,6 +15,7 @@ class Adam(torch.optim.Optimizer):
|
||||||
self,
|
self,
|
||||||
params,
|
params,
|
||||||
sbs_setting: list[bool],
|
sbs_setting: list[bool],
|
||||||
|
logging,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
beta1: float = 0.9,
|
beta1: float = 0.9,
|
||||||
beta2: float = 0.999,
|
beta2: float = 0.999,
|
||||||
|
@ -41,6 +42,7 @@ class Adam(torch.optim.Optimizer):
|
||||||
self.beta2 = beta2
|
self.beta2 = beta2
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.maximize = maximize
|
self.maximize = maximize
|
||||||
|
self._logging = logging
|
||||||
|
|
||||||
defaults = dict(
|
defaults = dict(
|
||||||
lr=lr,
|
lr=lr,
|
||||||
|
@ -149,8 +151,12 @@ class Adam(torch.optim.Optimizer):
|
||||||
if sbs_setting[i] is False:
|
if sbs_setting[i] is False:
|
||||||
param -= step_size * (exp_avg / denom)
|
param -= step_size * (exp_avg / denom)
|
||||||
else:
|
else:
|
||||||
delta = torch.exp(-step_size * (exp_avg / denom))
|
# delta = torch.exp(-step_size * (exp_avg / denom))
|
||||||
print(
|
delta = torch.tanh(-step_size * (exp_avg / denom))
|
||||||
f"{float(delta.min()) - 1.0:.4e} {float(delta.max()) - 1.0:.4e} {lr:.4e}"
|
delta += 1.0
|
||||||
|
delta *= 0.5
|
||||||
|
delta += 0.5
|
||||||
|
self._logging.info(
|
||||||
|
f"ADAM: Layer {i} -> dw_min:{float(delta.min()):.4e} dw_max:{float(delta.max()):.4e} lr:{lr:.4e}"
|
||||||
)
|
)
|
||||||
param *= delta
|
param *= delta
|
||||||
|
|
|
@ -3,6 +3,10 @@ import math
|
||||||
|
|
||||||
from network.CPP.PyMultiApp import MultiApp
|
from network.CPP.PyMultiApp import MultiApp
|
||||||
|
|
||||||
|
global_multiapp_gpu_setting: list[torch.Tensor] = []
|
||||||
|
global_multiapp_size: list[torch.Tensor] = []
|
||||||
|
global_multiapp_cpp: list[MultiApp] = []
|
||||||
|
|
||||||
|
|
||||||
class Conv2dApproximation(torch.nn.Module):
|
class Conv2dApproximation(torch.nn.Module):
|
||||||
|
|
||||||
|
@ -26,6 +30,9 @@ class Conv2dApproximation(torch.nn.Module):
|
||||||
device: torch.device
|
device: torch.device
|
||||||
dtype: torch.dtype
|
dtype: torch.dtype
|
||||||
|
|
||||||
|
multiapp_gpu_setting_position: int = -1
|
||||||
|
multiapp_cpp_position: int = -1
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
|
@ -68,6 +75,12 @@ class Conv2dApproximation(torch.nn.Module):
|
||||||
self.number_of_trunc_bits = number_of_trunc_bits
|
self.number_of_trunc_bits = number_of_trunc_bits
|
||||||
self.number_of_frac = number_of_frac
|
self.number_of_frac = number_of_frac
|
||||||
|
|
||||||
|
global_multiapp_gpu_setting.append(torch.tensor([0]))
|
||||||
|
global_multiapp_size.append(torch.tensor([0, 0, 0, 0]))
|
||||||
|
global_multiapp_cpp.append(MultiApp())
|
||||||
|
self.multiapp_gpu_setting_position = len(global_multiapp_gpu_setting) - 1
|
||||||
|
self.multiapp_cpp_position = len(global_multiapp_cpp) - 1
|
||||||
|
|
||||||
if self.use_bias is True:
|
if self.use_bias is True:
|
||||||
self.bias: torch.nn.parameter.Parameter | None = (
|
self.bias: torch.nn.parameter.Parameter | None = (
|
||||||
torch.nn.parameter.Parameter(
|
torch.nn.parameter.Parameter(
|
||||||
|
@ -190,6 +203,8 @@ class Conv2dApproximation(torch.nn.Module):
|
||||||
assert input.dim() == 4
|
assert input.dim() == 4
|
||||||
|
|
||||||
assert self.kernel_size is not None
|
assert self.kernel_size is not None
|
||||||
|
assert self.multiapp_gpu_setting_position != -1
|
||||||
|
assert self.multiapp_cpp_position != -1
|
||||||
|
|
||||||
input_size = torch.Tensor([int(input.shape[-2]), int(input.shape[-1])]).type(
|
input_size = torch.Tensor([int(input.shape[-2]), int(input.shape[-1])]).type(
|
||||||
dtype=torch.int64
|
dtype=torch.int64
|
||||||
|
@ -232,6 +247,8 @@ class Conv2dApproximation(torch.nn.Module):
|
||||||
int(self.number_of_trunc_bits), # 1
|
int(self.number_of_trunc_bits), # 1
|
||||||
int(self.number_of_frac), # 2
|
int(self.number_of_frac), # 2
|
||||||
int(number_of_cpu_processes), # 3
|
int(number_of_cpu_processes), # 3
|
||||||
|
int(self.multiapp_gpu_setting_position), # 4
|
||||||
|
int(self.multiapp_cpp_position), # 5
|
||||||
],
|
],
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
)
|
)
|
||||||
|
@ -267,6 +284,8 @@ class FunctionalMultiConv2d(torch.autograd.Function):
|
||||||
number_of_trunc_bits = int(parameter_list[1])
|
number_of_trunc_bits = int(parameter_list[1])
|
||||||
number_of_frac = int(parameter_list[2])
|
number_of_frac = int(parameter_list[2])
|
||||||
number_of_processes = int(parameter_list[3])
|
number_of_processes = int(parameter_list[3])
|
||||||
|
multiapp_gpu_setting_position = int(parameter_list[4])
|
||||||
|
multiapp_cpp_position = int(parameter_list[5])
|
||||||
|
|
||||||
assert input.device == weights.device
|
assert input.device == weights.device
|
||||||
|
|
||||||
|
@ -278,9 +297,54 @@ class FunctionalMultiConv2d(torch.autograd.Function):
|
||||||
)
|
)
|
||||||
assert output.is_contiguous() is True
|
assert output.is_contiguous() is True
|
||||||
|
|
||||||
multiplier: MultiApp = MultiApp()
|
multiapp_profile = global_multiapp_gpu_setting[
|
||||||
|
multiapp_gpu_setting_position
|
||||||
|
].clone()
|
||||||
|
|
||||||
multiplier.update_with_init_vector_multi_pattern(
|
multiapp_size = global_multiapp_size[multiapp_gpu_setting_position].clone()
|
||||||
|
|
||||||
|
if input.device != torch.device("cpu"):
|
||||||
|
if (
|
||||||
|
(multiapp_profile.numel() == 1)
|
||||||
|
or (multiapp_size[0] != int(output.shape[0]))
|
||||||
|
or (multiapp_size[1] != int(output.shape[1]))
|
||||||
|
or (multiapp_size[2] != int(output.shape[2]))
|
||||||
|
or (multiapp_size[3] != int(output.shape[3]))
|
||||||
|
):
|
||||||
|
multiapp_profile = torch.zeros(
|
||||||
|
(1, 7), dtype=torch.int64, device=torch.device("cpu")
|
||||||
|
)
|
||||||
|
|
||||||
|
global_multiapp_cpp[multiapp_cpp_position].gpu_occupancy_export(
|
||||||
|
int(output.shape[2]),
|
||||||
|
int(output.shape[3]),
|
||||||
|
int(output.shape[0]),
|
||||||
|
int(output.shape[1]),
|
||||||
|
multiapp_profile.data_ptr(),
|
||||||
|
int(multiapp_profile.shape[0]),
|
||||||
|
int(multiapp_profile.shape[1]),
|
||||||
|
)
|
||||||
|
global_multiapp_gpu_setting[
|
||||||
|
multiapp_gpu_setting_position
|
||||||
|
] = multiapp_profile.clone()
|
||||||
|
|
||||||
|
multiapp_size[0] = int(output.shape[0])
|
||||||
|
multiapp_size[1] = int(output.shape[1])
|
||||||
|
multiapp_size[2] = int(output.shape[2])
|
||||||
|
multiapp_size[3] = int(output.shape[3])
|
||||||
|
|
||||||
|
global_multiapp_size[
|
||||||
|
multiapp_gpu_setting_position
|
||||||
|
] = multiapp_size.clone()
|
||||||
|
|
||||||
|
else:
|
||||||
|
global_multiapp_cpp[multiapp_cpp_position].gpu_occupancy_import(
|
||||||
|
multiapp_profile.data_ptr(),
|
||||||
|
int(multiapp_profile.shape[0]),
|
||||||
|
int(multiapp_profile.shape[1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
global_multiapp_cpp[multiapp_cpp_position].update_entrypoint(
|
||||||
input.data_ptr(),
|
input.data_ptr(),
|
||||||
weights.data_ptr(),
|
weights.data_ptr(),
|
||||||
output.data_ptr(),
|
output.data_ptr(),
|
||||||
|
|
|
@ -47,8 +47,8 @@ class LearningParameters:
|
||||||
weight_noise_range: list[float] = field(default_factory=list)
|
weight_noise_range: list[float] = field(default_factory=list)
|
||||||
eps_xy_intitial: float = field(default=0.1)
|
eps_xy_intitial: float = field(default=0.1)
|
||||||
|
|
||||||
# disable_scale_grade: bool = field(default=False)
|
disable_scale_grade: bool = field(default=False)
|
||||||
# kepp_last_grad_scale: bool = field(default=True)
|
kepp_last_grad_scale: bool = field(default=True)
|
||||||
|
|
||||||
sbs_skip_gradient_calculation: list[bool] = field(default_factory=list)
|
sbs_skip_gradient_calculation: list[bool] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
265
network/SbS.py
265
network/SbS.py
|
@ -4,6 +4,13 @@ from network.CPP.PySpikeGeneration2DManyIP import SpikeGeneration2DManyIP
|
||||||
from network.CPP.PyHDynamicCNNManyIP import HDynamicCNNManyIP
|
from network.CPP.PyHDynamicCNNManyIP import HDynamicCNNManyIP
|
||||||
from network.calculate_output_size import calculate_output_size
|
from network.calculate_output_size import calculate_output_size
|
||||||
|
|
||||||
|
global_sbs_gpu_setting: list[torch.Tensor] = []
|
||||||
|
global_sbs_size: list[torch.Tensor] = []
|
||||||
|
global_sbs_hdynamic_cpp: list[HDynamicCNNManyIP] = []
|
||||||
|
global_spike_generation_gpu_setting: list[torch.Tensor] = []
|
||||||
|
global_spike_size: list[torch.Tensor] = []
|
||||||
|
global_spike_generation_cpp: list[SpikeGeneration2DManyIP] = []
|
||||||
|
|
||||||
|
|
||||||
class SbS(torch.nn.Module):
|
class SbS(torch.nn.Module):
|
||||||
|
|
||||||
|
@ -24,9 +31,9 @@ class SbS(torch.nn.Module):
|
||||||
_epsilon_xy_intitial: float
|
_epsilon_xy_intitial: float
|
||||||
_h_initial: torch.Tensor | None = None
|
_h_initial: torch.Tensor | None = None
|
||||||
_w_trainable: bool
|
_w_trainable: bool
|
||||||
# _last_grad_scale: torch.nn.parameter.Parameter
|
_last_grad_scale: torch.nn.parameter.Parameter
|
||||||
# _keep_last_grad_scale: bool
|
_keep_last_grad_scale: bool
|
||||||
# _disable_scale_grade: bool
|
_disable_scale_grade: bool
|
||||||
_forgetting_offset: torch.Tensor | None = None
|
_forgetting_offset: torch.Tensor | None = None
|
||||||
_weight_noise_range: list[float]
|
_weight_noise_range: list[float]
|
||||||
_skip_gradient_calculation: bool
|
_skip_gradient_calculation: bool
|
||||||
|
@ -43,6 +50,14 @@ class SbS(torch.nn.Module):
|
||||||
|
|
||||||
_number_of_grad_weight_contributions: float = 0.0
|
_number_of_grad_weight_contributions: float = 0.0
|
||||||
|
|
||||||
|
last_input_store: bool = False
|
||||||
|
last_input_data: torch.Tensor | None = None
|
||||||
|
|
||||||
|
sbs_gpu_setting_position: int = -1
|
||||||
|
sbs_hdynamic_cpp_position: int = -1
|
||||||
|
spike_generation_cpp_position: int = -1
|
||||||
|
spike_generation_gpu_setting_position: int = -1
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
number_of_input_neurons: int,
|
number_of_input_neurons: int,
|
||||||
|
@ -60,13 +75,14 @@ class SbS(torch.nn.Module):
|
||||||
padding: list[int] = [0, 0],
|
padding: list[int] = [0, 0],
|
||||||
number_of_cpu_processes: int = 1,
|
number_of_cpu_processes: int = 1,
|
||||||
w_trainable: bool = False,
|
w_trainable: bool = False,
|
||||||
# keep_last_grad_scale: bool = False,
|
keep_last_grad_scale: bool = False,
|
||||||
# disable_scale_grade: bool = True,
|
disable_scale_grade: bool = True,
|
||||||
forgetting_offset: float = -1.0,
|
forgetting_offset: float = -1.0,
|
||||||
skip_gradient_calculation: bool = False,
|
skip_gradient_calculation: bool = False,
|
||||||
device: torch.device | None = None,
|
device: torch.device | None = None,
|
||||||
default_dtype: torch.dtype | None = None,
|
default_dtype: torch.dtype | None = None,
|
||||||
gpu_tuning_factor: int = 5,
|
gpu_tuning_factor: int = 5,
|
||||||
|
layer_id: int = -1,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -76,9 +92,9 @@ class SbS(torch.nn.Module):
|
||||||
self.default_dtype = default_dtype
|
self.default_dtype = default_dtype
|
||||||
|
|
||||||
self._w_trainable = bool(w_trainable)
|
self._w_trainable = bool(w_trainable)
|
||||||
# self._keep_last_grad_scale = bool(keep_last_grad_scale)
|
self._keep_last_grad_scale = bool(keep_last_grad_scale)
|
||||||
self._skip_gradient_calculation = bool(skip_gradient_calculation)
|
self._skip_gradient_calculation = bool(skip_gradient_calculation)
|
||||||
# self._disable_scale_grade = bool(disable_scale_grade)
|
self._disable_scale_grade = bool(disable_scale_grade)
|
||||||
self._epsilon_xy_intitial = float(epsilon_xy_intitial)
|
self._epsilon_xy_intitial = float(epsilon_xy_intitial)
|
||||||
self._stride = strides
|
self._stride = strides
|
||||||
self._dilation = dilation
|
self._dilation = dilation
|
||||||
|
@ -95,6 +111,21 @@ class SbS(torch.nn.Module):
|
||||||
assert len(input_size) == 2
|
assert len(input_size) == 2
|
||||||
self._input_size = input_size
|
self._input_size = input_size
|
||||||
|
|
||||||
|
global_sbs_gpu_setting.append(torch.tensor([0]))
|
||||||
|
global_spike_generation_gpu_setting.append(torch.tensor([0]))
|
||||||
|
global_sbs_size.append(torch.tensor([0, 0, 0, 0]))
|
||||||
|
global_spike_size.append(torch.tensor([0, 0, 0, 0]))
|
||||||
|
|
||||||
|
global_sbs_hdynamic_cpp.append(HDynamicCNNManyIP())
|
||||||
|
global_spike_generation_cpp.append(SpikeGeneration2DManyIP())
|
||||||
|
|
||||||
|
self.sbs_gpu_setting_position = len(global_sbs_gpu_setting) - 1
|
||||||
|
self.sbs_hdynamic_cpp_position = len(global_sbs_hdynamic_cpp) - 1
|
||||||
|
self.spike_generation_cpp_position = len(global_spike_generation_cpp) - 1
|
||||||
|
self.spike_generation_gpu_setting_position = (
|
||||||
|
len(global_spike_generation_gpu_setting) - 1
|
||||||
|
)
|
||||||
|
|
||||||
# The GPU hates me...
|
# The GPU hates me...
|
||||||
# Too many SbS threads == bad
|
# Too many SbS threads == bad
|
||||||
# Thus I need to limit them...
|
# Thus I need to limit them...
|
||||||
|
@ -105,10 +136,10 @@ class SbS(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
self._gpu_tuning_factor = 0
|
self._gpu_tuning_factor = 0
|
||||||
|
|
||||||
# self._last_grad_scale = torch.nn.parameter.Parameter(
|
self._last_grad_scale = torch.nn.parameter.Parameter(
|
||||||
# torch.tensor(-1.0, dtype=self.default_dtype),
|
torch.tensor(-1.0, dtype=self.default_dtype),
|
||||||
# requires_grad=True,
|
requires_grad=True,
|
||||||
# )
|
)
|
||||||
|
|
||||||
self._forgetting_offset = torch.tensor(
|
self._forgetting_offset = torch.tensor(
|
||||||
forgetting_offset, dtype=self.default_dtype, device=self.device
|
forgetting_offset, dtype=self.default_dtype, device=self.device
|
||||||
|
@ -234,12 +265,12 @@ class SbS(torch.nn.Module):
|
||||||
self.threshold_weights(threshold_weight)
|
self.threshold_weights(threshold_weight)
|
||||||
self.norm_weights()
|
self.norm_weights()
|
||||||
|
|
||||||
# def after_batch(self, new_state: bool = False):
|
def after_batch(self, new_state: bool = False):
|
||||||
# if self._keep_last_grad_scale is True:
|
if self._keep_last_grad_scale is True:
|
||||||
# self._last_grad_scale.data = self._last_grad_scale.grad
|
self._last_grad_scale.data = self._last_grad_scale.grad
|
||||||
# self._keep_last_grad_scale = new_state
|
self._keep_last_grad_scale = new_state
|
||||||
|
|
||||||
# self._last_grad_scale.grad = torch.zeros_like(self._last_grad_scale.grad)
|
self._last_grad_scale.grad = torch.zeros_like(self._last_grad_scale.grad)
|
||||||
|
|
||||||
####################################################################
|
####################################################################
|
||||||
# Helper functions #
|
# Helper functions #
|
||||||
|
@ -339,6 +370,20 @@ class SbS(torch.nn.Module):
|
||||||
assert self._weights_exists is True
|
assert self._weights_exists is True
|
||||||
assert self._weights is not None
|
assert self._weights is not None
|
||||||
|
|
||||||
|
assert self.sbs_gpu_setting_position != -1
|
||||||
|
assert self.sbs_hdynamic_cpp_position != -1
|
||||||
|
assert self.spike_generation_cpp_position != -1
|
||||||
|
assert self.spike_generation_gpu_setting_position != -1
|
||||||
|
|
||||||
|
if labels is None:
|
||||||
|
labels_copy: torch.Tensor = torch.tensor(
|
||||||
|
[], dtype=torch.int64, device=self.device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
labels_copy = (
|
||||||
|
labels.detach().clone().type(dtype=torch.int64).to(device=self.device)
|
||||||
|
)
|
||||||
|
|
||||||
input_convolved = torch.nn.functional.fold(
|
input_convolved = torch.nn.functional.fold(
|
||||||
torch.nn.functional.unfold(
|
torch.nn.functional.unfold(
|
||||||
input.requires_grad_(True),
|
input.requires_grad_(True),
|
||||||
|
@ -354,6 +399,12 @@ class SbS(torch.nn.Module):
|
||||||
stride=(1, 1),
|
stride=(1, 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.last_input_store is True:
|
||||||
|
self.last_input_data = input_convolved.detach().clone()
|
||||||
|
self.last_input_data /= self.last_input_data.sum(dim=1, keepdim=True)
|
||||||
|
else:
|
||||||
|
self.last_input_data = None
|
||||||
|
|
||||||
epsilon_t_0: torch.Tensor = (
|
epsilon_t_0: torch.Tensor = (
|
||||||
(self._epsilon_t * self._epsilon_0).type(input.dtype).to(input.device)
|
(self._epsilon_t * self._epsilon_0).type(input.dtype).to(input.device)
|
||||||
)
|
)
|
||||||
|
@ -361,8 +412,8 @@ class SbS(torch.nn.Module):
|
||||||
parameter_list = torch.tensor(
|
parameter_list = torch.tensor(
|
||||||
[
|
[
|
||||||
int(self._w_trainable), # 0
|
int(self._w_trainable), # 0
|
||||||
int(0), # int(self._disable_scale_grade), # 1
|
int(self._disable_scale_grade), # 1
|
||||||
int(0), # int(self._keep_last_grad_scale), # 2
|
int(self._keep_last_grad_scale), # 2
|
||||||
int(self._skip_gradient_calculation), # 3
|
int(self._skip_gradient_calculation), # 3
|
||||||
int(self._number_of_spikes), # 4
|
int(self._number_of_spikes), # 4
|
||||||
int(self._number_of_cpu_processes), # 5
|
int(self._number_of_cpu_processes), # 5
|
||||||
|
@ -371,6 +422,10 @@ class SbS(torch.nn.Module):
|
||||||
int(self._gpu_tuning_factor), # 8
|
int(self._gpu_tuning_factor), # 8
|
||||||
int(self._output_layer), # 9
|
int(self._output_layer), # 9
|
||||||
int(self._local_learning), # 10
|
int(self._local_learning), # 10
|
||||||
|
int(self.sbs_gpu_setting_position), # 11
|
||||||
|
int(self.sbs_hdynamic_cpp_position), # 12
|
||||||
|
int(self.spike_generation_cpp_position), # 13
|
||||||
|
int(self.spike_generation_gpu_setting_position), # 14
|
||||||
],
|
],
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
)
|
)
|
||||||
|
@ -401,8 +456,9 @@ class SbS(torch.nn.Module):
|
||||||
self._weights,
|
self._weights,
|
||||||
self._h_initial,
|
self._h_initial,
|
||||||
parameter_list,
|
parameter_list,
|
||||||
# self._last_grad_scale,
|
self._last_grad_scale,
|
||||||
self._forgetting_offset,
|
self._forgetting_offset,
|
||||||
|
labels_copy,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._number_of_grad_weight_contributions += (
|
self._number_of_grad_weight_contributions += (
|
||||||
|
@ -422,8 +478,9 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
weights: torch.Tensor,
|
weights: torch.Tensor,
|
||||||
h_initial: torch.Tensor,
|
h_initial: torch.Tensor,
|
||||||
parameter_list: torch.Tensor,
|
parameter_list: torch.Tensor,
|
||||||
# grad_output_scale: torch.Tensor,
|
grad_output_scale: torch.Tensor,
|
||||||
forgetting_offset: torch.Tensor,
|
forgetting_offset: torch.Tensor,
|
||||||
|
labels: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
assert input.dim() == 4
|
assert input.dim() == 4
|
||||||
|
@ -444,6 +501,11 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
output_size_1: int = int(parameter_list[7])
|
output_size_1: int = int(parameter_list[7])
|
||||||
gpu_tuning_factor: int = int(parameter_list[8])
|
gpu_tuning_factor: int = int(parameter_list[8])
|
||||||
|
|
||||||
|
sbs_gpu_setting_position = int(parameter_list[11])
|
||||||
|
sbs_hdynamic_cpp_position = int(parameter_list[12])
|
||||||
|
spike_generation_cpp_position = int(parameter_list[13])
|
||||||
|
spike_generation_gpu_setting_position = int(parameter_list[14])
|
||||||
|
|
||||||
# ###########################################################
|
# ###########################################################
|
||||||
# Spike generation
|
# Spike generation
|
||||||
# ###########################################################
|
# ###########################################################
|
||||||
|
@ -480,9 +542,59 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
assert spikes.is_contiguous() is True
|
assert spikes.is_contiguous() is True
|
||||||
|
|
||||||
# time_start: float = time.perf_counter()
|
# time_start: float = time.perf_counter()
|
||||||
spike_generation: SpikeGeneration2DManyIP = SpikeGeneration2DManyIP()
|
spike_generation_profile = global_spike_generation_gpu_setting[
|
||||||
|
spike_generation_gpu_setting_position
|
||||||
|
].clone()
|
||||||
|
|
||||||
spike_generation.spike_generation(
|
spike_generation_size = global_spike_size[
|
||||||
|
spike_generation_gpu_setting_position
|
||||||
|
].clone()
|
||||||
|
|
||||||
|
if input.device != torch.device("cpu"):
|
||||||
|
if (
|
||||||
|
(spike_generation_profile.numel() == 1)
|
||||||
|
or (spike_generation_size[0] != int(spikes.shape[0]))
|
||||||
|
or (spike_generation_size[1] != int(spikes.shape[1]))
|
||||||
|
or (spike_generation_size[2] != int(spikes.shape[2]))
|
||||||
|
or (spike_generation_size[3] != int(spikes.shape[3]))
|
||||||
|
):
|
||||||
|
spike_generation_profile = torch.zeros(
|
||||||
|
(1, 7), dtype=torch.int64, device=torch.device("cpu")
|
||||||
|
)
|
||||||
|
|
||||||
|
global_spike_generation_cpp[
|
||||||
|
spike_generation_cpp_position
|
||||||
|
].gpu_occupancy_export(
|
||||||
|
int(spikes.shape[2]),
|
||||||
|
int(spikes.shape[3]),
|
||||||
|
int(spikes.shape[0]),
|
||||||
|
int(spikes.shape[1]),
|
||||||
|
spike_generation_profile.data_ptr(),
|
||||||
|
int(spike_generation_profile.shape[0]),
|
||||||
|
int(spike_generation_profile.shape[1]),
|
||||||
|
)
|
||||||
|
global_spike_generation_gpu_setting[
|
||||||
|
spike_generation_gpu_setting_position
|
||||||
|
] = spike_generation_profile.clone()
|
||||||
|
|
||||||
|
spike_generation_size[0] = int(spikes.shape[0])
|
||||||
|
spike_generation_size[1] = int(spikes.shape[1])
|
||||||
|
spike_generation_size[2] = int(spikes.shape[2])
|
||||||
|
spike_generation_size[3] = int(spikes.shape[3])
|
||||||
|
global_spike_size[
|
||||||
|
spike_generation_gpu_setting_position
|
||||||
|
] = spike_generation_size.clone()
|
||||||
|
|
||||||
|
else:
|
||||||
|
global_spike_generation_cpp[
|
||||||
|
spike_generation_cpp_position
|
||||||
|
].gpu_occupancy_import(
|
||||||
|
spike_generation_profile.data_ptr(),
|
||||||
|
int(spike_generation_profile.shape[0]),
|
||||||
|
int(spike_generation_profile.shape[1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
global_spike_generation_cpp[spike_generation_cpp_position].spike_generation(
|
||||||
input_cumsum.data_ptr(),
|
input_cumsum.data_ptr(),
|
||||||
int(input_cumsum.shape[0]),
|
int(input_cumsum.shape[0]),
|
||||||
int(input_cumsum.shape[1]),
|
int(input_cumsum.shape[1]),
|
||||||
|
@ -536,9 +648,46 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
assert weights.ndim == 2
|
assert weights.ndim == 2
|
||||||
assert h_initial.ndim == 1
|
assert h_initial.ndim == 1
|
||||||
|
|
||||||
h_dynamic: HDynamicCNNManyIP = HDynamicCNNManyIP()
|
sbs_profile = global_sbs_gpu_setting[sbs_gpu_setting_position].clone()
|
||||||
|
|
||||||
h_dynamic.update(
|
sbs_size = global_sbs_size[sbs_gpu_setting_position].clone()
|
||||||
|
|
||||||
|
if input.device != torch.device("cpu"):
|
||||||
|
if (
|
||||||
|
(sbs_profile.numel() == 1)
|
||||||
|
or (sbs_size[0] != int(output.shape[0]))
|
||||||
|
or (sbs_size[1] != int(output.shape[1]))
|
||||||
|
or (sbs_size[2] != int(output.shape[2]))
|
||||||
|
or (sbs_size[3] != int(output.shape[3]))
|
||||||
|
):
|
||||||
|
sbs_profile = torch.zeros(
|
||||||
|
(14, 7), dtype=torch.int64, device=torch.device("cpu")
|
||||||
|
)
|
||||||
|
|
||||||
|
global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_export(
|
||||||
|
int(output.shape[2]),
|
||||||
|
int(output.shape[3]),
|
||||||
|
int(output.shape[0]),
|
||||||
|
int(output.shape[1]),
|
||||||
|
sbs_profile.data_ptr(),
|
||||||
|
int(sbs_profile.shape[0]),
|
||||||
|
int(sbs_profile.shape[1]),
|
||||||
|
)
|
||||||
|
global_sbs_gpu_setting[sbs_gpu_setting_position] = sbs_profile.clone()
|
||||||
|
sbs_size[0] = int(output.shape[0])
|
||||||
|
sbs_size[1] = int(output.shape[1])
|
||||||
|
sbs_size[2] = int(output.shape[2])
|
||||||
|
sbs_size[3] = int(output.shape[3])
|
||||||
|
global_sbs_size[sbs_gpu_setting_position] = sbs_size.clone()
|
||||||
|
|
||||||
|
else:
|
||||||
|
global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].gpu_occupancy_import(
|
||||||
|
sbs_profile.data_ptr(),
|
||||||
|
int(sbs_profile.shape[0]),
|
||||||
|
int(sbs_profile.shape[1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
global_sbs_hdynamic_cpp[sbs_hdynamic_cpp_position].update(
|
||||||
output.data_ptr(),
|
output.data_ptr(),
|
||||||
int(output.shape[0]),
|
int(output.shape[0]),
|
||||||
int(output.shape[1]),
|
int(output.shape[1]),
|
||||||
|
@ -575,7 +724,8 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
weights,
|
weights,
|
||||||
output,
|
output,
|
||||||
parameter_list,
|
parameter_list,
|
||||||
# grad_output_scale,
|
grad_output_scale,
|
||||||
|
labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
@ -590,9 +740,12 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
weights,
|
weights,
|
||||||
output,
|
output,
|
||||||
parameter_list,
|
parameter_list,
|
||||||
# last_grad_scale,
|
last_grad_scale,
|
||||||
|
labels,
|
||||||
) = ctx.saved_tensors
|
) = ctx.saved_tensors
|
||||||
|
|
||||||
|
assert labels.numel() > 0
|
||||||
|
|
||||||
# ##############################################
|
# ##############################################
|
||||||
# Default output
|
# Default output
|
||||||
# ##############################################
|
# ##############################################
|
||||||
|
@ -603,13 +756,14 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
grad_h_initial = None
|
grad_h_initial = None
|
||||||
grad_parameter_list = None
|
grad_parameter_list = None
|
||||||
grad_forgetting_offset = None
|
grad_forgetting_offset = None
|
||||||
|
grad_labels = None
|
||||||
|
|
||||||
# ##############################################
|
# ##############################################
|
||||||
# Parameters
|
# Parameters
|
||||||
# ##############################################
|
# ##############################################
|
||||||
parameter_w_trainable: bool = bool(parameter_list[0])
|
parameter_w_trainable: bool = bool(parameter_list[0])
|
||||||
# parameter_disable_scale_grade: bool = bool(parameter_list[1])
|
parameter_disable_scale_grade: bool = bool(parameter_list[1])
|
||||||
# parameter_keep_last_grad_scale: bool = bool(parameter_list[2])
|
parameter_keep_last_grad_scale: bool = bool(parameter_list[2])
|
||||||
parameter_skip_gradient_calculation: bool = bool(parameter_list[3])
|
parameter_skip_gradient_calculation: bool = bool(parameter_list[3])
|
||||||
parameter_output_layer: bool = bool(parameter_list[9])
|
parameter_output_layer: bool = bool(parameter_list[9])
|
||||||
parameter_local_learning: bool = bool(parameter_list[10])
|
parameter_local_learning: bool = bool(parameter_list[10])
|
||||||
|
@ -617,13 +771,13 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
# ##############################################
|
# ##############################################
|
||||||
# Dealing with overall scale of the gradient
|
# Dealing with overall scale of the gradient
|
||||||
# ##############################################
|
# ##############################################
|
||||||
# if parameter_disable_scale_grade is False:
|
if parameter_disable_scale_grade is False:
|
||||||
# if parameter_keep_last_grad_scale is True:
|
if parameter_keep_last_grad_scale is True:
|
||||||
# last_grad_scale = torch.tensor(
|
last_grad_scale = torch.tensor(
|
||||||
# [torch.abs(grad_output).max(), last_grad_scale]
|
[torch.abs(grad_output).max(), last_grad_scale]
|
||||||
# ).max()
|
).max()
|
||||||
# grad_output /= last_grad_scale
|
grad_output /= last_grad_scale
|
||||||
# grad_output_scale = last_grad_scale.clone()
|
grad_output_scale = last_grad_scale.clone()
|
||||||
|
|
||||||
input /= input.sum(dim=1, keepdim=True, dtype=weights.dtype)
|
input /= input.sum(dim=1, keepdim=True, dtype=weights.dtype)
|
||||||
|
|
||||||
|
@ -640,8 +794,9 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
grad_weights,
|
grad_weights,
|
||||||
grad_h_initial,
|
grad_h_initial,
|
||||||
grad_parameter_list,
|
grad_parameter_list,
|
||||||
# grad_output_scale,
|
grad_output_scale,
|
||||||
grad_forgetting_offset,
|
grad_forgetting_offset,
|
||||||
|
grad_labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
# #################################################
|
# #################################################
|
||||||
|
@ -682,6 +837,41 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
.sum(-1)
|
.sum(-1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif (parameter_output_layer is True) and (parameter_local_learning is True):
|
||||||
|
|
||||||
|
target_one_hot: torch.Tensor = torch.zeros(
|
||||||
|
(
|
||||||
|
labels.shape[0],
|
||||||
|
output.shape[1],
|
||||||
|
),
|
||||||
|
device=input.device,
|
||||||
|
dtype=input.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
target_one_hot.scatter_(
|
||||||
|
1,
|
||||||
|
labels.to(input.device).unsqueeze(1),
|
||||||
|
torch.ones(
|
||||||
|
(labels.shape[0], 1),
|
||||||
|
device=input.device,
|
||||||
|
dtype=input.dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
target_one_hot = target_one_hot.unsqueeze(-1).unsqueeze(-1)
|
||||||
|
|
||||||
|
# (-2 * (input - backprop_bigr).unsqueeze(2) * (target_one_hot-output).unsqueeze(1))
|
||||||
|
# (-2 * input.unsqueeze(2) * (target_one_hot-output).unsqueeze(1))
|
||||||
|
grad_weights = (
|
||||||
|
(
|
||||||
|
-2
|
||||||
|
* (input - backprop_bigr).unsqueeze(2)
|
||||||
|
* target_one_hot.unsqueeze(1)
|
||||||
|
)
|
||||||
|
.sum(0)
|
||||||
|
.sum(-1)
|
||||||
|
.sum(-1)
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# #################################################
|
# #################################################
|
||||||
# Backprop
|
# Backprop
|
||||||
|
@ -709,6 +899,7 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
grad_weights,
|
grad_weights,
|
||||||
grad_h_initial,
|
grad_h_initial,
|
||||||
grad_parameter_list,
|
grad_parameter_list,
|
||||||
# grad_output_scale,
|
grad_output_scale,
|
||||||
grad_forgetting_offset,
|
grad_forgetting_offset,
|
||||||
|
grad_labels,
|
||||||
)
|
)
|
||||||
|
|
33
network/SbSReconstruction.py
Normal file
33
network/SbSReconstruction.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from network.SbS import SbS
|
||||||
|
|
||||||
|
|
||||||
|
class SbSReconstruction(torch.nn.Module):
    """Project activity back through the weights of an existing SbS layer.

    The module owns no weights of its own: it reads the weight matrix of
    the SbS layer handed to the constructor on every forward pass.
    """

    # The SbS layer whose weight matrix is reused for the reconstruction.
    _the_sbs_layer: SbS

    def __init__(
        self,
        the_sbs_layer: SbS,
    ) -> None:
        """Keep a reference to ``the_sbs_layer`` and mirror its device/dtype.

        Args:
            the_sbs_layer: The SbS layer that provides ``_weights``,
                ``_weights_exists``, ``device`` and ``default_dtype``.
        """
        super().__init__()

        self._the_sbs_layer = the_sbs_layer

        # Expose the same device / dtype attributes as the wrapped layer.
        self.device = self._the_sbs_layer.device
        self.default_dtype = self._the_sbs_layer.default_dtype

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the normalised, weight-mixed version of ``input``.

        ``input`` is divided by its sum over dim 1, multiplied with the
        broadcast weights of the wrapped layer, summed over dim 2 of the
        broadcast product and finally divided by its own sum over dim 1.
        There is no guard against a zero sum in either division.

        Args:
            input: Activity tensor.
                NOTE(review): presumably (batch, neurons, x, y) with
                ``neurons`` matching the second weight dimension -- confirm
                against the caller.

        Returns:
            The reconstruction tensor, normalised along dim 1.
        """
        sbs_layer = self._the_sbs_layer
        assert sbs_layer._weights_exists is True

        # Normalise the incoming activity along dim 1.
        channel_total = input.sum(dim=1, keepdim=True)
        input_norm = input / channel_total

        # Add a leading and two trailing singleton axes to the weights so
        # they broadcast against the activity (which gets a new axis at 1).
        weights_expanded = (
            sbs_layer._weights.data.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        )
        weighted = weights_expanded * input_norm.unsqueeze(1)

        # Collapse dim 2 of the broadcast product.
        output = weighted.sum(dim=2)

        # Re-normalise the result along dim 1 (in place on the fresh tensor).
        output /= output.sum(dim=1, keepdim=True)

        return output
|
|
@ -6,6 +6,7 @@ from network.Parameter import Config
|
||||||
from network.SbS import SbS
|
from network.SbS import SbS
|
||||||
from network.SplitOnOffLayer import SplitOnOffLayer
|
from network.SplitOnOffLayer import SplitOnOffLayer
|
||||||
from network.Conv2dApproximation import Conv2dApproximation
|
from network.Conv2dApproximation import Conv2dApproximation
|
||||||
|
from network.SbSReconstruction import SbSReconstruction
|
||||||
|
|
||||||
|
|
||||||
def build_network(
|
def build_network(
|
||||||
|
@ -159,12 +160,13 @@ def build_network(
|
||||||
padding=padding,
|
padding=padding,
|
||||||
number_of_cpu_processes=cfg.number_of_cpu_processes,
|
number_of_cpu_processes=cfg.number_of_cpu_processes,
|
||||||
w_trainable=w_trainable,
|
w_trainable=w_trainable,
|
||||||
# keep_last_grad_scale=cfg.learning_parameters.kepp_last_grad_scale,
|
keep_last_grad_scale=cfg.learning_parameters.kepp_last_grad_scale,
|
||||||
# disable_scale_grade=cfg.learning_parameters.disable_scale_grade,
|
disable_scale_grade=cfg.learning_parameters.disable_scale_grade,
|
||||||
forgetting_offset=cfg.forgetting_offset,
|
forgetting_offset=cfg.forgetting_offset,
|
||||||
skip_gradient_calculation=sbs_skip_gradient_calculation,
|
skip_gradient_calculation=sbs_skip_gradient_calculation,
|
||||||
device=device,
|
device=device,
|
||||||
default_dtype=default_dtype,
|
default_dtype=default_dtype,
|
||||||
|
layer_id=layer_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Adding the x,y output dimensions
|
# Adding the x,y output dimensions
|
||||||
|
@ -178,6 +180,25 @@ def build_network(
|
||||||
if cfg.network_structure.layer_type[layer_id].upper().find("LOCAL") != -1:
|
if cfg.network_structure.layer_type[layer_id].upper().find("LOCAL") != -1:
|
||||||
network[-1]._local_learning = True
|
network[-1]._local_learning = True
|
||||||
|
|
||||||
|
elif (
|
||||||
|
cfg.network_structure.layer_type[layer_id]
|
||||||
|
.upper()
|
||||||
|
.startswith("RECONSTRUCTION")
|
||||||
|
is True
|
||||||
|
):
|
||||||
|
logging.info(f"Layer: {layer_id} -> SbS Reconstruction Layer")
|
||||||
|
|
||||||
|
assert layer_id > 0
|
||||||
|
assert isinstance(network[-1], SbS) is True
|
||||||
|
|
||||||
|
network.append(SbSReconstruction(network[-1]))
|
||||||
|
network[-1]._w_trainable = False
|
||||||
|
|
||||||
|
if layer_id == len(cfg.network_structure.layer_type) - 1:
|
||||||
|
network[-2].last_input_store = True
|
||||||
|
|
||||||
|
input_size.append(input_size[-1])
|
||||||
|
|
||||||
# #############################################################
|
# #############################################################
|
||||||
# Split On Off Layer:
|
# Split On Off Layer:
|
||||||
# #############################################################
|
# #############################################################
|
||||||
|
|
|
@ -57,10 +57,13 @@ def build_optimizer(
|
||||||
optimizer_wf = Adam(
|
optimizer_wf = Adam(
|
||||||
parameter_list_weights,
|
parameter_list_weights,
|
||||||
parameter_list_sbs,
|
parameter_list_sbs,
|
||||||
|
logging=logging,
|
||||||
lr=cfg.learning_parameters.learning_rate_gamma_w,
|
lr=cfg.learning_parameters.learning_rate_gamma_w,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
optimizer_wf = Adam(parameter_list_weights, parameter_list_sbs)
|
optimizer_wf = Adam(
|
||||||
|
parameter_list_weights, parameter_list_sbs, logging=logging
|
||||||
|
)
|
||||||
|
|
||||||
elif cfg.learning_parameters.optimizer_name == "SGD":
|
elif cfg.learning_parameters.optimizer_name == "SGD":
|
||||||
logging.info("Using optimizer: SGD")
|
logging.info("Using optimizer: SGD")
|
||||||
|
|
|
@ -5,6 +5,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from network.SbS import SbS
|
from network.SbS import SbS
|
||||||
from network.save_weight_and_bias import save_weight_and_bias
|
from network.save_weight_and_bias import save_weight_and_bias
|
||||||
|
from network.SbSReconstruction import SbSReconstruction
|
||||||
|
|
||||||
|
|
||||||
def add_weight_and_bias_to_histogram(
|
def add_weight_and_bias_to_histogram(
|
||||||
|
@ -94,7 +95,7 @@ def loss_function(
|
||||||
device=device,
|
device=device,
|
||||||
dtype=default_dtype,
|
dtype=default_dtype,
|
||||||
),
|
),
|
||||||
).unsqueeze(-1).unsqueeze(-1)
|
)
|
||||||
|
|
||||||
h_y1 = torch.log(h + 1e-20)
|
h_y1 = torch.log(h + 1e-20)
|
||||||
|
|
||||||
|
@ -119,6 +120,44 @@ def loss_function(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def loss_function_reconstruction(
|
||||||
|
h_reco: torch.Tensor,
|
||||||
|
h_input: torch.Tensor,
|
||||||
|
loss_mode: int = 0,
|
||||||
|
loss_coeffs_mse: float = 0.0,
|
||||||
|
loss_coeffs_kldiv: float = 0.0,
|
||||||
|
) -> torch.Tensor | None:
|
||||||
|
assert loss_mode >= 0
|
||||||
|
assert loss_mode <= 0
|
||||||
|
|
||||||
|
assert h_reco.ndim == 4
|
||||||
|
assert h_input.ndim == 4
|
||||||
|
assert h_reco.shape[0] == h_input.shape[0]
|
||||||
|
assert h_reco.shape[1] == h_input.shape[1]
|
||||||
|
assert h_reco.shape[2] == h_input.shape[2]
|
||||||
|
assert h_reco.shape[3] == h_input.shape[3]
|
||||||
|
|
||||||
|
if loss_mode == 0:
|
||||||
|
|
||||||
|
h_reco_log = torch.log(h_reco + 1e-20)
|
||||||
|
|
||||||
|
my_loss: torch.Tensor = (
|
||||||
|
torch.nn.functional.mse_loss(
|
||||||
|
h_reco,
|
||||||
|
h_input,
|
||||||
|
reduction="sum",
|
||||||
|
)
|
||||||
|
* loss_coeffs_mse
|
||||||
|
+ torch.nn.functional.kl_div(h_reco_log, h_input + 1e-20, reduction="sum")
|
||||||
|
* loss_coeffs_kldiv
|
||||||
|
) / (loss_coeffs_kldiv + loss_coeffs_mse)
|
||||||
|
|
||||||
|
return my_loss
|
||||||
|
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def forward_pass_train(
|
def forward_pass_train(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
labels: torch.Tensor,
|
labels: torch.Tensor,
|
||||||
|
@ -228,15 +267,15 @@ def run_lr_scheduler(
|
||||||
tb.flush()
|
tb.flush()
|
||||||
|
|
||||||
|
|
||||||
def deal_with_gradient_scale(epoch_id: int, mini_batch_number: int, network):
    """Call ``after_batch`` on every ``SbS`` layer in ``network``.

    On the very first mini-batch (epoch 0, mini-batch 0) the layers are
    called as ``after_batch(True)``; on every later mini-batch as
    ``after_batch()``. Layers that are not ``SbS`` instances are skipped.

    Args:
        epoch_id: index of the current epoch.
        mini_batch_number: index of the current mini-batch.
        network: iterable container of layers (e.g. a ``Sequential``).
    """
    is_first_mini_batch = (epoch_id == 0) and (mini_batch_number == 0)

    for layer in network:
        if not isinstance(layer, SbS):
            continue

        # The two call forms are kept distinct on purpose: the default of
        # after_batch()'s argument is not visible from here.
        if is_first_mini_batch:
            layer.after_batch(True)
        else:
            layer.after_batch()
|
||||||
|
|
||||||
|
|
||||||
def loop_train(
|
def loop_train(
|
||||||
|
@ -318,12 +357,20 @@ def loop_train(
|
||||||
if last_test_performance < 0:
|
if last_test_performance < 0:
|
||||||
logging.info("")
|
logging.info("")
|
||||||
else:
|
else:
|
||||||
|
if isinstance(network[-1], SbSReconstruction) is False:
|
||||||
logging.info(
|
logging.info(
|
||||||
(
|
(
|
||||||
f"\t\t\tLast test performance: "
|
f"\t\t\tLast test performance: "
|
||||||
f"{last_test_performance/100.0:^6.2%}"
|
f"{last_test_performance/100.0:^6.2%}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
logging.info(
|
||||||
|
(
|
||||||
|
f"\t\t\tLast test performance: "
|
||||||
|
f"{last_test_performance:^6.2e}"
|
||||||
|
)
|
||||||
|
)
|
||||||
logging.info("----------------")
|
logging.info("----------------")
|
||||||
|
|
||||||
number_of_pattern_in_minibatch += h_x_labels.shape[0]
|
number_of_pattern_in_minibatch += h_x_labels.shape[0]
|
||||||
|
@ -345,6 +392,8 @@ def loop_train(
|
||||||
# #####################################################
|
# #####################################################
|
||||||
# Calculate the loss function
|
# Calculate the loss function
|
||||||
# #####################################################
|
# #####################################################
|
||||||
|
|
||||||
|
if isinstance(network[-1], SbSReconstruction) is False:
|
||||||
my_loss: torch.Tensor | None = loss_function(
|
my_loss: torch.Tensor | None = loss_function(
|
||||||
h=h_collection[-1],
|
h=h_collection[-1],
|
||||||
labels=h_x_labels,
|
labels=h_x_labels,
|
||||||
|
@ -357,6 +406,16 @@ def loop_train(
|
||||||
loss_coeffs_mse=float(cfg.learning_parameters.loss_coeffs_mse),
|
loss_coeffs_mse=float(cfg.learning_parameters.loss_coeffs_mse),
|
||||||
loss_coeffs_kldiv=float(cfg.learning_parameters.loss_coeffs_kldiv),
|
loss_coeffs_kldiv=float(cfg.learning_parameters.loss_coeffs_kldiv),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
assert cfg.learning_parameters.lr_scheduler_use_performance is False
|
||||||
|
my_loss = loss_function_reconstruction(
|
||||||
|
h_reco=h_collection[-1],
|
||||||
|
h_input=network[-2].last_input_data,
|
||||||
|
loss_mode=cfg.learning_parameters.loss_mode,
|
||||||
|
loss_coeffs_mse=float(cfg.learning_parameters.loss_coeffs_mse),
|
||||||
|
loss_coeffs_kldiv=float(cfg.learning_parameters.loss_coeffs_kldiv),
|
||||||
|
)
|
||||||
|
|
||||||
assert my_loss is not None
|
assert my_loss is not None
|
||||||
|
|
||||||
time_after_forward_and_loss: float = time.perf_counter()
|
time_after_forward_and_loss: float = time.perf_counter()
|
||||||
|
@ -374,6 +433,7 @@ def loop_train(
|
||||||
# Performance measures
|
# Performance measures
|
||||||
# #####################################################
|
# #####################################################
|
||||||
|
|
||||||
|
if isinstance(network[-1], SbSReconstruction) is False:
|
||||||
correct_in_minibatch += (
|
correct_in_minibatch += (
|
||||||
(h_collection[-1].argmax(dim=1).squeeze().cpu() == h_x_labels)
|
(h_collection[-1].argmax(dim=1).squeeze().cpu() == h_x_labels)
|
||||||
.sum()
|
.sum()
|
||||||
|
@ -391,11 +451,11 @@ def loop_train(
|
||||||
# the future error with it
|
# the future error with it
|
||||||
# Kind of deals with the vanishing /
|
# Kind of deals with the vanishing /
|
||||||
# exploding gradients
|
# exploding gradients
|
||||||
# deal_with_gradient_scale(
|
deal_with_gradient_scale(
|
||||||
# epoch_id=epoch_id,
|
epoch_id=epoch_id,
|
||||||
# mini_batch_number=mini_batch_number,
|
mini_batch_number=mini_batch_number,
|
||||||
# network=network,
|
network=network,
|
||||||
# )
|
)
|
||||||
|
|
||||||
# Measure the time for one mini-batch
|
# Measure the time for one mini-batch
|
||||||
time_forward += time_after_forward_and_loss - time_mini_batch_start
|
time_forward += time_after_forward_and_loss - time_mini_batch_start
|
||||||
|
@ -403,6 +463,7 @@ def loop_train(
|
||||||
|
|
||||||
if number_of_pattern_in_minibatch >= cfg.get_update_after_x_pattern():
|
if number_of_pattern_in_minibatch >= cfg.get_update_after_x_pattern():
|
||||||
|
|
||||||
|
if isinstance(network[-1], SbSReconstruction) is False:
|
||||||
logging.info(
|
logging.info(
|
||||||
(
|
(
|
||||||
f"{epoch_id:^6}=>{mini_batch_number:^6} "
|
f"{epoch_id:^6}=>{mini_batch_number:^6} "
|
||||||
|
@ -419,6 +480,22 @@ def loop_train(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# ":^6" centers the pattern count in a 6-character field; the previous
# "^6" (without the colon) XOR-ed the count with 6 and logged a wrong number.
logging.info(
    (
        f"{epoch_id:^6}=>{mini_batch_number:^6} "
        f"\t\tTraining {number_of_pattern_in_minibatch:^6} pattern "
        f"\t\t\tForward time: \t{time_forward:^6.2f}sec"
    )
)
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
(
|
||||||
|
f"\t\t\tLoss: {loss_in_minibatch/number_of_pattern_in_minibatch:^15.3e} "
|
||||||
|
f"\t\t\tBackward time: \t{time_backward:^6.2f}sec "
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
my_loss_for_batch = loss_in_minibatch / number_of_pattern_in_minibatch
|
my_loss_for_batch = loss_in_minibatch / number_of_pattern_in_minibatch
|
||||||
|
|
||||||
performance_for_batch = (
|
performance_for_batch = (
|
||||||
|
@ -510,3 +587,65 @@ def loop_test(
|
||||||
tb.flush()
|
tb.flush()
|
||||||
|
|
||||||
return performance
|
return performance
|
||||||
|
|
||||||
|
|
||||||
|
def loop_test_reconstruction(
    epoch_id: int,
    cfg: Config,
    network: torch.nn.modules.container.Sequential,
    my_loader_test: torch.utils.data.dataloader.DataLoader,
    the_dataset_test,
    device: torch.device,
    default_dtype: torch.dtype,
    logging,
    tb: SummaryWriter,
) -> float:
    """Run the test set through the network and report the reconstruction loss.

    For every batch of ``my_loader_test`` the network is evaluated with
    ``forward_pass_test`` and the loss between the last layer's output and
    ``network[-2].last_input_data`` is accumulated. The running mean loss
    per pattern is logged after each batch and finally written to
    TensorBoard under the tag "Test Error".

    Args:
        epoch_id: epoch index used as the TensorBoard step.
        cfg: configuration; the loss mode and loss coefficients are read
            from ``cfg.learning_parameters``.
        network: the network under test.
        my_loader_test: data loader yielding ``(input, labels)`` batches.
        the_dataset_test: test data set; its length is used for logging.
        device: passed through to ``forward_pass_test``.
        default_dtype: passed through to ``forward_pass_test``.
        logging: logger object providing ``info``.
        tb: TensorBoard writer.

    Returns:
        Accumulated loss divided by the number of tested patterns, or NaN
        if the loader yielded no batch (this case used to raise
        ``UnboundLocalError``).
    """
    test_count: int = 0
    test_loss: float = 0.0
    test_complete: int = len(the_dataset_test)
    # Defined up front so that an empty loader does not leave it unbound.
    performance: float = float("nan")

    logging.info("")
    logging.info("Testing:")

    for h_x, h_x_labels in my_loader_test:
        time_0 = time.perf_counter()

        h_collection = forward_pass_test(
            input=h_x,
            the_dataset_test=the_dataset_test,
            cfg=cfg,
            network=network,
            device=device,
            default_dtype=default_dtype,
        )

        # NOTE(review): network[-2] is presumably the SbS layer in front of
        # the reconstruction layer, holding the stored input -- confirm.
        my_loss: torch.Tensor | None = loss_function_reconstruction(
            h_reco=h_collection[-1],
            h_input=network[-2].last_input_data,
            loss_mode=cfg.learning_parameters.loss_mode,
            loss_coeffs_mse=float(cfg.learning_parameters.loss_coeffs_mse),
            loss_coeffs_kldiv=float(cfg.learning_parameters.loss_coeffs_kldiv),
        )

        assert my_loss is not None
        test_count += h_x_labels.shape[0]
        test_loss += my_loss.item()

        # Running mean over all patterns seen so far.
        performance = test_loss / test_count
        time_1 = time.perf_counter()
        time_measure_a = time_1 - time_0

        logging.info(
            (
                f"\t\t{test_count} of {test_complete}"
                f" with {performance:^6.2e} \t Time used: {time_measure_a:^6.2f}sec"
            )
        )

    logging.info("")

    tb.add_scalar("Test Error", performance, epoch_id)
    tb.flush()

    return performance
|
||||||
|
|
Loading…
Reference in a new issue