
Fix NaN value comparisons in relu, max and min ops #14262

Merged
merged 2 commits into from
Mar 10, 2019

Conversation

anirudhacharya
Member

@anirudhacharya anirudhacharya commented Feb 27, 2019

Description

Fix NaN comparisons in relu, max and min ops

Fixes #14157
Fixes #11115

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Change how the max, min and relu operators handle NaN values so that NaNs are propagated consistently, and add tests. A minimal sketch of the intended behavior follows below.
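
A minimal sketch of the intended behavior after this change (expected outputs assume the fixed build; this is illustrative, not part of the test suite):

>>> import numpy as np
>>> from mxnet import nd
>>> a = np.NaN * nd.ones(1)
>>> b = nd.zeros(1)
>>> nd.relu(a)        # no longer clipped to 0
[nan]
<NDArray 1 @cpu(0)>
>>> nd.maximum(a, b)  # NaN propagates regardless of argument order
[nan]
<NDArray 1 @cpu(0)>
>>> nd.minimum(b, a)
[nan]
<NDArray 1 @cpu(0)>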

@anirudh2290 @apeforest

Contributor

@apeforest apeforest left a comment


Why do we need this? Don't we already have an is_nan() for all operators?

@anirudhacharya
Member Author

@apeforest we need it because a call like mx.nd.relu(np.NaN*nd.ones(1), out) returns [0.], whereas activation functions are expected to propagate NaN values.

The maximum and minimum operators also have inconsistent behavior w.r.t. NaN values, depending on argument order:

>>> a = np.NaN*nd.ones(1)
>>> b = nd.zeros(1)

>>> nd.maximum(a,b)
[0.]
<NDArray 1 @cpu(0)>

>>> nd.maximum(b,a)
[nan]
<NDArray 1 @cpu(0)>
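
For comparison, NumPy's maximum propagates NaN regardless of argument order, while np.fmax is the explicitly NaN-ignoring variant (added here only as a reference point):

>>> import numpy as np
>>> a = np.array([np.nan])
>>> b = np.zeros(1)
>>> np.maximum(a, b)
array([nan])
>>> np.maximum(b, a)
array([nan])
>>> np.fmax(a, b)
array([0.])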

@szha
Member

szha commented Feb 27, 2019

@anirudhacharya I'm not sure that answers the question. Why do we need to start supporting nan values, especially given the extra handling required for nan?

@anirudhacharya
Member Author

anirudhacharya commented Feb 27, 2019

@szha Ideally I do not want the relu operator to clip NaN values to zero, especially when I am trying to debug a model.
And with regard to maximum and minimum, it is not about 'starting to support' nan values but about fixing the inconsistent handling of nan values.

Pytorch's relu behavior -

>>> import torch
>>> import torch.nn as nn
>>> m = nn.ReLU()
>>> input = np.NaN * torch.ones(1)
>>> out = m(input)
>>> out
tensor([nan])

Also I found a related issue here - #14157

Edit - Another issue filed some time ago which had slipped from my memory - #11115

@anirudhacharya
Member Author

@mxnet-label-bot add [pr-awaiting-review]

@marcoabreu marcoabreu added the pr-awaiting-review PR is waiting for code review label Feb 27, 2019
@szha
Member

szha commented Feb 28, 2019

@anirudhacharya thanks for the explanation. should relu grad deal with nan in a special way?

@anirudhacharya
Member Author

@szha yes I think the relu grad should also be handled in a special way, thanks for pointing it out.

Currently relu grad at nan returns 0 by evaluating this expression:

MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));

But max(NaN, 0) evaluates to NaN and that should translate to a relu grad value of 1 and not 0. I will make the changes to fix it.

FYI - Here is an in depth conversation on NaN handling - JuliaLang/julia#7866

@adrianloy

Nice PR! I also had a bug in my model, and because the relu activations removed the NaNs it took me much longer to realize there was a bug. The behaviour should definitely be changed!

@szha
Member

szha commented Mar 7, 2019

@anirudhacharya I'm not sure if relu grad should act like that. As a sanity check, consider if nan is larger than or smaller than 0.

szha
szha previously requested changes Mar 7, 2019
Member

@szha szha left a comment


seems that nan should be surfaced in relu grad instead of 1 when output is nan, because nan is not a number.

@wkcn
Member

wkcn commented Mar 7, 2019

Could we add a new operator to check whether there are nan values? It would be useful when debugging.
When nan appears in the model, it always means that the model has failed, and the output and gradient are not reliable.

@anirudhacharya
Member Author

@anirudhacharya I'm not sure if relu grad should act like that. As a sanity check, consider if nan is larger than or smaller than 0.

nan compared to any number is always False: nan > 0 -> False and nan < 0 -> False.

Ref - https://stackoverflow.com/questions/49011370/nan-propagation-and-ieee-754-standard/49040225

But there are languages and libraries which consider nan to be greater than any number, even np.inf.
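
A quick plain-Python illustration of these comparison rules:

>>> nan = float("nan")
>>> nan > 0
False
>>> nan < 0
False
>>> nan == nan
False
>>> import math
>>> math.isnan(nan)
True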

@szha
Member

szha commented Mar 7, 2019

@anirudhacharya it's not about comparison. nan is not in the domain of the function.

@anirudhacharya
Member Author

@szha What you say makes sense: nan is not a number and hence not in the realm of comparison, so any occurrence, either forward or backward, will have to be propagated.

Maybe PyTorch is doing it wrong, but just for comparison's sake, PyTorch seems to treat the gradient of relu at NaN as equal to 1 -

>>> import torch
>>> import numpy as np
>>> a = np.NaN * torch.ones(1)
>>> a.requires_grad_(True)
tensor([nan], requires_grad=True)
>>> m = torch.nn.ReLU()
>>> out = m(a)
>>> out.backward()
>>> a.grad
tensor([1.])

My main motivation when I first made changes to the relu forward behavior was that the operator silently clipping NaN values was very misleading while trying to build or debug models.

I am open to suggestions on how the relu gradient should behave; it would seem there is no single consensus on this and each community/library decides things for itself.

@anirudhacharya
Member Author

Could we add a new operator to check whether there are nan values? It would be useful when debugging.
When nan appears in the model, it always means that the model has failed, and the output and gradient are not reliable.

I think you are looking for this - https://mxnet.incubator.apache.org/api/python/ndarray/contrib.html?highlight=isnan#mxnet.ndarray.contrib.isnan
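
A short usage sketch of that operator (assuming mx.nd.contrib.isnan is available in your MXNet build):

>>> import mxnet as mx
>>> import numpy as np
>>> x = mx.nd.array([1.0, np.nan, -2.0])
>>> mx.nd.contrib.isnan(x)   # 1 where the entry is NaN, 0 elsewhere
[0. 1. 0.]
<NDArray 3 @cpu(0)>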

@anirudhacharya
Member Author

anirudhacharya commented Mar 8, 2019

I modified relu grad to also propagate NaN values. As discussed above, since NaN does not exist in the domain of the function, it cannot be mapped to any element in the range of the function; hence the gradient output is also NaN.
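
A small autograd check of the new gradient behavior (expected output assumes this fix; sketch only):

>>> import numpy as np
>>> from mxnet import nd, autograd
>>> x = np.NaN * nd.ones(1)
>>> x.attach_grad()
>>> with autograd.record():
...     y = nd.relu(x)
...
>>> y.backward()
>>> x.grad
[nan]
<NDArray 1 @cpu(0)>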

@szha szha dismissed their stale review March 8, 2019 03:59

concern addressed.

@szha
Member

szha commented Mar 8, 2019

@anirudhacharya one last thing: could you measure the performance before and after this change? The change is necessary nonetheless, but it would still be better if we could anticipate any performance impact from it. Thanks.

@anirudhacharya
Member Author

Run mode: before --> after (time in ms)

Whole CPU run:   0.843163 --> 0.864071
Forward CPU run: 0.016467 --> 0.043115
Whole GPU run:   0.460900 --> 0.480667
Forward GPU run: 0.058783 --> 0.059333

Script used:

import mxnet as mx
import numpy as np
from mxnet.test_utils import check_speed

ctx = mx.cpu()
# ctx = mx.gpu(0)  # switch to this context for the GPU numbers

# 3x500x500 input mixing negative, NaN and positive values
sample_data = mx.nd.ones((3, 500, 500), ctx=ctx)
sample_data[0] = -1.
sample_data[1] = np.NaN
sample = mx.sym.Variable("sample")
relu_sym = mx.sym.relu(data=sample)

print("Whole CPU run: ", check_speed(relu_sym, location={"sample": sample_data}, ctx=ctx, N=int(1e5), typ="whole"))
print("Forward CPU run: ", check_speed(relu_sym, location={"sample": sample_data}, ctx=ctx, N=int(1e5), typ="forward"))
# print("Whole GPU run: ", check_speed(relu_sym, location={"sample": sample_data}, ctx=ctx, N=int(1e5), typ="whole"))
# print("Forward GPU run: ", check_speed(relu_sym, location={"sample": sample_data}, ctx=ctx, N=int(1e5), typ="forward"))

@szha szha merged commit c645591 into apache:master Mar 10, 2019
@anirudhacharya anirudhacharya deleted the relu branch March 10, 2019 08:58
vdantu pushed a commit to vdantu/incubator-mxnet that referenced this pull request Mar 31, 2019
nswamy pushed a commit that referenced this pull request Apr 5, 2019
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
Labels
pr-awaiting-review PR is waiting for code review

Successfully merging this pull request may close these issues:

  • Inconsistent handling for nan
  • ReLU Clips NaNs to Zero