Counter intuitive behavior of nn.CrossEntropy/nn.NLLLoss with weights and issue with gradient accumulation #72047

Closed
idc9 opened this issue Jan 30, 2022 · 5 comments
Labels
module: loss (Problem is related to loss function), module: nn (Related to torch.nn), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@idc9

idc9 commented Jan 30, 2022

🚀 The feature, motivation and pitch

The behavior of mean reduction in nn.CrossEntropyLoss and nn.NLLLoss is counterintuitive when there are class weights, as discussed in #9882. The current behavior computes a weighted average (dividing by the sum of the target weights) instead of the unweighted average (dividing by the batch size) that people probably expect.
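
A minimal self-contained snippet (not from the original issue, values arbitrary) illustrating the documented behavior: with class weights, reduction='mean' divides by the sum of the target weights rather than by the batch size.

import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)                      # batch of 4 samples, 3 classes
targets = torch.tensor([0, 1, 2, 1])
class_weights = torch.tensor([1.0, 2.0, 0.5])

# per-sample losses already include the class-weight factor w_{y_i} * l_i
per_sample = nn.CrossEntropyLoss(weight=class_weights, reduction='none')(logits, targets)
w = class_weights[targets]

mean_loss = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')(logits, targets)
print(torch.allclose(mean_loss, per_sample.sum() / w.sum()))        # True: divides by the summed weights
print(torch.allclose(mean_loss, per_sample.sum() / len(targets)))   # False in general: not the batch size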

This counterintuitive behavior also causes an issue when doing gradient accumulation. In particular, when you scale the loss to account for gradient accumulation (i.e. to make the divisor batch_size x n_grad_accum_steps instead of just batch_size), you no longer get the exact gradients (i.e. the gradients you would have had if your batch size were batch_size x n_grad_accum_steps).

# Example gradient accumulation code
# loader, model, optimizer, weights set above 

# unweighted case -- no issue for gradient accumulation
loss_func = nn.CrossEntropyLoss(reduction='mean')

# weighted case with an issue
# loss_func = nn.CrossEntropyLoss(weight=weights, reduction='mean')

# this loop assumes everything is nicely divisible -- it can be modified to handle when
# num batches is not divisible by grad_accum and num_samples is not divisible by batch_size
for batch_idx, (x, y_true) in enumerate(loader):
    y_pred = model(x)
    loss = loss_func(y_pred, y_true)  # CrossEntropyLoss expects (input, target)
    
    # adjust for gradient accumulation
    # this gives you exact gradients in the unweighted case, but not in the weighted case!
    loss = loss / n_grad_accum_batches
    
    loss.backward()
    if (batch_idx + 1) % n_grad_accum_batches == 0:
        optimizer.step()
        optimizer.zero_grad()

You can of course address this by using reduction='sum' and averaging manually, but that is clunky and probably frequently overlooked.
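
For concreteness, here is a rough sketch of that workaround (my own, not from the issue; it reuses the hypothetical loader, model, optimizer, weights, and n_grad_accum_batches names from the snippet above, and assumes weights is the 1-D per-class weight tensor). It accumulates gradients of the unnormalized sum and rescales them by the total target weight once the full accumulation window has been seen, which reproduces the exact large-batch weighted mean.

# Sketch only: reduction='sum' plus a manual rescale of the accumulated gradients.
loss_func = nn.CrossEntropyLoss(weight=weights, reduction='sum')

accum_weight = 0.0
optimizer.zero_grad()
for batch_idx, (x, y_true) in enumerate(loader):
    y_pred = model(x)
    loss = loss_func(y_pred, y_true)              # sum_i w_{y_i} * l_i for this batch
    accum_weight += weights[y_true].sum().item()  # running divisor for the weighted mean
    loss.backward()                               # grads of the unnormalized sum accumulate

    if (batch_idx + 1) % n_grad_accum_batches == 0:
        # the full divisor is only known now, so rescale the accumulated gradients
        for p in model.parameters():
            if p.grad is not None:
                p.grad /= accum_weight
        optimizer.step()
        optimizer.zero_grad()
        accum_weight = 0.0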

Possible solution

A straightforward solution, and the most intuitive one -- at least to me -- would be to:

  1. make reduction='mean' perform an unweighted average
  2. introduce reduction='weighted_mean' for weighted averages (current behavior of mean)
  3. default to the unweighted mean, which is probably what users expect

Alternatives

No response

Additional context

No response

cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345

@albanD
Collaborator

albanD commented Feb 1, 2022

Hi,

Thanks for the suggestion.
Note that for BC reasons, we won't be able to change the default behavior.
Also, the current behavior is clearly documented, so I don't think there is any confusion here, right?

Note that in the case you share, even when weights are not involved, if you get a partial batch from your dataloader, the final loss will also be wrong. So when accumulating multiple batches into a single effective step, I would recommend using the sum reduction and doing the division yourself at the end.
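
Not part of the original comment, but a sketch of what this recommendation could look like in the unweighted case (same hypothetical names as the snippet in the issue body): dividing by the number of samples actually seen handles a partial final batch correctly.

# Sketch only: unweighted variant, robust to a partial final batch.
loss_func = nn.CrossEntropyLoss(reduction='sum')

n_seen = 0
optimizer.zero_grad()
for batch_idx, (x, y_true) in enumerate(loader):
    y_pred = model(x)
    loss_func(y_pred, y_true).backward()          # accumulate grads of the per-batch sums
    n_seen += len(y_true)

    last_batch = (batch_idx + 1) == len(loader)   # assumes a sized DataLoader
    if (batch_idx + 1) % n_grad_accum_batches == 0 or last_batch:
        for p in model.parameters():
            if p.grad is not None:
                p.grad /= n_seen                  # divide by the samples actually seen
        optimizer.step()
        optimizer.zero_grad()
        n_seen = 0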

albanD added the module: loss, module: nn, and triaged labels on Feb 1, 2022
@idc9
Author

idc9 commented Feb 1, 2022

Can you add an unweighted_average option? (Perhaps there's a better name).

Yes, the weighted average is documented clearly. It's still a counterintuitive choice that I suspect many people overlook, but I take your point about BC and defaulting to the weighted average.

You are right about partial batches -- and, perhaps worse, the number of batches not being divisible by the number of grad accumulation steps -- the code above was just for exposition. The user would have to write some additional code to compute the divisor manually in either the sum or mean case. I point out the gradient accumulation issue because I suspect many people using gradient accumulation rely on the default (weighted) mean reduction, which would be very clunky to address even with additional code.

@jbschlosser
Contributor

Is this a duplicate of #61309?

@idc9
Author

idc9 commented Feb 1, 2022

Ah thanks for pointing that out.

One of the reasons I posted this issue is that the weighted mean makes implementing gradient accumulation properly very difficult for weighted losses (it's a bit annoying but very doable to implement gradient accumulation with the unweighted mean reduction).

@jbschlosser
Contributor

Closing this issue for now as a duplicate but the additional context is useful! Let's continue discussion on #61309 to get this resolved.
