Add files via upload
This commit is contained in:
parent
8505df62e3
commit
e18690b0b3
4 changed files with 248 additions and 248 deletions
|
@ -8,8 +8,8 @@ def alicorn_data_loader(
|
||||||
num_pfinkel: list[int] | None,
|
num_pfinkel: list[int] | None,
|
||||||
load_stimuli_per_pfinkel: int,
|
load_stimuli_per_pfinkel: int,
|
||||||
condition: str,
|
condition: str,
|
||||||
|
logger,
|
||||||
data_path: str,
|
data_path: str,
|
||||||
logger=None,
|
|
||||||
) -> torch.utils.data.TensorDataset:
|
) -> torch.utils.data.TensorDataset:
|
||||||
"""
|
"""
|
||||||
- num_pfinkel: list of the angles that should be loaded (ranging from
|
- num_pfinkel: list of the angles that should be loaded (ranging from
|
||||||
|
|
|
@ -32,9 +32,6 @@ def make_cnn(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if conv_0_enable_softmax:
|
|
||||||
cnn.append(torch.nn.Softmax(dim=1))
|
|
||||||
|
|
||||||
setting_understood: bool = False
|
setting_understood: bool = False
|
||||||
if conv_activation_function.upper() == str("relu").upper():
|
if conv_activation_function.upper() == str("relu").upper():
|
||||||
cnn.append(torch.nn.ReLU())
|
cnn.append(torch.nn.ReLU())
|
||||||
|
@ -60,6 +57,9 @@ def make_cnn(
|
||||||
setting_understood = True
|
setting_understood = True
|
||||||
assert setting_understood
|
assert setting_understood
|
||||||
|
|
||||||
|
if conv_0_enable_softmax:
|
||||||
|
cnn.append(torch.nn.Softmax(dim=1))
|
||||||
|
|
||||||
# Changing structure
|
# Changing structure
|
||||||
for i in range(1, len(conv_out_channels_list)):
|
for i in range(1, len(conv_out_channels_list)):
|
||||||
if i == 1 and not train_conv_0:
|
if i == 1 and not train_conv_0:
|
||||||
|
|
Loading…
Reference in a new issue