valid_resnet2.py
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand instead of grabbing it all at once
import numpy as np
import scipy
import os
from scipy.spatial.distance import cdist
import scipy.io
import sys
try:
    import cPickle             # Python 2
except ImportError:
    import _pickle as cPickle  # Python 3
# Syspath for the folder with the utils files
# sys.path.insert(0, "/data/sylvestre")
import utils_data
import utils_cifar
import utils_icarl  # provides reading_data_and_preparing_network used below
######### Modifiable Settings ##########
batch_size = 128  # Batch size
nb_cl = 10        # Classes per group
nb_groups = 10    # Number of groups
top = 5           # Evaluate the top-X accuracy
itera = 9         # State of the network: 0 corresponds to the first batch of classes
eval_groups = np.array(range(itera + 1))  # Batches of classes on which to evaluate the classifier
gpu = '0'         # GPU to use
########################################
######### Paths ##########
# Working station
devkit_path = '/ssd_disk/ILSVRC2012/devkit'  # ILSVRC2012 devkit metadata (placeholder path, adjust to your setup)
train_path = '/ssd_disk/ILSVRC2012/train'
save_path = '/media/data/srebuffi/'
###########################
# Load ResNet settings
str_mixing = str(nb_cl) + 'mixing.pickle'
with open(str_mixing, 'rb') as fp:
    mixing = cPickle.load(fp)
str_settings_resnet = str(nb_cl) + 'settings_resnet.pickle'
with open(str_settings_resnet, 'rb') as fp:
    order = cPickle.load(fp)
    files_valid = cPickle.load(fp)
    files_train = cPickle.load(fp)
# Load class means
str_class_means = str(nb_cl) + 'class_means.pickle'
with open(str_class_means, 'rb') as fp:
    class_means = cPickle.load(fp)
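# class_means is assumed to have shape (feature_dim, nb_classes, 2, nb_groups):
# index 0 on the third axis holds the iCaRL exemplar means and index 1 the NCM
# means (layout inferred from the class_means[:, :, 0/1, itera] slicing below).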
# Loading the labels
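# parse_devkit_meta reads the ILSVRC2012 devkit metadata; it is expected to return
# a WNID-to-label dictionary, the human-readable class names and the validation
# ground-truth labels.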
labels_dic, label_names, validation_ground_truth = utils_data.parse_devkit_meta(devkit_path)
# Initialization
acc_list = np.zeros((nb_groups, 3))
# Load the evaluation files
print("Processing network after {} increments\t".format(itera))
print("Evaluation on batches {} \t".format(eval_groups))
files_from_cl = []
for i in eval_groups:
    files_from_cl.extend(files_valid[i])
inits, scores, label_batch, loss_class, file_string_batch, op_feature_map = utils_icarl.reading_data_and_preparing_network(
    files_from_cl, gpu, itera, batch_size, train_path, labels_dic, mixing, nb_groups, nb_cl, save_path)
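# reading_data_and_preparing_network (from the iCaRL utils) is expected to rebuild
# the ResNet for increment `itera`, restore its weights from save_path and set up
# an input queue over files_from_cl; op_feature_map is the last feature map, pooled
# below to obtain the per-image descriptors.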
with tf.Session(config=config) as sess:
    # Launch the prefetch system
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    sess.run(inits)
    # Evaluation routine
    stat_hb1 = []
    stat_icarl = []
    stat_ncm = []
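    # For each batch: run the network to get classifier scores and the pooled
    # feature map, L2-normalize the features, then score each class by the
    # negative squared Euclidean distance to its stored mean, once with the
    # iCaRL exemplar means (class_means[:, :, 0, itera]) and once with the
    # NCM means (class_means[:, :, 1, itera]).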
    for i in range(int(np.ceil(len(files_from_cl) / float(batch_size)))):  # float() keeps the ceil meaningful under Python 2 integer division
        sc, l, loss, files_tmp, feat_map_tmp = sess.run(
            [scores, label_batch, loss_class, file_string_batch, op_feature_map])
        mapped_prototypes = feat_map_tmp[:, 0, 0, :]
        pred_inter = (mapped_prototypes.T) / np.linalg.norm(mapped_prototypes.T, axis=0)
        sqd_icarl = -cdist(class_means[:, :, 0, itera].T, pred_inter.T, 'sqeuclidean').T
        sqd_ncm = -cdist(class_means[:, :, 1, itera].T, pred_inter.T, 'sqeuclidean').T
        # sc: classification scores from the network's output layer
        stat_hb1 += ([ll in best for ll, best in zip(l, np.argsort(sc, axis=1)[:, -top:])])
        stat_icarl += ([ll in best for ll, best in zip(l, np.argsort(sqd_icarl, axis=1)[:, -top:])])  # last `top` entries of the argsort, i.e. the top-5 classes
        stat_ncm += ([ll in best for ll, best in zip(l, np.argsort(sqd_ncm, axis=1)[:, -top:])])
    coord.request_stop()
    coord.join(threads)

print('Increment: %i' % itera)
print('Hybrid 1 top ' + str(top) + ' accuracy: %f' % np.average(stat_hb1))  # np.average over a list of booleans = accuracy
print('iCaRL top ' + str(top) + ' accuracy: %f' % np.average(stat_icarl))
print('NCM top ' + str(top) + ' accuracy: %f' % np.average(stat_ncm))
acc_list[itera, 0] = np.average(stat_icarl)
acc_list[itera, 1] = np.average(stat_hb1)
acc_list[itera, 2] = np.average(stat_ncm)
# Reset the graph to compute the numbers after the next increment
tf.reset_default_graph()
np.save('results_top' + str(top) + '_acc_cl' + str(nb_cl), acc_list)
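# The saved array has one row per increment and columns [iCaRL, Hybrid 1, NCM];
# only row `itera` is filled by this run.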