69 lines
2.2 KiB
Python
69 lines
2.2 KiB
Python
|
# %%
|
||
|
import torch
|
||
|
from network.Dataset import (
|
||
|
DatasetMaster,
|
||
|
DatasetCIFAR,
|
||
|
DatasetMNIST,
|
||
|
DatasetFashionMNIST,
|
||
|
)
|
||
|
from network.Parameter import Config
|
||
|
|
||
|
|
||
|
def build_datasets(
    cfg: Config,
) -> tuple[
    DatasetMaster,
    DatasetMaster,
    torch.utils.data.DataLoader,
    torch.utils.data.DataLoader,
]:
    """Build the train/test datasets and their DataLoaders from ``cfg``.

    The dataset class is chosen by ``cfg.data_mode`` ("CIFAR10", "MNIST" or
    "MNIST_FASHION"). Both splits are constructed with ``cfg.data_path`` as
    both the pattern path and the label path.

    Side effects on ``cfg``:
      * ``cfg.image_statistics.mean`` is set to the training set's ``mean``
        if it is currently empty.
      * ``cfg.image_statistics.the_size`` is set to dims 2 and 3 of the
        training set's ``pattern_storage.shape``. These are presumably the
        height and width of an (N, C, H, W) tensor; confirm against the
        dataset classes. Each is reduced by
        ``2 * cfg.augmentation.crop_width_in_pixel`` to account for the
        pixels cut away by the pattern filter.

    Args:
        cfg: Configuration object. Reads ``data_mode``, ``data_path``,
            ``batch_size``, ``image_statistics.mean`` and
            ``augmentation.crop_width_in_pixel``.

    Returns:
        ``(train_dataset, test_dataset, test_loader, train_loader)``.
        Note that the loaders come test-first, unlike the datasets. The test
        loader iterates in order (``shuffle=False``); the train loader
        shuffles.

    Raises:
        ValueError: If ``cfg.data_mode`` is not a supported mode, or if the
            crop width leaves a non-positive image size. ``ValueError``
            subclasses ``Exception``, so callers catching the previously
            raised generic ``Exception`` keep working.
    """
    # Map each supported data_mode to its dataset class. All classes share
    # the same constructor keywords, so one code path covers every mode.
    dataset_classes: dict[str, type[DatasetMaster]] = {
        "CIFAR10": DatasetCIFAR,
        "MNIST": DatasetMNIST,
        "MNIST_FASHION": DatasetFashionMNIST,
    }
    dataset_cls = dataset_classes.get(cfg.data_mode)
    if dataset_cls is None:
        raise ValueError(
            f"data_mode unknown: {cfg.data_mode!r} "
            f"(expected one of {sorted(dataset_classes)})"
        )

    # Load the input data (train split first, then test).
    the_dataset_train: DatasetMaster = dataset_cls(
        train=True, path_pattern=cfg.data_path, path_label=cfg.data_path
    )
    the_dataset_test: DatasetMaster = dataset_cls(
        train=False, path_pattern=cfg.data_path, path_label=cfg.data_path
    )

    # Only fill the mean when the config did not provide one. len() is used
    # rather than truthiness because the value may be a tensor.
    if len(cfg.image_statistics.mean) == 0:
        cfg.image_statistics.mean = the_dataset_train.mean

    # The basic size, minus the border cut away in the pattern filter.
    # Validate before storing so cfg never holds a nonsensical size.
    crop = 2 * cfg.augmentation.crop_width_in_pixel
    the_size = [
        the_dataset_train.pattern_storage.shape[2] - crop,
        the_dataset_train.pattern_storage.shape[3] - crop,
    ]
    if min(the_size) <= 0:
        raise ValueError(
            f"crop_width_in_pixel={cfg.augmentation.crop_width_in_pixel} "
            f"leaves a non-positive image size {the_size}"
        )
    cfg.image_statistics.the_size = the_size

    # Test loader is built before the train loader, as in the original.
    my_loader_test: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
        the_dataset_test, batch_size=cfg.batch_size, shuffle=False
    )
    my_loader_train: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
        the_dataset_train, batch_size=cfg.batch_size, shuffle=True
    )

    return the_dataset_train, the_dataset_test, my_loader_test, my_loader_train
|