102 lines
3.1 KiB
Python
102 lines
3.1 KiB
Python
import torch
|
|
from NNMF2d import NNMF2d
|
|
|
|
|
|
def make_optimize(
|
|
network: torch.nn.Sequential,
|
|
list_cnn_top_id: list[int],
|
|
list_other_id: list[int],
|
|
lr_initial_nnmf: float = 0.01,
|
|
lr_initial_cnn: float = 0.001,
|
|
lr_initial_cnn_top: float = 0.001,
|
|
eps=1e-10,
|
|
) -> tuple[
|
|
torch.optim.Adam | None,
|
|
torch.optim.Adam | None,
|
|
torch.optim.Adam | None,
|
|
torch.optim.lr_scheduler.ReduceLROnPlateau | None,
|
|
torch.optim.lr_scheduler.ReduceLROnPlateau | None,
|
|
torch.optim.lr_scheduler.ReduceLROnPlateau | None,
|
|
]:
|
|
|
|
list_cnn_top: list = []
|
|
# Init the cnn top layers 1x1 conv2d layers
|
|
for layerid in list_cnn_top_id:
|
|
for netp in network[layerid].parameters():
|
|
with torch.no_grad():
|
|
if netp.ndim == 1:
|
|
netp.data *= 0
|
|
if netp.ndim == 4:
|
|
assert netp.shape[-2] == 1
|
|
assert netp.shape[-1] == 1
|
|
netp[: netp.shape[0], : netp.shape[0], 0, 0] = torch.eye(
|
|
netp.shape[0], dtype=netp.dtype, device=netp.device
|
|
)
|
|
netp[netp.shape[0] :, :, 0, 0] = 0
|
|
netp[:, netp.shape[0] :, 0, 0] = 0
|
|
|
|
list_cnn_top.append(netp)
|
|
|
|
list_cnn: list = []
|
|
list_nnmf: list = []
|
|
for layerid in list_other_id:
|
|
if isinstance(network[layerid], torch.nn.Conv2d):
|
|
for netp in network[layerid].parameters():
|
|
list_cnn.append(netp)
|
|
|
|
if isinstance(network[layerid], NNMF2d):
|
|
for netp in network[layerid].parameters():
|
|
list_nnmf.append(netp)
|
|
|
|
# The optimizer
|
|
if len(list_nnmf) > 0:
|
|
optimizer_nnmf: torch.optim.Adam | None = torch.optim.Adam(
|
|
list_nnmf, lr=lr_initial_nnmf
|
|
)
|
|
else:
|
|
optimizer_nnmf = None
|
|
|
|
if len(list_cnn) > 0:
|
|
optimizer_cnn: torch.optim.Adam | None = torch.optim.Adam(
|
|
list_cnn, lr=lr_initial_cnn
|
|
)
|
|
else:
|
|
optimizer_cnn = None
|
|
|
|
if len(list_cnn_top) > 0:
|
|
optimizer_cnn_top: torch.optim.Adam | None = torch.optim.Adam(
|
|
list_cnn_top, lr=lr_initial_cnn_top
|
|
)
|
|
else:
|
|
optimizer_cnn_top = None
|
|
|
|
# The LR Scheduler
|
|
if optimizer_nnmf is not None:
|
|
lr_scheduler_nnmf: torch.optim.lr_scheduler.ReduceLROnPlateau | None = (
|
|
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_nnmf, eps=eps)
|
|
)
|
|
else:
|
|
lr_scheduler_nnmf = None
|
|
|
|
if optimizer_cnn is not None:
|
|
lr_scheduler_cnn: torch.optim.lr_scheduler.ReduceLROnPlateau | None = (
|
|
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_cnn, eps=eps)
|
|
)
|
|
else:
|
|
lr_scheduler_cnn = None
|
|
|
|
if optimizer_cnn_top is not None:
|
|
lr_scheduler_cnn_top: torch.optim.lr_scheduler.ReduceLROnPlateau | None = (
|
|
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_cnn_top, eps=eps)
|
|
)
|
|
else:
|
|
lr_scheduler_cnn_top = None
|
|
|
|
return (
|
|
optimizer_nnmf,
|
|
optimizer_cnn,
|
|
optimizer_cnn_top,
|
|
lr_scheduler_nnmf,
|
|
lr_scheduler_cnn,
|
|
lr_scheduler_cnn_top,
|
|
)
|