diff --git a/functions/SoftmaxPower.py b/functions/SoftmaxPower.py index 8a39ddc..b0fd0ce 100644 --- a/functions/SoftmaxPower.py +++ b/functions/SoftmaxPower.py @@ -32,7 +32,11 @@ class SoftmaxPower(torch.nn.Module): self.no_input_mode = False def forward(self, input: torch.Tensor) -> torch.Tensor: - output: torch.Tensor = torch.abs(input).pow(self.power) + if self.power != 0.0: + output: torch.Tensor = torch.abs(input).pow(self.power) + else: + output: torch.Tensor = torch.exp(input) + if self.dim is None: output = output / output.sum() else: