Add files via upload

This commit is contained in:
David Rotermund 2023-07-21 11:08:11 +02:00 committed by GitHub
parent fbc4516e58
commit 34a688d546
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -8,8 +8,8 @@ def alicorn_data_loader(
num_pfinkel: list[int] | None, num_pfinkel: list[int] | None,
load_stimuli_per_pfinkel: int, load_stimuli_per_pfinkel: int,
condition: str, condition: str,
logger,
data_path: str, data_path: str,
logger=None,
) -> torch.utils.data.TensorDataset: ) -> torch.utils.data.TensorDataset:
""" """
- num_pfinkel: list of the angles that should be loaded (ranging from - num_pfinkel: list of the angles that should be loaded (ranging from
@ -52,8 +52,9 @@ def alicorn_data_loader(
num_batches = load_stimuli_per_pfinkel // (stimuli_per_pfinkel // len(batch)) num_batches = load_stimuli_per_pfinkel // (stimuli_per_pfinkel // len(batch))
num_img_per_pfinkel = load_stimuli_per_pfinkel // num_batches num_img_per_pfinkel = load_stimuli_per_pfinkel // num_batches
logger.info(f"{num_batches} batches") if logger is not None:
logger.info(f"{num_img_per_pfinkel} stimuli per pfinkel.") logger.info(f"{num_batches} batches")
logger.info(f"{num_img_per_pfinkel} stimuli per pfinkel.")
# initialize data and label tensors: # initialize data and label tensors:
num_stimuli: int = len(angle) * num_batches * num_img_per_pfinkel * 2 num_stimuli: int = len(angle) * num_batches * num_img_per_pfinkel * 2
@ -64,8 +65,9 @@ def alicorn_data_loader(
(num_stimuli), dtype=torch.int64, device=torch.device("cpu") (num_stimuli), dtype=torch.int64, device=torch.device("cpu")
) )
logger.info(f"data tensor shape: {data_tensor.shape}") if logger is not None:
logger.info(f"label tensor shape: {label_tensor.shape}") logger.info(f"data tensor shape: {data_tensor.shape}")
logger.info(f"label tensor shape: {label_tensor.shape}")
# append data # append data
idx: int = 0 idx: int = 0