
'mean' reduction result in CrossEntropyLoss mismatches with manually computing mean #40560

Closed
vincentwen1995 opened this issue Jun 25, 2020 · 4 comments
Labels: high priority, module: docs, module: nn, triaged

Comments

vincentwen1995 commented Jun 25, 2020

🐛 Bug

Hi, during training I noticed that when specifying weights for CrossEntropyLoss, using the 'mean' reduction produces a different loss value than using the 'none' reduction and computing the mean manually.

To Reproduce

Steps to reproduce the behavior:

import torch
import torch.nn as nn

logits = torch.randn((16, 5))
targets = torch.empty(16, dtype=torch.long).random_(5)

weights = [1, 2, 3, 4, 5]
# built-in 'mean' reduction with class weights
cross_ent_mean = nn.CrossEntropyLoss(weight=torch.FloatTensor(weights), reduction='mean')
loss_a = cross_ent_mean(logits, targets)
print(loss_a)

# 'none' reduction, averaged manually over the batch
cross_ent = nn.CrossEntropyLoss(weight=torch.FloatTensor(weights), reduction='none')
loss_b = cross_ent(logits, targets).mean()
print(loss_b)

# this assertion fails: the two values differ
assert torch.equal(loss_a, loss_b)

Expected behavior

The loss computed by the two methods should be equal, i.e. torch.equal(loss_a, loss_b) should return True.

Environment

  • PyTorch Version (e.g., 1.0): 1.6.0.dev20200505+cu101
  • OS (e.g., Linux): Windows 10
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.7.6
  • CUDA/cuDNN version: CUDA 10.1 (irrelevant)
  • GPU models and configuration: GTX970M (irrelevant)
  • Any other relevant information:

Additional context

It seems to be a problem with the weights, since the mismatch disappears when I assign the same weight to all the classes; theoretically, however, the weighted loss should be computed before applying the mean() operation.

cc @ezyang @gchanan @zou3519 @jlin27 @albanD @mruberry

gchanan (Contributor) commented Jun 25, 2020

This isn't a bug, but the documentation isn't clear -- you'd only find the explanation if you look at the documentation for NLLLoss.

See #31295 for the case in NLLLoss.
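
A minimal sketch of the behavior the NLLLoss docs describe (illustrative only, not from the issue): with class weights, reduction='mean' divides the summed per-sample losses by the sum of the target-class weights rather than by the batch size.

import torch
import torch.nn as nn

logits = torch.randn(16, 5)
targets = torch.randint(0, 5, (16,))
weights = torch.tensor([1., 2., 3., 4., 5.])

# 'none' already multiplies each sample's loss by weights[target]
per_sample = nn.CrossEntropyLoss(weight=weights, reduction='none')(logits, targets)

# 'mean' normalizes by the summed weights of the targets, not by the number of samples
builtin_mean = nn.CrossEntropyLoss(weight=weights, reduction='mean')(logits, targets)
manual_mean = per_sample.sum() / weights[targets].sum()

assert torch.allclose(builtin_mean, manual_mean)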

gchanan added the module: docs, module: nn, triage review, and triaged labels and removed the triage review label on Jun 25, 2020
gchanan (Contributor) commented Jun 25, 2020

High priority for the docs fixes.

We should do the following:

  1. Add the weight description to CrossEntropyLoss, as in #31295 (nll_loss with weights: reduction 'mean' gives wrong result).
  2. The weight explanation has been added to NLLLoss, but the parameter documentation still retains the generic writeup:

reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: 'mean'

We should fix this up for NLLLoss and audit all the other places the reduction string is used to make sure these are accurate. If any of them aren't, we should fix them and also add the math description from point 1 for them.
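
For reference, the weighted-mean equation that point 1 asks for would look roughly like this (a paraphrase, with x the logits, y the targets, w the class weights, and N the batch size):

l_n = -w_{y_n} \log \frac{\exp(x_{n, y_n})}{\sum_{c} \exp(x_{n, c})},
\qquad
\ell(x, y) = \frac{\sum_{n=1}^{N} l_n}{\sum_{n=1}^{N} w_{y_n}}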

vincentwen1995 (Author) commented:

This isn't a bug, but the documentation isn't clear -- you'd only find the explanation if you look at the documentation for NLLLoss.

See #31295 for the case in NLLLoss.

Thank you for your quick response!
The explanation is indeed clear in the documentation for NLLLoss, although I am wondering what the rationale for this computation is. As I understand it, within a given batch the losses still carry the same weights for averaging; the only difference is that if the batch contains more samples of the majority label (which gets a smaller weight for a smaller impact), the normalizing coefficient grows larger, so this mini-batch gets a larger coefficient during back-propagation and therefore has more impact on the weight updates? Please correct me if I am wrong.

peterbell10 self-assigned this on Jul 5, 2020
peterbell10 (Collaborator) commented:

I am wondering what would be the argumentation for this computation.

It is the standard definition of the weighted mean. For integral weights, this gives you the same answer as the mean where each element i is duplicated weight[i] times. And if you set the weights to all equal, you will get a normal unweighted mean.

import torch

a = torch.rand(10)
w = torch.ones(10)
# with all weights equal, the weighted mean reduces to the plain mean
assert (a * w).sum() / w.sum() == a.mean()
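
To make the duplication point concrete, a small sketch with integral weights (torch.repeat_interleave is used here purely for illustration):

import torch

a = torch.rand(5)
w = torch.tensor([1., 2., 3., 4., 5.])

# weighted mean: sum(a_i * w_i) / sum(w_i)
weighted_mean = (a * w).sum() / w.sum()

# same value as the plain mean of a tensor where element i appears w[i] times
duplicated = torch.repeat_interleave(a, w.long())
assert torch.allclose(weighted_mean, duplicated.mean())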

csarofeen pushed a commit to csarofeen/pytorch that referenced this issue Jul 7, 2020
Summary:
Closes pytorch#40560

This adds the equation for the weighted mean to `CrossEntropyLoss`'s docs, and the `reduction` argument documentation for `CrossEntropyLoss` and `NLLLoss` no longer describes a non-weighted mean of the outputs.

Pull Request resolved: pytorch#40991

Differential Revision: D22395805

Pulled By: ezyang

fbshipit-source-id: a623b6dd2aab17220fe0bf706bd9b62d6ba531fd