diff --git a/functions/SoftmaxPower.py b/functions/SoftmaxPower.py
new file mode 100644
index 0000000..370683d
--- /dev/null
+++ b/functions/SoftmaxPower.py
@@ -0,0 +1,29 @@
+import torch
+
+
+class SoftmaxPower(torch.nn.Module):
+    dim: int | None
+    power: float
+
+    def __init__(self, power: float = 2.0, dim: int | None = None) -> None:
+        super().__init__()
+        self.dim = dim
+        self.power = power
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        if not hasattr(self, "dim"):
+            self.dim = None
+        if not hasattr(self, "power"):
+            self.power = 2.0
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        output: torch.Tensor = torch.abs(input).pow(self.power)
+        if self.dim is None:
+            output = output / output.sum()
+        else:
+            output = output / output.sum(dim=self.dim, keepdim=True)
+        return output
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim} ; power={self.power}"
diff --git a/functions/make_cnn.py b/functions/make_cnn.py
index 866b02c..72040b0 100644
--- a/functions/make_cnn.py
+++ b/functions/make_cnn.py
@@ -1,5 +1,6 @@
 import torch
 import numpy as np
+from functions.SoftmaxPower import SoftmaxPower
 
 
 def make_cnn(
@@ -14,6 +15,7 @@ def make_cnn(
     mp_1_stride: int,
     pooling_type: str,
     conv_0_enable_softmax: bool,
+    conv_0_power_softmax: float,
     l_relu_negative_slope: float,
 ) -> torch.nn.Sequential:
     assert len(conv_out_channels_list) >= 1
@@ -58,7 +60,10 @@ def make_cnn(
     assert setting_understood
 
     if conv_0_enable_softmax:
-        cnn.append(torch.nn.Softmax(dim=1))
+        if conv_0_power_softmax != 0.0:
+            cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax))
+        else:
+            cnn.append(torch.nn.Softmax(dim=1))
 
     # Changing structure
     for i in range(1, len(conv_out_channels_list)):
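
A minimal usage sketch (not part of the patch) of the new SoftmaxPower module, assuming the functions/ package layout introduced above. It illustrates what the layer computes: the absolute input raised to `power`, normalized so the values along `dim` sum to 1, instead of the exponential weighting used by torch.nn.Softmax.

    import torch
    from functions.SoftmaxPower import SoftmaxPower

    # Same configuration that make_cnn appends when conv_0_power_softmax != 0.0:
    # normalize |x|^power over the channel dimension.
    layer = SoftmaxPower(power=2.0, dim=1)

    x = torch.randn(4, 8, 16, 16)  # (batch, channels, height, width)
    y = layer(x)

    # Every spatial position now sums to 1 across the channel dimension.
    print(torch.allclose(y.sum(dim=1), torch.ones(4, 16, 16)))  # True
    print(layer)  # SoftmaxPower(dim=1 ; power=2.0)

With conv_0_power_softmax set to 0.0, make_cnn keeps appending torch.nn.Softmax(dim=1), so existing configurations retain their previous behaviour.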