-
Notifications
You must be signed in to change notification settings - Fork 180
/
make_antdirec_plots.py
60 lines (43 loc) · 1.9 KB
/
make_antdirec_plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import pickle
import glob
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Plot average return against the number of gradient steps for the
# "ant, forward/backward" task, comparing four conditions (MAML, pretrained,
# random and oracle) read from per-condition result pickles, and save the
# figure to 'antdirec_results.png'.

prefix = 'icml_antdirec_results_'
oracle_pkl = prefix + 'oracle.pkl'
maml_pkl = prefix + 'maml.pkl'
pretrain_pkl = prefix + 'pretrain.pkl'
random_pkl = prefix + 'random.pkl'

key = 'task_avg_returns'  # entry read from every results pickle
n_itr = 4  # number of points along the x-axis (columns kept per curve)


def _load_returns(pkl_path):
    """Return the ``key`` entry of the results pickle at *pkl_path* as an array.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file;
    only point this script at result files you generated yourself.
    """
    with open(pkl_path, 'rb') as f:
        return np.array(pickle.load(f)[key])


# The oracle curve is flat: take only the first entry of its results, turn it
# into a single column, and repeat that column across all n_itr time points.
oracle_data = np.reshape(_load_returns(oracle_pkl)[0], [-1, 1])
oracle_data = np.tile(oracle_data, [1, n_itr])

# The other conditions are transposed so that columns index the x-axis
# (tsplot treats rows as units and columns as time points), then truncated
# to the first n_itr columns.
# NOTE(review): assumes the stored arrays are laid out (step, task) - confirm.
maml_data = _load_returns(maml_pkl).T[:, :n_itr]
pretrain_data = _load_returns(pretrain_pkl).T[:, :n_itr]
random_data = _load_returns(random_pkl).T[:, :n_itr]

fig = plt.figure()

# (data, color, linestyle, marker, condition label) for each curve, in draw order.
curves = [
    (maml_data, 'g', '-', 'o', 'MAML (ours)'),
    (pretrain_data, 'b', '--', 's', 'pretrained'),
    (random_data, 'k', ':', '^', 'random'),
    (oracle_data, 'r', '-.', 'v', 'oracle'),
]
# NOTE: sns.tsplot was deprecated in seaborn 0.8 and removed in 0.9; this
# script needs an older seaborn (or a port to sns.lineplot) to run.
for curve_data, color, linestyle, marker, condition in curves:
    sns.tsplot(time=range(n_itr), data=curve_data, color=color,
               linestyle=linestyle, marker=marker, condition=condition,
               legend=False)

plt.xlabel('number of gradient steps', fontsize=26)
plt.ylabel('average return', fontsize=26)
plt.title('ant, forward/backward', fontsize=26)
plt.tight_layout()

ax = plt.gca()
plt.setp(ax.get_xticklabels(), fontsize=18)
plt.setp(ax.get_yticklabels(), fontsize=18)
# One tick per plotted point; derived from n_itr so the axis stays in sync
# with the data if n_itr is changed.
plt.xticks(np.arange(0, n_itr, 1.0))

plt.savefig('antdirec_results.png', bbox_inches='tight')