Skip to content

Commit

Permalink
Add class augmentations
Browse files Browse the repository at this point in the history
  • Loading branch information
tristandeleu committed Sep 22, 2019
1 parent a2583c2 commit 5ed4a5f
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torchmeta.utils.data import BatchMetaDataLoader
from torchmeta.datasets import Omniglot, MiniImagenet
from torchmeta.toy import Sinusoid
from torchmeta.transforms import ClassSplitter, Categorical
from torchmeta.transforms import ClassSplitter, Categorical, Rotation
from torchvision.transforms import ToTensor, Resize, Compose

from maml.model import ModelConvOmniglot, ModelConvMiniImagenet, ModelMLPSinusoid
Expand Down Expand Up @@ -41,6 +41,7 @@ def main(args):
dataset_transform = ClassSplitter(shuffle=True,
num_train_per_class=args.num_shots,
num_test_per_class=args.num_shots_test)
class_augmentations = [Rotation([90, 180, 270])]
if args.dataset == 'sinusoid':
transform = ToTensor()

Expand All @@ -60,10 +61,12 @@ def main(args):
meta_train_dataset = Omniglot(args.folder, transform=transform,
target_transform=Categorical(args.num_ways),
num_classes_per_task=args.num_ways, meta_train=True,
class_augmentations=class_augmentations,
dataset_transform=dataset_transform, download=True)
meta_val_dataset = Omniglot(args.folder, transform=transform,
target_transform=Categorical(args.num_ways),
num_classes_per_task=args.num_ways, meta_val=True,
class_augmentations=class_augmentations,
dataset_transform=dataset_transform)

model = ModelConvOmniglot(args.num_ways, hidden_size=args.hidden_size)
Expand All @@ -75,10 +78,12 @@ def main(args):
meta_train_dataset = MiniImagenet(args.folder, transform=transform,
target_transform=Categorical(args.num_ways),
num_classes_per_task=args.num_ways, meta_train=True,
class_augmentations=class_augmentations,
dataset_transform=dataset_transform, download=True)
meta_val_dataset = MiniImagenet(args.folder, transform=transform,
target_transform=Categorical(args.num_ways),
num_classes_per_task=args.num_ways, meta_val=True,
class_augmentations=class_augmentations,
dataset_transform=dataset_transform)

model = ModelConvMiniImagenet(args.num_ways, hidden_size=args.hidden_size)
Expand Down

0 comments on commit 5ed4a5f

Please sign in to comment.