Bernstein_Poster_2024/basis_mlp_x16/get_the_data_picture.py

111 lines
3.3 KiB
Python
Raw Permalink Normal View History

2024-11-05 18:20:02 +01:00
import torch
import torchvision # type: ignore
from data_loader import data_loader
import numpy as np
def _seed_worker(worker_id: int) -> None:
    """Seed NumPy's and Python's RNGs from the current torch seed.

    Intended to be used as a ``worker_init_fn`` callback; ``worker_id`` is
    accepted only because that callback signature supplies it.

    Kept at module level (instead of nested in ``get_the_data``) so the
    callback stays picklable if the loader starts worker processes.
    """
    # Function-scope import: the file's top-level imports do not include it.
    import random

    # NumPy only accepts seeds that fit into 32 bits.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    # NOTE: the previous code called ``torch.random.seed(worker_seed)``.
    # ``torch.random.seed()`` takes no argument (TypeError) and would reseed
    # torch non-deterministically; seeding Python's ``random`` is what the
    # PyTorch reproducibility recipe prescribes here.
    random.seed(worker_seed)


def get_the_data(
    dataset: str,
    batch_size_train: int,
    batch_size_test: int,
    torch_device: torch.device,
    input_dim_x: int,
    input_dim_y: int,
) -> tuple[
    torch.utils.data.dataloader.DataLoader,
    torch.utils.data.dataloader.DataLoader,
    torchvision.transforms.Compose,
]:
    """Build the train/test loaders and the test-time processing chain.

    The torchvision dataset is downloaded into ``data`` if necessary, and its
    raw ``data`` / ``targets`` are handed to the project's ``data_loader``.

    Args:
        dataset: One of ``"MNIST"``, ``"FashionMNIST"`` or ``"CIFAR10"``.
        batch_size_train: Batch size passed to the training loader.
        batch_size_test: Batch size passed to the test loader.
        torch_device: Device forwarded to ``data_loader``.
        input_dim_x: First size given to the center crop.
        input_dim_y: Second size given to the center crop.

    Returns:
        ``(train_dataloader, test_dataloader, test_processing_chain)``.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the supported names.
    """
    if dataset == "MNIST":
        dataset_class = torchvision.datasets.MNIST
    elif dataset == "FashionMNIST":
        dataset_class = torchvision.datasets.FashionMNIST
    elif dataset == "CIFAR10":
        dataset_class = torchvision.datasets.CIFAR10
    else:
        raise NotImplementedError("This dataset is not implemented.")

    tv_dataset_train = dataset_class(root="data", train=True, download=True)
    tv_dataset_test = dataset_class(root="data", train=False, download=True)

    # One fixed-seed generator, shared by both loaders.
    g = torch.Generator()
    g.manual_seed(0)

    # CIFAR10 exposes non-tensor ``data`` / ``targets`` and needs conversion;
    # the (Fashion)MNIST attributes are passed through unchanged.
    needs_conversion = dataset == "CIFAR10"

    def _build_loader(tv_dataset, batch_size: int):
        """Wrap one torchvision dataset split in the project's data_loader."""
        if needs_conversion:
            # Move the last axis to position 1 (presumably channels-last to
            # channels-first -- TODO confirm against data_loader's contract).
            pattern = torch.tensor(tv_dataset.data).movedim(-1, 1)
            labels = torch.tensor(tv_dataset.targets)
        else:
            pattern = tv_dataset.data
            labels = tv_dataset.targets

        return data_loader(
            torch_device=torch_device,
            batch_size=batch_size,
            pattern=pattern,
            labels=labels,
            shuffle=False,
            worker_init_fn=_seed_worker,
            generator=g,
        )

    train_dataloader = _build_loader(tv_dataset_train, batch_size_train)
    test_dataloader = _build_loader(tv_dataset_test, batch_size_test)

    # Test-time processing: center-crop to (input_dim_x, input_dim_y).
    test_processing_chain = torchvision.transforms.Compose(
        transforms=[torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))],
    )

    return (
        train_dataloader,
        test_dataloader,
        test_processing_chain,
    )