This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1426] Fix the wrong result of sum, mean, argmin, argmax when inputs contain inf or nan #16234

Merged: 21 commits, Nov 12, 2019

Conversation

wkcn
Member

@wkcn wkcn commented Sep 22, 2019

Description

Hi there,
This PR fixes the wrong results of sum(inf, inf) and mean(inf, inf).

Test Case:

import mxnet as mx
import numpy as np

def test(x):
    print('data', x.asnumpy())
    print('mean/sum', mx.nd.mean(x).asnumpy(), mx.nd.sum(x).asnumpy())
    print('argmin/argmax', mx.nd.argmin(x).asnumpy(), mx.nd.argmax(x).asnumpy())
    print('min/max', mx.nd.min(x).asnumpy(), mx.nd.max(x).asnumpy())
    print('-----')

x = mx.nd.array([np.inf, np.inf, 1])
test(x)

x = mx.nd.array([-np.inf, -np.inf, 1])
test(x)

x = mx.nd.array([np.inf, -np.inf, 1])
test(x)

x = mx.nd.array([np.nan, -np.inf, 1])
test(x)

x = mx.nd.array([np.nan, np.nan])
test(x)

x = mx.nd.array([np.nan, 1])
test(x)

The wrong result with mxnet_mkl-1.6.0b20191015-py2.py3-none-manylinux1_x86_64:

data [inf inf  1.]
mean/sum [nan] [nan]
argmin/argmax [2.] [0.]
min/max [1.] [inf]
-----
data [-inf -inf   1.]
mean/sum [nan] [nan]
argmin/argmax [0.] [2.]
min/max [-inf] [1.]
-----
data [ inf -inf   1.]
mean/sum [nan] [nan]
argmin/argmax [1.] [0.]
min/max [-inf] [inf]
-----
data [ nan -inf   1.]
mean/sum [nan] [nan]
argmin/argmax [1.] [2.]
min/max [-inf] [1.]
-----
data [nan nan]
mean/sum [nan] [nan]
argmin/argmax [0.] [0.]
min/max [inf] [-inf]
-----
data [nan  1.]
mean/sum [nan] [nan]
argmin/argmax [1.] [1.]
min/max [1.] [1.]

The correct result in this PR:

data [inf inf  1.]
mean/sum [inf] [inf]
argmin/argmax [2.] [0.]
min/max [1.] [inf]
-----
data [-inf -inf   1.]
mean/sum [-inf] [-inf]
argmin/argmax [0.] [2.]
min/max [-inf] [1.]
-----
data [ inf -inf   1.]
mean/sum [nan] [nan]
argmin/argmax [1.] [0.]
min/max [-inf] [inf]
-----
data [ nan -inf   1.]
mean/sum [nan] [nan]
argmin/argmax [0.] [0.]
min/max [nan] [nan]
-----
data [nan nan]
mean/sum [nan] [nan]
argmin/argmax [0.] [0.]
min/max [nan] [nan]
-----
data [nan  1.]
mean/sum [nan] [nan]
argmin/argmax [0.] [0.]
min/max [nan] [nan]

If we modify the test function to call NumPy instead:

def test(x):
    x = x.asnumpy()
    print('data', x)
    print('mean/sum', np.mean(x), np.sum(x))
    print('argmin/argmax', np.argmin(x), np.argmax(x))
    print('min/max', np.min(x), np.max(x))
    print('-----')

Here is the NumPy result:

data [inf inf  1.]
mean/sum inf inf
argmin/argmax 2 0
min/max 1.0 inf
-----
data [-inf -inf   1.]
mean/sum -inf -inf
argmin/argmax 0 2
min/max -inf 1.0
-----
data [ inf -inf   1.]
mean/sum nan nan
argmin/argmax 1 0
min/max -inf inf
-----
data [ nan -inf   1.]
mean/sum nan nan
argmin/argmax 0 0
min/max nan nan
-----
data [nan nan]
mean/sum nan nan
argmin/argmax 0 0
min/max nan nan
-----
data [nan  1.]
mean/sum nan nan
argmin/argmax 0 0
min/max nan nan
-----
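
One plausible way a sum over [inf, inf, 1] ends up as nan is compensated (Kahan) summation, where the error term `(t - total) - y` evaluates to `inf - inf`. The PR text does not say which summation scheme is involved, so the pure-Python sketch below is an assumption for illustration only; `kahan_sum_naive` and `kahan_sum_guarded` are hypothetical names, not MXNet or mshadow functions.

```python
import math

def kahan_sum_naive(values):
    """Compensated summation with no special handling of infinities."""
    total = 0.0
    residual = 0.0
    for v in values:
        y = v - residual
        t = total + y
        # With t == inf and y == inf this is inf - inf, which is nan.
        residual = (t - total) - y
        total = t
    return total

def kahan_sum_guarded(values):
    """Same scheme, but the residual is reset once the running sum is infinite."""
    total = 0.0
    residual = 0.0
    for v in values:
        y = v - residual
        t = total + y
        if math.isinf(t):
            residual = 0.0  # assumption: drop the correction term at +/-inf
        else:
            residual = (t - total) - y
        total = t
    return total

inf = float('inf')
print(kahan_sum_naive([inf, inf, 1.0]))     # nan
print(kahan_sum_guarded([inf, inf, 1.0]))   # inf
print(kahan_sum_guarded([inf, -inf, 1.0]))  # nan
```

The guarded variant still returns nan for [inf, -inf, 1] and for any input that contains nan, which matches the NumPy output listed above.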

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
    • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
    • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
    • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
    • For user-facing API changes, API doc string has been updated.
    • For new C++ functions in header files, their functionalities and arguments are documented.
    • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
    • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Add isnan_typed and isinf_typed in mshadow
  • Remove isnan_typed in src/operator/mshadow_op.h
  • Update mshadow/extension/reduce_with_axis.h to support NaN
  • Add related Python test cases
  • Update the Julia test case so that it is consistent with Julia's built-in argmin and argmax functions.
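
The NaN handling that the reduce changes aim for can be illustrated in plain Python. This is only a sketch of the NumPy-style rule visible in the results above (the first NaN wins, otherwise the first extreme value wins); the function names are hypothetical and it is not the code in mshadow/extension/reduce_with_axis.h. It assumes a non-empty list of floats.

```python
import math

def argmax_nan_first(values):
    """Index of the first NaN if there is one, else index of the first maximum."""
    best = 0
    for i, v in enumerate(values):
        if math.isnan(v):
            return i
        if v > values[best]:
            best = i
    return best

def argmin_nan_first(values):
    """Index of the first NaN if there is one, else index of the first minimum."""
    best = 0
    for i, v in enumerate(values):
        if math.isnan(v):
            return i
        if v < values[best]:
            best = i
    return best

def max_nan_propagating(values):
    """Maximum that returns nan as soon as any element is nan."""
    return values[argmax_nan_first(values)]

nan, inf = float('nan'), float('inf')
print(argmin_nan_first([nan, -inf, 1.0]), argmax_nan_first([nan, -inf, 1.0]))  # 0 0
print(argmin_nan_first([inf, -inf, 1.0]), argmax_nan_first([inf, -inf, 1.0]))  # 1 0
```

A plain comparison loop without the `isnan` test skips NaN silently, because every comparison against NaN is false; that is consistent with the old argmin/argmax output for [nan, -inf, 1] shown in the description.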

Review thread on 3rdparty/mshadow/mshadow/base.h (outdated, resolved)
@wkcn wkcn changed the title Fix the wrong result of sum(inf, inf) and mean(inf, inf) [MXNET-1426] Fix the wrong result of sum(inf, inf) and mean(inf, inf) Sep 22, 2019
@wkcn wkcn requested a review from iblislin as a code owner September 22, 2019 09:10
@wkcn wkcn changed the title [MXNET-1426] Fix the wrong result of sum(inf, inf) and mean(inf, inf) [MXNET-1426] Fix the wrong result of sum, mean, argmin, argmax when inputs contain inf or nan Sep 22, 2019
Review thread on julia/test/unittest/ndarray.jl (outdated, resolved)
Member

@iblislin iblislin left a comment

The Julia part looks fine to me.

@wkcn
Member Author

wkcn commented Sep 23, 2019

Hi @marcoabreu @access2rohit, could you please help review this PR?
It fixes the wrong results of sum, mean, argmin and argmax when the inputs contain inf or NaN.
Thank you!

@wkcn wkcn added the pr-awaiting-review PR is waiting for code review label Sep 23, 2019
@wkcn wkcn requested a review from szha as a code owner September 24, 2019 03:10
@wkcn
Member Author

wkcn commented Oct 9, 2019

Hi @reminisce and @haojin2, could you please help review this PR?

This PR makes the following functions consistent with NumPy.

sum, mean, argmin, argmax

Thank you!

@wkcn
Member Author

wkcn commented Oct 16, 2019

Hi @eric-haibin-lin, could you please help review this PR?
It fixes a bug where these operators output a wrong result, or one that is inconsistent with NumPy.

Thank you so much!

@eric-haibin-lin
Member

Would you mind also adding the result before this fix?

@wkcn
Member Author

wkcn commented Oct 17, 2019

Hi @eric-haibin-lin, I have updated the description with the test results from before the fix : )

@szha
Member

szha commented Oct 21, 2019

cc @reminisce

@wkcn
Member Author

wkcn commented Nov 10, 2019

Ping : )

}
template<>
MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) {
  return (val.half_ & 0x7fff) > 0x7c00;
Contributor

Can you turn these magic values into constants with documentation? While I get 0x7fff, 0x7c00, for example, looks quite arbitrary.

Member Author

Hi @marcoabreu, I added two constants, MSHADOW_HALF_SIGN_BIT and MSHADOW_HALF_EXPONENT_BITS, in 3rdparty/mshadow/mshadow/half.h and replaced the two magic values with them.
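
To see why those masks detect NaN, here is a small Python sketch built on the IEEE 754 binary16 layout (1 sign bit, 5 exponent bits, 10 mantissa bits). The constant values below are assumptions derived from that layout, not copied from half.h, and `half_bits`, `half_is_nan` and `half_is_inf` are hypothetical helper names.

```python
import struct

HALF_SIGN_BIT = 0x8000       # assumed value: the top bit of the 16-bit pattern
HALF_EXPONENT_BITS = 0x7C00  # assumed value: the five exponent bits, all set

def half_bits(x):
    """Return the 16-bit binary16 pattern of a Python float as an int."""
    return struct.unpack('<H', struct.pack('<e', x))[0]

def half_is_nan(bits):
    # Exponent all ones and a non-zero mantissa: strictly above the inf pattern.
    return (bits & ~HALF_SIGN_BIT) > HALF_EXPONENT_BITS

def half_is_inf(bits):
    # Exponent all ones and a zero mantissa, with either sign.
    return (bits & ~HALF_SIGN_BIT) == HALF_EXPONENT_BITS

print(hex(half_bits(float('inf'))))        # 0x7c00
print(half_is_nan(half_bits(float('nan'))))  # True
print(half_is_inf(half_bits(float('-inf')))) # True
```

Clearing the sign bit first means that negative NaN and negative infinity are classified the same way as their positive counterparts.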

if type(ndarray_ret) is mx.ndarray.NDArray:
    ndarray_ret = ndarray_ret.asnumpy()
assert (ndarray_ret.shape == numpy_ret.shape) or \
       (ndarray_ret.shape == (1,) and numpy_ret.shape == ()), "nd:%s, numpy:%s" \
       %(ndarray_ret.shape, numpy_ret.shape)
err = np.square(ndarray_ret - numpy_ret).mean()
assert err < 1E-4
if check_dtype:
Contributor

Could you elaborate on why you're introducing so much branching into a test? If the results are inconsistent, we should improve the test rather than skip the checks. I'd love to have more detail.

Member Author

@wkcn wkcn Nov 12, 2019

Hi @marcoabreu, here is the explanation.

  1. So much branching
    We need to test all reduce operators (min, max, argmin, argmax, sum, mean) when the inputs contain -inf, +inf, or NaN.

  2. Skipping the checks
    I replaced the old check with a new one. : )
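
For context on why a mean-squared-error check cannot stay as it is once the inputs contain inf or NaN: `inf - inf` and any arithmetic on NaN give NaN, and `nan < 1E-4` is false, so the assertion fails even when both results agree. The pure-Python sketch below shows one possible replacement check; it is an illustration with hypothetical helper names, not the check that was actually added in this PR.

```python
import math

def mse_check(actual, expected, tol=1e-4):
    """Old-style check: mean squared error below a tolerance."""
    err = sum((a - e) ** 2 for a, e in zip(actual, expected)) / len(actual)
    return err < tol

def same_special_or_close(a, e, tol=1e-4):
    """NaN matches NaN, inf matches inf of the same sign, finite values use tol."""
    if math.isnan(a) or math.isnan(e):
        return math.isnan(a) and math.isnan(e)
    if math.isinf(a) or math.isinf(e):
        return a == e
    return abs(a - e) <= tol

def results_match(actual, expected, tol=1e-4):
    """Element-wise comparison that tolerates matching special values."""
    return len(actual) == len(expected) and all(
        same_special_or_close(a, e, tol) for a, e in zip(actual, expected))

inf, nan = float('inf'), float('nan')
print(mse_check([inf], [inf]))      # False, although the values agree
print(results_match([inf], [inf]))  # True
print(results_match([nan], [nan]))  # True
```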

@marcoabreu
Contributor

marcoabreu commented Nov 10, 2019

I'll merge after the feedback has been addressed :) Sorry for the delay

@wkcn wkcn added pr-awaiting-merge Review and CI is complete. Ready to Merge and removed pr-awaiting-review PR is waiting for code review labels Nov 12, 2019
Labels
Bug Operator pr-awaiting-merge Review and CI is complete. Ready to Merge
6 participants