diff --git a/SbSLRScheduler.py b/SbSLRScheduler.py
deleted file mode 100644
index ea81038..0000000
--- a/SbSLRScheduler.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import torch
-
-
-class SbSLRScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
-    def __init__(
-        self,
-        optimizer,
-        mode: str = "min",
-        factor: float = 0.1,
-        patience: int = 10,
-        threshold: float = 1e-4,
-        threshold_mode: str = "rel",
-        cooldown: int = 0,
-        min_lr: float = 0,
-        eps: float = 1e-8,
-        verbose: bool = False,
-        tau: float = 10,
-    ) -> None:
-
-        super().__init__(
-            optimizer=optimizer,
-            mode=mode,
-            factor=factor,
-            patience=patience,
-            threshold=threshold,
-            threshold_mode=threshold_mode,
-            cooldown=cooldown,
-            min_lr=min_lr,
-            eps=eps,
-            verbose=verbose,
-        )
-        self.lowpass_tau: float = tau
-        self.lowpass_decay_value: float = 1.0 - (1.0 / self.lowpass_tau)
-
-        self.lowpass_number_of_steps: int = 0
-        self.loss_maximum_over_time: float | None = None
-        self.lowpass_memory: float = 0.0
-        self.lowpass_learning_rate_minimum_over_time: float | None = None
-        self.lowpass_learning_rate_minimum_over_time_past_step: float | None = None
-
-        self.previous_learning_rate: float | None = None
-        self.loss_normalized_past_step: float | None = None
-
-    def step(self, metrics, epoch=None) -> None:
-
-        loss = float(metrics)
-
-        if self.loss_maximum_over_time is None:
-            self.loss_maximum_over_time = loss
-
-        if self.loss_normalized_past_step is None:
-            self.loss_normalized_past_step = loss / self.loss_maximum_over_time
-
-        if self.previous_learning_rate is None:
-            self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore
-
-        # The parent lr scheduler controlls the basic learn rate
-        self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore
-        super().step(metrics=self.loss_normalized_past_step, epoch=epoch)
-
-        # If the parent changes the base learning rate,
-        # then we reset the adaptive part
-        if self.optimizer.param_groups[-1]["lr"] != self.previous_learning_rate:  # type: ignore
-            self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore
-
-            self.lowpass_number_of_steps = 0
-            self.loss_maximum_over_time = None
-            self.lowpass_memory = 0.0
-            self.lowpass_learning_rate_minimum_over_time = None
-            self.lowpass_learning_rate_minimum_over_time_past_step = None
-
-        if self.loss_maximum_over_time is None:
-            self.loss_maximum_over_time = loss
-        else:
-            self.loss_maximum_over_time = max(self.loss_maximum_over_time, loss)
-
-        self.lowpass_number_of_steps += 1
-
-        self.lowpass_memory = self.lowpass_memory * self.lowpass_decay_value + (
-            loss / self.loss_maximum_over_time
-        ) * (1.0 / self.lowpass_tau)
-
-        loss_normalized: float = self.lowpass_memory / (
-            1.0 - self.lowpass_decay_value ** float(self.lowpass_number_of_steps)
-        )
-
-        if self.lowpass_learning_rate_minimum_over_time is None:
-            self.lowpass_learning_rate_minimum_over_time = loss_normalized
-        else:
-            self.lowpass_learning_rate_minimum_over_time = min(
-                self.lowpass_learning_rate_minimum_over_time, loss_normalized
-            )
-
-        if self.lowpass_learning_rate_minimum_over_time_past_step is None:
-            self.lowpass_learning_rate_minimum_over_time_past_step = (
-                self.lowpass_learning_rate_minimum_over_time
-            )
-
-        self.optimizer.param_groups[-1]["lr"] *= (  # type: ignore
-            self.lowpass_learning_rate_minimum_over_time
-            / self.lowpass_learning_rate_minimum_over_time_past_step
-        )
-        self.lowpass_learning_rate_minimum_over_time_past_step = (
-            self.lowpass_learning_rate_minimum_over_time
-        )
-        self.loss_normalized_past_step = loss_normalized
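
For reference, a minimal sketch of how the removed scheduler was presumably driven. The model, the optimizer hyperparameters, and the train_one_epoch helper are illustrative assumptions, not part of this repository; only the SbSLRScheduler constructor and step() signature come from the file above. Like its parent ReduceLROnPlateau, the scheduler is stepped once per epoch with the current loss value.

import torch

from SbSLRScheduler import SbSLRScheduler  # the module removed by this diff

model = torch.nn.Linear(10, 1)  # placeholder model (assumption)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# tau sets the time constant of the low-pass filter over the max-normalized
# loss; the remaining keyword arguments are forwarded to ReduceLROnPlateau.
scheduler = SbSLRScheduler(optimizer, patience=10, tau=10)

for epoch in range(100):
    epoch_loss = train_one_epoch(model, optimizer)  # hypothetical helper
    # Scales the last param group's lr by the ratio of the current minimum of
    # the bias-corrected, low-pass-filtered normalized loss to its value at
    # the previous step, on top of the base rate managed by ReduceLROnPlateau.
    scheduler.step(epoch_loss)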