Adding the MultiLabelConfusionMatrix for computing multi-label confusion matrices. #1613
Conversation
@touqir14 thanks a lot for a quick and nice PR! Overall, the code looks good; later I can try to check the implementation vs tfa. Maybe we could add a comment in the code crediting the original implementation if it is just a ported tfa implementation. I also wonder, for the tests, if we could compare the results with something like sklearn?
This was not ported from tfa, but since you mentioned it, a simple transpose operation will ensure that its output matches that of tfa's implementation. I think it is reasonable to test against scikit-learn's implementation.
Btw, should we aim to keep the output consistent with tfa, or just keep this as it is, @vfdev-5?
Let's keep it consistent with sklearn and mention the difference vs tfa and what to do to make it consistent with tfa...
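For reference, a minimal sketch of converting between the two layouts (assuming, per the discussion above, that the sklearn-style per-class matrices are laid out as [[TN, FP], [FN, TP]] and that tfa's layout is simply their transpose):

import torch

# per-class 2x2 confusion matrices in the sklearn-style layout, shape (num_classes, 2, 2)
cm_sklearn_style = torch.tensor([[[1, 1], [0, 2]], [[1, 2], [0, 1]], [[1, 0], [1, 2]]])
# transposing the last two dims of each per-class block swaps FP and FN,
# which would give the (assumed) tfa-style layout
cm_tfa_style = cm_sklearn_style.transpose(1, 2)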
Added some tests, let me know what you think.
conf_mtrx = mtr.compute()
correct_conf_mtrx = torch.tensor([[[1, 1], [0, 2]], [[1, 2], [0, 1]], [[1, 0], [1, 2]]])

TestCase.assertTrue(
In general we use pytest and simply do assert something, err_message, so let's do it here in the same way.
About the tests, I think it would be worth putting them into a separate file: test_multilabel_confusion_matrix.py. Maybe the same for the code of MultiLabelConfusionMatrix?
In the tests, you can simply generate inputs, compute the result and compare it against the sklearn result, as is done for ConfusionMatrix (see the sketch after this comment).
Do you think it would make sense to ensure that MultiLabelConfusionMatrix works, and then test it on inputs of shape like (batch_size, classes, dim1, ...) for predictions/targets?
Another point to address a bit later is to ensure that it works with DDP. In this case, we have to replicate what is done here:
ignite/tests/ignite/metrics/test_confusion_matrix.py
Lines 783 to 790 in f1cc9fb
@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_gpu(local_rank, distributed_context_single_node_nccl):
    device = torch.device(f"cuda:{local_rank}")
    _test_distrib_multiclass_images(device)
    _test_distrib_accumulator_device(device)
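A rough sketch of the suggested pytest-style comparison against sklearn (assumptions: MultiLabelConfusionMatrix is importable from ignite.metrics once merged and follows the usual update((y_pred, y)) / compute() metric API; sklearn's multilabel_confusion_matrix returns per-class 2x2 matrices):

import numpy as np
import torch
from sklearn.metrics import multilabel_confusion_matrix
from ignite.metrics import MultiLabelConfusionMatrix  # assumed import path

def test_multilabel_confusion_matrix_vs_sklearn():
    num_classes = 3
    # random binary indicator matrices of shape (batch_size, num_classes)
    y_true = torch.randint(0, 2, size=(50, num_classes))
    y_pred = torch.randint(0, 2, size=(50, num_classes))
    mlcm = MultiLabelConfusionMatrix(num_classes=num_classes)
    mlcm.update((y_pred, y_true))
    expected = multilabel_confusion_matrix(y_true.numpy(), y_pred.numpy())
    assert np.all(expected == mlcm.compute().numpy()), "ignite CM does not match sklearn CM"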
Regarding the input format of (batch_size, classes, dim1, ...), is there any special use case? The tfa implementation and sklearn's function take 2D matrices.
We can imagine, for example, image segmentation with overlapping classes...
I see. So in the general case we can think of each class as having num_batches x dim_1 x ... x dim_n entries to compare between the predictions and the ground truth, whereas in the 2D case each class had num_batches entries.
Yes, it can be seen like that. In the ConfusionMatrix implementation, however, we do this reshape internally in some sense and can accept inputs like (B, C, H, W).
So in the test cases we can simply compare the output of ignite's implementation with sklearn's, feeding the latter with prediction and ground-truth arrays reshaped to [num_batches x dim_1 x ... x dim_n, num_classes].
Yes, exactly! Transposed and reshaped.
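A small sketch of that transpose-and-reshape step for an ND input (assuming the class dimension is dim 1, as in (B, C, H, W); element order does not matter for the confusion-matrix counts):

import torch

num_classes = 3
y = torch.randint(0, 2, size=(4, num_classes, 6, 8))     # e.g. (B, C, H, W)
# move the class dim to the end, then flatten everything else
y_2d = torch.movedim(y, 1, -1).reshape(-1, num_classes)  # (B * H * W, num_classes), ready for sklearn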
@vfdev-5, I have been a little busy, so this commit got a bit delayed. Let me know what you think.
@touqir14 thanks for the updates, and no worries about the delay! I'll take a look at the code later and comment.
Thanks for the update @touqir14!
I haven't yet looked at the tests in detail. There are some comments to address for the PR, and it seems that code formatting is not passing either...
@vfdev-5, feel free to review the tests. I can push updates addressing your current suggestions and any future suggestions together.
@touqir14 thanks again for the PR!
I have a few comments on the implementation and tests. Let me know what you think.
if (
    not isinstance(output, Sequence)
    or len(output) < 2
    or not isinstance(output[0], torch.Tensor)
    or not isinstance(output[1], torch.Tensor)
):
    raise ValueError(
        (r"Argument must consist of a Python Sequence of two tensors such that the first is the predicted"
         r" tensor and the second is the ground-truth tensor")
    )
This is a correct check for output, but we do not perform such a check in any other metric, as this is a sort of documented convention. Maybe we can remove it.
Or maybe, we could include this check in other metrics as we see fit? Either one is fine by me.
Let's remove it here; maybe we could add this check in the Metric class in another PR.
if y.dtype not in valid_types:
    raise ValueError(f"y must be of any type: {valid_types}")

if y_pred.numel() != ((y_pred == 0).sum() + (y_pred == 1).sum()).item():
Previously, we were checking for binary input with torch.equal(x, x ** 2) (only 0 and 1 satisfy x == x ** 2). I tried to compare the times of these two implementations:

import torch
y_pred = torch.randint(0, 2, size=(32, 10))

%%timeit
y_pred.numel() == ((y_pred == 0).sum() + (y_pred == 1).sum()).item()
> 50.3 µs ± 96.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit
torch.equal(y_pred, y_pred ** 2)
> 9.74 µs ± 23.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Probably, we can keep torch.equal(y_pred, y_pred ** 2).
That's a clever optimization! I will replace it.
if y_pred.numel() != ((y_pred == 0).sum() + (y_pred == 1).sum()).item():
    raise ValueError("y_pred must be a binary tensor")

if y.numel() != ((y == 0).sum() + (y == 1).sum()).item():
Same here
ignite_CM = mlcm.compute().numpy()
assert np.all(sklearn_CM.astype(np.int64) == ignite_CM.astype(np.int64))

return
return
num_classes = 3
cm = MultiLabelConfusionMatrix(num_classes=num_classes, device=metric_device)

y_true, y_pred = get_y_true_y_pred()
We have to generate data a bit differently depending on distributed rank.
Do you have any suggestions?
Oh, I see, our tests for confusion matrix have the same remark. I think we should rework them. Let me propose better tests for our confusion matrix, which could be adapted for this PR as well. We can take inspiration from here:
ignite/tests/ignite/metrics/test_recall.py
Lines 784 to 785 in e17acc7
y_true = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
y_preds = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
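A rough sketch of what the rank-dependent data generation could look like (hypothetical helper, following the test_recall.py pattern quoted above; idist is ignite.distributed):

import torch
import ignite.distributed as idist

def _get_rank_data(n_classes=3, offset=10, device="cpu"):
    # every rank builds the same global tensors from a fixed seed...
    torch.manual_seed(12)
    y_true = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
    y_preds = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
    # ...and each rank then updates the metric only with its own slice
    rank = idist.get_rank()
    return y_true[rank * offset : (rank + 1) * offset], y_preds[rank * offset : (rank + 1) * offset]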
There seems to be a test failure in the ...
def test_simple_ND_input():

    num_iters = 100
Let's change it to 5:
num_iters = 100
→ num_iters = 5
def test_simple_batched():

    num_iters = 100
Same here:
num_iters = 100
→ num_iters = 5
- The classes present in M are indexed as 0, ..., num_classes-1 as can be inferred from above.

Args:
    num_classes (int): Number of classes, should be > 1. See notes for more details.
num_classes (int): Number of classes, should be > 1. See notes for more details.
→ num_classes (int): Number of classes, should be > 1.
@touqir14 let's do it like that for this PR: let's comment out the distributed tests and fix all the remaining nits asked below. Then we can merge it. After that, in a follow-up PR, we can rework both distrib tests for CM and MLCM. What do you think?
That sounds good. I will push a commit next addressing the remaining issues.
Just pushed a commit. I think all the issues have been addressed. Let me know if I missed something.
Looks good @touqir14! Thanks!
Just fix the code formatting issue and it is good to go.
Another point from the checklist is not yet done:
We have to add an entry here as well: https://github.com/pytorch/ignite/blob/master/docs/source/metrics.rst
Could you please also merge current master into your branch, as "This branch is out-of-date with the base branch". Thanks
For the docs part here: https://github.com/pytorch/ignite/blob/master/docs/source/metrics.rst, I should just add ...
Yes, just that. You can see the docs preview here as well: https://deploy-preview-1613--pytorch-ignite-preview.netlify.app/
Reformatted the docstring, let me know how it looks.
Looks better :) https://deploy-preview-1613--pytorch-ignite-preview.netlify.app/metrics.html#ignite.metrics.MultiLabelConfusionMatrix
The description of the M[i, 0, 0] values is inlined, but it could go...
It's better to put the ...
Yes, I agree. I'm not sure whether rst/sphinx would accept it simply like that anyway; it's better to try it locally :)
Looks good to me.
Fixes #1609
Description:
This is an in-progress PR that adds functionality for computing confusion matrices for multi-label, multi-class classification problems. I will add tests and docstrings once I get initial feedback, to accommodate any necessary changes. The test_confusion_matrix.py file tests the ConfusionMatrix class quite thoroughly; I was wondering how thorough the tests for MultiLabelConfusionMatrix need to be.
The current version only works with prediction and ground-truth tensors of shape [batch_size, num_classes]. The tensor values need to be binary. The example below illustrates its use. cc'ing @vfdev-5
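The PR's original inline example is not reproduced here; a minimal usage sketch under the constraints described above (with an assumed ignite.metrics import path) might look like:

import torch
from ignite.metrics import MultiLabelConfusionMatrix  # assumed import path

mlcm = MultiLabelConfusionMatrix(num_classes=3)
y_pred = torch.tensor([[1, 0, 1], [0, 1, 0]])  # binary values, shape [batch_size, num_classes]
y_true = torch.tensor([[1, 0, 0], [0, 1, 1]])
mlcm.update((y_pred, y_true))
print(mlcm.compute())  # per-class 2x2 confusion matrices, shape (num_classes, 2, 2)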
Checklist: