New LR scheduler, minor fixes
This commit is contained in:
parent
654014b319
commit
b18a999cbf
4 changed files with 178 additions and 22 deletions
|
@ -74,9 +74,11 @@ class LearningParameters:
|
||||||
lr_scheduler_use_performance: bool = field(default=True)
|
lr_scheduler_use_performance: bool = field(default=True)
|
||||||
lr_scheduler_factor_w: float = field(default=0.75)
|
lr_scheduler_factor_w: float = field(default=0.75)
|
||||||
lr_scheduler_patience_w: int = field(default=-1)
|
lr_scheduler_patience_w: int = field(default=-1)
|
||||||
|
lr_scheduler_tau_w: int = field(default=10)
|
||||||
|
|
||||||
lr_scheduler_factor_eps_xy: float = field(default=0.75)
|
lr_scheduler_factor_eps_xy: float = field(default=0.75)
|
||||||
lr_scheduler_patience_eps_xy: int = field(default=-1)
|
lr_scheduler_patience_eps_xy: int = field(default=-1)
|
||||||
|
lr_scheduler_tau_eps_xy: int = field(default=10)
|
||||||
|
|
||||||
number_of_batches_for_one_update: int = field(default=1)
|
number_of_batches_for_one_update: int = field(default=1)
|
||||||
overload_path: str = field(default="./Previous")
|
overload_path: str = field(default="./Previous")
|
||||||
|
@ -144,8 +146,6 @@ class Config:
|
||||||
reduction_cooldown: float = field(default=25.0)
|
reduction_cooldown: float = field(default=25.0)
|
||||||
epsilon_0: float = field(default=1.0)
|
epsilon_0: float = field(default=1.0)
|
||||||
|
|
||||||
update_after_x_batch: float = field(default=1.0)
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
"""Post init determines the number of cores.
|
"""Post init determines the number of cores.
|
||||||
Creates the required directory and gives us an optimized
|
Creates the required directory and gives us an optimized
|
||||||
|
@ -183,4 +183,6 @@ class Config:
|
||||||
|
|
||||||
def get_update_after_x_pattern(self):
|
def get_update_after_x_pattern(self):
|
||||||
"""Tells us after how many pattern we need to update the weights."""
|
"""Tells us after how many pattern we need to update the weights."""
|
||||||
return self.batch_size * self.update_after_x_batch
|
return (
|
||||||
|
self.batch_size * self.learning_parameters.number_of_batches_for_one_update
|
||||||
|
)
|
||||||
|
|
6
SbS.py
6
SbS.py
|
@ -1147,9 +1147,9 @@ class FunctionalSbS(torch.autograd.Function):
|
||||||
input /= input.sum(dim=1, keepdim=True, dtype=torch.float32)
|
input /= input.sum(dim=1, keepdim=True, dtype=torch.float32)
|
||||||
|
|
||||||
# For debugging:
|
# For debugging:
|
||||||
# print(
|
# print(
|
||||||
# f"S: O: {output.min().item():e} {output.max().item():e} I: {input.min().item():e} {input.max().item():e} G: {grad_output.min().item():e} {grad_output.max().item():e}"
|
# f"S: O: {output.min().item():e} {output.max().item():e} I: {input.min().item():e} {input.max().item():e} G: {grad_output.min().item():e} {grad_output.max().item():e}"
|
||||||
# )
|
# )
|
||||||
|
|
||||||
epsilon_0_float: float = epsilon_0.item()
|
epsilon_0_float: float = epsilon_0.item()
|
||||||
|
|
||||||
|
|
106
SbSLRScheduler.py
Normal file
106
SbSLRScheduler.py
Normal file
|
@ -0,0 +1,106 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class SbSLRScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """ReduceLROnPlateau plus a loss-driven scaling of the learning rate.

    The parent scheduler controls the base learning rate. It is not fed the
    raw metric but the low-pass filtered, normalized loss computed in the
    previous call of :meth:`step`.

    On top of that, the learning rate of the *last* parameter group of the
    optimizer is multiplied in every step by the ratio between the current
    and the previous running minimum of the filtered loss. Whenever the
    parent changes the learning rate, the filter state is reset.

    Args:
        optimizer: Wrapped optimizer.
        mode, factor, patience, threshold, threshold_mode, cooldown, min_lr,
            eps: Forwarded unchanged to ``ReduceLROnPlateau``.
        verbose: Forwarded to ``ReduceLROnPlateau`` only when it is enabled.
        tau: Time constant (in steps) of the exponential low-pass filter.
            Must not be zero, because ``1 / tau`` is used as filter weight.
    """

    def __init__(
        self,
        optimizer,
        mode: str = "min",
        factor: float = 0.1,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: str = "rel",
        cooldown: int = 0,
        min_lr: float = 0,
        eps: float = 1e-8,
        verbose: bool = False,
        tau: float = 10,
    ) -> None:

        # `verbose` is deprecated in ReduceLROnPlateau (torch >= 2.2).
        # Omitting it while it is disabled is equivalent to the parent's
        # default and avoids touching the deprecated parameter needlessly.
        optional_kwargs = {"verbose": verbose} if verbose else {}

        super().__init__(
            optimizer=optimizer,
            mode=mode,
            factor=factor,
            patience=patience,
            threshold=threshold,
            threshold_mode=threshold_mode,
            cooldown=cooldown,
            min_lr=min_lr,
            eps=eps,
            **optional_kwargs,
        )

        # Exponential low-pass filter: memory <- memory * decay + x / tau
        self.lowpass_tau: float = tau
        self.lowpass_decay_value: float = 1.0 - (1.0 / self.lowpass_tau)

        # Filter state. It is reset when the parent changes the learning rate.
        self.lowpass_number_of_steps: int = 0
        self.loss_maximum_over_time: float | None = None
        self.lowpass_memory: float = 0.0
        self.lowpass_learning_rate_minimum_over_time: float | None = None
        self.lowpass_learning_rate_minimum_over_time_past_step: float | None = None

        # Values remembered from the previous call of step()
        self.previous_learning_rate: float | None = None
        self.loss_normalized_past_step: float | None = None

    def _normalize_loss(self, loss: float) -> float:
        """Return ``loss`` relative to the running maximum of the loss.

        A maximum of zero (e.g. a loss of exactly 0.0 in the first step)
        would otherwise be a division by zero. In that case 1.0 is returned,
        which is the value a first step with a non-zero loss yields.
        """
        maximum = self.loss_maximum_over_time
        if maximum is None or maximum == 0.0:
            return 1.0
        return loss / maximum

    def step(self, metrics, epoch=None) -> None:
        """Update the learning rate for the given loss value.

        Args:
            metrics: Current loss. It is converted via ``float()``.
            epoch: Forwarded unchanged to ``ReduceLROnPlateau.step``.
        """

        loss = float(metrics)
        param_group = self.optimizer.param_groups[-1]  # type: ignore

        if self.loss_maximum_over_time is None:
            self.loss_maximum_over_time = loss

        if self.loss_normalized_past_step is None:
            self.loss_normalized_past_step = self._normalize_loss(loss)

        # The parent lr scheduler controls the base learning rate.
        # It sees the filtered loss of the previous step.
        self.previous_learning_rate = param_group["lr"]
        super().step(metrics=self.loss_normalized_past_step, epoch=epoch)

        # If the parent changes the base learning rate,
        # then we reset the adaptive part
        if param_group["lr"] != self.previous_learning_rate:
            self.previous_learning_rate = param_group["lr"]

            self.lowpass_number_of_steps = 0
            self.loss_maximum_over_time = None
            self.lowpass_memory = 0.0
            self.lowpass_learning_rate_minimum_over_time = None
            self.lowpass_learning_rate_minimum_over_time_past_step = None

        # Running maximum of the raw loss (since the last reset)
        if self.loss_maximum_over_time is None:
            self.loss_maximum_over_time = loss
        else:
            self.loss_maximum_over_time = max(self.loss_maximum_over_time, loss)

        self.lowpass_number_of_steps += 1

        # Low-pass filter the normalized loss ...
        self.lowpass_memory = self.lowpass_memory * self.lowpass_decay_value + (
            self._normalize_loss(loss)
        ) * (1.0 / self.lowpass_tau)

        # ... and compensate for the filter memory starting at zero.
        loss_normalized: float = self.lowpass_memory / (
            1.0 - self.lowpass_decay_value ** float(self.lowpass_number_of_steps)
        )

        # Running minimum of the filtered loss (since the last reset)
        if self.lowpass_learning_rate_minimum_over_time is None:
            self.lowpass_learning_rate_minimum_over_time = loss_normalized
        else:
            self.lowpass_learning_rate_minimum_over_time = min(
                self.lowpass_learning_rate_minimum_over_time, loss_normalized
            )

        if self.lowpass_learning_rate_minimum_over_time_past_step is None:
            self.lowpass_learning_rate_minimum_over_time_past_step = (
                self.lowpass_learning_rate_minimum_over_time
            )

        # Scale the learning rate with the relative change of the running
        # minimum. A previous minimum of zero gives no usable ratio; the
        # learning rate is left untouched then instead of dividing by zero.
        past_minimum = self.lowpass_learning_rate_minimum_over_time_past_step
        if past_minimum != 0.0:
            param_group["lr"] *= (
                self.lowpass_learning_rate_minimum_over_time / past_minimum
            )

        self.lowpass_learning_rate_minimum_over_time_past_step = (
            self.lowpass_learning_rate_minimum_over_time
        )
        self.loss_normalized_past_step = loss_normalized
|
80
learn_it.py
80
learn_it.py
|
@ -54,6 +54,14 @@ from SbS import SbS
|
||||||
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
try:
|
||||||
|
from SbSLRScheduler import SbSLRScheduler
|
||||||
|
|
||||||
|
sbs_lr_scheduler: bool = True
|
||||||
|
except Exception:
|
||||||
|
sbs_lr_scheduler = False
|
||||||
|
|
||||||
|
|
||||||
tb = SummaryWriter()
|
tb = SummaryWriter()
|
||||||
|
|
||||||
torch.set_default_dtype(torch.float32)
|
torch.set_default_dtype(torch.float32)
|
||||||
|
@ -264,6 +272,7 @@ for id in range(0, len(network)):
|
||||||
parameter_list_epsilon_xy.append(network[id]._epsilon_xy)
|
parameter_list_epsilon_xy.append(network[id]._epsilon_xy)
|
||||||
|
|
||||||
if cfg.learning_parameters.optimizer_name == "Adam":
|
if cfg.learning_parameters.optimizer_name == "Adam":
|
||||||
|
logging.info("Using optimizer: Adam")
|
||||||
if cfg.learning_parameters.learning_rate_gamma_w > 0:
|
if cfg.learning_parameters.learning_rate_gamma_w > 0:
|
||||||
optimizer_wf: torch.optim.Optimizer = torch.optim.Adam(
|
optimizer_wf: torch.optim.Optimizer = torch.optim.Adam(
|
||||||
parameter_list_weights,
|
parameter_list_weights,
|
||||||
|
@ -286,20 +295,56 @@ if cfg.learning_parameters.optimizer_name == "Adam":
|
||||||
else:
|
else:
|
||||||
raise Exception("Optimizer not implemented")
|
raise Exception("Optimizer not implemented")
|
||||||
|
|
||||||
if cfg.learning_parameters.lr_schedule_name == "ReduceLROnPlateau":
|
do_lr_scheduler_step: bool = True
|
||||||
if cfg.learning_parameters.lr_scheduler_patience_w > 0:
|
|
||||||
lr_scheduler_wf = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
|
||||||
optimizer_wf,
|
|
||||||
factor=cfg.learning_parameters.lr_scheduler_factor_w,
|
|
||||||
patience=cfg.learning_parameters.lr_scheduler_patience_w,
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0:
|
if cfg.learning_parameters.lr_schedule_name == "None":
|
||||||
lr_scheduler_eps = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
logging.info("Using lr scheduler: None")
|
||||||
optimizer_eps,
|
do_lr_scheduler_step = False
|
||||||
factor=cfg.learning_parameters.lr_scheduler_factor_eps_xy,
|
|
||||||
patience=cfg.learning_parameters.lr_scheduler_patience_eps_xy,
|
elif cfg.learning_parameters.lr_schedule_name == "ReduceLROnPlateau":
|
||||||
)
|
logging.info("Using lr scheduler: ReduceLROnPlateau")
|
||||||
|
|
||||||
|
assert cfg.learning_parameters.lr_scheduler_factor_w > 0
|
||||||
|
assert cfg.learning_parameters.lr_scheduler_factor_eps_xy > 0
|
||||||
|
assert cfg.learning_parameters.lr_scheduler_patience_w > 0
|
||||||
|
assert cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0
|
||||||
|
|
||||||
|
lr_scheduler_wf = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
|
optimizer_wf,
|
||||||
|
factor=cfg.learning_parameters.lr_scheduler_factor_w,
|
||||||
|
patience=cfg.learning_parameters.lr_scheduler_patience_w,
|
||||||
|
)
|
||||||
|
|
||||||
|
lr_scheduler_eps = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
|
optimizer_eps,
|
||||||
|
factor=cfg.learning_parameters.lr_scheduler_factor_eps_xy,
|
||||||
|
patience=cfg.learning_parameters.lr_scheduler_patience_eps_xy,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif cfg.learning_parameters.lr_schedule_name == "SbSLRScheduler":
|
||||||
|
logging.info("Using lr scheduler: SbSLRScheduler")
|
||||||
|
|
||||||
|
assert cfg.learning_parameters.lr_scheduler_factor_w > 0
|
||||||
|
assert cfg.learning_parameters.lr_scheduler_factor_eps_xy > 0
|
||||||
|
assert cfg.learning_parameters.lr_scheduler_patience_w > 0
|
||||||
|
assert cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0
|
||||||
|
|
||||||
|
if sbs_lr_scheduler is False:
|
||||||
|
raise Exception("lr_scheduler: SbSLRScheduler.py missing")
|
||||||
|
|
||||||
|
lr_scheduler_wf = SbSLRScheduler(
|
||||||
|
optimizer_wf,
|
||||||
|
factor=cfg.learning_parameters.lr_scheduler_factor_w,
|
||||||
|
patience=cfg.learning_parameters.lr_scheduler_patience_w,
|
||||||
|
tau=cfg.learning_parameters.lr_scheduler_tau_w,
|
||||||
|
)
|
||||||
|
|
||||||
|
lr_scheduler_eps = SbSLRScheduler(
|
||||||
|
optimizer_eps,
|
||||||
|
factor=cfg.learning_parameters.lr_scheduler_factor_eps_xy,
|
||||||
|
patience=cfg.learning_parameters.lr_scheduler_patience_eps_xy,
|
||||||
|
tau=cfg.learning_parameters.lr_scheduler_tau_eps_xy,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception("lr_scheduler not implemented")
|
raise Exception("lr_scheduler not implemented")
|
||||||
|
|
||||||
|
@ -433,7 +478,6 @@ with torch.no_grad():
|
||||||
train_number_of_processed_pattern
|
train_number_of_processed_pattern
|
||||||
>= cfg.get_update_after_x_pattern()
|
>= cfg.get_update_after_x_pattern()
|
||||||
):
|
):
|
||||||
logging.info("\t\t\t*** Updating the weights ***")
|
|
||||||
my_loss_for_batch: float = (
|
my_loss_for_batch: float = (
|
||||||
train_loss[0] / train_number_of_processed_pattern
|
train_loss[0] / train_number_of_processed_pattern
|
||||||
)
|
)
|
||||||
|
@ -506,13 +550,13 @@ with torch.no_grad():
|
||||||
|
|
||||||
# Let the torch learning rate scheduler update the
|
# Let the torch learning rate scheduler update the
|
||||||
# learning rates of the optimizers
|
# learning rates of the optimizers
|
||||||
if cfg.learning_parameters.lr_scheduler_patience_w > 0:
|
if do_lr_scheduler_step is True:
|
||||||
if cfg.learning_parameters.lr_scheduler_use_performance is True:
|
if cfg.learning_parameters.lr_scheduler_use_performance is True:
|
||||||
lr_scheduler_wf.step(100.0 - performance)
|
lr_scheduler_wf.step(100.0 - performance)
|
||||||
else:
|
else:
|
||||||
lr_scheduler_wf.step(my_loss_for_batch)
|
lr_scheduler_wf.step(my_loss_for_batch)
|
||||||
|
|
||||||
if cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0:
|
if do_lr_scheduler_step is True:
|
||||||
if cfg.learning_parameters.lr_scheduler_use_performance is True:
|
if cfg.learning_parameters.lr_scheduler_use_performance is True:
|
||||||
lr_scheduler_eps.step(100.0 - performance)
|
lr_scheduler_eps.step(100.0 - performance)
|
||||||
else:
|
else:
|
||||||
|
@ -530,6 +574,10 @@ with torch.no_grad():
|
||||||
optimizer_eps.param_groups[-1]["lr"],
|
optimizer_eps.param_groups[-1]["lr"],
|
||||||
cfg.learning_step,
|
cfg.learning_step,
|
||||||
)
|
)
|
||||||
|
logging.info(
|
||||||
|
f"\t\t\tLearning rate: weights:{optimizer_wf.param_groups[-1]['lr']:^15.3e} \t epsilon xy:{optimizer_eps.param_groups[-1]['lr']:^15.3e}"
|
||||||
|
)
|
||||||
|
logging.info("\t\t\t*** Updating the weights ***")
|
||||||
|
|
||||||
cfg.learning_step += 1
|
cfg.learning_step += 1
|
||||||
train_loss = np.zeros((1), dtype=np.float32)
|
train_loss = np.zeros((1), dtype=np.float32)
|
||||||
|
|
Loading…
Reference in a new issue