outer.py (forked from dragen1860/MAML-Pytorch)
import torch
from torch import nn
from torch.nn import functional as F
from torch.multiprocessing import Pool
import numpy as np
from inner import Inner
from copy import deepcopy


class Outer(nn.Module):
    """
    Meta-learner for the outer loop of MAML.
    """

    def __init__(self, args, config=None):
        """
        :param args: namespace holding the meta-learning hyper-parameters
        :param config: network config, type: list of (op_name, param_list);
                       if None, a default architecture is chosen per data source
        """
        super(Outer, self).__init__()
        self.task_num = args.task_num
        self.inner_step = args.inner_step
        self.tuning_step = args.tuning_step
        self.inner_lr = args.inner_lr
        self.class_num = args.class_num
        self.train_sample_size_per_class = args.train_sample_size_per_class
        self.test_sample_size_per_class = args.test_sample_size_per_class
        if args.data_source == 'sinusoid':
            # sinusoid regression: small fully connected net, MSE loss
            self.classification = False
            self.loss_func = F.mse_loss
            if config is None:
                config = [('linear', [40, 1]),
                          ('relu', [True]),
                          ('linear', [40, 40]),
                          ('relu', [True]),
                          ('linear', [1, 40])]
        elif args.data_source == 'omniglot':
            # Omniglot classification: 4 strided conv blocks, cross-entropy loss
            # conv2d params follow the upstream MAML-Pytorch convention:
            # [ch_out, ch_in, kernel_h, kernel_w, stride, padding]
            self.classification = True
            self.loss_func = F.cross_entropy
            if config is None:
                config = [('conv2d', [64, 1, 3, 3, 2, 0]),
                          ('relu', [True]),
                          ('bn', [64]),
                          ('conv2d', [64, 64, 3, 3, 2, 0]),
                          ('relu', [True]),
                          ('bn', [64]),
                          ('conv2d', [64, 64, 3, 3, 2, 0]),
                          ('relu', [True]),
                          ('bn', [64]),
                          ('conv2d', [64, 64, 2, 2, 1, 0]),
                          ('relu', [True]),
                          ('bn', [64]),
                          ('flatten', []),
                          ('linear', [self.class_num, 64])]
        elif args.data_source == 'miniimagenet':
            # miniImageNet classification: 4 conv + max-pool blocks, cross-entropy loss
            self.classification = True
            self.loss_func = F.cross_entropy
            if config is None:
                config = [('conv2d', [32, 3, 3, 3, 1, 0]),
                          ('relu', [True]),
                          ('bn', [32]),
                          ('max_pool2d', [2, 2, 0]),
                          ('conv2d', [32, 32, 3, 3, 1, 0]),
                          ('relu', [True]),
                          ('bn', [32]),
                          ('max_pool2d', [2, 2, 0]),
                          ('conv2d', [32, 32, 3, 3, 1, 0]),
                          ('relu', [True]),
                          ('bn', [32]),
                          ('max_pool2d', [2, 2, 0]),
                          ('conv2d', [32, 32, 3, 3, 1, 0]),
                          ('relu', [True]),
                          ('bn', [32]),
                          ('max_pool2d', [2, 1, 0]),
                          ('flatten', []),
                          ('linear', [self.class_num, 32 * 5 * 5])]
        else:
            raise NotImplementedError
        self.model = Inner(config)
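
    # NOTE (assumption, inferred from the calls below): Inner is a functional network
    # in the style of the upstream MAML-Pytorch Learner. It is invoked as
    # model(x, weights, bn_training=...), so adapted fast weights can be passed in
    # explicitly, and model.parameters() exposes the meta-parameters.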
    def forward(self, x_train, y_train, x_test):
        """
        :param x_train: [task_num, class_num*train_sample_size_per_class, input_size]
        :param y_train: [task_num, class_num*train_sample_size_per_class, output_size]
        :param x_test:  [task_num, class_num*test_sample_size_per_class, input_size]
        :return: predictions on x_test made with the task-adapted (fast) weights
        """
        output_size = (self.task_num, self.class_num * self.test_sample_size_per_class)
        if self.classification:
            output_size += (self.class_num,)
            output = y_train.new_empty(output_size, dtype=torch.float)
        else:
            output_size += y_train.shape[2:]
            output = y_train.new_empty(output_size)
        for i in range(self.task_num):
            # first inner-loop step: adapt from the meta-parameters
            output_inner = self.model(x_train[i], self.model.parameters(), bn_training=True)
            loss_inner = self.loss_func(output_inner, y_train[i])
            grad = torch.autograd.grad(loss_inner, self.model.parameters())
            fast_weights = list(map(lambda p: p[1] - self.inner_lr * p[0], zip(grad, self.model.parameters())))
            # remaining inner-loop steps: adapt from the current fast weights
            for _ in range(1, self.inner_step):
                output_inner = self.model(x_train[i], fast_weights, bn_training=True)
                loss_inner = self.loss_func(output_inner, y_train[i])
                grad = torch.autograd.grad(loss_inner, fast_weights)
                fast_weights = list(map(lambda p: p[1] - self.inner_lr * p[0], zip(grad, fast_weights)))
            # predictions on the query (test) set with the adapted weights
            output[i] = self.model(x_test[i], fast_weights, bn_training=True)
        return output
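
    # forward() only unrolls the inner loop and returns the query-set predictions;
    # the outer (meta) update is left to the training script, which is expected to
    # compute a loss on these predictions against the query labels and step a
    # meta-optimizer on self.model's parameters (see the usage sketch at the end
    # of this file).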
    def fine_tuning(self, x_train, y_train, x_test, y_test):
        """
        :param x_train: [task_num, class_num*train_sample_size_per_class, input_size]
        :param y_train: [task_num, class_num*train_sample_size_per_class, output_size]
        :param x_test:  [task_num, class_num*test_sample_size_per_class, input_size]
        :param y_test:  [task_num, class_num*test_sample_size_per_class, output_size]
        :return: per-step losses (and accuracies for classification), averaged over tasks
        """
        task_num = x_train.shape[0]
        # loss_summary layout:
        #   index 0      -> train loss before any update
        #   index 1      -> test loss before any update
        #   index j + 2  -> test loss after update step j + 1 (j >= 0)
        loss_summary = np.zeros(self.tuning_step + 2)
        if self.classification:
            accuracy_summary = np.zeros(self.tuning_step + 2)
        # fine-tune on a copy so the meta-parameters are left untouched
        model = deepcopy(self.model)
        for i in range(task_num):
            # first fine-tuning step: adapt from the meta-parameters
            output_inner = model(x_train[i], model.parameters(), bn_training=True)
            loss_inner = self.loss_func(output_inner, y_train[i])
            # train error before any update
            loss_summary[0] += loss_inner.item()
            grad = torch.autograd.grad(loss_inner, model.parameters())
            fast_weights = list(map(lambda p: p[1] - self.inner_lr * p[0], zip(grad, model.parameters())))
            if self.classification:
                with torch.no_grad():
                    y_pred_inner = F.softmax(output_inner, dim=1).argmax(dim=1)
                    accuracy_summary[0] += torch.eq(y_pred_inner, y_train[i]).to(torch.float).mean().item()
            # test error before any update
            with torch.no_grad():
                output_test = model(x_test[i], model.parameters(), bn_training=True)
                loss_test = self.loss_func(output_test, y_test[i])
                loss_summary[1] += loss_test.item()
                if self.classification:
                    y_pred_test = F.softmax(output_test, dim=1).argmax(dim=1)
                    accuracy_summary[1] += torch.eq(y_pred_test, y_test[i]).to(torch.float).mean().item()
            # test error after the first update
            with torch.no_grad():
                output_test = model(x_test[i], fast_weights, bn_training=True)
                loss_test = self.loss_func(output_test, y_test[i])
                loss_summary[2] += loss_test.item()
                if self.classification:
                    y_pred_test = F.softmax(output_test, dim=1).argmax(dim=1)
                    accuracy_summary[2] += torch.eq(y_pred_test, y_test[i]).to(torch.float).mean().item()
            # remaining fine-tuning steps
            for j in range(1, self.tuning_step):
                output_inner = model(x_train[i], fast_weights, bn_training=True)
                loss_inner = self.loss_func(output_inner, y_train[i])
                grad = torch.autograd.grad(loss_inner, fast_weights)
                fast_weights = list(map(lambda p: p[1] - self.inner_lr * p[0], zip(grad, fast_weights)))
                # test error after update step j + 1
                with torch.no_grad():
                    output_test = model(x_test[i], fast_weights, bn_training=True)
                    loss_test = self.loss_func(output_test, y_test[i])
                    loss_summary[j + 2] += loss_test.item()
                    if self.classification:
                        y_pred_test = F.softmax(output_test, dim=1).argmax(dim=1)
                        accuracy_summary[j + 2] += torch.eq(y_pred_test, y_test[i]).to(torch.float).mean().item()
        del model
        if self.classification:
            return loss_summary / task_num, accuracy_summary / task_num
        else:
            return loss_summary / task_num
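

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): one meta-training step on the
# sinusoid task, followed by a fine-tuning evaluation. The Namespace fields
# below simply mirror the attributes read in __init__; the meta learning rate,
# Adam optimizer, and dummy data are illustrative choices, and the Inner
# network from inner.py is assumed to accept externally supplied weights as
# it is called above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(task_num=4, inner_step=5, tuning_step=10, inner_lr=0.01,
                     class_num=1, train_sample_size_per_class=10,
                     test_sample_size_per_class=10, data_source='sinusoid')
    meta_learner = Outer(args)
    meta_optimizer = torch.optim.Adam(meta_learner.parameters(), lr=1e-3)

    # dummy sinusoid batch: [task_num, samples_per_task, 1]
    x_train = torch.rand(args.task_num, args.train_sample_size_per_class, 1)
    y_train = torch.sin(x_train)
    x_test = torch.rand(args.task_num, args.test_sample_size_per_class, 1)
    y_test = torch.sin(x_test)

    # outer-loop (meta) update: loss on the query predictions returned by forward()
    y_pred = meta_learner(x_train, y_train, x_test)
    meta_loss = F.mse_loss(y_pred, y_test)
    meta_optimizer.zero_grad()
    meta_loss.backward()
    meta_optimizer.step()

    # evaluation: per-step test losses while fine-tuning a copy of the meta-model
    loss_per_step = meta_learner.fine_tuning(x_train, y_train, x_test, y_test)
    print('meta loss: %.4f' % meta_loss.item())
    print('fine-tuning test losses:', loss_per_step)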