New LR scheduler, minor fixes

David Rotermund 2022-05-08 15:43:10 +02:00 committed by GitHub
parent 654014b319
commit b18a999cbf
4 changed files with 178 additions and 22 deletions


@@ -74,9 +74,11 @@ class LearningParameters:
    lr_scheduler_use_performance: bool = field(default=True)
    lr_scheduler_factor_w: float = field(default=0.75)
    lr_scheduler_patience_w: int = field(default=-1)
    lr_scheduler_tau_w: int = field(default=10)
    lr_scheduler_factor_eps_xy: float = field(default=0.75)
    lr_scheduler_patience_eps_xy: int = field(default=-1)
    lr_scheduler_tau_eps_xy: int = field(default=10)
    number_of_batches_for_one_update: int = field(default=1)
    overload_path: str = field(default="./Previous")
@@ -144,8 +146,6 @@ class Config:
    reduction_cooldown: float = field(default=25.0)
    epsilon_0: float = field(default=1.0)
    update_after_x_batch: float = field(default=1.0)

    def __post_init__(self) -> None:
        """Post init determines the number of cores.
        Creates the required directory and gives us an optimized
@@ -183,4 +183,6 @@ class Config:
    def get_update_after_x_pattern(self):
        """Tells us after how many patterns we need to update the weights."""
        return self.batch_size * self.update_after_x_batch
        return (
            self.batch_size * self.learning_parameters.number_of_batches_for_one_update
        )
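A quick illustration of what the reworked helper returns, using assumed values (batch_size = 50 and number_of_batches_for_one_update = 4 are illustrative, not values from this commit):

# Illustrative sketch; both values below are assumptions.
batch_size = 50
number_of_batches_for_one_update = 4

# After this change the update interval is counted in whole batches
# instead of the float-valued update_after_x_batch:
update_after_x_pattern = batch_size * number_of_batches_for_one_update  # 200 patterns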

SbSLRScheduler.py (new file, 106 lines)

@@ -0,0 +1,106 @@
import torch


class SbSLRScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
    def __init__(
        self,
        optimizer,
        mode: str = "min",
        factor: float = 0.1,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: str = "rel",
        cooldown: int = 0,
        min_lr: float = 0,
        eps: float = 1e-8,
        verbose: bool = False,
        tau: float = 10,
    ) -> None:
        super().__init__(
            optimizer=optimizer,
            mode=mode,
            factor=factor,
            patience=patience,
            threshold=threshold,
            threshold_mode=threshold_mode,
            cooldown=cooldown,
            min_lr=min_lr,
            eps=eps,
            verbose=verbose,
        )

        self.lowpass_tau: float = tau
        self.lowpass_decay_value: float = 1.0 - (1.0 / self.lowpass_tau)
        self.lowpass_number_of_steps: int = 0
        self.loss_maximum_over_time: float | None = None
        self.lowpass_memory: float = 0.0
        self.lowpass_learning_rate_minimum_over_time: float | None = None
        self.lowpass_learning_rate_minimum_over_time_past_step: float | None = None
        self.previous_learning_rate: float | None = None
        self.loss_normalized_past_step: float | None = None

    def step(self, metrics, epoch=None) -> None:
        loss = float(metrics)

        if self.loss_maximum_over_time is None:
            self.loss_maximum_over_time = loss

        if self.loss_normalized_past_step is None:
            self.loss_normalized_past_step = loss / self.loss_maximum_over_time

        if self.previous_learning_rate is None:
            self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore

        # The parent lr scheduler controls the basic learning rate
        self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore
        super().step(metrics=self.loss_normalized_past_step, epoch=epoch)

        # If the parent changes the base learning rate,
        # then we reset the adaptive part
        if self.optimizer.param_groups[-1]["lr"] != self.previous_learning_rate:  # type: ignore
            self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore
            self.lowpass_number_of_steps = 0
            self.loss_maximum_over_time = None
            self.lowpass_memory = 0.0
            self.lowpass_learning_rate_minimum_over_time = None
            self.lowpass_learning_rate_minimum_over_time_past_step = None

        if self.loss_maximum_over_time is None:
            self.loss_maximum_over_time = loss
        else:
            self.loss_maximum_over_time = max(self.loss_maximum_over_time, loss)

        self.lowpass_number_of_steps += 1

        # Exponential low-pass filter of the loss, normalized by its maximum so far
        self.lowpass_memory = self.lowpass_memory * self.lowpass_decay_value + (
            loss / self.loss_maximum_over_time
        ) * (1.0 / self.lowpass_tau)

        # Bias correction for the start-up phase of the filter
        loss_normalized: float = self.lowpass_memory / (
            1.0 - self.lowpass_decay_value ** float(self.lowpass_number_of_steps)
        )

        if self.lowpass_learning_rate_minimum_over_time is None:
            self.lowpass_learning_rate_minimum_over_time = loss_normalized
        else:
            self.lowpass_learning_rate_minimum_over_time = min(
                self.lowpass_learning_rate_minimum_over_time, loss_normalized
            )

        if self.lowpass_learning_rate_minimum_over_time_past_step is None:
            self.lowpass_learning_rate_minimum_over_time_past_step = (
                self.lowpass_learning_rate_minimum_over_time
            )

        # Scale the learning rate by the relative change of the filtered-loss
        # minimum since the previous step
        self.optimizer.param_groups[-1]["lr"] *= (  # type: ignore
            self.lowpass_learning_rate_minimum_over_time
            / self.lowpass_learning_rate_minimum_over_time_past_step
        )

        self.lowpass_learning_rate_minimum_over_time_past_step = (
            self.lowpass_learning_rate_minimum_over_time
        )

        self.loss_normalized_past_step = loss_normalized
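For orientation, a minimal usage sketch of the new scheduler. The toy model, optimizer, loop and dummy loss below are assumptions for illustration and are not part of this commit; only the constructor arguments and the step() call follow SbSLRScheduler.py above.

import torch

from SbSLRScheduler import SbSLRScheduler

# Illustrative toy model and optimizer; only the scheduler call pattern matters here.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = SbSLRScheduler(optimizer, factor=0.75, patience=100, tau=10.0)

for epoch in range(200):
    # A dummy, shrinking loss stands in for a real training epoch here.
    loss = 1.0 / (1.0 + epoch)
    # The parent ReduceLROnPlateau is driven by the previous step's filtered,
    # normalized loss; the adaptive part rescales the last param group's lr.
    scheduler.step(loss)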


@@ -54,6 +54,14 @@ from SbS import SbS
from torch.utils.tensorboard import SummaryWriter

try:
    from SbSLRScheduler import SbSLRScheduler

    sbs_lr_scheduler: bool = True
except Exception:
    sbs_lr_scheduler = False

tb = SummaryWriter()
torch.set_default_dtype(torch.float32)
@@ -264,6 +272,7 @@ for id in range(0, len(network)):
parameter_list_epsilon_xy.append(network[id]._epsilon_xy)
if cfg.learning_parameters.optimizer_name == "Adam":
logging.info("Using optimizer: Adam")
if cfg.learning_parameters.learning_rate_gamma_w > 0:
optimizer_wf: torch.optim.Optimizer = torch.optim.Adam(
parameter_list_weights,
@@ -286,20 +295,56 @@ if cfg.learning_parameters.optimizer_name == "Adam":
else:
raise Exception("Optimizer not implemented")
if cfg.learning_parameters.lr_schedule_name == "ReduceLROnPlateau":
if cfg.learning_parameters.lr_scheduler_patience_w > 0:
do_lr_scheduler_step: bool = True
if cfg.learning_parameters.lr_schedule_name == "None":
logging.info("Using lr scheduler: None")
do_lr_scheduler_step = False
elif cfg.learning_parameters.lr_schedule_name == "ReduceLROnPlateau":
logging.info("Using lr scheduler: ReduceLROnPlateau")
assert cfg.learning_parameters.lr_scheduler_factor_w > 0
assert cfg.learning_parameters.lr_scheduler_factor_eps_xy > 0
assert cfg.learning_parameters.lr_scheduler_patience_w > 0
assert cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0
lr_scheduler_wf = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_wf,
factor=cfg.learning_parameters.lr_scheduler_factor_w,
patience=cfg.learning_parameters.lr_scheduler_patience_w,
)
if cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0:
lr_scheduler_eps = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer_eps,
factor=cfg.learning_parameters.lr_scheduler_factor_eps_xy,
patience=cfg.learning_parameters.lr_scheduler_patience_eps_xy,
)
elif cfg.learning_parameters.lr_schedule_name == "SbSLRScheduler":
logging.info("Using lr scheduler: SbSLRScheduler")
assert cfg.learning_parameters.lr_scheduler_factor_w > 0
assert cfg.learning_parameters.lr_scheduler_factor_eps_xy > 0
assert cfg.learning_parameters.lr_scheduler_patience_w > 0
assert cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0
if sbs_lr_scheduler is False:
raise Exception("lr_scheduler: SbSLRScheduler.py missing")
lr_scheduler_wf = SbSLRScheduler(
optimizer_wf,
factor=cfg.learning_parameters.lr_scheduler_factor_w,
patience=cfg.learning_parameters.lr_scheduler_patience_w,
tau=cfg.learning_parameters.lr_scheduler_tau_w,
)
lr_scheduler_eps = SbSLRScheduler(
optimizer_eps,
factor=cfg.learning_parameters.lr_scheduler_factor_eps_xy,
patience=cfg.learning_parameters.lr_scheduler_patience_eps_xy,
tau=cfg.learning_parameters.lr_scheduler_tau_eps_xy,
)
else:
raise Exception("lr_scheduler not implemented")
@@ -433,7 +478,6 @@ with torch.no_grad():
train_number_of_processed_pattern
>= cfg.get_update_after_x_pattern()
):
logging.info("\t\t\t*** Updating the weights ***")
my_loss_for_batch: float = (
train_loss[0] / train_number_of_processed_pattern
)
@@ -506,13 +550,13 @@ with torch.no_grad():
# Let the torch learning rate scheduler update the
# learning rates of the optimizers
if cfg.learning_parameters.lr_scheduler_patience_w > 0:
if do_lr_scheduler_step is True:
if cfg.learning_parameters.lr_scheduler_use_performance is True:
lr_scheduler_wf.step(100.0 - performance)
else:
lr_scheduler_wf.step(my_loss_for_batch)
if cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0:
if do_lr_scheduler_step is True:
if cfg.learning_parameters.lr_scheduler_use_performance is True:
lr_scheduler_eps.step(100.0 - performance)
else:
@@ -530,6 +574,10 @@ with torch.no_grad():
optimizer_eps.param_groups[-1]["lr"],
cfg.learning_step,
)
logging.info(
f"\t\t\tLearning rate: weights:{optimizer_wf.param_groups[-1]['lr']:^15.3e} \t epsilon xy:{optimizer_eps.param_groups[-1]['lr']:^15.3e}"
)
logging.info("\t\t\t*** Updating the weights ***")
cfg.learning_step += 1
train_loss = np.zeros((1), dtype=np.float32)
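As a closing note, the adaptive part of SbSLRScheduler rests on a bias-corrected low-pass filter of the normalized loss. A small self-contained check with an assumed tau of 10 and a constant input of 1.0 shows that the correction term removes the start-up bias of the filter:

tau = 10.0
decay = 1.0 - 1.0 / tau  # 0.9, mirroring lowpass_decay_value in SbSLRScheduler

memory = 0.0
for n in range(1, 4):
    memory = memory * decay + 1.0 * (1.0 / tau)
    normalized = memory / (1.0 - decay**n)
    print(n, round(memory, 3), round(normalized, 6))
    # n=1: memory≈0.100, normalized≈1.0
    # n=2: memory≈0.190, normalized≈1.0
    # n=3: memory≈0.271, normalized≈1.0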