Add files via upload
This commit is contained in:
parent
8505df62e3
commit
e18690b0b3
4 changed files with 248 additions and 248 deletions
|
@ -8,8 +8,8 @@ def alicorn_data_loader(
|
|||
num_pfinkel: list[int] | None,
|
||||
load_stimuli_per_pfinkel: int,
|
||||
condition: str,
|
||||
logger,
|
||||
data_path: str,
|
||||
logger=None,
|
||||
) -> torch.utils.data.TensorDataset:
|
||||
"""
|
||||
- num_pfinkel: list of the angles that should be loaded (ranging from
|
||||
|
|
|
@ -32,9 +32,6 @@ def make_cnn(
|
|||
)
|
||||
)
|
||||
|
||||
if conv_0_enable_softmax:
|
||||
cnn.append(torch.nn.Softmax(dim=1))
|
||||
|
||||
setting_understood: bool = False
|
||||
if conv_activation_function.upper() == str("relu").upper():
|
||||
cnn.append(torch.nn.ReLU())
|
||||
|
@ -60,6 +57,9 @@ def make_cnn(
|
|||
setting_understood = True
|
||||
assert setting_understood
|
||||
|
||||
if conv_0_enable_softmax:
|
||||
cnn.append(torch.nn.Softmax(dim=1))
|
||||
|
||||
# Changing structure
|
||||
for i in range(1, len(conv_out_channels_list)):
|
||||
if i == 1 and not train_conv_0:
|
||||
|
|
Loading…
Reference in a new issue