Add files via upload

This commit is contained in:
David Rotermund 2023-07-21 11:08:11 +02:00 committed by GitHub
parent fbc4516e58
commit 34a688d546
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -8,8 +8,8 @@ def alicorn_data_loader(
num_pfinkel: list[int] | None, num_pfinkel: list[int] | None,
load_stimuli_per_pfinkel: int, load_stimuli_per_pfinkel: int,
condition: str, condition: str,
logger,
data_path: str, data_path: str,
logger=None,
) -> torch.utils.data.TensorDataset: ) -> torch.utils.data.TensorDataset:
""" """
- num_pfinkel: list of the angles that should be loaded (ranging from - num_pfinkel: list of the angles that should be loaded (ranging from
@ -52,6 +52,7 @@ def alicorn_data_loader(
num_batches = load_stimuli_per_pfinkel // (stimuli_per_pfinkel // len(batch)) num_batches = load_stimuli_per_pfinkel // (stimuli_per_pfinkel // len(batch))
num_img_per_pfinkel = load_stimuli_per_pfinkel // num_batches num_img_per_pfinkel = load_stimuli_per_pfinkel // num_batches
if logger is not None:
logger.info(f"{num_batches} batches") logger.info(f"{num_batches} batches")
logger.info(f"{num_img_per_pfinkel} stimuli per pfinkel.") logger.info(f"{num_img_per_pfinkel} stimuli per pfinkel.")
@ -64,6 +65,7 @@ def alicorn_data_loader(
(num_stimuli), dtype=torch.int64, device=torch.device("cpu") (num_stimuli), dtype=torch.int64, device=torch.device("cpu")
) )
if logger is not None:
logger.info(f"data tensor shape: {data_tensor.shape}") logger.info(f"data tensor shape: {data_tensor.shape}")
logger.info(f"label tensor shape: {label_tensor.shape}") logger.info(f"label tensor shape: {label_tensor.shape}")