Add a mean_mode option to SoftmaxPower (rescale output by the mean input magnitude) and thread a conv_0_meanmode_softmax flag through make_cnn
This commit is contained in:
parent
fc782baeb3
commit
b8000f87a5
2 changed files with 18 additions and 7 deletions
|
@ -4,11 +4,15 @@ import torch
|
||||||
class SoftmaxPower(torch.nn.Module):
|
class SoftmaxPower(torch.nn.Module):
|
||||||
dim: int | None
|
dim: int | None
|
||||||
power: float
|
power: float
|
||||||
|
mean_mode: bool
|
||||||
|
|
||||||
def __init__(
    self,
    power: float = 2.0,
    dim: int | None = None,
    mean_mode: bool = False,
) -> None:
    """Set up a power-softmax normalization module.

    Args:
        power: Exponent applied to the absolute values of the input
            before normalization.
        dim: Dimension along which the normalization sum is taken;
            ``None`` normalizes over the entire tensor.
        mean_mode: When ``True``, ``forward`` additionally rescales its
            output by the mean magnitude of the input (taken over dim 1).
    """
    super().__init__()
    self.power = power
    self.dim = dim
    self.mean_mode = mean_mode
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
super().__setstate__(state)
|
super().__setstate__(state)
|
||||||
|
@ -16,6 +20,8 @@ class SoftmaxPower(torch.nn.Module):
|
||||||
self.dim = None
|
self.dim = None
|
||||||
if not hasattr(self, "power"):
|
if not hasattr(self, "power"):
|
||||||
self.power = 2.0
|
self.power = 2.0
|
||||||
|
if not hasattr(self, "mean_mode"):
|
||||||
|
self.mean_mode = False
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
output: torch.Tensor = torch.abs(input).pow(self.power)
|
output: torch.Tensor = torch.abs(input).pow(self.power)
|
||||||
|
@ -23,6 +29,10 @@ class SoftmaxPower(torch.nn.Module):
|
||||||
output = output / output.sum()
|
output = output / output.sum()
|
||||||
else:
|
else:
|
||||||
output = output / output.sum(dim=self.dim, keepdim=True)
|
output = output / output.sum(dim=self.dim, keepdim=True)
|
||||||
|
|
||||||
|
if self.mean_mode:
|
||||||
|
return torch.abs(input).mean(dim=1, keepdim=True) * output
|
||||||
|
else:
|
||||||
return input * output
|
return input * output
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
|
|
|
@ -16,6 +16,7 @@ def make_cnn(
|
||||||
pooling_type: str,
|
pooling_type: str,
|
||||||
conv_0_enable_softmax: bool,
|
conv_0_enable_softmax: bool,
|
||||||
conv_0_power_softmax: float,
|
conv_0_power_softmax: float,
|
||||||
|
conv_0_meanmode_softmax: bool,
|
||||||
l_relu_negative_slope: float,
|
l_relu_negative_slope: float,
|
||||||
) -> torch.nn.Sequential:
|
) -> torch.nn.Sequential:
|
||||||
assert len(conv_out_channels_list) >= 1
|
assert len(conv_out_channels_list) >= 1
|
||||||
|
@ -60,11 +61,11 @@ def make_cnn(
|
||||||
assert setting_understood
|
assert setting_understood
|
||||||
|
|
||||||
if conv_0_enable_softmax:
|
if conv_0_enable_softmax:
|
||||||
# if conv_0_power_softmax != 0.0:
|
cnn.append(
|
||||||
# cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax))
|
SoftmaxPower(
|
||||||
# else:
|
dim=1, power=conv_0_power_softmax, mean_mode=conv_0_meanmode_softmax
|
||||||
# cnn.append(torch.nn.Softmax(dim=1))
|
)
|
||||||
cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax))
|
)
|
||||||
|
|
||||||
# Changing structure
|
# Changing structure
|
||||||
for i in range(1, len(conv_out_channels_list)):
|
for i in range(1, len(conv_out_channels_list)):
|
||||||
|
|
Loading…
Reference in a new issue