pynnmf/make_optimize.py

104 lines
3.2 KiB
Python
Raw Normal View History

2024-05-30 14:08:44 +02:00
import torch
from NNMFConv2d import NNMFConv2d
from NNMFConv2dP import NNMFConv2dP
def make_optimize(
network: torch.nn.Sequential,
list_cnn_top_id: list[int],
list_other_id: list[int],
lr_initial_nnmf: float = 0.01,
lr_initial_cnn: float = 0.001,
lr_initial_cnn_top: float = 0.001,
eps=1e-10,
) -> tuple[
torch.optim.Adam | None,
torch.optim.Adam | None,
torch.optim.Adam | None,
torch.optim.lr_scheduler.ReduceLROnPlateau | None,
torch.optim.lr_scheduler.ReduceLROnPlateau | None,
torch.optim.lr_scheduler.ReduceLROnPlateau | None,
]:
list_cnn_top: list = []
# Init the cnn top layers 1x1 conv2d layers
for layerid in list_cnn_top_id:
for netp in network[layerid].parameters():
with torch.no_grad():
if netp.ndim == 1:
netp.data *= 0
if netp.ndim == 4:
assert netp.shape[-2] == 1
assert netp.shape[-1] == 1
netp[: netp.shape[0], : netp.shape[0], 0, 0] = torch.eye(
netp.shape[0], dtype=netp.dtype, device=netp.device
)
netp[netp.shape[0] :, :, 0, 0] = 0
netp[:, netp.shape[0] :, 0, 0] = 0
list_cnn_top.append(netp)
list_cnn: list = []
list_nnmf: list = []
for layerid in list_other_id:
if isinstance(network[layerid], torch.nn.Conv2d):
for netp in network[layerid].parameters():
list_cnn.append(netp)
if isinstance(network[layerid], (NNMFConv2d, NNMFConv2dP)):
for netp in network[layerid].parameters():
list_nnmf.append(netp)
# The optimizer
if len(list_nnmf) > 0:
optimizer_nnmf: torch.optim.Adam | None = torch.optim.Adam(
list_nnmf, lr=lr_initial_nnmf
)
else:
optimizer_nnmf = None
if len(list_cnn) > 0:
optimizer_cnn: torch.optim.Adam | None = torch.optim.Adam(
list_cnn, lr=lr_initial_cnn
)
else:
optimizer_cnn = None
if len(list_cnn_top) > 0:
optimizer_cnn_top: torch.optim.Adam | None = torch.optim.Adam(
list_cnn_top, lr=lr_initial_cnn_top
)
else:
optimizer_cnn_top = None
# The LR Scheduler
if optimizer_nnmf is not None:
lr_scheduler_nnmf: torch.optim.lr_scheduler.ReduceLROnPlateau | None = (
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_nnmf, eps=eps)
)
else:
lr_scheduler_nnmf = None
if optimizer_cnn is not None:
lr_scheduler_cnn: torch.optim.lr_scheduler.ReduceLROnPlateau | None = (
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_cnn, eps=eps)
)
else:
lr_scheduler_cnn = None
if optimizer_cnn_top is not None:
lr_scheduler_cnn_top: torch.optim.lr_scheduler.ReduceLROnPlateau | None = (
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_cnn_top, eps=eps)
)
else:
lr_scheduler_cnn_top = None
return (
optimizer_nnmf,
optimizer_cnn,
optimizer_cnn_top,
lr_scheduler_nnmf,
lr_scheduler_cnn,
lr_scheduler_cnn_top,
)