From 7bcd733d3da38662bc67627c885549fea60de62c Mon Sep 17 00:00:00 2001 From: David Rotermund Date: Sat, 1 Jun 2024 02:14:18 +0200 Subject: [PATCH] Add files via upload --- make_optimize.py | 8 ++++++++ run_network.py | 1 + 2 files changed, 9 insertions(+) diff --git a/make_optimize.py b/make_optimize.py index dc5d4f8..c647740 100644 --- a/make_optimize.py +++ b/make_optimize.py @@ -1,5 +1,6 @@ import torch from NNMF2d import NNMF2d +from Y import Y def make_optimize( @@ -44,6 +45,13 @@ def make_optimize( for netp in network[layerid].parameters(): list_cnn.append(netp) + if isinstance(network[layerid], Y): + for sublayer in network[layerid].segments: + for subsublayer in sublayer: + if isinstance(subsublayer, NNMF2d): + for netp in subsublayer.parameters(): + list_nnmf.append(netp) + if isinstance(network[layerid], NNMF2d): for netp in network[layerid].parameters(): list_nnmf.append(netp) diff --git a/run_network.py b/run_network.py index d68e0a0..2059a11 100644 --- a/run_network.py +++ b/run_network.py @@ -127,6 +127,7 @@ def main( print(network) + ( optimizer_nnmf, optimizer_cnn,