Add files via upload
This commit is contained in:
parent
fbc4516e58
commit
34a688d546
1 changed files with 7 additions and 5 deletions
|
@ -8,8 +8,8 @@ def alicorn_data_loader(
|
|||
num_pfinkel: list[int] | None,
|
||||
load_stimuli_per_pfinkel: int,
|
||||
condition: str,
|
||||
logger,
|
||||
data_path: str,
|
||||
logger=None,
|
||||
) -> torch.utils.data.TensorDataset:
|
||||
"""
|
||||
- num_pfinkel: list of the angles that should be loaded (ranging from
|
||||
|
@ -52,6 +52,7 @@ def alicorn_data_loader(
|
|||
num_batches = load_stimuli_per_pfinkel // (stimuli_per_pfinkel // len(batch))
|
||||
num_img_per_pfinkel = load_stimuli_per_pfinkel // num_batches
|
||||
|
||||
if logger is not None:
|
||||
logger.info(f"{num_batches} batches")
|
||||
logger.info(f"{num_img_per_pfinkel} stimuli per pfinkel.")
|
||||
|
||||
|
@ -64,6 +65,7 @@ def alicorn_data_loader(
|
|||
(num_stimuli), dtype=torch.int64, device=torch.device("cpu")
|
||||
)
|
||||
|
||||
if logger is not None:
|
||||
logger.info(f"data tensor shape: {data_tensor.shape}")
|
||||
logger.info(f"label tensor shape: {label_tensor.shape}")
|
||||
|
||||
|
|
Loading…
Reference in a new issue