diff --git a/functions/make_cnn.py b/functions/make_cnn.py index f752e31..49d9edf 100644 --- a/functions/make_cnn.py +++ b/functions/make_cnn.py @@ -17,6 +17,7 @@ def make_cnn( conv_0_enable_softmax: bool, conv_0_power_softmax: float, conv_0_meanmode_softmax: bool, + conv_0_no_input_mode_softmax: bool, l_relu_negative_slope: float, ) -> torch.nn.Sequential: assert len(conv_out_channels_list) >= 1 @@ -63,7 +64,10 @@ def make_cnn( if conv_0_enable_softmax: cnn.append( SoftmaxPower( - dim=1, power=conv_0_power_softmax, mean_mode=conv_0_meanmode_softmax + dim=1, + power=conv_0_power_softmax, + mean_mode=conv_0_meanmode_softmax, + no_input_mode=conv_0_no_input_mode_softmax, ) )