104 lines
3.2 KiB
Python
104 lines
3.2 KiB
Python
|
import torch
|
||
|
from NNMFConv2d import NNMFConv2d
|
||
|
from NNMFConv2dP import NNMFConv2dP
|
||
|
|
||
|
|
||
|
def make_optimize(
|
||
|
network: torch.nn.Sequential,
|
||
|
list_cnn_top_id: list[int],
|
||
|
list_other_id: list[int],
|
||
|
lr_initial_nnmf: float = 0.01,
|
||
|
lr_initial_cnn: float = 0.001,
|
||
|
lr_initial_cnn_top: float = 0.001,
|
||
|
eps=1e-10,
|
||
|
) -> tuple[
|
||
|
torch.optim.Adam | None,
|
||
|
torch.optim.Adam | None,
|
||
|
torch.optim.Adam | None,
|
||
|
torch.optim.lr_scheduler.ReduceLROnPlateau | None,
|
||
|
torch.optim.lr_scheduler.ReduceLROnPlateau | None,
|
||
|
torch.optim.lr_scheduler.ReduceLROnPlateau | None,
|
||
|
]:
|
||
|
|
||
|
list_cnn_top: list = []
|
||
|
# Init the cnn top layers 1x1 conv2d layers
|
||
|
for layerid in list_cnn_top_id:
|
||
|
for netp in network[layerid].parameters():
|
||
|
with torch.no_grad():
|
||
|
if netp.ndim == 1:
|
||
|
netp.data *= 0
|
||
|
if netp.ndim == 4:
|
||
|
assert netp.shape[-2] == 1
|
||
|
assert netp.shape[-1] == 1
|
||
|
netp[: netp.shape[0], : netp.shape[0], 0, 0] = torch.eye(
|
||
|
netp.shape[0], dtype=netp.dtype, device=netp.device
|
||
|
)
|
||
|
netp[netp.shape[0] :, :, 0, 0] = 0
|
||
|
netp[:, netp.shape[0] :, 0, 0] = 0
|
||
|
|
||
|
list_cnn_top.append(netp)
|
||
|
|
||
|
list_cnn: list = []
|
||
|
list_nnmf: list = []
|
||
|
for layerid in list_other_id:
|
||
|
if isinstance(network[layerid], torch.nn.Conv2d):
|
||
|
for netp in network[layerid].parameters():
|
||
|
list_cnn.append(netp)
|
||
|
|
||
|
if isinstance(network[layerid], (NNMFConv2d, NNMFConv2dP)):
|
||
|
for netp in network[layerid].parameters():
|
||
|
list_nnmf.append(netp)
|
||
|
|
||
|
# The optimizer
|
||
|
if len(list_nnmf) > 0:
|
||
|
optimizer_nnmf: torch.optim.Adam | None = torch.optim.Adam(
|
||
|
list_nnmf, lr=lr_initial_nnmf
|
||
|
)
|
||
|
else:
|
||
|
optimizer_nnmf = None
|
||
|
|
||
|
if len(list_cnn) > 0:
|
||
|
optimizer_cnn: torch.optim.Adam | None = torch.optim.Adam(
|
||
|
list_cnn, lr=lr_initial_cnn
|
||
|
)
|
||
|
else:
|
||
|
optimizer_cnn = None
|
||
|
|
||
|
if len(list_cnn_top) > 0:
|
||
|
optimizer_cnn_top: torch.optim.Adam | None = torch.optim.Adam(
|
||
|
list_cnn_top, lr=lr_initial_cnn_top
|
||
|
)
|
||
|
else:
|
||
|
optimizer_cnn_top = None
|
||
|
|
||
|
# The LR Scheduler
|
||
|
if optimizer_nnmf is not None:
|
||
|
lr_scheduler_nnmf: torch.optim.lr_scheduler.ReduceLROnPlateau | None = (
|
||
|
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_nnmf, eps=eps)
|
||
|
)
|
||
|
else:
|
||
|
lr_scheduler_nnmf = None
|
||
|
|
||
|
if optimizer_cnn is not None:
|
||
|
lr_scheduler_cnn: torch.optim.lr_scheduler.ReduceLROnPlateau | None = (
|
||
|
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_cnn, eps=eps)
|
||
|
)
|
||
|
else:
|
||
|
lr_scheduler_cnn = None
|
||
|
|
||
|
if optimizer_cnn_top is not None:
|
||
|
lr_scheduler_cnn_top: torch.optim.lr_scheduler.ReduceLROnPlateau | None = (
|
||
|
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_cnn_top, eps=eps)
|
||
|
)
|
||
|
else:
|
||
|
lr_scheduler_cnn_top = None
|
||
|
|
||
|
return (
|
||
|
optimizer_nnmf,
|
||
|
optimizer_cnn,
|
||
|
optimizer_cnn_top,
|
||
|
lr_scheduler_nnmf,
|
||
|
lr_scheduler_cnn,
|
||
|
lr_scheduler_cnn_top,
|
||
|
)
|