Delete SbSLRScheduler.py
parent 45d146e529
commit ce3a1224c7
1 changed file with 0 additions and 106 deletions
SbSLRScheduler.py
@@ -1,106 +0,0 @@
import torch


class SbSLRScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
    def __init__(
        self,
        optimizer,
        mode: str = "min",
        factor: float = 0.1,
        patience: int = 10,
        threshold: float = 1e-4,
        threshold_mode: str = "rel",
        cooldown: int = 0,
        min_lr: float = 0,
        eps: float = 1e-8,
        verbose: bool = False,
        tau: float = 10,
    ) -> None:
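        # ReduceLROnPlateau controls the base learning rate; on top of that,
        # a bias-corrected exponential low-pass filter (time constant tau)
        # of the normalized loss rescales the last parameter group's LR on
        # every step.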

        super().__init__(
            optimizer=optimizer,
            mode=mode,
            factor=factor,
            patience=patience,
            threshold=threshold,
            threshold_mode=threshold_mode,
            cooldown=cooldown,
            min_lr=min_lr,
            eps=eps,
            verbose=verbose,
        )
        self.lowpass_tau: float = tau
        self.lowpass_decay_value: float = 1.0 - (1.0 / self.lowpass_tau)
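        # 1 - 1/tau is the EMA decay: with the default tau = 10, each step
        # keeps 90% of the running value and mixes in 10% of the new sample.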

        self.lowpass_number_of_steps: int = 0
        self.loss_maximum_over_time: float | None = None
        self.lowpass_memory: float = 0.0
        self.lowpass_learning_rate_minimum_over_time: float | None = None
        self.lowpass_learning_rate_minimum_over_time_past_step: float | None = None

        self.previous_learning_rate: float | None = None
        self.loss_normalized_past_step: float | None = None

    def step(self, metrics, epoch=None) -> None:

        loss = float(metrics)

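        # First call only: seed the running maximum, the normalized-loss
        # history, and the remembered base learning rate.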
        if self.loss_maximum_over_time is None:
            self.loss_maximum_over_time = loss

        if self.loss_normalized_past_step is None:
            self.loss_normalized_past_step = loss / self.loss_maximum_over_time

        if self.previous_learning_rate is None:
            self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore

        # The parent LR scheduler controls the base learning rate
        self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore
        super().step(metrics=self.loss_normalized_past_step, epoch=epoch)

        # If the parent changes the base learning rate,
        # then we reset the adaptive part
        if self.optimizer.param_groups[-1]["lr"] != self.previous_learning_rate:  # type: ignore
            self.previous_learning_rate = self.optimizer.param_groups[-1]["lr"]  # type: ignore

            self.lowpass_number_of_steps = 0
            self.loss_maximum_over_time = None
            self.lowpass_memory = 0.0
            self.lowpass_learning_rate_minimum_over_time = None
            self.lowpass_learning_rate_minimum_over_time_past_step = None

        if self.loss_maximum_over_time is None:
            self.loss_maximum_over_time = loss
        else:
            self.loss_maximum_over_time = max(self.loss_maximum_over_time, loss)

        self.lowpass_number_of_steps += 1

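        # Exponential moving average of the loss normalized by its running
        # maximum; the 1/tau input weight matches the (1 - 1/tau) decay.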
        self.lowpass_memory = self.lowpass_memory * self.lowpass_decay_value + (
            loss / self.loss_maximum_over_time
        ) * (1.0 / self.lowpass_tau)

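        # Debias the zero-initialized average so early steps are not biased
        # low (the same 1 - decay^t correction used by Adam).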
        loss_normalized: float = self.lowpass_memory / (
            1.0 - self.lowpass_decay_value ** float(self.lowpass_number_of_steps)
        )

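        # Track the all-time minimum of the filtered, normalized loss.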
        if self.lowpass_learning_rate_minimum_over_time is None:
            self.lowpass_learning_rate_minimum_over_time = loss_normalized
        else:
            self.lowpass_learning_rate_minimum_over_time = min(
                self.lowpass_learning_rate_minimum_over_time, loss_normalized
            )

        if self.lowpass_learning_rate_minimum_over_time_past_step is None:
            self.lowpass_learning_rate_minimum_over_time_past_step = (
                self.lowpass_learning_rate_minimum_over_time
            )

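        # Rescale the last parameter group's LR by the relative improvement of
        # that minimum since the previous step (the ratio is <= 1, so the LR
        # only shrinks between parent resets).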
        self.optimizer.param_groups[-1]["lr"] *= (  # type: ignore
            self.lowpass_learning_rate_minimum_over_time
            / self.lowpass_learning_rate_minimum_over_time_past_step
        )
        self.lowpass_learning_rate_minimum_over_time_past_step = (
            self.lowpass_learning_rate_minimum_over_time
        )
        self.loss_normalized_past_step = loss_normalized
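For context, a minimal sketch of how this scheduler would have been driven (the model, optimizer settings, and compute_validation_loss helper are hypothetical placeholders, not taken from this repository):

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = SbSLRScheduler(optimizer, patience=10, tau=10)

for epoch in range(100):
    validation_loss = compute_validation_loss(model)  # hypothetical helper
    # A single call drives the ReduceLROnPlateau parent and applies the
    # adaptive low-pass rescaling to the last parameter group's LR.
    scheduler.step(validation_loss)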