'mean' reduction result in CrossEntropyLoss mismatches with manually computing mean #40560
Comments
This isn't a bug, but the documentation isn't clear -- you'd only find the explanation if you look at the documentation for NLLLoss. See #31295 for the case in NLLLoss.
High priority for the docs fixes. We should do the following:
1. Add the math description of the weighted mean to the CrossEntropyLoss docs.
2. Fix this up for NLLLoss and audit all the other places the reduction string is used to make sure the descriptions are accurate. If any of them aren't, we should fix them and also add the math description from 1) for them.
Thank you for your quick response!
It is the standard definition of the weighted mean. For integral weights, this gives you the same answer as the ordinary mean of a sample in which each element is repeated according to its weight:

```python
a = torch.rand(10)
w = torch.ones(10)
assert (a * w).sum() / w.sum() == a.mean()
```
Summary: Closes pytorch#40560. This adds the equation for the weighted mean to `CrossEntropyLoss`'s docs, and the `reduction` argument docs for `CrossEntropyLoss` and `NLLLoss` no longer describe a non-weighted mean of the outputs. Pull Request resolved: pytorch#40991 Differential Revision: D22395805 Pulled By: ezyang fbshipit-source-id: a623b6dd2aab17220fe0bf706bd9b62d6ba531fd
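For reference, the weighted mean that `reduction='mean'` computes (and that the PR above documents) can be written as follows, where $l_n = -w_{y_n} \log p_{n,y_n}$ is the already-weighted per-sample loss returned by `reduction='none'`:

$$
\ell = \frac{\sum_{n=1}^{N} l_n}{\sum_{n=1}^{N} w_{y_n}}
$$

Note the denominator is the sum of the per-sample weights $w_{y_n}$, not the batch size $N$, which is exactly why a plain `.mean()` over the `'none'` output disagrees.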
🐛 Bug
Hi, during training I noticed that when specifying weights for CrossEntropyLoss, using the 'mean' reduction produces a different loss output compared to using the 'none' reduction and computing the mean manually.
To Reproduce
Steps to reproduce the behavior:
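The original reproduction script was not preserved in this thread. The sketch below (shapes, weights, and targets are illustrative, not the reporter's originals) shows the mismatch, and also the manual weighted mean that does agree with `reduction='mean'`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
weights = torch.tensor([1.0, 2.0, 3.0])  # illustrative per-class weights
logits = torch.randn(8, 3)
targets = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1])

loss_a = nn.CrossEntropyLoss(weight=weights, reduction='mean')(logits, targets)
loss_b = nn.CrossEntropyLoss(weight=weights, reduction='none')(logits, targets).mean()

# 'mean' divides the weighted sum by the sum of the per-sample weights,
# while .mean() divides by the batch size, so the two disagree:
print(torch.equal(loss_a, loss_b))  # False

# Dividing by the sum of the per-sample weights instead matches 'mean':
per_sample = nn.CrossEntropyLoss(weight=weights, reduction='none')(logits, targets)
loss_c = per_sample.sum() / weights[targets].sum()
print(torch.allclose(loss_a, loss_c))  # True
```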
Expected behavior
The losses computed by the two methods should be equal:
torch.equal(loss_a, loss_b) = True
Environment
How you installed PyTorch (conda, pip, source): pip

Additional context
It seems to be a problem with the weights, as I tried assigning the same weight to all the classes; theoretically, however, the weighted loss should be computed before applying the mean() operation.
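A quick check makes the constant-weight case concrete (a sketch; the class count and tensors are illustrative). With the same weight c on every class, `reduction='mean'` returns the unweighted mean, because c appears in both the numerator and the weight-sum denominator and cancels, while manually averaging the `'none'` output keeps the factor of c:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(6, 3)
targets = torch.tensor([0, 1, 2, 0, 1, 2])

w = torch.full((3,), 2.0)  # same weight (c = 2.0) for every class
mean_red = nn.CrossEntropyLoss(weight=w)(logits, targets)
manual = nn.CrossEntropyLoss(weight=w, reduction='none')(logits, targets).mean()
unweighted = nn.CrossEntropyLoss()(logits, targets)

# 'mean' normalizes by the weight sum, so the constant weight cancels...
assert torch.allclose(mean_red, unweighted)
# ...while the manual mean divides by the batch size and keeps the factor of 2
assert torch.allclose(manual, 2 * mean_red)
```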
cc @ezyang @gchanan @zou3519 @jlin27 @albanD @mruberry