diff --git a/functions/SoftmaxPower.py b/functions/SoftmaxPower.py index 370683d..a5e0f39 100644 --- a/functions/SoftmaxPower.py +++ b/functions/SoftmaxPower.py @@ -23,7 +23,7 @@ class SoftmaxPower(torch.nn.Module): output = output / output.sum() else: output = output / output.sum(dim=self.dim, keepdim=True) - return output + return input * output def extra_repr(self) -> str: return f"dim={self.dim} ; power={self.power}" diff --git a/functions/make_cnn.py b/functions/make_cnn.py index 72040b0..625ee79 100644 --- a/functions/make_cnn.py +++ b/functions/make_cnn.py @@ -60,10 +60,11 @@ def make_cnn( assert setting_understood if conv_0_enable_softmax: - if conv_0_power_softmax != 0.0: - cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax)) - else: - cnn.append(torch.nn.Softmax(dim=1)) +# if conv_0_power_softmax != 0.0: +# cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax)) +# else: +# cnn.append(torch.nn.Softmax(dim=1)) + cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax)) # Changing structure for i in range(1, len(conv_out_channels_list)):