Add files via upload

This commit is contained in:
David Rotermund 2023-07-22 14:53:46 +02:00 committed by GitHub
parent 8505df62e3
commit e18690b0b3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 248 additions and 248 deletions

View file

@ -8,8 +8,8 @@ def alicorn_data_loader(
num_pfinkel: list[int] | None, num_pfinkel: list[int] | None,
load_stimuli_per_pfinkel: int, load_stimuli_per_pfinkel: int,
condition: str, condition: str,
logger,
data_path: str, data_path: str,
logger=None,
) -> torch.utils.data.TensorDataset: ) -> torch.utils.data.TensorDataset:
""" """
- num_pfinkel: list of the angles that should be loaded (ranging from - num_pfinkel: list of the angles that should be loaded (ranging from

View file

@ -32,9 +32,6 @@ def make_cnn(
) )
) )
if conv_0_enable_softmax:
cnn.append(torch.nn.Softmax(dim=1))
setting_understood: bool = False setting_understood: bool = False
if conv_activation_function.upper() == str("relu").upper(): if conv_activation_function.upper() == str("relu").upper():
cnn.append(torch.nn.ReLU()) cnn.append(torch.nn.ReLU())
@ -60,6 +57,9 @@ def make_cnn(
setting_understood = True setting_understood = True
assert setting_understood assert setting_understood
if conv_0_enable_softmax:
cnn.append(torch.nn.Softmax(dim=1))
# Changing structure # Changing structure
for i in range(1, len(conv_out_channels_list)): for i in range(1, len(conv_out_channels_list)):
if i == 1 and not train_conv_0: if i == 1 and not train_conv_0: