pytorch-sbs/SbSLRScheduler.py

import torch


class SbSLRScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """ReduceLROnPlateau variant with an additional adaptive term.

    On top of the parent scheduler, the learning rate of the last
    parameter group is rescaled on every step by the ratio of the
    running minimum of a low-pass filtered, max-normalized loss between
    the current and the previous step. Whenever the parent scheduler
    changes the base learning rate, the adaptive state is reset.
    """

    def __init__(
        self,
        optimizer,
        mode: str = "min",
        factor: float = 0.1,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: str = "rel",
        cooldown: int = 0,
        min_lr: float = 0,
        eps: float = 1e-8,
        verbose: bool = False,
        tau: float = 10,
    ) -> None:
        super().__init__(
            optimizer=optimizer,
            mode=mode,
            factor=factor,
            patience=patience,
            threshold=threshold,
            threshold_mode=threshold_mode,
            cooldown=cooldown,
            min_lr=min_lr,
            eps=eps,
            verbose=verbose,
        )
        # Time constant of the exponential low-pass filter (in steps)
        self.lowpass_tau: float = tau
        self.lowpass_decay_value: float = 1.0 - (1.0 / self.lowpass_tau)
        self.lowpass_number_of_steps: int = 0
        self.loss_maximum_over_time: float | None = None
        self.lowpass_memory: float = 0.0
        self.lowpass_learning_rate_minimum_over_time: float | None = None
        self.lowpass_learning_rate_minimum_over_time_past_step: float | None = None
        self.previous_learning_rate: float | None = None
        self.loss_normalized_past_step: float | None = None

    def step(self, metrics, epoch=None) -> None:
        loss = float(metrics)

        # Lazy initialization on the first call
        if self.loss_maximum_over_time is None:
            self.loss_maximum_over_time = loss
        if self.loss_normalized_past_step is None:
            self.loss_normalized_past_step = loss / self.loss_maximum_over_time
        if self.previous_learning_rate is None:
            self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore

        # The parent lr scheduler controls the base learning rate
        self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore
        super().step(metrics=self.loss_normalized_past_step, epoch=epoch)

        # If the parent changes the base learning rate,
        # then we reset the adaptive part
        if self.optimizer.param_groups[-1]["lr"] != self.previous_learning_rate:  # type: ignore
            self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore
            self.lowpass_number_of_steps = 0
            self.loss_maximum_over_time = None
            self.lowpass_memory = 0.0
            self.lowpass_learning_rate_minimum_over_time = None
            self.lowpass_learning_rate_minimum_over_time_past_step = None
        # Track the largest loss seen so far for normalization
        if self.loss_maximum_over_time is None:
            self.loss_maximum_over_time = loss
        else:
            self.loss_maximum_over_time = max(self.loss_maximum_over_time, loss)

        # Exponential moving average of the max-normalized loss,
        # with bias correction for the first steps
        self.lowpass_number_of_steps += 1
        self.lowpass_memory = self.lowpass_memory * self.lowpass_decay_value + (
            loss / self.loss_maximum_over_time
        ) * (1.0 / self.lowpass_tau)
        loss_normalized: float = self.lowpass_memory / (
            1.0 - self.lowpass_decay_value ** float(self.lowpass_number_of_steps)
        )

        # Running minimum of the filtered loss
        if self.lowpass_learning_rate_minimum_over_time is None:
            self.lowpass_learning_rate_minimum_over_time = loss_normalized
        else:
            self.lowpass_learning_rate_minimum_over_time = min(
                self.lowpass_learning_rate_minimum_over_time, loss_normalized
            )
        if self.lowpass_learning_rate_minimum_over_time_past_step is None:
            self.lowpass_learning_rate_minimum_over_time_past_step = (
                self.lowpass_learning_rate_minimum_over_time
            )

        # Rescale the lr of the last parameter group by how much the
        # running minimum of the filtered loss dropped since the last step
        self.optimizer.param_groups[-1]["lr"] *= (  # type: ignore
            self.lowpass_learning_rate_minimum_over_time
            / self.lowpass_learning_rate_minimum_over_time_past_step
        )
        self.lowpass_learning_rate_minimum_over_time_past_step = (
            self.lowpass_learning_rate_minimum_over_time
        )
        self.loss_normalized_past_step = loss_normalized