Add files via upload

This commit is contained in:
David Rotermund 2023-07-23 01:45:45 +02:00 committed by GitHub
parent fc782baeb3
commit b8000f87a5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 7 deletions

View file

@ -4,11 +4,15 @@ import torch
class SoftmaxPower(torch.nn.Module):
dim: int | None
power: float
mean_mode: bool
def __init__(self, power: float = 2.0, dim: int | None = None) -> None:
def __init__(
self, power: float = 2.0, dim: int | None = None, mean_mode: bool = False
) -> None:
super().__init__()
self.dim = dim
self.power = power
self.mean_mode = mean_mode
def __setstate__(self, state):
super().__setstate__(state)
@ -16,6 +20,8 @@ class SoftmaxPower(torch.nn.Module):
self.dim = None
if not hasattr(self, "power"):
self.power = 2.0
if not hasattr(self, "mean_mode"):
self.mean_mode = False
def forward(self, input: torch.Tensor) -> torch.Tensor:
output: torch.Tensor = torch.abs(input).pow(self.power)
@ -23,7 +29,11 @@ class SoftmaxPower(torch.nn.Module):
output = output / output.sum()
else:
output = output / output.sum(dim=self.dim, keepdim=True)
return input * output
if self.mean_mode:
return torch.abs(input).mean(dim=1, keepdim=True) * output
else:
return input * output
def extra_repr(self) -> str:
return f"dim={self.dim} ; power={self.power}"

View file

@ -16,6 +16,7 @@ def make_cnn(
pooling_type: str,
conv_0_enable_softmax: bool,
conv_0_power_softmax: float,
conv_0_meanmode_softmax: bool,
l_relu_negative_slope: float,
) -> torch.nn.Sequential:
assert len(conv_out_channels_list) >= 1
@ -60,11 +61,11 @@ def make_cnn(
assert setting_understood
if conv_0_enable_softmax:
# if conv_0_power_softmax != 0.0:
# cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax))
# else:
# cnn.append(torch.nn.Softmax(dim=1))
cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax))
cnn.append(
SoftmaxPower(
dim=1, power=conv_0_power_softmax, mean_mode=conv_0_meanmode_softmax
)
)
# Changing structure
for i in range(1, len(conv_out_channels_list)):