diff --git a/cnn_training.py b/cnn_training.py index 8ff0eee..dad319d 100644 --- a/cnn_training.py +++ b/cnn_training.py @@ -165,6 +165,8 @@ def run_network( logger=logger, data_path=data_path, ) + assert data_train.__len__() > 0 + input_shape = data_train.__getitem__(0)[1].shape logger.info("Loading test data") data_test = alicorn_data_loader( @@ -252,6 +254,7 @@ def run_network( conv_0_power_softmax=conv_0_power_softmax, conv_0_meanmode_softmax=conv_0_meanmode_softmax, conv_0_no_input_mode_softmax=conv_0_no_input_mode_softmax, + input_shape=input_shape, ).to(device) logger.info(model)