Review notes (free text ahead of the first "diff --git" header; not part of any hunk).

What the patch does:
  * SoftmaxPower gains a `mean_mode` flag (constructor argument, attribute, and a
    default of False in __setstate__ when the attribute is absent from the
    restored state).
  * make_cnn gains a `conv_0_meanmode_softmax` parameter and passes it to
    SoftmaxPower as `mean_mode`.

Changed relative to the submitted patch (last hunk of SoftmaxPower.py only):
  * The mean-mode branch averages over `self.dim` (or over the whole tensor when
    `self.dim` is None), matching the normalization directly above it, instead of
    a hard-coded dim=1. Results are identical when dim == 1.
  * extra_repr also reports `mean_mode`.

NOTE(review): line breaks and indentation were rebuilt from a whitespace-collapsed
copy of this patch. The spacing inside the removed `#` comment lines and the
indentation of context lines could not be recovered exactly - check them against
the pre-image before applying.
NOTE(review): the post-image hash on the SoftmaxPower.py `index` line predates the
changes listed above - regenerate it if it is relied upon.
NOTE(review): `conv_0_meanmode_softmax` is a required parameter inserted before
`l_relu_negative_slope`; callers that pass arguments positionally need updating -
confirm all call sites.

diff --git a/functions/SoftmaxPower.py b/functions/SoftmaxPower.py
index a5e0f39..22b97cf 100644
--- a/functions/SoftmaxPower.py
+++ b/functions/SoftmaxPower.py
@@ -4,11 +4,15 @@ import torch
 class SoftmaxPower(torch.nn.Module):
     dim: int | None
     power: float
+    mean_mode: bool
 
-    def __init__(self, power: float = 2.0, dim: int | None = None) -> None:
+    def __init__(
+        self, power: float = 2.0, dim: int | None = None, mean_mode: bool = False
+    ) -> None:
         super().__init__()
         self.dim = dim
         self.power = power
+        self.mean_mode = mean_mode
 
     def __setstate__(self, state):
         super().__setstate__(state)
@@ -16,6 +20,8 @@ class SoftmaxPower(torch.nn.Module):
             self.dim = None
         if not hasattr(self, "power"):
             self.power = 2.0
+        if not hasattr(self, "mean_mode"):
+            self.mean_mode = False
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         output: torch.Tensor = torch.abs(input).pow(self.power)
@@ -23,7 +29,18 @@ class SoftmaxPower(torch.nn.Module):
             output = output / output.sum()
         else:
             output = output / output.sum(dim=self.dim, keepdim=True)
-        return input * output
+
+        if not self.mean_mode:
+            return input * output
+
+        # Mean mode: scale the weights by the mean magnitude taken over the
+        # same axis (or the whole tensor) that the normalization above used,
+        # instead of a hard-coded dim=1.
+        magnitude = torch.abs(input)
+        if self.dim is None:
+            return magnitude.mean() * output
+        return magnitude.mean(dim=self.dim, keepdim=True) * output
 
     def extra_repr(self) -> str:
-        return f"dim={self.dim} ; power={self.power}"
+        # Include mean_mode so printed modules show their full configuration.
+        return f"dim={self.dim} ; power={self.power} ; mean_mode={self.mean_mode}"
diff --git a/functions/make_cnn.py b/functions/make_cnn.py
index 625ee79..f752e31 100644
--- a/functions/make_cnn.py
+++ b/functions/make_cnn.py
@@ -16,6 +16,7 @@ def make_cnn(
     pooling_type: str,
     conv_0_enable_softmax: bool,
     conv_0_power_softmax: float,
+    conv_0_meanmode_softmax: bool,
     l_relu_negative_slope: float,
 ) -> torch.nn.Sequential:
     assert len(conv_out_channels_list) >= 1
@@ -60,11 +61,11 @@ def make_cnn(
     assert setting_understood
 
     if conv_0_enable_softmax:
-# if conv_0_power_softmax != 0.0:
-# cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax))
-# else:
-# cnn.append(torch.nn.Softmax(dim=1))
-        cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax))
+        cnn.append(
+            SoftmaxPower(
+                dim=1, power=conv_0_power_softmax, mean_mode=conv_0_meanmode_softmax
+            )
+        )
 
     # Changing structure
     for i in range(1, len(conv_out_channels_list)):