-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils_.py
112 lines (86 loc) · 2.85 KB
/
utils_.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import sys
import errno
import numpy as np
from sklearn.metrics import roc_curve, roc_auc_score, accuracy_score, average_precision_score
def compute_oscr(x1, x2, pred, labels):
    """
    Compute the Open Set Classification Rate (OSCR).

    OSCR is the area under the curve of Correct Classification Rate (CCR)
    versus False Positive Rate (FPR) obtained by sweeping an acceptance
    threshold over the scores. For a given threshold, CCR is the fraction of
    known-class samples that are accepted AND correctly classified, and FPR
    is the fraction of unknown-class samples that are accepted. Both rates are
    always evaluated at the same threshold; samples with tied scores share a
    single threshold, so ties contribute half credit (as in ROC AUC).

    :param x1: open set score for each known class sample (B_k,); lower
        scores are accepted first (scores are negated internally)
    :param x2: open set score for each unknown class sample (B_u,)
    :param pred: predicted class for each known class sample (B_k,)
    :param labels: correct class for each known class sample (B_k,)
    :return: Open Set Classification Rate as a float in [0, 1]
    :raises ValueError: if there are no known or no unknown samples, or if
        ``pred``/``labels`` do not provide one entry per known sample
    """
    # Negate so that larger values mean "accept as a known-class sample".
    known = -np.asarray(x1, dtype=np.float64)
    unknown = -np.asarray(x2, dtype=np.float64)
    correct = np.asarray(pred) == np.asarray(labels)

    n_known, n_unknown = len(known), len(unknown)
    if n_known == 0 or n_unknown == 0:
        raise ValueError(
            "compute_oscr needs at least one known and one unknown sample"
        )
    if np.shape(correct) != known.shape:
        raise ValueError("pred and labels must have one entry per known sample")

    scores = np.concatenate((known, unknown))
    is_correct = np.concatenate((correct.astype(np.float64), np.zeros(n_unknown)))
    is_unknown = np.concatenate((np.zeros(n_known), np.ones(n_unknown)))

    # Sweep the threshold from the most to the least "known-looking" score.
    order = np.argsort(-scores, kind="stable")
    scores = scores[order]
    # Running counts of accepted correct-known / unknown samples when the
    # threshold is lowered to each sorted score (both use the same threshold).
    cc = np.cumsum(is_correct[order])
    fp = np.cumsum(is_unknown[order])

    # Samples with equal scores are accepted together, so only the last
    # position of each run of tied scores is a valid curve point.
    last_in_run = np.append(scores[1:] != scores[:-1], True)

    # Curve starts at (0, 0) (nothing accepted); the final point has FPR = 1.
    ccr = np.concatenate(([0.0], cc[last_in_run] / n_known))
    fpr = np.concatenate(([0.0], fp[last_in_run] / n_unknown))

    # Trapezoidal rule; FPR is non-decreasing along the sweep.
    return float(np.sum(np.diff(fpr) * (ccr[1:] + ccr[:-1]) / 2.0))
def mkdir_if_missing(directory):
    """
    Create ``directory`` (including missing parents) if it does not exist.

    An empty path -- what ``os.path.dirname`` returns for a bare file name --
    denotes the current directory and is left alone.

    :param directory: path of the directory to create
    :raises OSError: if creation fails for any reason other than the path
        already existing
    """
    if not directory:
        # os.makedirs('') would raise FileNotFoundError.
        return
    if not os.path.exists(directory):
        try:
            os.makedirs(directory)
        except OSError as e:
            # Tolerate the path appearing between the check and makedirs
            # (e.g. created concurrently).
            if e.errno != errno.EEXIST:
                raise
class Logger(object):
    """
    Write console output to external text file.

    Every message passed to ``write`` goes to the stream that was
    ``sys.stdout`` when the logger was created and, if ``fpath`` was given,
    also to that file (opened for writing, truncating any existing content).
    Closing the logger closes only the file; the console stream is flushed
    but left open.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
    """

    def __init__(self, fpath=None):
        # Keep a handle on the current stdout so output is still echoed there.
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            directory = os.path.dirname(fpath)
            # A bare file name has no directory component to create.
            if directory:
                mkdir_if_missing(directory)
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        # Return the logger so ``with Logger(path) as log:`` binds it.
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            # Ask the OS to commit the log to disk, not just its buffers.
            os.fsync(self.file.fileno())

    def close(self):
        """Close the log file (idempotent); flush but never close the console."""
        # The console is the process-wide stdout captured in __init__; closing
        # it would make every later print() fail.
        if not getattr(self.console, "closed", False):
            self.console.flush()
        if self.file is not None:
            self.file.close()
            # Later write()/flush() calls then only touch the console instead
            # of failing on a closed file.
            self.file = None