78 lines
2.2 KiB
Python
78 lines
2.2 KiB
Python
import torch
|
|
import torchvision # type: ignore
|
|
from data_loader import data_loader
|
|
|
|
import numpy as np
|
|
|
|
|
|
def _seed_worker(worker_id: int) -> None:
    """Seed NumPy and Python's ``random`` inside a DataLoader worker.

    Used as ``worker_init_fn``. The per-worker torch seed is read with
    ``torch.initial_seed()`` and reused (reduced to 32 bits) for the other
    RNGs, so any NumPy / ``random`` based randomness in workers is
    reproducible. Defined at module level rather than nested so it can be
    pickled when worker processes are started with the ``spawn`` method.

    Args:
        worker_id: Index of the worker; unused, but required by the
            ``worker_init_fn`` calling convention.
    """
    import random

    # NumPy only accepts 32-bit seeds, torch seeds are 64-bit.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def get_the_data(
    dataset: str,
    batch_size_test: int,
    torch_device: torch.device,
    input_dim_x: int,
    input_dim_y: int,
) -> tuple[
    torch.utils.data.dataloader.DataLoader,
    torchvision.transforms.Compose,
]:
    """Build the (non-shuffled) test loader and test-time transform for a dataset.

    The torchvision test split is downloaded to ``"data"`` if not present.

    Args:
        dataset: One of ``"MNIST"``, ``"FashionMNIST"`` or ``"CIFAR10"``.
        batch_size_test: Batch size forwarded to ``data_loader``.
        torch_device: Device forwarded to ``data_loader``.
        input_dim_x: First element of the ``CenterCrop`` size (height).
        input_dim_y: Second element of the ``CenterCrop`` size (width).

    Returns:
        ``(test_dataloader, test_processing_chain)``: the loader built by
        ``data_loader`` over the test split, and a ``Compose`` that
        center-crops to ``(input_dim_x, input_dim_y)``.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the supported names.
    """
    dataset_classes = {
        "MNIST": torchvision.datasets.MNIST,
        "FashionMNIST": torchvision.datasets.FashionMNIST,
        "CIFAR10": torchvision.datasets.CIFAR10,
    }
    if dataset not in dataset_classes:
        raise NotImplementedError("This dataset is not implemented.")

    tv_dataset_test = dataset_classes[dataset](
        root="data", train=False, download=True
    )

    if dataset == "CIFAR10":
        # torchvision keeps CIFAR10 images as a channels-last NumPy array and
        # the labels as a list; convert both to tensors and move the channel
        # axis from last to position 1.
        pattern = torch.tensor(tv_dataset_test.data).movedim(-1, 1)
        labels = torch.tensor(tv_dataset_test.targets)
    else:
        # MNIST / FashionMNIST already expose `data` and `targets` as tensors.
        pattern = tv_dataset_test.data
        labels = tv_dataset_test.targets

    # Fixed-seed generator so any randomness the loader draws is reproducible.
    generator = torch.Generator()
    generator.manual_seed(0)

    test_dataloader = data_loader(
        torch_device=torch_device,
        batch_size=batch_size_test,
        pattern=pattern,
        labels=labels,
        shuffle=False,
        worker_init_fn=_seed_worker,
        generator=generator,
    )

    # Test-time preprocessing: a deterministic center crop only.
    test_processing_chain = torchvision.transforms.Compose(
        transforms=[torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))],
    )

    return test_dataloader, test_processing_chain
|