import torch
from torch.utils.data import Dataset
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10
from torchvision import transforms


class DiffSet(Dataset):
    """Image-only dataset held fully in memory, rescaled and channel-first.

    The constructor downloads (if needed) one of the supported torchvision
    datasets, reads its raw ``.data`` array, and stores every image in
    ``self.input_seq`` after:

    * padding the single-channel datasets ("MNIST", "Fashion") by 2 pixels
      on each side and adding a trailing channel axis,
    * rescaling with ``x / 255 * 2 - 1`` (maps 8-bit pixel values onto
      ``[-1, 1]``),
    * moving the channel axis from position 3 to position 1.

    Labels are not kept; indexing returns the image tensor only.

    Attributes:
        dataset_len: Number of images, taken from ``len(<dataset>.data)``.
        depth: Channel count (1 for "MNIST"/"Fashion", 3 for "CIFAR").
        size: Side length recorded for the stored images (always 32).
        input_seq: The preprocessed image tensor, indexed by ``__getitem__``.
    """

    def __init__(self, train: bool, dataset: str = "MNIST", root: str = "./data"):
        """Load and preprocess the requested dataset.

        Args:
            train: Forwarded to the torchvision dataset's ``train`` argument
                to select the training or the test split.
            dataset: One of ``"MNIST"``, ``"Fashion"`` or ``"CIFAR"``.
            root: Directory the torchvision dataset is downloaded to / read
                from. Defaults to ``"./data"``, the previously fixed location.

        Raises:
            KeyError: If ``dataset`` is not one of the supported names.
        """
        super().__init__()

        datasets = {
            "MNIST": MNIST,
            "Fashion": FashionMNIST,
            "CIFAR": CIFAR10,
        }
        try:
            dataset_cls = datasets[dataset]
        except KeyError:
            # Same exception type as a plain dict lookup, but with a message
            # that tells the caller which names are accepted.
            raise KeyError(
                f"Unknown dataset {dataset!r}; expected one of {sorted(datasets)}"
            ) from None

        # No transform is passed: only the raw ``.data`` array is read below,
        # the dataset object itself is never indexed.
        source = dataset_cls(root, download=True, train=train)
        self.dataset_len = len(source.data)

        if dataset == "CIFAR":
            # as_tensor avoids the legacy torch.Tensor(...) constructor; the
            # division below produces the floating-point tensor.
            data = torch.as_tensor(source.data)
            self.depth = 3
        else:
            # "MNIST" / "Fashion": pad the two trailing (spatial) dims by 2 on
            # each side, then append a channel axis so that both dataset
            # families share the same 4-D layout before the moveaxis below.
            data = transforms.Pad(2)(source.data).unsqueeze(3)
            self.depth = 1
        self.size = 32

        # Rescale to [-1, 1] (for 8-bit inputs), then go channel-first.
        scaled = ((data / 255.0) * 2.0) - 1.0
        self.input_seq = scaled.moveaxis(3, 1)

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return self.dataset_len

    def __getitem__(self, item):
        """Return the preprocessed image(s) at index ``item``."""
        return self.input_seq[item]