New LR scheduler, minor fixes

David Rotermund 2022-05-08 15:43:10 +02:00 committed by GitHub
parent 654014b319
commit b18a999cbf
4 changed files with 178 additions and 22 deletions


@@ -74,9 +74,11 @@ class LearningParameters:
     lr_scheduler_use_performance: bool = field(default=True)
     lr_scheduler_factor_w: float = field(default=0.75)
     lr_scheduler_patience_w: int = field(default=-1)
+    lr_scheduler_tau_w: int = field(default=10)

     lr_scheduler_factor_eps_xy: float = field(default=0.75)
     lr_scheduler_patience_eps_xy: int = field(default=-1)
+    lr_scheduler_tau_eps_xy: int = field(default=10)

     number_of_batches_for_one_update: int = field(default=1)
     overload_path: str = field(default="./Previous")

@@ -144,8 +146,6 @@ class Config:
     reduction_cooldown: float = field(default=25.0)
     epsilon_0: float = field(default=1.0)
-    update_after_x_batch: float = field(default=1.0)

     def __post_init__(self) -> None:
         """Post init determines the number of cores.
         Creates the required directory and gives us an optimized

@@ -183,4 +183,6 @@ class Config:
     def get_update_after_x_pattern(self):
         """Tells us after how many pattern we need to update the weights."""
-        return self.batch_size * self.update_after_x_batch
+        return (
+            self.batch_size * self.learning_parameters.number_of_batches_for_one_update
+        )
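
With update_after_x_batch removed, the number of processed patterns per weight update now follows directly from the batch size and number_of_batches_for_one_update. A tiny worked example (illustration only; the concrete batch_size is an assumption, not taken from the commit):

# Illustration only, not part of the commit.
batch_size = 500  # assumed value
number_of_batches_for_one_update = 1  # dataclass default from LearningParameters
patterns_per_update = batch_size * number_of_batches_for_one_update  # -> 500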

SbSLRScheduler.py  (new file, 106 lines)

@@ -0,0 +1,106 @@
import torch


class SbSLRScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
    def __init__(
        self,
        optimizer,
        mode: str = "min",
        factor: float = 0.1,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: str = "rel",
        cooldown: int = 0,
        min_lr: float = 0,
        eps: float = 1e-8,
        verbose: bool = False,
        tau: float = 10,
    ) -> None:
        super().__init__(
            optimizer=optimizer,
            mode=mode,
            factor=factor,
            patience=patience,
            threshold=threshold,
            threshold_mode=threshold_mode,
            cooldown=cooldown,
            min_lr=min_lr,
            eps=eps,
            verbose=verbose,
        )

        # State of the exponential low-pass filter over the normalized loss
        self.lowpass_tau: float = tau
        self.lowpass_decay_value: float = 1.0 - (1.0 / self.lowpass_tau)
        self.lowpass_number_of_steps: int = 0
        self.loss_maximum_over_time: float | None = None
        self.lowpass_memory: float = 0.0
        self.lowpass_learning_rate_minimum_over_time: float | None = None
        self.lowpass_learning_rate_minimum_over_time_past_step: float | None = None
        self.previous_learning_rate: float | None = None
        self.loss_normalized_past_step: float | None = None

    def step(self, metrics, epoch=None) -> None:

        loss = float(metrics)

        if self.loss_maximum_over_time is None:
            self.loss_maximum_over_time = loss

        if self.loss_normalized_past_step is None:
            self.loss_normalized_past_step = loss / self.loss_maximum_over_time

        if self.previous_learning_rate is None:
            self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore

        # The parent lr scheduler controls the base learning rate
        self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore
        super().step(metrics=self.loss_normalized_past_step, epoch=epoch)

        # If the parent changes the base learning rate,
        # then we reset the adaptive part
        if self.optimizer.param_groups[-1]["lr"] != self.previous_learning_rate:  # type: ignore
            self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore
            self.lowpass_number_of_steps = 0
            self.loss_maximum_over_time = None
            self.lowpass_memory = 0.0
            self.lowpass_learning_rate_minimum_over_time = None
            self.lowpass_learning_rate_minimum_over_time_past_step = None

        if self.loss_maximum_over_time is None:
            self.loss_maximum_over_time = loss
        else:
            self.loss_maximum_over_time = max(self.loss_maximum_over_time, loss)

        # Bias-corrected low-pass filter of the loss, normalized by its maximum
        self.lowpass_number_of_steps += 1
        self.lowpass_memory = self.lowpass_memory * self.lowpass_decay_value + (
            loss / self.loss_maximum_over_time
        ) * (1.0 / self.lowpass_tau)

        loss_normalized: float = self.lowpass_memory / (
            1.0 - self.lowpass_decay_value ** float(self.lowpass_number_of_steps)
        )

        # Track the minimum of the filtered loss over time
        if self.lowpass_learning_rate_minimum_over_time is None:
            self.lowpass_learning_rate_minimum_over_time = loss_normalized
        else:
            self.lowpass_learning_rate_minimum_over_time = min(
                self.lowpass_learning_rate_minimum_over_time, loss_normalized
            )

        if self.lowpass_learning_rate_minimum_over_time_past_step is None:
            self.lowpass_learning_rate_minimum_over_time_past_step = (
                self.lowpass_learning_rate_minimum_over_time
            )

        # Scale the learning rate by the relative improvement of the
        # filtered-loss minimum since the previous step
        self.optimizer.param_groups[-1]["lr"] *= (  # type: ignore
            self.lowpass_learning_rate_minimum_over_time
            / self.lowpass_learning_rate_minimum_over_time_past_step
        )

        self.lowpass_learning_rate_minimum_over_time_past_step = (
            self.lowpass_learning_rate_minimum_over_time
        )

        self.loss_normalized_past_step = loss_normalized
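
A minimal usage sketch for the new scheduler (not part of the commit): the toy model, optimizer, and loss values below are invented for illustration; only SbSLRScheduler itself comes from this file. In the training script the metric passed to step() is either the batch loss or 100 - performance.

# Illustration only; assumed toy setup, not from the commit.
import torch

from SbSLRScheduler import SbSLRScheduler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

scheduler = SbSLRScheduler(
    optimizer,
    factor=0.75,  # forwarded to the parent ReduceLROnPlateau
    patience=25,  # forwarded to the parent ReduceLROnPlateau
    tau=10,       # time constant of the exponential low-pass filter
)

for step, fake_loss in enumerate([1.0, 0.9, 0.95, 0.8, 0.85]):
    scheduler.step(fake_loss)  # adaptive part rescales param_groups[-1]["lr"]
    print(step, optimizer.param_groups[-1]["lr"])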


@@ -54,6 +54,14 @@ from SbS import SbS

 from torch.utils.tensorboard import SummaryWriter

+try:
+    from SbSLRScheduler import SbSLRScheduler
+
+    sbs_lr_scheduler: bool = True
+except Exception:
+    sbs_lr_scheduler = False
+
 tb = SummaryWriter()

 torch.set_default_dtype(torch.float32)

@@ -264,6 +272,7 @@ for id in range(0, len(network)):
     parameter_list_epsilon_xy.append(network[id]._epsilon_xy)

 if cfg.learning_parameters.optimizer_name == "Adam":
+    logging.info("Using optimizer: Adam")
     if cfg.learning_parameters.learning_rate_gamma_w > 0:
         optimizer_wf: torch.optim.Optimizer = torch.optim.Adam(
             parameter_list_weights,

@@ -286,20 +295,56 @@
 else:
     raise Exception("Optimizer not implemented")

-if cfg.learning_parameters.lr_schedule_name == "ReduceLROnPlateau":
-    if cfg.learning_parameters.lr_scheduler_patience_w > 0:
-        lr_scheduler_wf = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            optimizer_wf,
-            factor=cfg.learning_parameters.lr_scheduler_factor_w,
-            patience=cfg.learning_parameters.lr_scheduler_patience_w,
-        )
-    if cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0:
-        lr_scheduler_eps = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            optimizer_eps,
-            factor=cfg.learning_parameters.lr_scheduler_factor_eps_xy,
-            patience=cfg.learning_parameters.lr_scheduler_patience_eps_xy,
-        )
+do_lr_scheduler_step: bool = True
+
+if cfg.learning_parameters.lr_schedule_name == "None":
+    logging.info("Using lr scheduler: None")
+    do_lr_scheduler_step = False
+
+elif cfg.learning_parameters.lr_schedule_name == "ReduceLROnPlateau":
+    logging.info("Using lr scheduler: ReduceLROnPlateau")
+
+    assert cfg.learning_parameters.lr_scheduler_factor_w > 0
+    assert cfg.learning_parameters.lr_scheduler_factor_eps_xy > 0
+    assert cfg.learning_parameters.lr_scheduler_patience_w > 0
+    assert cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0
+
+    lr_scheduler_wf = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer_wf,
+        factor=cfg.learning_parameters.lr_scheduler_factor_w,
+        patience=cfg.learning_parameters.lr_scheduler_patience_w,
+    )
+
+    lr_scheduler_eps = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer_eps,
+        factor=cfg.learning_parameters.lr_scheduler_factor_eps_xy,
+        patience=cfg.learning_parameters.lr_scheduler_patience_eps_xy,
+    )
+
+elif cfg.learning_parameters.lr_schedule_name == "SbSLRScheduler":
+    logging.info("Using lr scheduler: SbSLRScheduler")
+
+    assert cfg.learning_parameters.lr_scheduler_factor_w > 0
+    assert cfg.learning_parameters.lr_scheduler_factor_eps_xy > 0
+    assert cfg.learning_parameters.lr_scheduler_patience_w > 0
+    assert cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0
+
+    if sbs_lr_scheduler is False:
+        raise Exception("lr_scheduler: SbSLRScheduler.py missing")
+
+    lr_scheduler_wf = SbSLRScheduler(
+        optimizer_wf,
+        factor=cfg.learning_parameters.lr_scheduler_factor_w,
+        patience=cfg.learning_parameters.lr_scheduler_patience_w,
+        tau=cfg.learning_parameters.lr_scheduler_tau_w,
+    )
+
+    lr_scheduler_eps = SbSLRScheduler(
+        optimizer_eps,
+        factor=cfg.learning_parameters.lr_scheduler_factor_eps_xy,
+        patience=cfg.learning_parameters.lr_scheduler_patience_eps_xy,
+        tau=cfg.learning_parameters.lr_scheduler_tau_eps_xy,
+    )
+
 else:
     raise Exception("lr_scheduler not implemented")
@@ -433,7 +478,6 @@ with torch.no_grad():
                     train_number_of_processed_pattern
                     >= cfg.get_update_after_x_pattern()
                 ):
-                    logging.info("\t\t\t*** Updating the weights ***")
                     my_loss_for_batch: float = (
                         train_loss[0] / train_number_of_processed_pattern
                     )

@@ -506,13 +550,13 @@ with torch.no_grad():
                     # Let the torch learning rate scheduler update the
                     # learning rates of the optimiers
-                    if cfg.learning_parameters.lr_scheduler_patience_w > 0:
+                    if do_lr_scheduler_step is True:
                         if cfg.learning_parameters.lr_scheduler_use_performance is True:
                             lr_scheduler_wf.step(100.0 - performance)
                         else:
                             lr_scheduler_wf.step(my_loss_for_batch)

-                    if cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0:
+                    if do_lr_scheduler_step is True:
                         if cfg.learning_parameters.lr_scheduler_use_performance is True:
                             lr_scheduler_eps.step(100.0 - performance)
                         else:

@@ -530,6 +574,10 @@ with torch.no_grad():
                         optimizer_eps.param_groups[-1]["lr"],
                         cfg.learning_step,
                     )
+                    logging.info(
+                        f"\t\t\tLearning rate: weights:{optimizer_wf.param_groups[-1]['lr']:^15.3e} \t epsilon xy:{optimizer_eps.param_groups[-1]['lr']:^15.3e}"
+                    )
+                    logging.info("\t\t\t*** Updating the weights ***")

                     cfg.learning_step += 1
                     train_loss = np.zeros((1), dtype=np.float32)