import torch
class SoftmaxPower(torch.nn.Module):
|
|
|
|
dim: int | None
|
|
|
|
power: float
|
2023-07-23 01:45:45 +02:00
|
|
|
mean_mode: bool
|
2023-07-22 16:51:57 +02:00
|
|
|
|
2023-07-23 01:45:45 +02:00
|
|
|
def __init__(
|
|
|
|
self, power: float = 2.0, dim: int | None = None, mean_mode: bool = False
|
|
|
|
) -> None:
|
2023-07-22 16:51:57 +02:00
|
|
|
super().__init__()
|
|
|
|
self.dim = dim
|
|
|
|
self.power = power
|
2023-07-23 01:45:45 +02:00
|
|
|
self.mean_mode = mean_mode
|
2023-07-22 16:51:57 +02:00
|
|
|
|
|
|
|
def __setstate__(self, state):
|
|
|
|
super().__setstate__(state)
|
|
|
|
if not hasattr(self, "dim"):
|
|
|
|
self.dim = None
|
|
|
|
if not hasattr(self, "power"):
|
|
|
|
self.power = 2.0
|
2023-07-23 01:45:45 +02:00
|
|
|
if not hasattr(self, "mean_mode"):
|
|
|
|
self.mean_mode = False
|
2023-07-22 16:51:57 +02:00
|
|
|
|
|
|
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
|
|
output: torch.Tensor = torch.abs(input).pow(self.power)
|
|
|
|
if self.dim is None:
|
|
|
|
output = output / output.sum()
|
|
|
|
else:
|
|
|
|
output = output / output.sum(dim=self.dim, keepdim=True)
|
2023-07-23 01:45:45 +02:00
|
|
|
|
|
|
|
if self.mean_mode:
|
|
|
|
return torch.abs(input).mean(dim=1, keepdim=True) * output
|
|
|
|
else:
|
|
|
|
return input * output
|
2023-07-22 16:51:57 +02:00
|
|
|
|
|
|
|
def extra_repr(self) -> str:
|
|
|
|
return f"dim={self.dim} ; power={self.power}"
|