Add files via upload
This commit is contained in:
parent
ab0f4084c3
commit
2bcb14def3
1 changed files with 26 additions and 3 deletions
|
@ -5,14 +5,20 @@ class SoftmaxPower(torch.nn.Module):
|
||||||
dim: int | None
|
dim: int | None
|
||||||
power: float
|
power: float
|
||||||
mean_mode: bool
|
mean_mode: bool
|
||||||
|
no_input_mode: bool
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, power: float = 2.0, dim: int | None = None, mean_mode: bool = False
|
self,
|
||||||
|
power: float = 2.0,
|
||||||
|
dim: int | None = None,
|
||||||
|
mean_mode: bool = False,
|
||||||
|
no_input_mode: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.power = power
|
self.power = power
|
||||||
self.mean_mode = mean_mode
|
self.mean_mode = mean_mode
|
||||||
|
self.no_input_mode = no_input_mode
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
super().__setstate__(state)
|
super().__setstate__(state)
|
||||||
|
@ -22,6 +28,8 @@ class SoftmaxPower(torch.nn.Module):
|
||||||
self.power = 2.0
|
self.power = 2.0
|
||||||
if not hasattr(self, "mean_mode"):
|
if not hasattr(self, "mean_mode"):
|
||||||
self.mean_mode = False
|
self.mean_mode = False
|
||||||
|
if not hasattr(self, "no_input_mode"):
|
||||||
|
self.no_input_mode = False
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
output: torch.Tensor = torch.abs(input).pow(self.power)
|
output: torch.Tensor = torch.abs(input).pow(self.power)
|
||||||
|
@ -30,10 +38,25 @@ class SoftmaxPower(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
output = output / output.sum(dim=self.dim, keepdim=True)
|
output = output / output.sum(dim=self.dim, keepdim=True)
|
||||||
|
|
||||||
if self.mean_mode:
|
if self.no_input_mode:
|
||||||
|
return output
|
||||||
|
elif self.mean_mode:
|
||||||
return torch.abs(input).mean(dim=1, keepdim=True) * output
|
return torch.abs(input).mean(dim=1, keepdim=True) * output
|
||||||
else:
|
else:
|
||||||
return input * output
|
return input * output
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
return f"dim={self.dim} ; power={self.power}"
|
if self.power != 0.0:
|
||||||
|
return (
|
||||||
|
f"dim={self.dim}; "
|
||||||
|
f"power={self.power}; "
|
||||||
|
f"mean_mode={self.mean_mode}; "
|
||||||
|
f"no_input_mode={self.no_input_mode}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
f"dim={self.dim}; "
|
||||||
|
"exp-mode; "
|
||||||
|
f"mean_mode={self.mean_mode}; "
|
||||||
|
f"no_input_mode={self.no_input_mode}"
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in a new issue