kk_contour_net_shallow/functions/SoftmaxPower.py
2023-07-22 16:51:57 +02:00

29 lines
838 B
Python

import torch
class SoftmaxPower(torch.nn.Module):
dim: int | None
power: float
def __init__(self, power: float = 2.0, dim: int | None = None) -> None:
super().__init__()
self.dim = dim
self.power = power
def __setstate__(self, state):
super().__setstate__(state)
if not hasattr(self, "dim"):
self.dim = None
if not hasattr(self, "power"):
self.power = 2.0
def forward(self, input: torch.Tensor) -> torch.Tensor:
output: torch.Tensor = torch.abs(input).pow(self.power)
if self.dim is None:
output = output / output.sum()
else:
output = output / output.sum(dim=self.dim, keepdim=True)
return output
def extra_repr(self) -> str:
return f"dim={self.dim} ; power={self.power}"