From 28a21b32c2a350dd86b0cfe5bde848e3e03d7cb1 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Thu, 27 Jul 2023 20:13:59 +0200 Subject: [PATCH] Add files via upload --- cnn_training.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cnn_training.py b/cnn_training.py index 8ff0eee..dad319d 100644 --- a/cnn_training.py +++ b/cnn_training.py @@ -165,6 +165,8 @@ def run_network( logger=logger, data_path=data_path, ) + assert data_train.__len__() > 0 + input_shape = data_train.__getitem__(0)[1].shape logger.info("Loading test data") data_test = alicorn_data_loader( @@ -252,6 +254,7 @@ def run_network( conv_0_power_softmax=conv_0_power_softmax, conv_0_meanmode_softmax=conv_0_meanmode_softmax, conv_0_no_input_mode_softmax=conv_0_no_input_mode_softmax, + input_shape=input_shape, ).to(device) logger.info(model)