Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ayushm-agrawal committed Apr 17, 2021
1 parent b9b5571 commit 8127b28
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def load_mnist(configs):
def load_imagenet(configs):
# transform for the training data
train_transforms = transforms.Compose([
transforms.RandomCrop(224, padding=4),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
Expand All @@ -171,8 +171,8 @@ def load_imagenet(configs):

train_data_path = configs.data_path + "/train/"
val_data_path = configs.data_path + "/val/"
train_set = datasets.ImageFolder(train_data_path, transforms=train_transforms)
val_set = datasets.ImageFolder(val_data_path, transforms=val_transforms)
train_set = datasets.ImageFolder(train_data_path, transform=train_transforms)
val_set = datasets.ImageFolder(val_data_path, transform=val_transforms)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=configs.batch_size,
num_workers=configs.num_workers, shuffle=True)
Expand All @@ -182,7 +182,5 @@ def load_imagenet(configs):
print('Number of iterations required to get through training data of length {}: {}'.format(
len(train_set), len(train_loader)))

print(train_set.data.shape)
print(val_set.data.shape)

return {'train': train_loader, 'test': val_loader}

0 comments on commit 8127b28

Please sign in to comment.