Skip to content

Commit

Permalink
[Bugfix] Fix bug in cross entropy loss (#3457)
Browse files Browse the repository at this point in the history
Thanks for your contribution and we appreciate it a lot. The following
instructions would make your pull request more healthy and more easily
get feedback. If you do not understand some items, don't worry, just
make the pull request and seek help from maintainers.

## Motivation

Fixes #3412

## Modification

We just need to replace tensor creation using torch.stack() instead of
torch.tensor().

## BC-breaking (Optional)

Does the modification introduce changes that break the
backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the
downstream projects should modify their code to keep compatibility with
this PR.

## Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases
here, and update the documentation.

## Checklist

1. Pre-commit or other linting tools are used to fix the potential lint
issues.
2. The modification is covered by complete unit tests. If not, please
add more unit test to ensure the correctness.
3. If the modification has potential influence on downstream projects,
this PR should be tested with downstream projects, like MMDet or
MMDet3D.
4. The documentation has been modified accordingly, like docstring or
example tutorials.
  • Loading branch information
mmeendez8 committed Dec 4, 2023
1 parent cbf9af1 commit e51f511
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
5 changes: 3 additions & 2 deletions mmseg/models/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ def cross_entropy(pred,

else:
# the average factor should take the class weights into account
label_weights = torch.tensor([class_weight[cls] for cls in label],
device=class_weight.device)
label_weights = torch.stack([class_weight[cls] for cls in label
]).to(device=class_weight.device)

if avg_non_ignore:
label_weights[label == ignore_index] = 0
avg_factor = label_weights.sum()
Expand Down
28 changes: 28 additions & 0 deletions tests/test_models/test_losses/test_cross_entropy_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F

from mmseg.models.losses import CrossEntropyLoss, weight_reduce_loss


def test_cross_entropy_loss_class_weights():
loss_class = CrossEntropyLoss
pred = torch.rand((1, 10, 4, 4))
target = torch.randint(0, 10, (1, 4, 4))
class_weight = torch.ones(10)
avg_factor = target.numel()

cross_entropy_loss = F.cross_entropy(
pred, target, weight=class_weight, reduction='none', ignore_index=-100)

expected_loss = weight_reduce_loss(
cross_entropy_loss,
weight=None,
reduction='mean',
avg_factor=avg_factor)

# Test loss forward
loss = loss_class(class_weight=class_weight.tolist())(pred, target)

assert isinstance(loss, torch.Tensor)
assert expected_loss == loss

0 comments on commit e51f511

Please sign in to comment.