-
Notifications
You must be signed in to change notification settings - Fork 42
/
datasets.py
108 lines (94 loc) · 5.34 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch.nn.functional as F
from collections import namedtuple
from torchmeta.datasets import Omniglot, MiniImagenet
from torchmeta.toy import Sinusoid
from torchmeta.transforms import ClassSplitter, Categorical, Rotation
from torchvision.transforms import ToTensor, Resize, Compose
from maml.model import ModelConvOmniglot, ModelConvMiniImagenet, ModelMLPSinusoid
from maml.utils import ToTensor1D
# Immutable record bundling what a benchmark provides: the three
# meta-dataset splits (train / val / test), the model and the loss function.
Benchmark = namedtuple(
    'Benchmark',
    [
        'meta_train_dataset',
        'meta_val_dataset',
        'meta_test_dataset',
        'model',
        'loss_function',
    ],
)
def get_benchmark_by_name(name,
folder,
num_ways,
num_shots,
num_shots_test,
hidden_size=None):
dataset_transform = ClassSplitter(shuffle=True,
num_train_per_class=num_shots,
num_test_per_class=num_shots_test)
if name == 'sinusoid':
transform = ToTensor1D()
meta_train_dataset = Sinusoid(num_shots + num_shots_test,
num_tasks=1000000,
transform=transform,
target_transform=transform,
dataset_transform=dataset_transform)
meta_val_dataset = Sinusoid(num_shots + num_shots_test,
num_tasks=1000000,
transform=transform,
target_transform=transform,
dataset_transform=dataset_transform)
meta_test_dataset = Sinusoid(num_shots + num_shots_test,
num_tasks=1000000,
transform=transform,
target_transform=transform,
dataset_transform=dataset_transform)
model = ModelMLPSinusoid(hidden_sizes=[40, 40])
loss_function = F.mse_loss
elif name == 'omniglot':
class_augmentations = [Rotation([90, 180, 270])]
transform = Compose([Resize(28), ToTensor()])
meta_train_dataset = Omniglot(folder,
transform=transform,
target_transform=Categorical(num_ways),
num_classes_per_task=num_ways,
meta_train=True,
class_augmentations=class_augmentations,
dataset_transform=dataset_transform,
download=True)
meta_val_dataset = Omniglot(folder,
transform=transform,
target_transform=Categorical(num_ways),
num_classes_per_task=num_ways,
meta_val=True,
class_augmentations=class_augmentations,
dataset_transform=dataset_transform)
meta_test_dataset = Omniglot(folder,
transform=transform,
target_transform=Categorical(num_ways),
num_classes_per_task=num_ways,
meta_test=True,
dataset_transform=dataset_transform)
model = ModelConvOmniglot(num_ways, hidden_size=hidden_size)
loss_function = F.cross_entropy
elif name == 'miniimagenet':
transform = Compose([Resize(84), ToTensor()])
meta_train_dataset = MiniImagenet(folder,
transform=transform,
target_transform=Categorical(num_ways),
num_classes_per_task=num_ways,
meta_train=True,
dataset_transform=dataset_transform,
download=True)
meta_val_dataset = MiniImagenet(folder,
transform=transform,
target_transform=Categorical(num_ways),
num_classes_per_task=num_ways,
meta_val=True,
dataset_transform=dataset_transform)
meta_test_dataset = MiniImagenet(folder,
transform=transform,
target_transform=Categorical(num_ways),
num_classes_per_task=num_ways,
meta_test=True,
dataset_transform=dataset_transform)
model = ModelConvMiniImagenet(num_ways, hidden_size=hidden_size)
loss_function = F.cross_entropy
else:
raise NotImplementedError('Unknown dataset `{0}`.'.format(name))
return Benchmark(meta_train_dataset=meta_train_dataset,
meta_val_dataset=meta_val_dataset,
meta_test_dataset=meta_test_dataset,
model=model,
loss_function=loss_function)