diff --git a/Parameter.py b/Parameter.py index 61a9346..d9d7016 100644 --- a/Parameter.py +++ b/Parameter.py @@ -74,9 +74,11 @@ class LearningParameters: lr_scheduler_use_performance: bool = field(default=True) lr_scheduler_factor_w: float = field(default=0.75) lr_scheduler_patience_w: int = field(default=-1) + lr_scheduler_tau_w: int = field(default=10) lr_scheduler_factor_eps_xy: float = field(default=0.75) lr_scheduler_patience_eps_xy: int = field(default=-1) + lr_scheduler_tau_eps_xy: int = field(default=10) number_of_batches_for_one_update: int = field(default=1) overload_path: str = field(default="./Previous") @@ -144,8 +146,6 @@ class Config: reduction_cooldown: float = field(default=25.0) epsilon_0: float = field(default=1.0) - update_after_x_batch: float = field(default=1.0) - def __post_init__(self) -> None: """Post init determines the number of cores. Creates the required directory and gives us an optimized @@ -183,4 +183,6 @@ class Config: def get_update_after_x_pattern(self): """Tells us after how many pattern we need to update the weights.""" - return self.batch_size * self.update_after_x_batch + return ( + self.batch_size * self.learning_parameters.number_of_batches_for_one_update + ) diff --git a/SbS.py b/SbS.py index df50b28..59cd9fe 100644 --- a/SbS.py +++ b/SbS.py @@ -1147,9 +1147,9 @@ class FunctionalSbS(torch.autograd.Function): input /= input.sum(dim=1, keepdim=True, dtype=torch.float32) # For debugging: -# print( -# f"S: O: {output.min().item():e} {output.max().item():e} I: {input.min().item():e} {input.max().item():e} G: {grad_output.min().item():e} {grad_output.max().item():e}" -# ) + # print( + # f"S: O: {output.min().item():e} {output.max().item():e} I: {input.min().item():e} {input.max().item():e} G: {grad_output.min().item():e} {grad_output.max().item():e}" + # ) epsilon_0_float: float = epsilon_0.item() diff --git a/SbSLRScheduler.py b/SbSLRScheduler.py new file mode 100644 index 0000000..ea81038 --- /dev/null +++ b/SbSLRScheduler.py 
import torch


class SbSLRScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """ReduceLROnPlateau with an additional low-pass driven learning-rate decay.

    Two mechanisms act on the learning rate of the optimizer's *last*
    parameter group (``optimizer.param_groups[-1]``):

    1. The parent ``ReduceLROnPlateau`` controls the base learning rate. It
       is fed the normalized, low-pass filtered loss computed in the
       previous call to :meth:`step`.
    2. An adaptive part multiplies the learning rate by the ratio between
       the current and the previous running minimum of the normalized,
       low-pass filtered loss. Whenever the parent changes the learning
       rate, the adaptive state is reset.

    Args:
        optimizer: Wrapped optimizer.
        mode, factor, patience, threshold, threshold_mode, cooldown, min_lr,
        eps, verbose: Forwarded unchanged to ``ReduceLROnPlateau``.
            ``verbose`` is only forwarded when it is truthy, so the default
            does not trigger the deprecation warning of recent PyTorch
            releases.
        tau: Time constant (in calls to :meth:`step`) of the exponential
            low-pass filter applied to the normalized loss. Must be > 0.

    Raises:
        ValueError: If ``tau`` is not strictly positive.
    """

    def __init__(
        self,
        optimizer,
        mode: str = "min",
        factor: float = 0.1,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: str = "rel",
        cooldown: int = 0,
        min_lr: float = 0,
        eps: float = 1e-8,
        verbose: bool = False,
        tau: float = 10,
    ) -> None:

        # tau == 0 would divide by zero below; tau < 0 would give a decay
        # factor larger than one, i.e. an unstable filter.
        if tau <= 0:
            raise ValueError(f"tau must be > 0 (got {tau})")

        parent_kwargs = {
            "optimizer": optimizer,
            "mode": mode,
            "factor": factor,
            "patience": patience,
            "threshold": threshold,
            "threshold_mode": threshold_mode,
            "cooldown": cooldown,
            "min_lr": min_lr,
            "eps": eps,
        }
        if verbose:
            parent_kwargs["verbose"] = verbose
        super().__init__(**parent_kwargs)

        self.lowpass_tau: float = tau
        self.lowpass_decay_value: float = 1.0 - (1.0 / self.lowpass_tau)

        self.lowpass_number_of_steps: int = 0
        self.loss_maximum_over_time: float | None = None
        self.lowpass_memory: float = 0.0
        self.lowpass_learning_rate_minimum_over_time: float | None = None
        self.lowpass_learning_rate_minimum_over_time_past_step: float | None = None

        self.previous_learning_rate: float | None = None
        self.loss_normalized_past_step: float | None = None

    @staticmethod
    def _safe_ratio(numerator: float, denominator: float) -> float:
        """Return numerator / denominator, or 1.0 if the denominator is zero."""
        if denominator == 0:
            return 1.0
        return numerator / denominator

    def _reset_lowpass_state(self) -> None:
        """Forget the low-pass filter and the running extrema."""
        self.lowpass_number_of_steps = 0
        self.loss_maximum_over_time = None
        self.lowpass_memory = 0.0
        self.lowpass_learning_rate_minimum_over_time = None
        self.lowpass_learning_rate_minimum_over_time_past_step = None

    def step(self, metrics, epoch=None) -> None:
        """Update the learning rate of the last parameter group.

        Args:
            metrics: Current loss (anything ``float()`` accepts).
            epoch: Forwarded unchanged to the parent scheduler.
        """

        loss = float(metrics)

        if self.loss_maximum_over_time is None:
            self.loss_maximum_over_time = loss

        # A metric of exactly zero (e.g. 100% performance) must not crash
        # the normalization; it is treated as "loss equals its maximum".
        if self.loss_normalized_past_step is None:
            self.loss_normalized_past_step = self._safe_ratio(
                loss, self.loss_maximum_over_time
            )

        param_group = self.optimizer.param_groups[-1]  # type: ignore

        # The parent lr scheduler controls the basic learn rate
        self.previous_learning_rate = param_group["lr"]
        super().step(metrics=self.loss_normalized_past_step, epoch=epoch)

        # If the parent changes the base learning rate,
        # then we reset the adaptive part
        if param_group["lr"] != self.previous_learning_rate:
            self.previous_learning_rate = param_group["lr"]
            self._reset_lowpass_state()

        if self.loss_maximum_over_time is None:
            self.loss_maximum_over_time = loss
        else:
            self.loss_maximum_over_time = max(self.loss_maximum_over_time, loss)

        self.lowpass_number_of_steps += 1

        # Exponential low-pass filter over the loss relative to its maximum
        self.lowpass_memory = self.lowpass_memory * self.lowpass_decay_value + (
            self._safe_ratio(loss, self.loss_maximum_over_time)
        ) * (1.0 / self.lowpass_tau)

        # Compensate the start-up bias of the filter (memory starts at 0)
        loss_normalized: float = self.lowpass_memory / (
            1.0 - self.lowpass_decay_value ** float(self.lowpass_number_of_steps)
        )

        if self.lowpass_learning_rate_minimum_over_time is None:
            self.lowpass_learning_rate_minimum_over_time = loss_normalized
        else:
            self.lowpass_learning_rate_minimum_over_time = min(
                self.lowpass_learning_rate_minimum_over_time, loss_normalized
            )

        if self.lowpass_learning_rate_minimum_over_time_past_step is None:
            self.lowpass_learning_rate_minimum_over_time_past_step = (
                self.lowpass_learning_rate_minimum_over_time
            )

        # Scale the learning rate by the change of the running minimum
        scale = self._safe_ratio(
            self.lowpass_learning_rate_minimum_over_time,
            self.lowpass_learning_rate_minimum_over_time_past_step,
        )
        current_lr = param_group["lr"]
        scaled_lr = current_lr * scale
        # Respect the min_lr given to the parent, but never raise the rate
        lr_floor = min(current_lr, self.min_lrs[-1])
        param_group["lr"] = scaled_lr if scaled_lr >= lr_floor else lr_floor

        # Keep the value reported by get_last_lr() in sync with the optimizer
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]  # type: ignore

        self.lowpass_learning_rate_minimum_over_time_past_step = (
            self.lowpass_learning_rate_minimum_over_time
        )
        self.loss_normalized_past_step = loss_normalized
logging.info("Using optimizer: Adam") if cfg.learning_parameters.learning_rate_gamma_w > 0: optimizer_wf: torch.optim.Optimizer = torch.optim.Adam( parameter_list_weights, @@ -286,20 +295,56 @@ if cfg.learning_parameters.optimizer_name == "Adam": else: raise Exception("Optimizer not implemented") -if cfg.learning_parameters.lr_schedule_name == "ReduceLROnPlateau": - if cfg.learning_parameters.lr_scheduler_patience_w > 0: - lr_scheduler_wf = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer_wf, - factor=cfg.learning_parameters.lr_scheduler_factor_w, - patience=cfg.learning_parameters.lr_scheduler_patience_w, - ) +do_lr_scheduler_step: bool = True - if cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0: - lr_scheduler_eps = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer_eps, - factor=cfg.learning_parameters.lr_scheduler_factor_eps_xy, - patience=cfg.learning_parameters.lr_scheduler_patience_eps_xy, - ) +if cfg.learning_parameters.lr_schedule_name == "None": + logging.info("Using lr scheduler: None") + do_lr_scheduler_step = False + +elif cfg.learning_parameters.lr_schedule_name == "ReduceLROnPlateau": + logging.info("Using lr scheduler: ReduceLROnPlateau") + + assert cfg.learning_parameters.lr_scheduler_factor_w > 0 + assert cfg.learning_parameters.lr_scheduler_factor_eps_xy > 0 + assert cfg.learning_parameters.lr_scheduler_patience_w > 0 + assert cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0 + + lr_scheduler_wf = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer_wf, + factor=cfg.learning_parameters.lr_scheduler_factor_w, + patience=cfg.learning_parameters.lr_scheduler_patience_w, + ) + + lr_scheduler_eps = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer_eps, + factor=cfg.learning_parameters.lr_scheduler_factor_eps_xy, + patience=cfg.learning_parameters.lr_scheduler_patience_eps_xy, + ) + +elif cfg.learning_parameters.lr_schedule_name == "SbSLRScheduler": + logging.info("Using lr scheduler: SbSLRScheduler") + + assert 
cfg.learning_parameters.lr_scheduler_factor_w > 0 + assert cfg.learning_parameters.lr_scheduler_factor_eps_xy > 0 + assert cfg.learning_parameters.lr_scheduler_patience_w > 0 + assert cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0 + + if sbs_lr_scheduler is False: + raise Exception("lr_scheduler: SbSLRScheduler.py missing") + + lr_scheduler_wf = SbSLRScheduler( + optimizer_wf, + factor=cfg.learning_parameters.lr_scheduler_factor_w, + patience=cfg.learning_parameters.lr_scheduler_patience_w, + tau=cfg.learning_parameters.lr_scheduler_tau_w, + ) + + lr_scheduler_eps = SbSLRScheduler( + optimizer_eps, + factor=cfg.learning_parameters.lr_scheduler_factor_eps_xy, + patience=cfg.learning_parameters.lr_scheduler_patience_eps_xy, + tau=cfg.learning_parameters.lr_scheduler_tau_eps_xy, + ) else: raise Exception("lr_scheduler not implemented") @@ -433,7 +478,6 @@ with torch.no_grad(): train_number_of_processed_pattern >= cfg.get_update_after_x_pattern() ): - logging.info("\t\t\t*** Updating the weights ***") my_loss_for_batch: float = ( train_loss[0] / train_number_of_processed_pattern ) @@ -506,13 +550,13 @@ with torch.no_grad(): # Let the torch learning rate scheduler update the # learning rates of the optimiers - if cfg.learning_parameters.lr_scheduler_patience_w > 0: + if do_lr_scheduler_step is True: if cfg.learning_parameters.lr_scheduler_use_performance is True: lr_scheduler_wf.step(100.0 - performance) else: lr_scheduler_wf.step(my_loss_for_batch) - if cfg.learning_parameters.lr_scheduler_patience_eps_xy > 0: + if do_lr_scheduler_step is True: if cfg.learning_parameters.lr_scheduler_use_performance is True: lr_scheduler_eps.step(100.0 - performance) else: @@ -530,6 +574,10 @@ with torch.no_grad(): optimizer_eps.param_groups[-1]["lr"], cfg.learning_step, ) + logging.info( + f"\t\t\tLearning rate: weights:{optimizer_wf.param_groups[-1]['lr']:^15.3e} \t epsilon xy:{optimizer_eps.param_groups[-1]['lr']:^15.3e}" + ) + logging.info("\t\t\t*** Updating the weights 
***") cfg.learning_step += 1 train_loss = np.zeros((1), dtype=np.float32)