28 lines
764 B
Python
28 lines
764 B
Python
|
import torch
|
||
|
|
||
|
|
||
|
def data_loader(
    pattern: torch.Tensor,
    labels: torch.Tensor,
    batch_size: int = 128,
    shuffle: bool = True,
    torch_device: torch.device = torch.device("cpu"),
) -> torch.utils.data.dataloader.DataLoader:
    """Wrap pattern and label tensors in a ``DataLoader``.

    The patterns are moved to ``torch_device``, converted to ``float32`` and
    divided by their global maximum. A 3-D pattern tensor gets a singleton
    dimension inserted at index 1 (presumably the channel axis -- confirm
    against the consuming model). The labels are moved to ``torch_device``
    and converted to ``int64``.

    The input tensors are never modified; normalisation is done on a new
    tensor.

    Args:
        pattern: Input patterns with at least 3 dimensions.
        labels: Labels; the first dimension must match ``pattern`` (this is
            enforced by ``TensorDataset``).
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data every epoch.
        torch_device: Device on which the stored tensors are kept.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` of (pattern, label) pairs.

    Raises:
        AssertionError: If ``pattern`` has fewer than 3 dimensions.
    """
    assert pattern.ndim >= 3, "pattern must have at least 3 dimensions"

    pattern_storage: torch.Tensor = pattern.to(torch_device).type(torch.float32)
    if pattern_storage.ndim == 3:
        pattern_storage = pattern_storage.unsqueeze(1)

    # Divide out-of-place: when `pattern` is already float32 on the target
    # device, `.to()`/`.type()` return the very same tensor (and `unsqueeze`
    # a view of it), so an in-place `/=` would overwrite the caller's data.
    max_value = pattern_storage.max()
    if max_value != 0:
        pattern_storage = pattern_storage / max_value
    # A maximum of exactly zero is left unscaled to avoid 0/0 -> NaN.

    label_storage: torch.Tensor = labels.to(torch_device).type(torch.int64)

    dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(pattern_storage, label_storage),
        batch_size=batch_size,
        shuffle=shuffle,
    )

    return dataloader
|