Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
atharva-tendle committed May 21, 2021
1 parent bea3562 commit 49b969e
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def load_cifar10(configs):
])

# load datasets, downloading if needed
train_set = CIFAR10('./data/cifar10', train=True, download=True,
train_set = CIFAR10('../data/cifar10', train=True, download=True,
transform=train_transforms)
test_set = CIFAR10('./data/cifar10', train=False, download=True,
test_set = CIFAR10('../data/cifar10', train=False, download=True,
transform=test_transforms)

train_loader = torch.utils.data.DataLoader(
Expand Down Expand Up @@ -73,9 +73,9 @@ def load_cifar100(configs):
])

# load datasets, downloading if needed
train_set = CIFAR100('./data/cifar100', train=True, download=True,
train_set = CIFAR100('../data/cifar100', train=True, download=True,
transform=train_transforms)
test_set = CIFAR100('./data/cifar100', train=False, download=True,
test_set = CIFAR100('../data/cifar100', train=False, download=True,
transform=test_transforms)

train_loader = torch.utils.data.DataLoader(
Expand All @@ -91,7 +91,7 @@ def load_cifar100(configs):

return {'train': train_loader, 'test': test_loader}

def load_cifar100(configs):
def load_svhn(configs):
# transform for the training data
train_transforms = transforms.Compose([
transforms.RandomCrop(32),
Expand All @@ -108,9 +108,9 @@ def load_cifar100(configs):
])

# load datasets, downloading if needed
train_set = SVHN('./data/svhn', split="train", download=True,
train_set = SVHN('../data/svhn', split="train", download=True,
transform=train_transforms)
test_set = SVHN('./data/svhn', split="test", download=True,
test_set = SVHN('../data/svhn', split="test", download=True,
transform=test_transforms)

train_loader = torch.utils.data.DataLoader(
Expand Down Expand Up @@ -141,9 +141,9 @@ def load_fashionmnist(configs):
])

# load datasets, downloading if needed
train_set = FashionMNIST('./data/fashionmnist', train=True, download=True,
train_set = FashionMNIST('../data/fashionmnist', train=True, download=True,
transform=train_transforms)
test_set = FashionMNIST('./data/fashionmnist', train=False, download=True,
test_set = FashionMNIST('../data/fashionmnist', train=False, download=True,
transform=test_transforms)

train_loader = torch.utils.data.DataLoader(
Expand Down Expand Up @@ -173,9 +173,9 @@ def load_mnist(configs):
])

# load datasets, downloading if needed
train_set = MNIST('./data/mnist', train=True, download=True,
train_set = MNIST('../data/mnist', train=True, download=True,
transform=train_transforms)
test_set = MNIST('./data/mnist', train=False, download=True,
test_set = MNIST('../data/mnist', train=False, download=True,
transform=test_transforms)

train_loader = torch.utils.data.DataLoader(
Expand Down

0 comments on commit 49b969e

Please sign in to comment.