diff --git a/functions/SoftmaxPower.py b/functions/SoftmaxPower.py index 22b97cf..8a39ddc 100644 --- a/functions/SoftmaxPower.py +++ b/functions/SoftmaxPower.py @@ -5,14 +5,20 @@ class SoftmaxPower(torch.nn.Module): dim: int | None power: float mean_mode: bool + no_input_mode: bool def __init__( - self, power: float = 2.0, dim: int | None = None, mean_mode: bool = False + self, + power: float = 2.0, + dim: int | None = None, + mean_mode: bool = False, + no_input_mode: bool = False, ) -> None: super().__init__() self.dim = dim self.power = power self.mean_mode = mean_mode + self.no_input_mode = no_input_mode def __setstate__(self, state): super().__setstate__(state) @@ -22,6 +28,8 @@ class SoftmaxPower(torch.nn.Module): self.power = 2.0 if not hasattr(self, "mean_mode"): self.mean_mode = False + if not hasattr(self, "no_input_mode"): + self.no_input_mode = False def forward(self, input: torch.Tensor) -> torch.Tensor: output: torch.Tensor = torch.abs(input).pow(self.power) @@ -30,10 +38,25 @@ class SoftmaxPower(torch.nn.Module): else: output = output / output.sum(dim=self.dim, keepdim=True) - if self.mean_mode: + if self.no_input_mode: + return output + elif self.mean_mode: return torch.abs(input).mean(dim=1, keepdim=True) * output else: return input * output def extra_repr(self) -> str: - return f"dim={self.dim} ; power={self.power}" + if self.power != 0.0: + return ( + f"dim={self.dim}; " + f"power={self.power}; " + f"mean_mode={self.mean_mode}; " + f"no_input_mode={self.no_input_mode}" + ) + else: + return ( + f"dim={self.dim}; " + "exp-mode; " + f"mean_mode={self.mean_mode}; " + f"no_input_mode={self.no_input_mode}" + )