Improve handling of unbalanced confusion matrices #41
Hi @ExcaliburZero thanks for taking the time to write this. Apologies for the late response, I've been quite busy. First off, do you think you could open a separate PR for the NaN values bugfix so I can merge that immediately and keep this new feature suggestion separate? Thanks!
Anyway, I think allowing people to select which classes to show on the CM is a useful feature, however, I'm not a fan of the new arguments introduced here. The need to use indices would be confusing for anyone not familiar with the internals. Nobody knows what the index of a class is, unless they explicitly see that it's the index in the array you get from I think a better way is to let them pass in a list of the actual classes they want included in the x-axis and y-axis respectively. But you have to be careful to handle edge cases: classes that are not in the data at all, duplicate classes, etc.
I have made a separate PR for the NaN values bugfix (#42). I agree that the arguments are a bit confusing to use. They should definitely take in the names of the categories instead, though as you noted this is complicated by the possible values that can be passed in. I'll make those changes to the arguments and add in some good validation. (Though I'll be a bit busy, so I will probably get around to working on this on Monday.)
Force-pushed from 7e70d2d to 3699401.
I have changed the options to take in the names of the labels instead.
I have also added validation for the both |
Here are some examples of the error messages. Duplicate labels:
Missing labels:
Sorry for the late review. I've been very busy the past few days. Will add my comments now.
scikitplot/plotters.py
Outdated
@@ -87,6 +94,50 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, title=None, normalize=Fal
    else:
        classes = np.asarray(labels)

    def validate_labels(known_classes, passed_labels, argument_name):
Please put this outside of the function, add a detailed docstring, and add a unit test.
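As a rough illustration of what this request could look like, here is a minimal sketch of a module-level `validate_labels` with a docstring. It is not the code that was merged: the body, the use of `ValueError`, and the `np.unique(..., return_counts=True)` approach to finding duplicates are assumptions made for the example. Only the function signature and the wording of the two error messages come from the diff under review.

```python
import numpy as np


def validate_labels(known_classes, passed_labels, argument_name):
    """Check that the passed labels are unique and are all known classes.

    Hypothetical sketch for illustration; not the implementation in this PR.

    Args:
        known_classes (array-like): Classes present in the confusion matrix.
        passed_labels (array-like): Labels the caller wants to display.
        argument_name (str): Name of the argument, used in error messages.

    Raises:
        ValueError: If a label is passed twice or is not a known class.
    """
    known_classes = np.asarray(known_classes)
    passed_labels = np.asarray(passed_labels)

    # Duplicates are the labels that occur more than once.
    unique_labels, counts = np.unique(passed_labels, return_counts=True)
    duplicates = unique_labels[counts > 1]
    if duplicates.size > 0:
        raise ValueError(
            "The following duplicate labels were passed into {0}: {1}".format(
                argument_name, ", ".join(str(d) for d in duplicates)))

    # Absent labels are passed labels that are not among the known classes.
    absent = passed_labels[~np.isin(passed_labels, known_classes)]
    if absent.size > 0:
        raise ValueError(
            "The following labels were passed into {0}, but were not found "
            "in labels: {1}".format(
                argument_name, ", ".join(str(a) for a in absent)))
```

Converting each label with `str()` before joining keeps the error messages working for non-string labels such as integers.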
scikitplot/plotters.py
Outdated
duplicate_indexes = indexes[~np.isin(indexes, unique_indexes)]
duplicate_labels = passed_labels[duplicate_indexes]

msg = "The following duplicate labels were passed into %s: %s" % (argument_name, ", ".join(duplicate_labels))
Line length is too long. Also, since the rest of the codebase uses .format instead of % string formatting, I'd prefer it if you use that as well.
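One way to address both points is sketched below. The message text and the names `argument_name` and `duplicate_labels` are taken from the diff line above; the sample values are made up for the example.

```python
# Sample values, chosen only for illustration.
argument_name = "pred_labels"
duplicate_labels = ["cat", "dog"]

# str.format with a continuation line keeps each line under 79 characters.
msg = "The following duplicate labels were passed into {0}: {1}".format(
    argument_name, ", ".join(duplicate_labels))
```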
scikitplot/plotters.py
Outdated
if np.any(passed_labels_absent):
    absent_labels = passed_labels[passed_labels_absent]

    msg = "The following labels were passed into %s, but were not found in labels: %s" % (argument_name, ", ".join(absent_labels))
Line length is too long. Also, since the rest of the codebase uses .format instead of % string formatting, I'd prefer it if you use that as well.
scikitplot/plotters.py
Outdated

pred_classes = classes[pred_label_indexes]
cm = cm[:,pred_label_indexes][:,0,:]

if normalize:
I think you should calculate normalized values before slicing the array according to pred_classes and true_classes, otherwise the calculated normalized values might be wrong.
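A small made-up example shows why the order matters. The matrix and the `keep` mask are hypothetical; the point is only that slicing first changes the row sums used for normalization.

```python
import numpy as np

# Hypothetical 3-class confusion matrix (rows: true, columns: predicted).
cm = np.array([[8, 1, 1],
               [2, 6, 2],
               [1, 0, 9]])

# Hypothetical selection: display only the first two predicted classes.
keep = np.array([True, True, False])

# Normalize over the full rows first, then drop the unwanted columns.
normalized = cm.astype(float) / cm.sum(axis=1, keepdims=True)
normalize_then_slice = normalized[:, keep]

# Drop the columns first, then normalize over what is left.
sliced = cm[:, keep]
slice_then_normalize = sliced.astype(float) / sliced.sum(axis=1, keepdims=True)

# First row: [0.8, 0.1] when normalizing first, [8/9, 1/9] when slicing first.
```

In the second variant the hidden column no longer contributes to the row total, so the displayed fractions overstate how often each shown class was predicted.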
scikitplot/plotters.py
Outdated
pred_label_indexes = np.where(np.isin(classes, pred_labels))

pred_classes = classes[pred_label_indexes]
cm = cm[:,pred_label_indexes][:,0,:]
stylefix: cm = cm[:, pred_label_indexes][:, 0, :]
Aside from my comments, this looks pretty good. The only thing missing is a set of unit tests to properly define the behavior of this new functionality. Thanks!
Force-pushed from 08977e5 to 4b59b8c.
I have made the changes you mentioned. Right now I am just having an issue where it looks like |
Okay, the tests seem to all work correctly in the Travis CI builds. Do you want me to also write some tests for the new arguments for |
Yes please. Just one more test running through the new arguments.
I have added a test for the new arguments, and also fixed and added a test for an issue it had with non-string labels.
scikitplot/plotters.py
Outdated
else:
    validate_labels(classes, true_labels, "true_labels")

true_label_indexes = np.where(np.in1d(classes, true_labels))
Can't say for sure, but I did some tests and wouldn't
true_label_indexes = np.in1d(classes, true_labels)
work just as well?
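The suggestion can be checked with a quick sketch on made-up data. The example uses `np.isin`, which gives the same result as `np.in1d` for 1-D input and is the spelling recommended by newer NumPy releases; the sample `classes` and `true_labels` values are assumptions.

```python
import numpy as np

# Made-up classes and a made-up subset of labels to display.
classes = np.array(["a", "b", "c", "d"])
true_labels = ["b", "d"]

mask = np.isin(classes, true_labels)  # boolean array, one entry per class
positions = np.where(mask)            # tuple holding one array of positions

# Both forms select the same elements from a 1-D array.
selected_by_mask = classes[mask]
selected_by_positions = classes[positions]
```

For a 1-D array the boolean mask and the `np.where` result are interchangeable as indices, so the `np.where` call adds nothing here.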
scikitplot/plotters.py
Outdated
else:
    validate_labels(classes, pred_labels, "pred_labels")

pred_label_indexes = np.where(np.in1d(classes, pred_labels))
Same here,
pred_label_indexes = np.in1d(classes, pred_labels)
Although for this one, you'll want to change L167 to
cm = cm[:, pred_label_indexes]
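The reason the column slice has to change can be seen from the shapes involved. This is a sketch on a made-up matrix, assuming the selection is a boolean mask over the columns.

```python
import numpy as np

# Made-up 3x4 matrix and a made-up column selection.
cm = np.arange(12).reshape(3, 4)
mask = np.array([False, True, False, True])
positions = np.where(mask)  # tuple holding one array of column positions

# Indexing columns with the np.where tuple adds an extra axis of length 1,
# which is why the earlier code followed it with [:, 0, :].
via_tuple = cm[:, positions]  # shape (3, 1, 2)

# Indexing columns with the boolean mask gives the 2-D result directly.
via_mask = cm[:, mask]        # shape (3, 2)
```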
Add options to plot only certain specified labels in confusion matrices to allow for cases where some "true" labels are not in the "predicted" label set or vice versa. This can be useful in cases where a classifier with certain labels is applied to a dataset with a disjoint or partially disjoint set of related labels. Also add tests for some of the new functionality.
I have now made those changes.
LGTM! Thanks a lot for this feature!
Here I have made a few changes that make it easier to plot confusion matrices where the true and predicted sets of labels are not the same. This is a case that can occur when doing something like applying "new" categories to a dataset with an older set of categories.
The changes included are the following:
Fix an issue with NaN values showing up when unbalanced confusion matrices are normalized: rows with no entries sum to zero, so normalizing each cell in those rows divided by zero.
Add options to limit the labels displayed on the true and predicted axes, as with unbalanced confusion matrices some of the labels can be only in the set of true labels or only in the set of predicted labels.
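The NaN issue described in the first item can be reproduced and avoided as sketched below. This is one possible guard using `np.divide` with its `where` argument, and is not necessarily how the fix in this PR or in #42 was written; the sample matrix is made up.

```python
import numpy as np

# Made-up unbalanced matrix: the second class never occurs as a true label,
# so its row sums to zero.
cm = np.array([[5, 5],
               [0, 0]])

row_sums = cm.sum(axis=1, keepdims=True).astype(float)

# Divide only where the row sum is non-zero; other rows keep the 0.0 from out.
normalized = np.divide(
    cm, row_sums,
    out=np.zeros(cm.shape, dtype=float),
    where=row_sums != 0)
```

Leaving empty rows at zero is a design choice for the example; a plotting function could equally mask or hide such rows.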
You can see the effect of the new options here: