import random

import numpy as np
import torch
import torchvision  # type: ignore

from data_loader import data_loader


def get_the_data(
    dataset: str,
    batch_size_train: int,
    batch_size_test: int,
    torch_device: torch.device,
    input_dim_x: int,
    input_dim_y: int,
) -> tuple[
    torch.utils.data.dataloader.DataLoader,
    torch.utils.data.dataloader.DataLoader,
    torchvision.transforms.Compose,
]:
    """Build train/test dataloaders and a test-time preprocessing chain.

    The requested torchvision dataset is downloaded (if needed) into ``data/``
    and wrapped with the project's ``data_loader`` helper, without shuffling.

    Args:
        dataset: One of ``"MNIST"``, ``"FashionMNIST"`` or ``"CIFAR10"``.
        batch_size_train: Batch size for the training loader.
        batch_size_test: Batch size for the test loader.
        torch_device: Device handed through to ``data_loader``.
        input_dim_x: Height of the center crop applied at test time.
        input_dim_y: Width of the center crop applied at test time.

    Returns:
        ``(train_dataloader, test_dataloader, test_processing_chain)``, where
        the last element is a ``Compose`` containing a single ``CenterCrop``.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the supported names.
    """
    dataset_classes = {
        "MNIST": torchvision.datasets.MNIST,
        "FashionMNIST": torchvision.datasets.FashionMNIST,
        "CIFAR10": torchvision.datasets.CIFAR10,
    }
    if dataset not in dataset_classes:
        raise NotImplementedError("This dataset is not implemented.")

    dataset_class = dataset_classes[dataset]
    tv_dataset_train = dataset_class(root="data", train=True, download=True)
    tv_dataset_test = dataset_class(root="data", train=False, download=True)

    def seed_worker(worker_id: int) -> None:
        """Derive the numpy / ``random`` seeds from the worker's torch seed.

        Follows the PyTorch reproducibility recipe. The original version
        called ``torch.random.seed(worker_seed)``, but ``torch.random.seed``
        accepts no argument, so it raised ``TypeError``. Even without the
        argument it would have reseeded torch non-deterministically.
        """
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    # NOTE: the same generator object is shared by both loaders, as before.
    g = torch.Generator()
    g.manual_seed(0)

    def to_pattern_and_labels(tv_dataset) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the (pattern, labels) pair that ``data_loader`` expects."""
        if dataset in ("MNIST", "FashionMNIST"):
            # Passed through unchanged, exactly as stored on the dataset.
            return tv_dataset.data, tv_dataset.targets
        # CIFAR10: convert to tensors and move the trailing axis to position 1.
        # This presumably turns channels-last images into (N, C, H, W).
        pattern = torch.tensor(tv_dataset.data).movedim(-1, 1)
        labels = torch.tensor(tv_dataset.targets)
        return pattern, labels

    def make_loader(tv_dataset, batch_size: int):
        """Wrap one torchvision split with the project's ``data_loader``."""
        pattern, labels = to_pattern_and_labels(tv_dataset)
        return data_loader(
            torch_device=torch_device,
            batch_size=batch_size,
            pattern=pattern,
            labels=labels,
            shuffle=False,
            worker_init_fn=seed_worker,
            generator=g,
        )

    train_dataloader = make_loader(tv_dataset_train, batch_size_train)
    test_dataloader = make_loader(tv_dataset_test, batch_size_test)

    # Deterministic test-time preprocessing: center-crop to the network input size.
    test_processing_chain = torchvision.transforms.Compose(
        transforms=[torchvision.transforms.CenterCrop((input_dim_x, input_dim_y))],
    )

    return (
        train_dataloader,
        test_dataloader,
        test_processing_chain,
    )