
Adding the MultiLabelConfusionMatrix for computing multi label confusion matrices. #1613

Merged (13 commits) on Feb 20, 2021

Conversation

@touqir14 (Contributor) commented Feb 4, 2021

Fixes #1609

Description:
This is an in-progress PR that adds functionality for computing confusion matrices for multi-label multi-class classification problems. I will add tests and docstrings once I get initial feedback, to accommodate any necessary changes (if any). The test_confusion_matrix.py file tests the ConfusionMatrix class quite thoroughly; I was wondering how thorough the tests for MultiLabelConfusionMatrix need to be.

The current version only works with prediction and ground truth tensors of shape [batch_size, num_classes]. The tensor values need to be binary. The example below illustrates its use.

import torch
from ignite.metrics import MultiLabelConfusionMatrix

'''
Both inputs below must be tensors of shape N x K, where N is the number of samples
and K is the number of classes. The values must be binary: a 1 in the i'th row and
j'th column marks the j'th label as present for the i'th sample.
'''

predicted = torch.tensor([[1, 1, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1]])
ground_truth = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 1], [1, 0, 1]])
mtr_ignite = MultiLabelConfusionMatrix(3, normalized=False)
mtr_ignite.update([predicted, ground_truth])
conf_matrix = mtr_ignite.compute()


'''
conf_matrix[i, 0, 0] corresponds to count/rate of true negatives of class i,
conf_matrix[i, 0, 1] corresponds to count/rate of false positives of class i,
conf_matrix[i, 1, 0] corresponds to count/rate of false negatives of class i,
conf_matrix[i, 1, 1] corresponds to count/rate of true positives of class i.

With normalization, i.e. MultiLabelConfusionMatrix(num_classes=3, normalized=True), for all i:
conf_matrix[i, 0, 0] + conf_matrix[i, 0, 1] + conf_matrix[i, 1, 0] + conf_matrix[i, 1, 1] = 1
'''
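For the example inputs above, working out the per-class TN/FP/FN/TP counts by hand (4 samples, 3 classes), compute() should return the tensor below; with normalized=True, each 2x2 block would simply be divided by the 4 samples:

# Expected output of compute() for the example above (counts worked out by hand):
# class 0: TN=1, FP=1, FN=0, TP=2
# class 1: TN=1, FP=2, FN=0, TP=1
# class 2: TN=1, FP=0, FN=1, TP=2
tensor([[[1, 1],
         [0, 2]],

        [[1, 2],
         [0, 1]],

        [[1, 0],
         [1, 2]]])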

ccing @vfdev-5

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

@vfdev-5 (Collaborator) commented Feb 5, 2021

@touqir14 thanks a lot for a quick and nice PR! Overall the code looks good; later I can try to check the implementation against tfa. Maybe we could add a comment in the code crediting the original implementation, if it is just a ported tfa implementation. For the tests, I also wonder if we could compare the results against something like sklearn?

@touqir14 (Contributor, Author) commented Feb 5, 2021

This was not ported from tfa, but since you mentioned it, a simple transpose operation will ensure that its output matches that of tfa's implementation. I think it is reasonable to test against scikit-learn's implementation.
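If matching tfa were ever needed, presumably that transpose would be over the per-class 2x2 blocks, something like the line below (an assumption based on the comment above, not verified against tfa):

tfa_style = conf_matrix.transpose(1, 2)  # swap the last two axes of the (num_classes, 2, 2) output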

@touqir14 (Contributor, Author) commented Feb 5, 2021

Btw, should we aim to keep the output consistent with tfa, or just keep it as it is, @vfdev-5?

@vfdev-5 (Collaborator) commented Feb 5, 2021

Let's keep it consistent with sklearn, and mention the difference vs tfa and what to do to make it consistent with tfa...

@touqir14 (Contributor, Author) commented Feb 5, 2021

Added some tests, let me know what you think.

conf_mtrx = mtr.compute()
correct_conf_mtrx = torch.tensor([[[1, 1], [0, 2]], [[1, 2], [0, 1]], [[1, 0], [1, 2]]])

TestCase.assertTrue(
@vfdev-5 (Collaborator) commented Feb 5, 2021

In general we use pytest and simply write assert something, err_message, so let's do it here in the same way.
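For the snippet above it would be something along these lines (the error message is illustrative):

assert torch.equal(conf_mtrx, correct_conf_mtrx), "computed confusion matrix does not match the expected one"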

About the tests, I think it would be interesting to put them into a separate file: test_multilabel_confusion_matrix.py. Maybe the same for the code of MultiLabelConfusionMatrix?

In the tests, you can simply generate inputs, compute the result, and compare it against the sklearn result, as is done for ConfusionMatrix.
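A minimal sketch of such a comparison, assuming sklearn's multilabel_confusion_matrix and random binary inputs (the names and sizes here are illustrative):

import numpy as np
import torch
from sklearn.metrics import multilabel_confusion_matrix
from ignite.metrics import MultiLabelConfusionMatrix

# random binary predictions/targets of shape (num_samples, num_classes)
y_true = torch.randint(0, 2, size=(20, 3))
y_pred = torch.randint(0, 2, size=(20, 3))

mlcm = MultiLabelConfusionMatrix(num_classes=3)
mlcm.update((y_pred, y_true))

expected = multilabel_confusion_matrix(y_true.numpy(), y_pred.numpy())
assert np.all(mlcm.compute().numpy() == expected)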
Do you think it would make sense to first ensure that MultiLabelConfusionMatrix works, and then also test it on inputs of shape (batch_size, classes, dim1, ...) for predictions/targets?

Another point to address a bit later is to ensure that it works with DDP. In this case, we have to replicate what is done here:

@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_gpu(local_rank, distributed_context_single_node_nccl):
    device = torch.device(f"cuda:{local_rank}")
    _test_distrib_multiclass_images(device)
    _test_distrib_accumulator_device(device)

@touqir14 (Contributor, Author)

Regarding the input format of (batch_size, classes, dim1, ...), is there any special use case? The tfa implementation and sklearn's function take 2D matrices.

@vfdev-5 (Collaborator)

We can imagine, for example, image segmentation with overlapping classes...

@touqir14 (Contributor, Author)

I see. So in the general case we can think of each class as having num_batches x dim_1 x ... x dim_n entries to compare between the predictions and the ground truth, whereas in the 2D case each class had num_batches entries.

@vfdev-5 (Collaborator)

Yes, it can be seen like that. In the ConfusionMatrix implementation, however, we do this reshape internally in some sense, so it can accept inputs like (B, C, H, W).

@touqir14 (Contributor, Author)

So in the test cases we can simply compare the output of ignite's implementation with sklearn's, feeding the latter the prediction and ground truth arrays reshaped to [num_batches x dim_1 x ... x dim_n, num_classes].

@vfdev-5 (Collaborator) commented Feb 5, 2021

Yes, exactly! Transposed and reshaped.
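For an ND input of shape (B, C, d1, ..., dn), the sklearn reference could be prepared roughly like this (a sketch; y_true and y_pred stand for the ND tensors used in the test):

import torch
from sklearn.metrics import multilabel_confusion_matrix

def to_2d(t):
    # (B, C, d1, ..., dn) -> (B * d1 * ... * dn, C): move the class axis last, then flatten
    return torch.movedim(t, 1, -1).reshape(-1, t.shape[1]).numpy()

expected = multilabel_confusion_matrix(to_2d(y_true), to_2d(y_pred))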

@touqir14 (Contributor, Author) commented Feb 12, 2021

@vfdev-5, I have been a little busy, so this commit got a bit delayed. Let me know what you think.
As of now, the MultiLabelConfusionMatrix class and the tests are in separate files. I have also added the DDP tests.

@vfdev-5 (Collaborator) commented Feb 12, 2021

@touqir14 thanks for the updates, and no worries about the delay! I'll take a look at the code later and comment.

@vfdev-5 (Collaborator) left a comment

Thanks for the update @touqir14!
I haven't yet looked at the tests in detail. There are some comments to address for the PR, and it seems the code-formatting check is not passing either...

Review comments on:
ignite/metrics/multilabel_confusion_matrix.py
tests/ignite/metrics/test_confusion_matrix.py
tests/ignite/metrics/test_multilabel_confusion_matrix.py
@touqir14 (Contributor, Author)

@vfdev-5, feel free to review the tests. I can push updates addressing your current suggestions and any future suggestions together.

@vfdev-5 (Collaborator) commented Feb 19, 2021

> @vfdev-5, feel free to review the tests. I can push updates addressing your current suggestions and any future suggestions together.

@touqir14 let me do that in the coming days. Thanks for pinging!

@vfdev-5 (Collaborator) left a comment

@touqir14 thanks again for the PR!
I have a few comments on the implementation and tests. Let me know what you think.

Review comments on:
ignite/metrics/confusion_matrix.py
ignite/metrics/multilabel_confusion_matrix.py
Comment on lines 66 to 75
if (
    not isinstance(output, Sequence)
    or len(output) < 2
    or not isinstance(output[0], torch.Tensor)
    or not isinstance(output[1], torch.Tensor)
):
    raise ValueError(
        (r"Argument must consist of a Python Sequence of two tensors such that the first is the predicted"
         r" tensor and the second is the ground-truth tensor")
    )
@vfdev-5 (Collaborator)

This is a correct check for output, but we do not perform such a test anywhere in the other metrics, since the output format is a sort of documented convention. Maybe we can remove it.

@touqir14 (Contributor, Author)

Or maybe we could include this check in other metrics as we see fit? Either one is fine by me.

@vfdev-5 (Collaborator)

Let's remove it here; maybe we could add this check to the Metric class in another PR.

if y.dtype not in valid_types:
    raise ValueError(f"y must be of any type: {valid_types}")

if y_pred.numel() != ((y_pred == 0).sum() + (y_pred == 1).sum()).item():
@vfdev-5 (Collaborator)

Previously, we were checking for binary input as torch.equal(x, x**2). I tried to compare times of these two implementations:

import torch
y_pred = torch.randint(0, 2, size=(32, 10))

%%timeit
y_pred.numel() == ((y_pred == 0).sum() + (y_pred == 1).sum()).item()
> 50.3 µs ± 96.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit
torch.equal(y_pred, y_pred ** 2)
> 9.74 µs ± 23.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Probably, we can keep torch.equal(y_pred, y_pred ** 2).

@touqir14 (Contributor, Author)

That's a clever optimization! I will replace that.
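The replacement would then look roughly like this (0 and 1 are the only values left unchanged by squaring, so the equality holds exactly for binary tensors; the second error message mirrors the existing one):

if not torch.equal(y_pred, y_pred ** 2):
    raise ValueError("y_pred must be a binary tensor")

if not torch.equal(y, y ** 2):
    raise ValueError("y must be a binary tensor")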

if y_pred.numel() != ((y_pred == 0).sum() + (y_pred == 1).sum()).item():
    raise ValueError("y_pred must be a binary tensor")

if y.numel() != ((y == 0).sum() + (y == 1).sum()).item():
@vfdev-5 (Collaborator)

Same here

ignite_CM = mlcm.compute().numpy()
assert np.all(sklearn_CM.astype(np.int64) == ignite_CM.astype(np.int64))

return
@vfdev-5 (Collaborator)

Suggested change: remove the trailing return.

num_classes = 3
cm = MultiLabelConfusionMatrix(num_classes=num_classes, device=metric_device)

y_true, y_pred = get_y_true_y_pred()
@vfdev-5 (Collaborator)

We have to generate data a bit differently depending on distributed rank.

@touqir14 (Contributor, Author)

Do you have any suggestions?

@vfdev-5 (Collaborator)

Oh, I see, our tests for ConfusionMatrix have the same issue. I think we should rework them. Let me propose better tests for our confusion matrix, which could be adapted for this PR as well. We can take inspiration from here:

y_true = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
y_preds = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
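Roughly, the idea would be: all processes generate the same full tensors (e.g., with a fixed seed), each rank updates the metric with its own offset-sized slice, and compute(), which aggregates the counts across processes, is compared against a reference computed on the full data. A sketch (offset, n_classes, metric_device follow the quoted snippet; expected_cm stands for such a reference, e.g. from sklearn):

rank = idist.get_rank()

cm = MultiLabelConfusionMatrix(num_classes=n_classes, device=metric_device)
# each process feeds only its own chunk of the shared data
cm.update((y_preds[rank * offset : (rank + 1) * offset], y_true[rank * offset : (rank + 1) * offset]))

assert np.all(cm.compute().cpu().numpy() == expected_cm)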

Review comments on:
tests/ignite/metrics/test_multilabel_confusion_matrix.py
@touqir14 (Contributor, Author)

There seems to be a test failure in the test_distrib_cpu function. I am not sure what is causing it; in the non-distributed setting the tests ran fine.

Review comment on: ignite/metrics/multilabel_confusion_matrix.py

def test_simple_ND_input():

    num_iters = 100
@vfdev-5 (Collaborator)

Let's change it to 5

Suggested change: replace num_iters = 100 with num_iters = 5.


def test_simple_batched():

    num_iters = 100
@vfdev-5 (Collaborator)

Same here

Suggested change: replace num_iters = 100 with num_iters = 5.

- The classes present in M are indexed as 0, ..., num_classes-1 as can be inferred from above.

Args:
    num_classes (int): Number of classes, should be > 1. See notes for more details.
@vfdev-5 (Collaborator)

Suggested change: drop "See notes for more details." so the description reads: num_classes (int): Number of classes, should be > 1.


@vfdev-5 (Collaborator) commented Feb 20, 2021

@touqir14 let's proceed like that for this PR: comment out the distributed tests and fix all the remaining nits raised below, and we can merge it. After that, in a follow-up PR, we can rework the distributed tests for both CM and MLCM. What do you think?

@touqir14 (Contributor, Author)

That sounds good. I will push a commit next addressing the remaining issues.

@touqir14 (Contributor, Author)

Just pushed a commit. I think all the issues have been addressed. Let me know if I missed something.

@vfdev-5 (Collaborator) left a comment

Looks good @touqir14! Thanks!
Just fix the code-formatting issue and it is good to go.

@vfdev-5 (Collaborator) commented Feb 20, 2021

Another point from the checklist is not yet done:

  • Documentation is updated (if required)

We have to add an entry here as well: https://github.com/pytorch/ignite/blob/master/docs/source/metrics.rst

Could you please also merge the current master into your branch, as "This branch is out-of-date with the base branch". Thanks!

@touqir14 (Contributor, Author)

For the docs part here: https://github.com/pytorch/ignite/blob/master/docs/source/metrics.rst, I should just add .. autoclass:: MultiLabelConfusionMatrix to the complete list of metrics, right?

@vfdev-5 (Collaborator) commented Feb 20, 2021

Yes, just .. autoclass:: MultiLabelConfusionMatrix in alphabetical order, and it should be good.

You can see the docs preview here as well: https://deploy-preview-1613--pytorch-ignite-preview.netlify.app/

@touqir14 (Contributor, Author)

Reformatted the docstring, let me know how it looks.

@vfdev-5 (Collaborator) commented Feb 20, 2021

> Reformatted the docstring, let me know how it looks.

Looks better :) https://deploy-preview-1613--pytorch-ignite-preview.netlify.app/metrics.html#ignite.metrics.MultiLabelConfusionMatrix

The description of the M[i, 0, 0] values is inlined, but it can go...

@touqir14 (Contributor, Author)

It's better to put the M[...] parts in a bulleted format, I think. Will adding an extra "-" before each line put them in their own bulleted list? Like:

- The confusion matrix 'M' is of dimension (num_classes, 2, 2).
      - M[i, 0, 0] corresponds to count/rate of true negatives of class i,
      - M[i, 0, 1] corresponds to count/rate of false positives of class i,
      - M[i, 1, 0] corresponds to count/rate of false negatives of class i,
      - M[i, 1, 1] corresponds to count/rate of true positives of class i.

@vfdev-5 (Collaborator) commented Feb 20, 2021

Yes, I agree. I'm not sure if RST/Sphinx would accept it simply like that; anyway, it's better to try it locally :)
Here is info on how to generate the docs locally: https://github.com/pytorch/ignite/blob/master/CONTRIBUTING.md#local-documentation-building-and-deploying

@touqir14 (Contributor, Author)

Looks good to me.

@vfdev-5 merged commit abe1ddd into pytorch:master on Feb 20, 2021
@vfdev-5 (Collaborator) commented Feb 20, 2021

@touqir14 thanks again for the PR! If you are interested in finalizing the distributed parts, we can discuss that in issue #1657.

@touqir14 (Contributor, Author)

Thanks @vfdev-5 for your valuable feedback! I will have a look at #1657 soon.

Successfully merging this pull request may close these issues: Support for MultiLabel MultiClass classification confusion matrix computation.