Add files via upload

This commit is contained in:
David Rotermund 2023-07-23 01:37:54 +02:00 committed by GitHub
parent 5a799066f2
commit fc782baeb3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 5 deletions

View file

@ -23,7 +23,7 @@ class SoftmaxPower(torch.nn.Module):
output = output / output.sum() output = output / output.sum()
else: else:
output = output / output.sum(dim=self.dim, keepdim=True) output = output / output.sum(dim=self.dim, keepdim=True)
return output return input * output
def extra_repr(self) -> str: def extra_repr(self) -> str:
return f"dim={self.dim} ; power={self.power}" return f"dim={self.dim} ; power={self.power}"

View file

@ -60,10 +60,11 @@ def make_cnn(
assert setting_understood assert setting_understood
if conv_0_enable_softmax: if conv_0_enable_softmax:
if conv_0_power_softmax != 0.0: # if conv_0_power_softmax != 0.0:
cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax)) # cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax))
else: # else:
cnn.append(torch.nn.Softmax(dim=1)) # cnn.append(torch.nn.Softmax(dim=1))
cnn.append(SoftmaxPower(dim=1, power=conv_0_power_softmax))
# Changing structure # Changing structure
for i in range(1, len(conv_out_channels_list)): for i in range(1, len(conv_out_channels_list)):