pytorch-sbs/network/DatasetMix.py
2023-02-21 14:37:51 +01:00

90 lines
2.6 KiB
Python

import torch
from network.Dataset import DatasetMNIST, DatasetFashionMNIST, DatasetCIFAR
import math
class DatasetMNISTMix(DatasetMNIST):
    """DatasetMNIST variant whose items are a blend of two differently-labelled samples.

    Each item is ``alpha * image_a + (1 - alpha) * image_b``, where ``image_a``
    is the sample at the requested index and ``image_b`` is a randomly drawn
    sample whose label differs from that of ``image_a``.
    """

    def __init__(
        self,
        train: bool = False,
        path_pattern: str = "./",
        path_label: str = "./",
        alpha: float = 1.0,
    ) -> None:
        """Initialise the base dataset and store the blend weight.

        Args:
            train: Forwarded unchanged to ``DatasetMNIST.__init__``.
            path_pattern: Forwarded unchanged to ``DatasetMNIST.__init__``.
            path_label: Forwarded unchanged to ``DatasetMNIST.__init__``.
            alpha: Weight of the indexed sample in the blend; the partner
                sample is weighted with ``1 - alpha``. Must lie in [0, 1]
                (checked on item access).
        """
        super().__init__(train, path_pattern, path_label)
        self.alpha = alpha

    def __getitem__(self, index: int) -> tuple[torch.Tensor, list[int]]:  # type: ignore
        """Return the blended image and the labels of both source samples.

        Args:
            index: Index of the primary sample in the base dataset.

        Returns:
            A tuple ``(image, [target_a, target_b])`` where ``target_a`` is the
            label of the indexed sample and ``target_b`` the (different) label
            of the randomly drawn partner.

        Raises:
            ValueError: If ``self.alpha`` is outside [0, 1] (or is NaN).

        Note:
            The partner is found by rejection sampling. If no sample with a
            label different from ``target_a`` exists, the loop never ends.
        """
        # Checked on every access because ``alpha`` is a plain attribute that
        # can be reassigned after construction. An explicit raise is used
        # instead of ``assert`` so the check survives ``python -O``.
        if not 0.0 <= self.alpha <= 1.0:
            raise ValueError(
                f"alpha must be within [0.0, 1.0], got {self.alpha!r}"
            )

        image_a, target_a = super().__getitem__(index)

        # Start with equal labels so the loop body runs at least once.
        target_b: int = target_a
        while target_b == target_a:
            # torch.rand yields values in [0, 1), so the floored product is a
            # valid index in [0, number_of_pattern - 1].
            partner_index = int(
                math.floor(self.number_of_pattern * torch.rand((1)).item())
            )
            image_b, target_b = super().__getitem__(partner_index)

        image = self.alpha * image_a + (1.0 - self.alpha) * image_b
        target = [target_a, target_b]
        return image, target
class DatasetFashionMNISTMix(DatasetFashionMNIST):
    """DatasetFashionMNIST variant whose items blend two differently-labelled samples.

    Each item is ``alpha * image_a + (1 - alpha) * image_b``, where ``image_a``
    is the sample at the requested index and ``image_b`` is a randomly drawn
    sample whose label differs from that of ``image_a``.
    """

    def __init__(
        self,
        train: bool = False,
        path_pattern: str = "./",
        path_label: str = "./",
        alpha: float = 1.0,
    ) -> None:
        """Initialise the base dataset and store the blend weight.

        Args:
            train: Forwarded unchanged to ``DatasetFashionMNIST.__init__``.
            path_pattern: Forwarded unchanged to ``DatasetFashionMNIST.__init__``.
            path_label: Forwarded unchanged to ``DatasetFashionMNIST.__init__``.
            alpha: Weight of the indexed sample in the blend; the partner
                sample is weighted with ``1 - alpha``. Must lie in [0, 1]
                (checked on item access).
        """
        super().__init__(train, path_pattern, path_label)
        self.alpha = alpha

    def __getitem__(self, index: int) -> tuple[torch.Tensor, list[int]]:  # type: ignore
        """Return the blended image and the labels of both source samples.

        Args:
            index: Index of the primary sample in the base dataset.

        Returns:
            A tuple ``(image, [target_a, target_b])`` where ``target_a`` is the
            label of the indexed sample and ``target_b`` the (different) label
            of the randomly drawn partner.

        Raises:
            ValueError: If ``self.alpha`` is outside [0, 1] (or is NaN).

        Note:
            The partner is found by rejection sampling. If no sample with a
            label different from ``target_a`` exists, the loop never ends.
        """
        # Checked on every access because ``alpha`` is a plain attribute that
        # can be reassigned after construction. An explicit raise is used
        # instead of ``assert`` so the check survives ``python -O``.
        if not 0.0 <= self.alpha <= 1.0:
            raise ValueError(
                f"alpha must be within [0.0, 1.0], got {self.alpha!r}"
            )

        image_a, target_a = super().__getitem__(index)

        # Start with equal labels so the loop body runs at least once.
        target_b: int = target_a
        while target_b == target_a:
            # torch.rand yields values in [0, 1), so the floored product is a
            # valid index in [0, number_of_pattern - 1].
            partner_index = int(
                math.floor(self.number_of_pattern * torch.rand((1)).item())
            )
            image_b, target_b = super().__getitem__(partner_index)

        image = self.alpha * image_a + (1.0 - self.alpha) * image_b
        target = [target_a, target_b]
        return image, target
class DatasetCIFARMix(DatasetCIFAR):
    """DatasetCIFAR variant whose items are a blend of two differently-labelled samples.

    Each item is ``alpha * image_a + (1 - alpha) * image_b``, where ``image_a``
    is the sample at the requested index and ``image_b`` is a randomly drawn
    sample whose label differs from that of ``image_a``.
    """

    def __init__(
        self,
        train: bool = False,
        path_pattern: str = "./",
        path_label: str = "./",
        alpha: float = 1.0,
    ) -> None:
        """Initialise the base dataset and store the blend weight.

        Args:
            train: Forwarded unchanged to ``DatasetCIFAR.__init__``.
            path_pattern: Forwarded unchanged to ``DatasetCIFAR.__init__``.
            path_label: Forwarded unchanged to ``DatasetCIFAR.__init__``.
            alpha: Weight of the indexed sample in the blend; the partner
                sample is weighted with ``1 - alpha``. Must lie in [0, 1]
                (checked on item access).
        """
        super().__init__(train, path_pattern, path_label)
        self.alpha = alpha

    def __getitem__(self, index: int) -> tuple[torch.Tensor, list[int]]:  # type: ignore
        """Return the blended image and the labels of both source samples.

        Args:
            index: Index of the primary sample in the base dataset.

        Returns:
            A tuple ``(image, [target_a, target_b])`` where ``target_a`` is the
            label of the indexed sample and ``target_b`` the (different) label
            of the randomly drawn partner.

        Raises:
            ValueError: If ``self.alpha`` is outside [0, 1] (or is NaN).

        Note:
            The partner is found by rejection sampling. If no sample with a
            label different from ``target_a`` exists, the loop never ends.
        """
        # Checked on every access because ``alpha`` is a plain attribute that
        # can be reassigned after construction. An explicit raise is used
        # instead of ``assert`` so the check survives ``python -O``.
        if not 0.0 <= self.alpha <= 1.0:
            raise ValueError(
                f"alpha must be within [0.0, 1.0], got {self.alpha!r}"
            )

        image_a, target_a = super().__getitem__(index)

        # Start with equal labels so the loop body runs at least once.
        target_b: int = target_a
        while target_b == target_a:
            # torch.rand yields values in [0, 1), so the floored product is a
            # valid index in [0, number_of_pattern - 1].
            partner_index = int(
                math.floor(self.number_of_pattern * torch.rand((1)).item())
            )
            image_b, target_b = super().__getitem__(partner_index)

        image = self.alpha * image_a + (1.0 - self.alpha) * image_b
        target = [target_a, target_b]
        return image, target