diff --git a/make_optimize.py b/make_optimize.py index dc5d4f8..c647740 100644 --- a/make_optimize.py +++ b/make_optimize.py @@ -1,5 +1,6 @@ import torch from NNMF2d import NNMF2d +from Y import Y def make_optimize( @@ -44,6 +45,13 @@ def make_optimize( for netp in network[layerid].parameters(): list_cnn.append(netp) + if isinstance(network[layerid], Y): + for sublayer in network[layerid].segments: + for subsublayer in sublayer: + if isinstance(subsublayer, NNMF2d): + for netp in subsublayer.parameters(): + list_nnmf.append(netp) + if isinstance(network[layerid], NNMF2d): for netp in network[layerid].parameters(): list_nnmf.append(netp) diff --git a/run_network.py b/run_network.py index d68e0a0..2059a11 100644 --- a/run_network.py +++ b/run_network.py @@ -127,6 +127,7 @@ def main( print(network) + ( optimizer_nnmf, optimizer_cnn,