# kk_contour_net_shallow/functions/SoftmaxPower.py
import torch
class SoftmaxPower(torch.nn.Module):
dim: int | None
power: float
2023-07-23 01:45:45 +02:00
mean_mode: bool
2023-07-23 01:54:23 +02:00
no_input_mode: bool
2023-07-22 16:51:57 +02:00
2023-07-23 01:45:45 +02:00
def __init__(
2023-07-23 01:54:23 +02:00
self,
power: float = 2.0,
dim: int | None = None,
mean_mode: bool = False,
no_input_mode: bool = False,
2023-07-23 01:45:45 +02:00
) -> None:
2023-07-22 16:51:57 +02:00
super().__init__()
self.dim = dim
self.power = power
2023-07-23 01:45:45 +02:00
self.mean_mode = mean_mode
2023-07-23 01:54:23 +02:00
self.no_input_mode = no_input_mode
2023-07-22 16:51:57 +02:00
def __setstate__(self, state):
super().__setstate__(state)
if not hasattr(self, "dim"):
self.dim = None
if not hasattr(self, "power"):
self.power = 2.0
2023-07-23 01:45:45 +02:00
if not hasattr(self, "mean_mode"):
self.mean_mode = False
2023-07-23 01:54:23 +02:00
if not hasattr(self, "no_input_mode"):
self.no_input_mode = False
2023-07-22 16:51:57 +02:00
def forward(self, input: torch.Tensor) -> torch.Tensor:
2023-07-25 00:57:03 +02:00
if self.power != 0.0:
output: torch.Tensor = torch.abs(input).pow(self.power)
else:
output: torch.Tensor = torch.exp(input)
2023-07-22 16:51:57 +02:00
if self.dim is None:
output = output / output.sum()
else:
output = output / output.sum(dim=self.dim, keepdim=True)
2023-07-23 01:45:45 +02:00
2023-07-23 01:54:23 +02:00
if self.no_input_mode:
return output
elif self.mean_mode:
2023-07-23 01:45:45 +02:00
return torch.abs(input).mean(dim=1, keepdim=True) * output
else:
return input * output
2023-07-22 16:51:57 +02:00
def extra_repr(self) -> str:
2023-07-23 01:54:23 +02:00
if self.power != 0.0:
return (
f"dim={self.dim}; "
f"power={self.power}; "
f"mean_mode={self.mean_mode}; "
f"no_input_mode={self.no_input_mode}"
)
else:
return (
f"dim={self.dim}; "
"exp-mode; "
f"mean_mode={self.mean_mode}; "
f"no_input_mode={self.no_input_mode}"
)