87 lines
3 KiB
Python
87 lines
3 KiB
Python
|
# %%
|
||
|
import torch
|
||
|
from network.Parameter import Config
|
||
|
|
||
|
try:
    # Optional project scheduler; it may be absent from some checkouts.
    from network.SbSLRScheduler import SbSLRScheduler

    sbs_lr_scheduler: bool = True
except ImportError:
    # Keep the name bound so module-level references to it (such as the
    # return annotation of build_lr_scheduler, evaluated at definition time)
    # do not raise NameError when the optional module is unavailable.
    SbSLRScheduler = None
    sbs_lr_scheduler = False
|
||
|
|
||
|
|
||
|
def build_lr_scheduler(
    optimizer, cfg: "Config", logging
) -> "list[torch.optim.lr_scheduler.ReduceLROnPlateau | SbSLRScheduler | None]":
    """Create one learning-rate scheduler per optimizer.

    The scheduler type is chosen by ``cfg.learning_parameters.lr_schedule_name``:

    * ``"None"``: no scheduler; ``None`` is stored for every optimizer.
    * ``"ReduceLROnPlateau"``: ``torch.optim.lr_scheduler.ReduceLROnPlateau``
      with ``eps=1e-14``. ``lr_scheduler_factor_w`` and
      ``lr_scheduler_patience_w`` are passed through only when both are
      positive; otherwise PyTorch's default ``factor``/``patience`` are used.
    * ``"SbSLRScheduler"``: the optional project ``SbSLRScheduler`` built from
      ``lr_scheduler_factor_w``, ``lr_scheduler_patience_w`` and
      ``lr_scheduler_tau_w``. If any of the three is not positive, ``None``
      is stored instead.

    An optimizer entry that is ``None`` always yields a ``None`` scheduler.

    Args:
        optimizer: Non-empty sequence of optimizers; entries may be ``None``.
        cfg: Configuration whose ``learning_parameters`` carries the settings
            listed above.
        logging: Logger-like object; only its ``info`` method is called.

    Returns:
        A list with one scheduler (or ``None``) per entry of ``optimizer``,
        in the same order.

    Raises:
        AssertionError: If ``optimizer`` is empty (check skipped under ``-O``).
        ImportError: If ``"SbSLRScheduler"`` is requested but the optional
            module could not be imported.
        NotImplementedError: If ``lr_schedule_name`` is not recognised.
    """
    # The return annotation above is a string so that it is not evaluated at
    # definition time: SbSLRScheduler may be unbound when its import failed.
    assert len(optimizer) > 0

    learning_parameters = cfg.learning_parameters
    schedule_name = learning_parameters.lr_schedule_name

    lr_scheduler_list: list[
        torch.optim.lr_scheduler.ReduceLROnPlateau | SbSLRScheduler | None
    ] = []

    for id_optimizer, current_optimizer in enumerate(optimizer):
        if schedule_name == "None":
            logging.info(f"Using lr scheduler for optimizer {id_optimizer} : None")
            lr_scheduler_list.append(None)

        elif schedule_name == "ReduceLROnPlateau":
            logging.info(
                f"Using lr scheduler for optimizer {id_optimizer}: ReduceLROnPlateau"
            )

            if current_optimizer is None:
                lr_scheduler_list.append(None)
            elif (learning_parameters.lr_scheduler_factor_w <= 0) or (
                learning_parameters.lr_scheduler_patience_w <= 0
            ):
                # Non-positive settings are treated as "unset": fall back to
                # PyTorch's default factor and patience.
                lr_scheduler_list.append(
                    torch.optim.lr_scheduler.ReduceLROnPlateau(
                        current_optimizer,
                        eps=1e-14,
                    )
                )
            else:
                lr_scheduler_list.append(
                    torch.optim.lr_scheduler.ReduceLROnPlateau(
                        current_optimizer,
                        factor=learning_parameters.lr_scheduler_factor_w,
                        patience=learning_parameters.lr_scheduler_patience_w,
                        eps=1e-14,
                    )
                )

        elif schedule_name == "SbSLRScheduler":
            logging.info(
                f"Using lr scheduler for optimizer {id_optimizer}: SbSLRScheduler"
            )

            if not sbs_lr_scheduler:
                # ImportError subclasses Exception, so existing handlers still match.
                raise ImportError(
                    f"lr_scheduler for optimizer {id_optimizer}: SbSLRScheduler.py missing"
                )

            if current_optimizer is None:
                lr_scheduler_list.append(None)
            elif (
                (learning_parameters.lr_scheduler_factor_w <= 0)
                or (learning_parameters.lr_scheduler_patience_w <= 0)
                or (learning_parameters.lr_scheduler_tau_w <= 0)
            ):
                # Any non-positive setting disables the SbS scheduler.
                lr_scheduler_list.append(None)
            else:
                lr_scheduler_list.append(
                    SbSLRScheduler(
                        current_optimizer,
                        factor=learning_parameters.lr_scheduler_factor_w,
                        patience=learning_parameters.lr_scheduler_patience_w,
                        tau=learning_parameters.lr_scheduler_tau_w,
                    )
                )

        else:
            # NotImplementedError subclasses Exception, so existing handlers still match.
            raise NotImplementedError("lr_scheduler not implemented")

    return lr_scheduler_list
|