-
Notifications
You must be signed in to change notification settings - Fork 1
/
valid_resnet.py
113 lines (92 loc) · 3.95 KB
/
valid_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Evaluation script: for every increment, reports the top-`top` accuracy of the
# network's own scores ("Hybrid 1") and of two nearest-class-mean classifiers
# built from precomputed class means ("iCaRL" and "NCM").
import os
import sys

try:
    import cPickle  # Python 2
except ImportError:
    import _pickle as cPickle  # Python 3: C implementation of pickle

import numpy as np
import scipy
import scipy.io
from scipy.spatial.distance import cdist
import tensorflow as tf

# Syspath for the folder with the utils files
# sys.path.insert(0, "/data/sylvestre")
import utils_cifar
import utils_data
import utils_icarl  # provides reading_data_and_preparing_network, used in the evaluation loop

# Let TensorFlow allocate GPU memory on demand instead of reserving it all upfront.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
######### Modifiable Settings ##########
batch_size = 128     # Batch size
nb_cl = 100          # Classes per group
nb_groups = 10       # Number of groups
top = 5              # Choose to evaluate the top X accuracy
is_cumul = 'cumul'   # Evaluate on the cumul of classes seen so far if 'cumul', otherwise only on the first group
gpu = '0'            # Used GPU
########################################

######### Paths ##########
# Working station
train_path = '/data/datasets/imagenets72'
save_path = '/data/srebuffi/backup/'
# Devkit folder handed to utils_data.parse_devkit_meta when loading the labels.
# No machine-specific default is known, so it must be supplied explicitly;
# failing here gives a clear message instead of a NameError further down.
devkit_path = os.environ.get('DEVKIT_PATH', '')
if not devkit_path:
    sys.exit('Please set the DEVKIT_PATH environment variable to the devkit folder.')
###########################
# Load ResNet settings (`mixing` is handed to the network builder later on).
str_mixing = '{}mixing.pickle'.format(nb_cl)
with open(str_mixing, 'rb') as fp:
    mixing = cPickle.load(fp)

# The settings pickle holds three objects written back to back, so they must be
# read in exactly this order, even though only files_valid is used by the
# evaluation loop.
str_settings_resnet = '{}settings_resnet.pickle'.format(nb_cl)
with open(str_settings_resnet, 'rb') as fp:
    order, files_valid, files_train = (cPickle.load(fp) for _ in range(3))

# Load class means; indexed later as class_means[:, :, {0: iCaRL, 1: NCM}, itera].
str_class_means = '{}class_means.pickle'.format(nb_cl)
with open(str_class_means, 'rb') as fp:
    class_means = cPickle.load(fp)

# Loading the labels; only labels_dic is needed by the evaluation loop.
labels_dic, label_names, validation_ground_truth = utils_data.parse_devkit_meta(devkit_path)
def _topk_hits(labels, scores, k):
    """Return, per sample, whether its label is among its k highest-scoring classes.

    labels -- sequence of class indices, one per row of `scores`.
    scores -- 2-D array (n_samples, n_classes); a larger value means a better match.
    k      -- how many top-ranked classes count as a hit.
    """
    best_k = np.argsort(scores, axis=1)[:, -k:]
    return [label in best for label, best in zip(labels, best_k)]


# Initialization: one row per increment, columns = (iCaRL, Hybrid 1, NCM) accuracy.
acc_list = np.zeros((nb_groups, 3))
for itera in range(nb_groups):
    print("Processing network after {} increments\t".format(itera))
    # Evaluation on cumul of classes or original classes
    if is_cumul == 'cumul':
        eval_groups = np.array(range(itera + 1))
    else:
        eval_groups = [0]
    print("Evaluation on batches {} \t".format(eval_groups))

    # Load the evaluation files
    files_from_cl = []
    for group in eval_groups:
        files_from_cl.extend(files_valid[group])
    # Integer ceiling division. Under Python 2 (supported via the cPickle import),
    # `len / batch_size` floors, so np.ceil of it silently dropped the last
    # partial batch.
    nb_batches = (len(files_from_cl) + batch_size - 1) // batch_size

    inits, scores, label_batch, loss_class, file_string_batch, op_feature_map = \
        utils_icarl.reading_data_and_preparing_network(
            files_from_cl, gpu, itera, batch_size, train_path, labels_dic, mixing,
            nb_groups, nb_cl, save_path)

    with tf.Session(config=config) as sess:
        # Launch the prefetch system
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        # Evaluation routine
        stat_hb1 = []
        stat_icarl = []
        stat_ncm = []
        try:
            sess.run(inits)
            for _batch in range(nb_batches):
                # All five tensors are still fetched so each step runs the same graph
                # work as before; only scores, labels and features are used.
                sc, l, _loss, _files, feat_map_tmp = sess.run(
                    [scores, label_batch, loss_class, file_string_batch, op_feature_map])
                # Feature vector taken at spatial position (0, 0) of the feature map;
                # presumably a 1x1 pooled map -- TODO confirm.
                mapped_prototypes = feat_map_tmp[:, 0, 0, :]
                # L2-normalise each sample's feature (a column after the transpose).
                pred_inter = mapped_prototypes.T / np.linalg.norm(mapped_prototypes.T, axis=0)
                # Negated squared distances to the class means, so larger = closer.
                sqd_icarl = -cdist(class_means[:, :, 0, itera].T, pred_inter.T, 'sqeuclidean').T
                sqd_ncm = -cdist(class_means[:, :, 1, itera].T, pred_inter.T, 'sqeuclidean').T
                stat_hb1 += _topk_hits(l, sc, top)
                stat_icarl += _topk_hits(l, sqd_icarl, top)
                stat_ncm += _topk_hits(l, sqd_ncm, top)
        finally:
            # Always stop the prefetch threads, even if evaluation raised.
            coord.request_stop()
            coord.join(threads)

    acc_hb1 = np.average(stat_hb1)
    acc_icarl = np.average(stat_icarl)
    acc_ncm = np.average(stat_ncm)
    print('Increment: %i' % itera)
    print('Hybrid 1 top ' + str(top) + ' accuracy: %f' % acc_hb1)
    print('iCaRL top ' + str(top) + ' accuracy: %f' % acc_icarl)
    print('NCM top ' + str(top) + ' accuracy: %f' % acc_ncm)
    acc_list[itera, 0] = acc_icarl
    acc_list[itera, 1] = acc_hb1
    acc_list[itera, 2] = acc_ncm

    # Reset the graph to compute the numbers after the next increment
    tf.reset_default_graph()

# np.save appends the '.npy' extension to this name.
np.save('results_top' + str(top) + '_acc_' + is_cumul + '_cl' + str(nb_cl), acc_list)