Skip to content

Commit

Permalink
added SVHN
Browse files Browse the repository at this point in the history
  • Loading branch information
atharva-tendle committed May 20, 2021
1 parent 2fac474 commit bea3562
Showing 1 changed file with 39 additions and 1 deletion.
40 changes: 39 additions & 1 deletion utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST, SVHN


def load_dataset(configs):
Expand All @@ -12,6 +12,8 @@ def load_dataset(configs):
return load_mnist(configs)
elif dataset == "fashionmnist":
return load_fashionmnist(configs)
elif dataset == "svhn":
return load_svhn(configs)
elif dataset == "cifar-10":
return load_cifar10(configs)
elif dataset == "cifar-100":
Expand Down Expand Up @@ -89,6 +91,42 @@ def load_cifar100(configs):

return {'train': train_loader, 'test': test_loader}

def load_cifar100(configs):

This comment has been minimized.

Copy link
@ayushm-agrawal

ayushm-agrawal May 21, 2021

Collaborator

Change the method name to load_svhn(configs)

This comment has been minimized.

Copy link
@atharva-tendle

atharva-tendle May 21, 2021

Author Collaborator

Oops. I should've double-checked!

# transform for the training data
train_transforms = transforms.Compose([
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])

test_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])

# load datasets, downloading if needed
train_set = SVHN('./data/svhn', split="train", download=True,

This comment has been minimized.

Copy link
@ayushm-agrawal

ayushm-agrawal May 21, 2021

Collaborator

change this to ../data/svhn.

We don't want to add data to our repository

This comment has been minimized.

Copy link
@atharva-tendle

atharva-tendle May 21, 2021

Author Collaborator

Yep! good catch.

transform=train_transforms)
test_set = SVHN('./data/svhn', split="test", download=True,
transform=test_transforms)

train_loader = torch.utils.data.DataLoader(
train_set, batch_size=configs.batch_size, num_workers=0)
test_loader = torch.utils.data.DataLoader(
test_set, batch_size=configs.batch_size, num_workers=0)

print('Number of iterations required to get through training data of length {}: {}'.format(
len(train_set), len(train_loader)))

print(train_set.data.shape)
print(test_set.data.shape)

return {'train': train_loader, 'test': test_loader}


def load_fashionmnist(configs):
# transform for the training data
train_transforms = transforms.Compose([
Expand Down

1 comment on commit bea3562

@ayushm-agrawal
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updates Required

  • Change the method name to load_svhn(configs)
  • The data download path on 111 and 112 should be ../data/svhn

Looks good otherwise

Please sign in to comment.