backprop is single precision now

David Rotermund 2022-05-04 14:42:44 +02:00 committed by GitHub
parent aac88b41cf
commit 654014b319
5 changed files with 98 additions and 82 deletions
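In short: the SbS module, its custom autograd Function, and the train/test scripts switch from torch.float64 to torch.float32, and the numerical guards move with the dtype. Sanitize/clamp limits drop from 1e300 to 1e30 (1e300 overflows float32), and small 1e-20 additive constants protect logs and divisions instead of the old eps-based offset. A minimal sketch of the dtype reasoning, using only standard PyTorch calls:

import torch

# float32 tops out near 3.4e38, so the old 1e300 limits would turn into inf.
print(torch.finfo(torch.float32).max)            # ~3.4028235e+38
print(torch.tensor(1e300, dtype=torch.float32))  # tensor(inf)
print(torch.tensor(1e30, dtype=torch.float32))   # finite, usable as a clamp bound

# Hence the pattern that recurs throughout this commit:
x = torch.randn(4, dtype=torch.float32)
torch.nan_to_num(x, out=x, nan=1e30, posinf=1e30, neginf=-1e30)  # was 1e300
torch.clip(x, out=x, min=-1e30, max=1e30)                        # was 1e300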

@ -71,6 +71,7 @@ class LearningParameters:
learning_rate_threshold_eps_xy: float = field(default=0.00001)
lr_schedule_name: str = field(default="ReduceLROnPlateau")
lr_scheduler_use_performance: bool = field(default=True)
lr_scheduler_factor_w: float = field(default=0.75)
lr_scheduler_patience_w: int = field(default=-1)
@ -133,11 +134,12 @@ class Config:
number_of_cpu_processes: int = field(default=-1)
number_of_spikes: int = field(default=0)
cooldown_after_number_of_spikes: int = field(default=0)
cooldown_after_number_of_spikes: int = field(default=-1)
weight_path: str = field(default="./Weights/")
eps_xy_path: str = field(default="./EpsXY/")
data_path: str = field(default="./")
results_path: str = field(default="./Results")
reduction_cooldown: float = field(default=25.0)
epsilon_0: float = field(default=1.0)
@ -159,6 +161,7 @@ class Config:
os.makedirs(self.weight_path, exist_ok=True)
os.makedirs(self.eps_xy_path, exist_ok=True)
os.makedirs(self.data_path, exist_ok=True)
os.makedirs(self.results_path, exist_ok=True)
self.batch_size = (
self.batch_size // self.number_of_cpu_processes
@ -170,6 +173,9 @@ class Config:
def get_epsilon_t(self):
"""Generates the time series of the basic epsilon."""
np_epsilon_t: np.ndarray = np.ones((self.number_of_spikes), dtype=np.float32)
if (self.cooldown_after_number_of_spikes < self.number_of_spikes) and (
self.cooldown_after_number_of_spikes >= 0
):
np_epsilon_t[
self.cooldown_after_number_of_spikes : self.number_of_spikes
] /= self.reduction_cooldown
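For reference, a self-contained sketch of what get_epsilon_t computes after this change, with illustrative values for the config fields; the new default of -1 for cooldown_after_number_of_spikes disables the cooldown branch entirely:

import numpy as np

number_of_spikes = 10                # illustrative
cooldown_after_number_of_spikes = 6  # illustrative; the new default -1 skips the cooldown
reduction_cooldown = 25.0            # default from the Config above

np_epsilon_t = np.ones((number_of_spikes), dtype=np.float32)
if (cooldown_after_number_of_spikes < number_of_spikes) and (
    cooldown_after_number_of_spikes >= 0
):
    np_epsilon_t[cooldown_after_number_of_spikes:number_of_spikes] /= reduction_cooldown
print(np_epsilon_t)  # [1. 1. 1. 1. 1. 1. 0.04 0.04 0.04 0.04]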

SbS.py
@ -122,7 +122,7 @@ class SbS(torch.nn.Module):
self.initialize_epsilon_xy(epsilon_xy_intitial)
self.epsilon_0 = torch.tensor(epsilon_0, dtype=torch.float64)
self.epsilon_0 = torch.tensor(epsilon_0, dtype=torch.float32)
self.number_of_cpu_processes = torch.tensor(
number_of_cpu_processes, dtype=torch.int64
@ -130,7 +130,7 @@ class SbS(torch.nn.Module):
self.number_of_spikes = torch.tensor(number_of_spikes, dtype=torch.int64)
self.epsilon_t = epsilon_t.type(dtype=torch.float64)
self.epsilon_t = epsilon_t.type(dtype=torch.float32)
self.initialize_weights(
is_pooling_layer=is_pooling_layer,
@ -155,7 +155,7 @@ class SbS(torch.nn.Module):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 4
assert value.dtype == torch.float64
assert value.dtype == torch.float32
if self._epsilon_xy_exists is False:
self._epsilon_xy = torch.nn.parameter.Parameter(
value.detach().clone(memory_format=torch.contiguous_format),
@ -176,7 +176,7 @@ class SbS(torch.nn.Module):
assert value is not None
assert torch.is_tensor(value) is True
assert torch.numel(value) == 1
assert value.dtype == torch.float64
assert value.dtype == torch.float32
assert value.item() > 0
self._epsilon_0 = value.detach().clone(memory_format=torch.contiguous_format)
self._epsilon_0.requires_grad_(False)
@ -190,7 +190,7 @@ class SbS(torch.nn.Module):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 1
assert value.dtype == torch.float64
assert value.dtype == torch.float32
self._epsilon_t = value.detach().clone(memory_format=torch.contiguous_format)
self._epsilon_t.requires_grad_(False)
@ -206,9 +206,9 @@ class SbS(torch.nn.Module):
assert value is not None
assert torch.is_tensor(value) is True
assert value.dim() == 2
assert value.dtype == torch.float64
assert value.dtype == torch.float32
temp: torch.Tensor = value.detach().clone(memory_format=torch.contiguous_format)
temp /= temp.sum(dim=0, keepdim=True, dtype=torch.float64)
temp /= temp.sum(dim=0, keepdim=True, dtype=torch.float32)
if self._weights_exists is False:
self._weights = torch.nn.parameter.Parameter(
temp,
@ -402,7 +402,7 @@ class SbS(torch.nn.Module):
assert input is not None
assert torch.is_tensor(input) is True
assert input.dim() == 4
assert input.dtype == torch.float64
assert input.dtype == torch.float32
# Are we happy with the rest of the network?
assert self._epsilon_xy_exists is True
@ -499,7 +499,7 @@ class SbS(torch.nn.Module):
torch.unsqueeze(
torch.unsqueeze(
torch.unsqueeze(
torch.arange(0, int(value[0]), dtype=torch.float64),
torch.arange(0, int(value[0]), dtype=torch.float32),
1,
),
0,
@ -516,7 +516,7 @@ class SbS(torch.nn.Module):
torch.unsqueeze(
torch.unsqueeze(
torch.unsqueeze(
torch.arange(0, int(value[1]), dtype=torch.float64),
torch.arange(0, int(value[1]), dtype=torch.float32),
0,
),
0,
@ -537,7 +537,7 @@ class SbS(torch.nn.Module):
assert torch.numel(noise_amplitude) == 1
assert noise_amplitude.item() >= 0
assert noise_amplitude.dtype == torch.float64
assert noise_amplitude.dtype == torch.float32
assert self._number_of_neurons is not None
assert self._number_of_input_neurons is not None
@ -550,7 +550,7 @@ class SbS(torch.nn.Module):
int(self._number_of_input_neurons),
int(self._number_of_neurons),
),
dtype=torch.float64,
dtype=torch.float32,
)
torch.nn.init.uniform_(weights, a=1.0, b=(1.0 + noise_amplitude.item()))
@ -571,7 +571,7 @@ class SbS(torch.nn.Module):
int(self._number_of_neurons),
int(self._number_of_neurons),
),
dtype=torch.float64,
dtype=torch.float32,
)
for i in range(0, int(self._number_of_neurons)):
@ -593,7 +593,7 @@ class SbS(torch.nn.Module):
weights = self._make_pooling_weights()
else:
weights = self._initial_random_weights(
torch.tensor(noise_amplitude, dtype=torch.float64)
torch.tensor(noise_amplitude, dtype=torch.float32)
)
weights = weights.moveaxis(-1, 0).moveaxis(-1, 1)
@ -628,7 +628,7 @@ class SbS(torch.nn.Module):
int(self._kernel_size[1]),
),
eps_xy_intitial,
dtype=torch.float64,
dtype=torch.float32,
)
self.epsilon_xy = eps_xy_temp
@ -660,7 +660,7 @@ class SbS(torch.nn.Module):
fill_value: float = float(self._epsilon_xy.data.mean())
self._epsilon_xy.data = torch.full_like(
self._epsilon_xy.data, fill_value, dtype=torch.float64
self._epsilon_xy.data, fill_value, dtype=torch.float32
)
def threshold_epsilon_xy(self, threshold: float) -> None:
@ -688,9 +688,9 @@ class SbS(torch.nn.Module):
temp: torch.Tensor = (
self._weights.data.detach()
.clone(memory_format=torch.contiguous_format)
.type(dtype=torch.float64)
.type(dtype=torch.float32)
)
temp /= temp.sum(dim=0, keepdim=True, dtype=torch.float64)
temp /= temp.sum(dim=0, keepdim=True, dtype=torch.float32)
self._weights.data = temp
def threshold_weights(self, threshold: float) -> None:
@ -708,11 +708,11 @@ class FunctionalSbS(torch.autograd.Function):
@staticmethod
def forward( # type: ignore
ctx,
input_float64: torch.Tensor,
epsilon_xy_float64: torch.Tensor,
epsilon_0_float64: torch.Tensor,
epsilon_t_float64: torch.Tensor,
weights_float64: torch.Tensor,
input: torch.Tensor,
epsilon_xy: torch.Tensor,
epsilon_0: torch.Tensor,
epsilon_t: torch.Tensor,
weights: torch.Tensor,
kernel_size: torch.Tensor,
stride: torch.Tensor,
dilation: torch.Tensor,
@ -724,11 +724,7 @@ class FunctionalSbS(torch.autograd.Function):
alpha_number_of_iterations: torch.Tensor,
) -> torch.Tensor:
input = input_float64.type(dtype=torch.float32)
epsilon_xy = epsilon_xy_float64.type(dtype=torch.float32)
weights = weights_float64.type(dtype=torch.float32)
epsilon_0 = epsilon_0_float64.type(dtype=torch.float32)
epsilon_t = epsilon_t_float64.type(dtype=torch.float32)
torch.set_default_dtype(torch.float32)
assert input.dim() == 4
assert torch.numel(kernel_size) == 2
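Since the inputs now arrive as float32, the explicit *_float64 casts at the top of forward() are dropped and the default dtype is pinned instead. A small illustration of what torch.set_default_dtype(torch.float32) controls (float32 is already PyTorch's factory default, so the call mainly guards against a different global default being set elsewhere):

import torch

torch.set_default_dtype(torch.float32)
# Factory calls without an explicit dtype now yield float32, so intermediate
# buffers created inside forward() match the float32 inputs and weights.
print(torch.ones(3).dtype)           # torch.float32
print(torch.arange(0.0, 4.0).dtype)  # torch.float32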
@ -1097,7 +1093,7 @@ class FunctionalSbS(torch.autograd.Function):
)
alpha_dynamic = alpha_temp.sum(dim=1, keepdim=True)
alpha_dynamic += torch.finfo(torch.float32).eps * 1000
alpha_dynamic += 1e-20
# Alpha normalization
alpha_dynamic /= alpha_dynamic.sum(dim=3, keepdim=True).sum(
@ -1114,13 +1110,11 @@ class FunctionalSbS(torch.autograd.Function):
# Save the necessary data for the backward pass #
############################################################
output = output.type(dtype=torch.float64)
ctx.save_for_backward(
input_convolved,
epsilon_xy_convolved,
epsilon_0_float64,
weights_float64,
epsilon_0,
weights,
output,
kernel_size,
stride,
@ -1136,8 +1130,8 @@ class FunctionalSbS(torch.autograd.Function):
# Get the variables back
(
input_float32,
epsilon_xy_float32,
input,
epsilon_xy,
epsilon_0,
weights,
output,
@ -1148,9 +1142,9 @@ class FunctionalSbS(torch.autograd.Function):
input_size,
) = ctx.saved_tensors
input = input_float32.type(dtype=torch.float64)
input /= input.sum(dim=1, keepdim=True, dtype=torch.float64)
epsilon_xy = epsilon_xy_float32.type(dtype=torch.float64)
torch.set_default_dtype(torch.float32)
input /= input.sum(dim=1, keepdim=True, dtype=torch.float32)
# For debugging:
# print(
@ -1172,21 +1166,21 @@ class FunctionalSbS(torch.autograd.Function):
backprop_bigr: torch.Tensor = backprop_r.sum(axis=2)
temp: torch.Tensor = input / backprop_bigr**2
temp: torch.Tensor = input / (backprop_bigr**2 + 1e-20)
backprop_f: torch.Tensor = output.unsqueeze(1) * temp.unsqueeze(2)
torch.nan_to_num(
backprop_f, out=backprop_f, nan=1e300, posinf=1e300, neginf=-1e300
backprop_f, out=backprop_f, nan=1e30, posinf=1e30, neginf=-1e30
)
torch.clip(backprop_f, out=backprop_f, min=-1e300, max=1e300)
torch.clip(backprop_f, out=backprop_f, min=-1e30, max=1e30)
tempz: torch.Tensor = 1.0 / backprop_bigr
tempz: torch.Tensor = 1.0 / (backprop_bigr + 1e-20)
backprop_z: torch.Tensor = backprop_r * tempz.unsqueeze(2)
torch.nan_to_num(
backprop_z, out=backprop_z, nan=1e300, posinf=1e300, neginf=-1e300
backprop_z, out=backprop_z, nan=1e30, posinf=1e30, neginf=-1e30
)
torch.clip(backprop_z, out=backprop_z, min=-1e300, max=1e300)
torch.clip(backprop_z, out=backprop_z, min=-1e30, max=1e30)
result_omega: torch.Tensor = backprop_bigr.unsqueeze(2) * grad_output.unsqueeze(
1
@ -1211,9 +1205,9 @@ class FunctionalSbS(torch.autograd.Function):
grad_weights = result_omega.sum(0).sum(-1).sum(-1)
torch.nan_to_num(
grad_weights, out=grad_weights, nan=1e300, posinf=1e300, neginf=-1e300
grad_weights, out=grad_weights, nan=1e30, posinf=1e30, neginf=-1e30
)
torch.clip(grad_weights, out=grad_weights, min=-1e300, max=1e300)
torch.clip(grad_weights, out=grad_weights, min=-1e30, max=1e30)
grad_input = torch.nn.functional.fold(
torch.nn.functional.unfold(
@ -1230,9 +1224,9 @@ class FunctionalSbS(torch.autograd.Function):
stride=stride,
)
torch.nan_to_num(
grad_input, out=grad_input, nan=1e300, posinf=1e300, neginf=-1e300
grad_input, out=grad_input, nan=1e30, posinf=1e30, neginf=-1e30
)
torch.clip(grad_input, out=grad_input, min=-1e300, max=1e300)
torch.clip(grad_input, out=grad_input, min=-1e30, max=1e30)
grad_eps_xy_temp = torch.nn.functional.fold(
result_eps_xy.moveaxis(0, -1)
@ -1260,9 +1254,9 @@ class FunctionalSbS(torch.autograd.Function):
.contiguous(memory_format=torch.contiguous_format)
)
torch.nan_to_num(
grad_eps_xy, out=grad_eps_xy, nan=1e300, posinf=1e300, neginf=-1e300
grad_eps_xy, out=grad_eps_xy, nan=1e30, posinf=1e30, neginf=-1e30
)
torch.clip(grad_eps_xy, out=grad_eps_xy, min=-1e300, max=1e300)
torch.clip(grad_eps_xy, out=grad_eps_xy, min=-1e30, max=1e30)
grad_epsilon_0 = None
grad_epsilon_t = None
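The backward pass now guards its divisions with a 1e-20 additive constant and sanitizes and clamps the gradients to magnitudes that float32 can represent. A self-contained sketch of the recurring pattern, with illustrative shapes:

import torch

backprop_bigr = torch.zeros(2, 3)           # may contain exact zeros
output = torch.rand(2, 3)

temp = output / (backprop_bigr**2 + 1e-20)  # 1e-20 avoids inf/NaN from division by zero
torch.nan_to_num(temp, out=temp, nan=1e30, posinf=1e30, neginf=-1e30)
torch.clip(temp, out=temp, min=-1e30, max=1e30)  # 1e300 would overflow float32
print(temp.dtype, temp.abs().max())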

@ -56,6 +56,8 @@ from torch.utils.tensorboard import SummaryWriter
tb = SummaryWriter()
torch.set_default_dtype(torch.float32)
#######################################################################
# We want to log what is going on into a file and screen #
#######################################################################
@ -191,7 +193,7 @@ for id in range(0, len(network)):
if os.path.exists(filename) is True:
network[id].weights = torch.tensor(
np.load(filename),
dtype=torch.float64,
dtype=torch.float32,
)
wf[id] = np.load(filename)
@ -206,7 +208,7 @@ for id in range(0, len(network)):
if os.path.exists(filename) is True:
network[id].epsilon_xy = torch.tensor(
np.load(filename),
dtype=torch.float64,
dtype=torch.float32,
)
eps_xy[id] = np.load(filename)
@ -225,7 +227,7 @@ for id in range(0, len(network)):
if len(file_to_load) == 1:
network[id].weights = torch.tensor(
np.load(file_to_load[0]),
dtype=torch.float64,
dtype=torch.float32,
)
wf[id] = np.load(file_to_load[0])
logging.info(f"File used: {file_to_load[0]}")
@ -243,7 +245,7 @@ for id in range(0, len(network)):
if len(file_to_load) == 1:
network[id].epsilon_xy = torch.tensor(
np.load(file_to_load[0]),
dtype=torch.float64,
dtype=torch.float32,
)
eps_xy[id] = np.load(file_to_load[0])
logging.info(f"File used: {file_to_load[0]}")
@ -346,7 +348,7 @@ with torch.no_grad():
h_collection = []
h_collection.append(
the_dataset_train.pattern_filter_train(h_x, cfg).type(
dtype=torch.float64
dtype=torch.float32
)
)
for id in range(0, len(network)):
@ -365,21 +367,21 @@ with torch.no_grad():
target_one_hot = (
target_one_hot.unsqueeze(2)
.unsqueeze(2)
.type(dtype=torch.float64)
.type(dtype=torch.float32)
)
# through the loss functions
h_y1 = torch.log(h_collection[-1])
h_y2 = torch.nan_to_num(h_y1, nan=0.0, posinf=0.0, neginf=0.0)
h_y1 = torch.log(h_collection[-1] + 1e-20)
my_loss: torch.Tensor = (
(
torch.nn.functional.mse_loss(
h_collection[-1], target_one_hot, reduction="none"
h_collection[-1],
target_one_hot,
reduction="none",
)
* cfg.learning_parameters.loss_coeffs_mse
+ torch.nn.functional.kl_div(
h_y2, target_one_hot, reduction="none"
h_y1, target_one_hot + 1e-20, reduction="none"
)
* cfg.learning_parameters.loss_coeffs_kldiv
)
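A short sketch of the new loss guards with made-up values: adding 1e-20 inside the log keeps log(0) finite, which replaces the earlier nan_to_num cleanup of h_y1, and the same constant keeps the kl_div target strictly positive:

import torch

h = torch.tensor([[0.0, 0.2, 0.8]])               # illustrative network output
target_one_hot = torch.tensor([[0.0, 0.0, 1.0]])  # illustrative one-hot target

h_y1 = torch.log(h + 1e-20)                       # finite even where h == 0
mse = torch.nn.functional.mse_loss(h, target_one_hot, reduction="none")
kld = torch.nn.functional.kl_div(h_y1, target_one_hot + 1e-20, reduction="none")
print(mse.sum().item(), kld.sum().item())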
@ -392,6 +394,7 @@ with torch.no_grad():
time_1: float = time.perf_counter()
my_loss.backward()
my_loss_float = my_loss.item()
time_2: float = time.perf_counter()
@ -447,7 +450,7 @@ with torch.no_grad():
network[id].norm_weights()
else:
network[id].weights = torch.tensor(
wf[id], dtype=torch.float64
wf[id], dtype=torch.float32
)
if cfg.network_structure.eps_xy_trainable[id] is True:
@ -458,7 +461,7 @@ with torch.no_grad():
network[id].mean_epsilon_xy()
else:
network[id].epsilon_xy = torch.tensor(
eps_xy[id], dtype=torch.float64
eps_xy[id], dtype=torch.float32
)
if cfg.network_structure.w_trainable[id] is True:
@ -504,13 +507,18 @@ with torch.no_grad():
# Let the torch learning rate scheduler update the
# learning rates of the optimizers
if cfg.learning_parameters.lr_scheduler_patience_w > 0:
if cfg.learning_parameters.lr_scheduler_use_performance is True:
lr_scheduler_wf.step(100.0 - performance)
else:
lr_scheduler_wf.step(my_loss_for_batch)
if cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0:
if cfg.learning_parameters.lr_scheduler_use_performance is True:
lr_scheduler_eps.step(100.0 - performance)
else:
lr_scheduler_eps.step(my_loss_for_batch)
tb.add_scalar(
"Train Error", 100.0 - performance, cfg.learning_step
)
tb.add_scalar("Train Error", 100.0 - performance, cfg.learning_step)
tb.add_scalar("Train Loss", my_loss_for_batch, cfg.learning_step)
tb.add_scalar(
"Learning Rate Scale WF",
@ -568,7 +576,7 @@ with torch.no_grad():
h_h: torch.Tensor = network(
the_dataset_test.pattern_filter_test(h_x, cfg).type(
dtype=torch.float64
dtype=torch.float32
)
)
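The scheduler changes above mean the ReduceLROnPlateau schedulers are only stepped when their patience is positive (the new default is -1), and the new lr_scheduler_use_performance flag chooses between stepping on the error (100.0 - performance) or on the batch loss. A minimal sketch under assumed settings; the Adam optimizer here is illustrative and not taken from the diff:

import torch

lr_scheduler_patience_w = -1                 # new default: stepping disabled
lr_scheduler_use_performance = True          # new flag, default True
performance, my_loss_for_batch = 95.0, 0.1   # illustrative metrics

w = torch.nn.Parameter(torch.zeros(1))
optimizer_wf = torch.optim.Adam([w], lr=1e-3)
lr_scheduler_wf = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_wf, factor=0.75  # lr_scheduler_factor_w default
)

if lr_scheduler_patience_w > 0:
    if lr_scheduler_use_performance is True:
        lr_scheduler_wf.step(100.0 - performance)
    else:
        lr_scheduler_wf.step(my_loss_for_batch)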

test_all.sh (new file, 6 lines)

@ -0,0 +1,6 @@
#!/bin/bash
for i in $(seq 1 1 999)
do
echo $i
/home/davrot/P3.10/bin/python3 test_it.py mnist.json $i
done

@ -182,7 +182,7 @@ for id in range(0, len(network)):
if len(file_to_load) == 1:
network[id].weights = torch.tensor(
np.load(file_to_load[0]),
dtype=torch.float64,
dtype=torch.float32,
)
wf[id] = np.load(file_to_load[0])
logging.info(f"File used: {file_to_load[0]}")
@ -200,7 +200,7 @@ for id in range(0, len(network)):
if len(file_to_load) == 1:
network[id].epsilon_xy = torch.tensor(
np.load(file_to_load[0]),
dtype=torch.float64,
dtype=torch.float32,
)
eps_xy[id] = np.load(file_to_load[0])
logging.info(f"File used: {file_to_load[0]}")
@ -219,7 +219,7 @@ for id in range(0, len(network)):
if os.path.exists(filename) is True:
network[id].weights = torch.tensor(
np.load(filename),
dtype=torch.float64,
dtype=torch.float32,
)
wf[id] = np.load(filename)
@ -234,7 +234,7 @@ for id in range(0, len(network)):
if os.path.exists(filename) is True:
network[id].epsilon_xy = torch.tensor(
np.load(filename),
dtype=torch.float64,
dtype=torch.float32,
)
eps_xy[id] = np.load(filename)
@ -256,7 +256,7 @@ with torch.no_grad():
time_0 = time.perf_counter()
h_h: torch.Tensor = network(
the_dataset_test.pattern_filter_test(h_x, cfg).type(dtype=torch.float64)
the_dataset_test.pattern_filter_test(h_x, cfg).type(dtype=torch.float32)
)
test_correct += (h_h.argmax(dim=1).squeeze() == h_x_labels).sum().numpy()
@ -271,6 +271,8 @@ with torch.no_grad():
f" with {performance/100:^6.2%} \t Time used: {time_measure_a:^6.2f}sec"
)
)
np_performance = np.array(performance)
np.save(f"{cfg.results_path}/{cfg.learning_step}.npy", np_performance)
# %%
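Tying this back to the results_path that Config now creates in the first file: the test script saves each learning step's test performance (in percent) as a NumPy scalar under that directory. A hedged read-back sketch; the path and step number below are assumed from the defaults, not taken from the diff:

import numpy as np

np_performance = np.load("./Results/0.npy")  # results_path default "./Results", step 0 assumed
print(float(np_performance))                 # test performance in percent for that step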