diff --git a/functions/alicorn_data_loader.py b/functions/alicorn_data_loader.py index 3e8f1cd..71fb6db 100644 --- a/functions/alicorn_data_loader.py +++ b/functions/alicorn_data_loader.py @@ -8,8 +8,8 @@ def alicorn_data_loader( num_pfinkel: list[int] | None, load_stimuli_per_pfinkel: int, condition: str, - logger, data_path: str, + logger=None, ) -> torch.utils.data.TensorDataset: """ - num_pfinkel: list of the angles that should be loaded (ranging from @@ -52,8 +52,9 @@ def alicorn_data_loader( num_batches = load_stimuli_per_pfinkel // (stimuli_per_pfinkel // len(batch)) num_img_per_pfinkel = load_stimuli_per_pfinkel // num_batches - logger.info(f"{num_batches} batches") - logger.info(f"{num_img_per_pfinkel} stimuli per pfinkel.") + if logger is not None: + logger.info(f"{num_batches} batches") + logger.info(f"{num_img_per_pfinkel} stimuli per pfinkel.") # initialize data and label tensors: num_stimuli: int = len(angle) * num_batches * num_img_per_pfinkel * 2 @@ -64,8 +65,9 @@ def alicorn_data_loader( (num_stimuli), dtype=torch.int64, device=torch.device("cpu") ) - logger.info(f"data tensor shape: {data_tensor.shape}") - logger.info(f"label tensor shape: {label_tensor.shape}") + if logger is not None: + logger.info(f"data tensor shape: {data_tensor.shape}") + logger.info(f"label tensor shape: {label_tensor.shape}") # append data idx: int = 0