pytorch-sbs/network/build_lr_scheduler.py

88 lines
3 KiB
Python
Raw Normal View History

2023-01-05 13:23:58 +01:00
# %%
import torch
from network.Parameter import Config
# SbSLRScheduler is an optional module. Record whether it could be imported so
# that requesting the "SbSLRScheduler" schedule can raise a descriptive error
# instead of failing on an unbound name.
try:
    from network.SbSLRScheduler import SbSLRScheduler
except Exception:
    sbs_lr_scheduler = False
else:
    sbs_lr_scheduler: bool = True
def _build_reduce_lr_on_plateau(opt, params):
    """Build a ``ReduceLROnPlateau`` scheduler for a single optimizer.

    Returns ``None`` when ``opt`` is ``None``. If ``lr_scheduler_factor_w`` or
    ``lr_scheduler_patience_w`` is non-positive, both are ignored and torch's
    default factor/patience are used. ``eps`` is always ``1e-14``.
    """
    if opt is None:
        return None
    if params.lr_scheduler_factor_w <= 0 or params.lr_scheduler_patience_w <= 0:
        return torch.optim.lr_scheduler.ReduceLROnPlateau(opt, eps=1e-14)
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt,
        factor=params.lr_scheduler_factor_w,
        patience=params.lr_scheduler_patience_w,
        eps=1e-14,
    )


def _build_sbs_lr_scheduler(opt, params):
    """Build an ``SbSLRScheduler`` for a single optimizer, or return ``None``.

    ``None`` is returned when ``opt`` is ``None`` or when any of
    ``lr_scheduler_factor_w``, ``lr_scheduler_patience_w`` or
    ``lr_scheduler_tau_w`` is non-positive. The caller is responsible for
    checking that the optional ``SbSLRScheduler`` import succeeded.
    """
    if opt is None:
        return None
    if (
        params.lr_scheduler_factor_w <= 0
        or params.lr_scheduler_patience_w <= 0
        or params.lr_scheduler_tau_w <= 0
    ):
        return None
    return SbSLRScheduler(
        opt,
        factor=params.lr_scheduler_factor_w,
        patience=params.lr_scheduler_patience_w,
        tau=params.lr_scheduler_tau_w,
    )


# The annotations are strings on purpose. Function annotations are evaluated
# when ``def`` runs, and ``SbSLRScheduler`` is unbound whenever its optional
# import failed. An unquoted annotation would then raise NameError on module
# import and make the "SbSLRScheduler.py missing" fallback unreachable.
def build_lr_scheduler(
    optimizer, cfg: "Config", logging
) -> "list[torch.optim.lr_scheduler.ReduceLROnPlateau | SbSLRScheduler | None]":
    """Create one learning-rate scheduler (or ``None``) per optimizer.

    The scheduler type is chosen by ``cfg.learning_parameters.lr_schedule_name``:

    * ``"None"``: no scheduler; ``None`` is stored for every optimizer.
    * ``"ReduceLROnPlateau"``: a ``torch.optim.lr_scheduler.ReduceLROnPlateau``
      with ``eps=1e-14``. ``lr_scheduler_factor_w`` and
      ``lr_scheduler_patience_w`` are passed through only when both are
      positive; otherwise torch's defaults are used.
    * ``"SbSLRScheduler"``: an ``SbSLRScheduler`` built from
      ``lr_scheduler_factor_w``, ``lr_scheduler_patience_w`` and
      ``lr_scheduler_tau_w``. If any of the three is non-positive, ``None`` is
      stored instead.

    An entry of ``optimizer`` that is ``None`` always yields ``None``.

    Args:
        optimizer: Non-empty sequence of optimizers; entries may be ``None``.
        cfg: Configuration object; only ``cfg.learning_parameters`` is read.
        logging: Logger-like object; its ``info`` method is called once per
            optimizer.

    Returns:
        A list with one scheduler (or ``None``) per entry of ``optimizer``,
        in the same order.

    Raises:
        AssertionError: If ``optimizer`` is empty (not checked under
            ``python -O``).
        Exception: If the schedule name is not recognised, or if
            ``"SbSLRScheduler"`` is requested but the optional
            ``network.SbSLRScheduler`` import failed.
    """
    assert len(optimizer) > 0

    params = cfg.learning_parameters
    schedule_name = params.lr_schedule_name
    lr_scheduler_list: list = []

    for id_optimizer, opt in enumerate(optimizer):
        if schedule_name == "None":
            logging.info(f"Using lr scheduler for optimizer {id_optimizer} : None")
            lr_scheduler_list.append(None)
        elif schedule_name == "ReduceLROnPlateau":
            logging.info(
                f"Using lr scheduler for optimizer {id_optimizer}: ReduceLROnPlateau"
            )
            lr_scheduler_list.append(_build_reduce_lr_on_plateau(opt, params))
        elif schedule_name == "SbSLRScheduler":
            logging.info(
                f"Using lr scheduler for optimizer {id_optimizer}: SbSLRScheduler"
            )
            # Checked before the ``None`` test, so a missing module is reported
            # even for optimizer slots that are ``None``.
            if sbs_lr_scheduler is False:
                raise Exception(
                    f"lr_scheduler for optimizer {id_optimizer}: SbSLRScheduler.py missing"
                )
            lr_scheduler_list.append(_build_sbs_lr_scheduler(opt, params))
        else:
            raise Exception("lr_scheduler not implemented")
    return lr_scheduler_list