
[MXNET-1359] Adds a multiclass-MCC metric derived from Pearson #14461

Merged 2 commits on Apr 10, 2019
Changes from 1 commit
Adds a multiclass-MCC metric derived from Pearson
tlby committed Mar 23, 2019
commit 7fb48d27c7aebe710b4dcfbabfdc0e56cd52f00d
132 changes: 131 additions & 1 deletion python/mxnet/metric.py
@@ -860,7 +860,7 @@ class MCC(EvalMetric):

    .. note::

        This version of MCC only supports binary classification.
        This version of MCC only supports binary classification. See PCC.

    Parameters
    ----------
@@ -1476,6 +1476,136 @@ def update(self, labels, preds):
        self.global_num_inst += 1


@register
class PCC(EvalMetric):
    """PCC is a multiclass equivalent for the Matthews correlation coefficient derived
    from a discrete solution to the Pearson correlation coefficient.

    .. math::
        \\text{PCC} = \\frac {\\sum _{k}\\sum _{l}\\sum _{m}C_{kk}C_{lm}-C_{kl}C_{mk}}
        {{\\sqrt {\\sum _{k}(\\sum _{l}C_{kl})(\\sum _{k'|k'\\neq k}\\sum _{l'}C_{k'l'})}}
        {\\sqrt {\\sum _{k}(\\sum _{l}C_{lk})(\\sum _{k'|k'\\neq k}\\sum _{l'}C_{l'k'})}}}

    defined in terms of a K x K confusion matrix C.

    When there are more than two labels the PCC will no longer range between -1 and +1.
    Instead the minimum value will be between -1 and 0 depending on the true distribution.
    The maximum value is always +1.

    Parameters
    ----------
    name : str
        Name of this metric instance for display.
    output_names : list of str, or None
        Name of predictions that should be used when updating with update_dict.
        By default include all predictions.
    label_names : list of str, or None
        Name of labels that should be used when updating with update_dict.
        By default include all labels.

    Examples
    --------
    >>> # In this example the network almost always predicts positive
    >>> false_positives = 1000
    >>> false_negatives = 1
    >>> true_positives = 10000
    >>> true_negatives = 1
    >>> predicts = [mx.nd.array(
        [[.3, .7]]*false_positives +
        [[.7, .3]]*true_negatives +
        [[.7, .3]]*false_negatives +
        [[.3, .7]]*true_positives
    )]
    >>> labels = [mx.nd.array(
        [0]*(false_positives + true_negatives) +
        [1]*(false_negatives + true_positives)
    )]
    >>> f1 = mx.metric.F1()
    >>> f1.update(preds = predicts, labels = labels)
    >>> pcc = mx.metric.PCC()
    >>> pcc.update(preds = predicts, labels = labels)
    >>> print(f1.get())
    ('f1', 0.95233560306652054)
    >>> print(pcc.get())
    ('pcc', 0.01917751877733392)
    """
    def __init__(self, name='pcc',
                 output_names=None, label_names=None,
                 has_global_stats=True):
        self.k = 2
        super(PCC, self).__init__(
            name=name, output_names=output_names, label_names=label_names,
            has_global_stats=has_global_stats)

    def _grow(self, inc):
        self.lcm = numpy.pad(
            self.lcm, ((0, inc), (0, inc)), 'constant', constant_values=(0))
        self.gcm = numpy.pad(
            self.gcm, ((0, inc), (0, inc)), 'constant', constant_values=(0))
        self.k += inc

    def _calc_mcc(self, cmat):
        n = cmat.sum()
        x = cmat.sum(axis=1)
        y = cmat.sum(axis=0)
        cov_xx = numpy.sum(x * (n - x))
        cov_yy = numpy.sum(y * (n - y))
        if cov_xx == 0 or cov_yy == 0:
            return float('nan')
        i = cmat.diagonal()
        cov_xy = numpy.sum(i * n - x * y)
        return cov_xy / (cov_xx * cov_yy) ** 0.5

    def update(self, labels, preds):
        """Updates the internal evaluation result.

        Parameters
        ----------
        labels : list of `NDArray`
            The labels of the data.

        preds : list of `NDArray`
            Predicted values.
        """
        labels, preds = check_label_shapes(labels, preds, True)

        # update the confusion matrix
        for label, pred in zip(labels, preds):
            label = label.astype('int32', copy=False).asnumpy()
            pred = pred.asnumpy().argmax(axis=1)
            n = max(pred.max(), label.max())
            if n >= self.k:
                self._grow(n + 1 - self.k)
            bcm = numpy.zeros((self.k, self.k))
            for i, j in zip(pred, label):
                bcm[i, j] += 1
            self.lcm += bcm
            self.gcm += bcm

        self.num_inst += 1
        self.global_num_inst += 1

    @property
    def sum_metric(self):
        return self._calc_mcc(self.lcm) * self.num_inst

    @property
    def global_sum_metric(self):
        return self._calc_mcc(self.gcm) * self.global_num_inst

    def reset(self):
        """Resets the internal evaluation result to initial state."""
        self.global_num_inst = 0.
        self.gcm = numpy.zeros((self.k, self.k))
        self.reset_local()

    def reset_local(self):
        """Resets the local portion of the internal evaluation results
        to initial state."""
        self.num_inst = 0.
        self.lcm = numpy.zeros((self.k, self.k))


@register
class Loss(EvalMetric):
"""Dummy metric for directly printing loss.
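The docstring formula above is the standard multiclass generalization of MCC over a K x K confusion matrix, and `_calc_mcc` evaluates it from the matrix's row sums, column sums, and diagonal rather than the literal triple sum. A minimal standalone NumPy sketch of the same computation (the `multiclass_mcc` helper and the sample matrix are illustrative only, not part of this change):

import numpy as np

def multiclass_mcc(cmat):
    # cmat: K x K confusion matrix; cmat[i, j] counts (predicted i, actual j)
    cmat = np.asarray(cmat, dtype=float)
    n = cmat.sum()            # total samples
    x = cmat.sum(axis=1)      # predictions per class
    y = cmat.sum(axis=0)      # labels per class
    cov_xx = np.sum(x * (n - x))
    cov_yy = np.sum(y * (n - y))
    if cov_xx == 0 or cov_yy == 0:
        return float('nan')   # degenerate: all mass in a single row or column
    cov_xy = np.sum(cmat.diagonal() * n - x * y)
    return cov_xy / (cov_xx * cov_yy) ** 0.5

# 3-class matrix reused from the unit test below; with K > 2 the score can no
# longer reach -1, but +1 still corresponds to a perfectly diagonal matrix
print(multiclass_mcc([[23, 13, 3],
                      [7, 19, 11],
                      [2, 5, 17]]))

The unit test below checks the same equivalence directly against the triple-sum form of the formula.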
82 changes: 82 additions & 0 deletions tests/python/unittest/test_metric.py
@@ -34,6 +34,7 @@ def test_metrics():
    check_metric('mcc')
    check_metric('perplexity', -1)
    check_metric('pearsonr')
    check_metric('pcc')
    check_metric('nll_loss')
    check_metric('loss')
    composite = mx.metric.create(['acc', 'f1'])
@@ -89,6 +90,7 @@ def test_global_metric():
    _check_global_metric('mcc', shape=(10,2), average='micro')
    _check_global_metric('perplexity', -1)
    _check_global_metric('pearsonr', use_same_shape=True)
    _check_global_metric('pcc', shape=(10,2))
    _check_global_metric('nll_loss')
    _check_global_metric('loss')
    _check_global_metric('ce')
@@ -253,6 +255,86 @@ def test_pearsonr():
    _, pearsonr = metric.get()
    assert pearsonr == pearsonr_expected

def cm_batch(cm):
    # generate a batch yielding a given confusion matrix
    n = len(cm)
    ident = np.identity(n)
    labels = []
    preds = []
    for i in range(n):
        for j in range(n):
            labels += [ i ] * cm[i][j]
            preds += [ ident[j] ] * cm[i][j]
    return ([ mx.nd.array(labels, dtype='int32') ], [ mx.nd.array(preds) ])

def test_pcc():
    labels, preds = cm_batch([
        [ 7, 3 ],
        [ 2, 5 ],
    ])
    met_pcc = mx.metric.create('pcc')
    met_pcc.update(labels, preds)
    _, pcc = met_pcc.get()

    # pcc should agree with mcc for binary classification
    met_mcc = mx.metric.create('mcc')
    met_mcc.update(labels, preds)
    _, mcc = met_mcc.get()
    np.testing.assert_almost_equal(pcc, mcc)

    # pcc should agree with Pearson for binary classification
    met_pear = mx.metric.create('pearsonr')
    met_pear.update(labels, [p.argmax(axis=1) for p in preds])
    _, pear = met_pear.get()
    np.testing.assert_almost_equal(pcc, pear)

    # check multiclass case against reference implementation
    CM = [
        [ 23, 13, 3 ],
        [ 7, 19, 11 ],
        [ 2, 5, 17 ],
    ]
    K = 3
    ref = sum(
        CM[k][k] * CM[l][m] - CM[k][l] * CM[m][k]
        for k in range(K)
        for l in range(K)
        for m in range(K)
    ) / (sum(
        sum(CM[k][l] for l in range(K)) * sum(
            sum(CM[f][g] for g in range(K))
            for f in range(K)
            if f != k
        )
        for k in range(K)
    ) * sum(
        sum(CM[l][k] for l in range(K)) * sum(
            sum(CM[f][g] for f in range(K))
            for g in range(K)
            if g != k
        )
        for k in range(K)
    )) ** 0.5
    labels, preds = cm_batch(CM)
    met_pcc.reset()
    met_pcc.update(labels, preds)
    _, pcc = met_pcc.get()
    np.testing.assert_almost_equal(pcc, ref)

    # things that should not change metric score:
    # * order
    # * batch size
    # * update frequency
    labels = [ [ i ] for i in labels[0] ]
    labels.reverse()
    preds = [ [ i.reshape((1, -1)) ] for i in preds[0] ]
    preds.reverse()

    met_pcc.reset()
    for l, p in zip(labels, preds):
        met_pcc.update(l, p)
    assert pcc == met_pcc.get()[1]

def test_single_array_input():
    pred = mx.nd.array([[1,2,3,4]])
    label = pred + 0.1
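As an extra cross-check outside this test file, the new metric can be compared against scikit-learn's matthews_corrcoef, which implements the same multiclass generalization. A rough usage sketch, assuming an mxnet build that includes this change and an environment with scikit-learn; the toy label and prediction arrays below are made up for illustration:

import mxnet as mx
import numpy as np
from sklearn.metrics import matthews_corrcoef

# made-up 3-class labels and predictions
y_true = np.array([0, 0, 1, 1, 1, 2, 2, 0, 2, 1])
y_pred = np.array([0, 1, 1, 1, 2, 2, 0, 0, 2, 1])

pcc = mx.metric.create('pcc')
# one-hot rows stand in for per-class scores; argmax recovers y_pred
pcc.update([mx.nd.array(y_true, dtype='int32')],
           [mx.nd.array(np.eye(3)[y_pred])])

print(pcc.get())                          # ('pcc', ...)
print(matthews_corrcoef(y_true, y_pred))  # expected to agree to rounding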