Add files via upload
This commit is contained in:
parent
e18690b0b3
commit
a516d05146
2 changed files with 35 additions and 1 deletions
29
functions/SoftmaxPower.py
Normal file
29
functions/SoftmaxPower.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
import torch
|
||||
|
||||
|
||||
class SoftmaxPower(torch.nn.Module):
|
||||
dim: int | None
|
||||
power: float
|
||||
|
||||
def __init__(self, power: float = 2.0, dim: int | None = None) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.power = power
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setstate__(state)
|
||||
if not hasattr(self, "dim"):
|
||||
self.dim = None
|
||||
if not hasattr(self, "power"):
|
||||
self.power = 2.0
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
output: torch.Tensor = torch.abs(input).pow(self.power)
|
||||
if self.dim is None:
|
||||
output = output / output.sum()
|
||||
else:
|
||||
output = output / output.sum(dim=self.dim, keepdim=True)
|
||||
return output
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f"dim={self.dim} ; power={self.power}"
|
|
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from functions.SoftmaxPower import SoftmaxPower
|
||||
|
||||
|
||||
def make_cnn(
|
||||
|
@ -14,6 +15,7 @@ def make_cnn(
|
|||
mp_1_stride: int,
|
||||
pooling_type: str,
|
||||
conv_0_enable_softmax: bool,
|
||||
conv_0_power_softmax: float,
|
||||
l_relu_negative_slope: float,
|
||||
) -> torch.nn.Sequential:
|
||||
assert len(conv_out_channels_list) >= 1
|
||||
|
@ -58,6 +60,9 @@ def make_cnn(
|
|||
assert setting_understood
|
||||
|
||||
if conv_0_enable_softmax:
|
||||
if conv_0_power_softmax != 0.0:
|
||||
cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax))
|
||||
else:
|
||||
cnn.append(torch.nn.Softmax(dim=1))
|
||||
|
||||
# Changing structure
|
||||
|
|
Loading…
Reference in a new issue