
Improve handling of unbalanced confusion matrices #41

Merged · 1 commit into reiinakano:master · Aug 24, 2017

Conversation

@ExcaliburZero (Contributor)

Here I have made a few changes that make it easier to plot confusion matrices where the true and predicted sets of labels are not the same. This case can occur when, for example, applying "new" categories to a dataset labeled with an older set of categories.

The changes included are the following:

Fix an issue with NaN values showing up when unbalanced confusion matrices are normalized: rows with no entries sum to zero, so normalizing each cell divides by zero (a sketch of the guard follows this list).

Add options to limit the labels displayed on the true and predicted axes, since with unbalanced confusion matrices some labels may appear only in the set of true labels or only in the set of predicted labels.
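
A minimal sketch of the divide-by-zero guard, assuming row-wise normalization like plot_confusion_matrix performs; the actual fix (later split into PR #42) may differ:

import numpy as np

cm = np.array([[2., 0., 0.],
               [0., 0., 0.],   # a class with no true instances: empty row
               [0., 1., 1.]])

row_sums = cm.sum(axis=1)[:, np.newaxis]
safe_sums = np.where(row_sums == 0, 1, row_sums)  # never divide by zero
cm_normalized = cm / safe_sums  # all-zero rows stay zero instead of NaN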

You can see the effect of the new options here:

import numpy as np
import matplotlib.pyplot as plt
import scikitplot as sciplt

y_true = np.array(["A", "A", "B", "B", "B", "C", "D"])
y_pred = np.array(["A", "A", "Ba", "Bb", "Ba", "C", "D"])

print(y_true.shape)
print(y_pred.shape)

# The true and predicted label sets differ: "B" appears only in y_true,
# while "Ba" and "Bb" appear only in y_pred.
true_labels = np.unique(y_true)
pred_labels = np.unique(y_pred)

labels = np.sort(np.unique(np.concatenate([true_labels, pred_labels])))

# Indexes into `labels` of the entries to show on each axis.
true_label_indexes = np.where(np.isin(labels, true_labels))
pred_label_indexes = np.where(np.isin(labels, pred_labels))

sciplt.plotters.plot_confusion_matrix(
    y_true, y_pred, hide_zeros=True, normalize=True,
    true_label_indexes=true_label_indexes,
    pred_label_indexes=pred_label_indexes, labels=labels)
plt.show()

[figure_1: the resulting confusion matrix plot]

@reiinakano (Owner) commented Aug 4, 2017

Hi @ExcaliburZero, thanks for taking the time to write this. Apologies for the late response; I've been quite busy.

First off, do you think you could open a separate PR for the NaN values bugfix so I can merge that immediately and keep this new feature suggestion separate? Thanks!

Anyway, I think allowing people to select which classes to show on the CM is a useful feature. However, I'm not a fan of the new arguments introduced here. The need to use indices would be confusing for anyone not familiar with the internals. Nobody knows what the index of a class is unless they explicitly see that it's the index in the array you get from np.sort(np.unique(np.concatenate([true_labels, pred_labels]))), and nobody's going to go dig that info out of the source code.

I think a better way is to let them pass in a list of the actual classes they want included in the x-axis and y-axis respectively. But you have to be careful to handle edge cases: classes that are not in the data at all, duplicate classes, etc.

@ExcaliburZero (Contributor, Author)

I have made a separate PR for the NaN values bugfix (#42).

I agree that the arguments are a bit confusing to use. They should instead take the names of the categories, though as you noted this is complicated by the range of values that can be passed in.

I'll make those changes to the arguments and add in some good validation. (Though I'll be a bit busy, so I will probably get around to working on this on Monday.)

@ExcaliburZero (Contributor, Author)

I have changed the options to take in the names of the labels instead.

true_labels = ["A", "B", "C", "D"]
pred_labels = ["A", "Ba", "Bb", "C", "D"]

sciplt.plotters.plot_confusion_matrix(y_true, y_pred, true_labels=true_labels,
                                      pred_labels=pred_labels)

I have also added validation for both true_labels and pred_labels, checking that there are no duplicate labels and no labels that are absent from classes. If either issue is present, a ValueError is raised with a descriptive error message.
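
A minimal sketch of what that validation might look like, reconstructed from the fragments quoted in the review below; names follow the diff, but the merged code may differ in details:

import numpy as np

def validate_labels(known_classes, passed_labels, argument_name):
    """Raise a ValueError if passed_labels contains duplicates or labels
    that do not appear in known_classes."""
    known_classes = np.asarray(known_classes)
    passed_labels = np.asarray(passed_labels)

    # np.unique returns the index of each label's first occurrence, so any
    # index outside that set points at a duplicate.
    indexes = np.arange(len(passed_labels))
    unique_indexes = np.unique(passed_labels, return_index=True)[1]
    duplicate_labels = passed_labels[indexes[~np.in1d(indexes, unique_indexes)]]
    if duplicate_labels.size > 0:
        raise ValueError(
            "The following duplicate labels were passed into {0}: "
            "{1}".format(argument_name, ", ".join(map(str, duplicate_labels))))

    # Labels that never occur in the data at all.
    absent_labels = passed_labels[~np.in1d(passed_labels, known_classes)]
    if absent_labels.size > 0:
        raise ValueError(
            "The following labels were passed into {0}, but were not found "
            "in labels: {1}".format(
                argument_name, ", ".join(map(str, absent_labels))))

map(str, ...) keeps the messages working for the non-string labels that come up later in the thread.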

@ExcaliburZero (Contributor, Author)

Here are some examples of the error messages.

Duplicate labels:

true_labels = ["A", "B", "C", "D", "D", "A"]
pred_labels = ["A", "Ba", "Bb", "C", "D", "F", "G"]

sciplt.plotters.plot_confusion_matrix(y_true, y_pred, true_labels=true_labels,
                                      pred_labels=pred_labels)
ValueError: The following duplicate labels were passed into true_labels: D, A

Missing labels:

true_labels = ["A", "B", "C", "D"]
pred_labels = ["A", "Ba", "Bb", "C", "D", "F", "G"]

sciplt.plotters.plot_confusion_matrix(y_true, y_pred, true_labels=true_labels,
                                      pred_labels=pred_labels)

ValueError: The following labels were passed into pred_labels, but were not found in labels: F, G

@reiinakano (Owner)

Sorry for the late review. I've been very busy the past few days. Will add my comments now.

@@ -87,6 +94,50 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, title=None, normalize=Fal
    else:
        classes = np.asarray(labels)

def validate_labels(known_classes, passed_labels, argument_name):
@reiinakano (Owner) · Aug 19, 2017

Please put this outside of the function, add a detailed docstring, and add a unit test.

duplicate_indexes = indexes[~np.isin(indexes, unique_indexes)]
duplicate_labels = passed_labels[duplicate_indexes]

msg = "The following duplicate labels were passed into %s: %s" % (argument_name, ", ".join(duplicate_labels))
@reiinakano (Owner) · Aug 19, 2017

This line is too long. Also, since the rest of the codebase uses .format instead of % string formatting, I'd prefer it if you use that as well.

if np.any(passed_labels_absent):
    absent_labels = passed_labels[passed_labels_absent]

    msg = "The following labels were passed into %s, but were not found in labels: %s" % (argument_name, ", ".join(absent_labels))
@reiinakano (Owner) · Aug 19, 2017

This line is too long. Also, since the rest of the codebase uses .format instead of % string formatting, I'd prefer it if you use that as well.


pred_classes = classes[pred_label_indexes]
cm = cm[:,pred_label_indexes][:,0,:]

if normalize:
@reiinakano (Owner) · Aug 19, 2017

I think you should calculate the normalized values before slicing the array according to pred_classes and true_classes; otherwise the calculated normalized values might be wrong.
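
A small illustration of the suggested ordering, assuming normalization divides each row by its row sum (variable names taken from the quoted diff):

# Normalize first, while cm still contains every class...
if normalize:
    cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# ...and only then slice out the requested predicted-label columns.
# Slicing first would drop counts from the row sums and skew the ratios.
cm = cm[:, pred_label_indexes][:, 0, :]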

pred_label_indexes = np.where(np.isin(classes, pred_labels))

pred_classes = classes[pred_label_indexes]
cm = cm[:,pred_label_indexes][:,0,:]
@reiinakano (Owner) · Aug 19, 2017

stylefix: cm = cm[:, pred_label_indexes][:, 0, :]

@reiinakano (Owner) commented Aug 19, 2017

Aside from my comments, this looks pretty good.

The only thing missing is the appropriate unit tests to properly define the behavior of this new functionality. Thanks!

@ExcaliburZero (Contributor, Author)

I have made the changes you mentioned.

Right now I am just having an issue where it looks like numpy.isin does not work in the Python 2.7 build.
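
(A likely cause, though the thread does not say so explicitly: np.isin was only added in NumPy 1.13, so an older NumPy in the Python 2.7 CI environment would not have it. np.in1d is the long-standing equivalent for 1-D inputs, which is what the code switches to below.)

import numpy as np

# np.in1d predates np.isin and behaves the same for 1-D inputs.
print(np.in1d(["A", "B", "C"], ["B", "C", "Z"]))  # [False  True  True]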

@ExcaliburZero (Contributor, Author)

Okay, the tests seem to all work correctly in the Travis CI builds.

Do you want me to also write some tests for the new arguments for plot_confusion_matrix?

@reiinakano (Owner)

Yes please. Just one more test running through the new arguments.

@ExcaliburZero (Contributor, Author)

I have added a test for the new arguments, and also fixed an issue it had with non-string labels and added a test for that.

else:
    validate_labels(classes, true_labels, "true_labels")

    true_label_indexes = np.where(np.in1d(classes, true_labels))
@reiinakano (Owner)

Can't say for sure, but I did some tests and wouldn't
true_label_indexes = np.in1d(classes, true_labels)
work just as well?

else:
    validate_labels(classes, pred_labels, "pred_labels")

    pred_label_indexes = np.where(np.in1d(classes, pred_labels))
@reiinakano (Owner)

Same here,

pred_label_indexes = np.in1d(classes, pred_labels)

Although for this one, you'll want to change L167 to
cm = cm[:, pred_label_indexes]
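
A quick illustration, with made-up data, of why the boolean mask also removes the need for the [:, 0, :] workaround:

import numpy as np

cm = np.arange(16).reshape(4, 4)
classes = np.array(["A", "B", "C", "D"])
mask = np.in1d(classes, ["A", "C"])

print(cm[:, np.where(mask)].shape)  # (4, 1, 2): np.where wraps the index
                                    # array in a tuple, hence the [:, 0, :]
print(cm[:, mask].shape)            # (4, 2): the mask selects columns directly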

Commit message: Add options to plot only certain specified labels in confusion
matrices to allow for cases where some "true" labels are not in the
"predicted" label set or vice versa.

This can be useful in cases where a classifier with certain labels is
applied to a dataset with a disjoint or partially disjoint set of
related labels.

Also add tests for some of the new functionality.
@ExcaliburZero (Contributor, Author)

I have now made those changes.

@reiinakano (Owner)

LGTM!

Thanks a lot for this feature!

@reiinakano merged commit d5402fe into reiinakano:master on Aug 24, 2017