Add files via upload

This commit is contained in:
David Rotermund 2023-07-22 16:51:57 +02:00 committed by GitHub
parent e18690b0b3
commit a516d05146
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 1 deletions

29
functions/SoftmaxPower.py Normal file
View file

@ -0,0 +1,29 @@
import torch
class SoftmaxPower(torch.nn.Module):
dim: int | None
power: float
def __init__(self, power: float = 2.0, dim: int | None = None) -> None:
super().__init__()
self.dim = dim
self.power = power
def __setstate__(self, state):
super().__setstate__(state)
if not hasattr(self, "dim"):
self.dim = None
if not hasattr(self, "power"):
self.power = 2.0
def forward(self, input: torch.Tensor) -> torch.Tensor:
output: torch.Tensor = torch.abs(input).pow(self.power)
if self.dim is None:
output = output / output.sum()
else:
output = output / output.sum(dim=self.dim, keepdim=True)
return output
def extra_repr(self) -> str:
return f"dim={self.dim} ; power={self.power}"

View file

@ -1,5 +1,6 @@
import torch import torch
import numpy as np import numpy as np
from functions.SoftmaxPower import SoftmaxPower
def make_cnn( def make_cnn(
@ -14,6 +15,7 @@ def make_cnn(
mp_1_stride: int, mp_1_stride: int,
pooling_type: str, pooling_type: str,
conv_0_enable_softmax: bool, conv_0_enable_softmax: bool,
conv_0_power_softmax: float,
l_relu_negative_slope: float, l_relu_negative_slope: float,
) -> torch.nn.Sequential: ) -> torch.nn.Sequential:
assert len(conv_out_channels_list) >= 1 assert len(conv_out_channels_list) >= 1
@ -58,6 +60,9 @@ def make_cnn(
assert setting_understood assert setting_understood
if conv_0_enable_softmax: if conv_0_enable_softmax:
if conv_0_power_softmax != 0.0:
cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax))
else:
cnn.append(torch.nn.Softmax(dim=1)) cnn.append(torch.nn.Softmax(dim=1))
# Changing structure # Changing structure