This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add stable nrm2 Reducer #11573

Merged
merged 7 commits into from
Jul 11, 2018

Conversation

leezu (Contributor) commented Jul 5, 2018

Description

The MXNet L2 norm implementation is currently prone to underflow/overflow when called on NDArrays with very small or very large elements. This PR adds a stable L2 norm reducer and switches the dense norm operator to use it.

The implementation follows the BLAS nrm2 reference implementation.

As a final rescaling step is necessary at the end of the reduction, this PR also introduces a new Finalize method for Reducers. This method is a no-op for all previous reducers. A separate PR needs to add the Finalize method to the maximum and minimum reducers in mshadow.
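
For reference, a minimal standalone sketch of the scaled sum-of-squares scheme the reducer is based on; the function name and signature below are illustrative and not the actual mshadow reducer interface:

#include <cmath>
#include <cstddef>

// Illustrative sketch: accumulate sum(x_i^2) as ssq * scale^2, rescaling whenever an
// element exceeds the current scale, so that neither ssq nor scale over/underflows.
// The final rescaling corresponds to the Finalize step added in this PR.
template <typename DType>
DType stable_nrm2(const DType* x, std::size_t n) {
  DType scale = 0, ssq = 1;
  for (std::size_t i = 0; i < n; ++i) {
    const DType abs_x = std::abs(x[i]);
    if (abs_x == 0) continue;
    if (scale < abs_x) {
      // New largest element: rescale the accumulated sum of squares to the new scale.
      ssq = 1 + ssq * (scale / abs_x) * (scale / abs_x);
      scale = abs_x;
    } else {
      ssq += (abs_x / scale) * (abs_x / scale);
    }
  }
  return scale * std::sqrt(ssq);  // undo the scaling once at the end
}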

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Improve numerical stability for L2 norm in norm() operator
  • Improve numerical stability of sum in broadcast reduce

Comments

  • The test_norm() test is currently disabled (Hanging flaky test test_operator.test_norm @ Python 3: GPU Win #11509). It is re-enabled in this PR. Maybe using a different norm2 implementation will fix the flakiness of the test.
  • I tried to avoid adding the Finalize method to all existing reducers (i.e. maximum, minimum, sum, product, nansum, nanprod) using SFINAE, but I didn't get it to work immediately (Finalize() was correctly skipped if the respective reducer didn't implement it, but for unknown reasons it was also skipped if the reducer did implement it). As there are not many reducers, I added the method to all of them for now. The SFINAE approach was as follows:
template <typename Reducer, typename DType>
auto reducer_finalize(DType val, DType residual, int) -> decltype(Reducer::Finalize, void()) {
   Reducer::Finalize(val, residual);
}

template <typename Reducer, typename DType>
auto reducer_finalize(DType val, DType residual, long) -> decltype(void()) {}
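
For comparison, a detection-idiom variant keyed on the call expression rather than the function name might look like the following; this is an untested sketch and not part of this PR:

// Preferred overload when Reducer::Finalize(val, residual) is a well-formed call;
// the int/long dummy parameter makes this overload win for a literal 0 argument.
template <typename Reducer, typename DType>
auto reducer_finalize(DType& val, DType& residual, int)
    -> decltype(Reducer::Finalize(val, residual), void()) {
  Reducer::Finalize(val, residual);
}

// Fallback when no matching Finalize exists: do nothing.
template <typename Reducer, typename DType>
void reducer_finalize(DType& val, DType& residual, long) {}

// Call site: reducer_finalize<Reducer>(val, residual, 0);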

@leezu leezu requested a review from anirudh2290 as a code owner July 10, 2018 00:38
@leezu leezu force-pushed the stablenrm2 branch 3 times, most recently from 590ddef to 0c391f6, on July 10, 2018 05:11
leezu (Contributor Author) commented Jul 10, 2018

@anirudh2290 @szha This is ready for review. Tests are passing, but for some reason the deploy stage failed due to a Jenkins error. Rerunning now.

check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, rtol=1e-2, atol=1e-3)

# Disable numeric gradient https://github.com/apache/incubator-mxnet/issues/11509
# # check gradient
Member

Is this supposed to be commented out? Did this not help for #11509?

leezu (Contributor Author)

Unfortunately it didn't help. I observed a similar error to #11509 (comment)

template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale, volatile DType& src_ssq, volatile DType& src_scale) { // NOLINT(*)
  if (dst_scale != 0 && dst_scale >= src_scale) {
    dst_ssq = dst_ssq + src_ssq * (src_scale / dst_scale) * (src_scale / dst_scale);
Member

Can you please elaborate on how this expression was obtained?

leezu (Contributor Author) commented Jul 11, 2018

Sure. Remember that we use a scaled sum of squares to compute the L2 norm, to avoid the numeric instability caused by squaring and subsequently taking the square root of very small / large numbers.
For efficient reduction on GPU, multiple reducers each compute a reduction over a part of the vector to be reduced. Their results are scaled sums of squares. To combine the reducers, we must find a common scale for all of them. Following the implementation of Reduce, I chose the largest scale.

The above equation simply rescales the sum of squares of the reducer that currently uses the smaller scale value, such that in the end norm(x) = sqrt(ssq) * scale = dst_scale * sqrt(dst_ssq + src_ssq * (src_scale/dst_scale) * (src_scale/dst_scale)) = sqrt(src_scale*src_scale*src_ssq + dst_scale*dst_scale*dst_ssq) (where we want to avoid the rightmost form due to numerical instability; here ssq and scale denote what is written to dst_ssq and dst_scale in the above code).
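
For illustration, both branches of such a merge could look roughly like this; a simplified sketch without the volatile qualifiers and exact structure of the actual reducer:

// Combine two scaled sums of squares so that afterwards
// dst_scale * sqrt(dst_ssq) equals the norm over both partial reductions.
template <typename DType>
void merge_scaled_ssq(DType& dst_ssq, DType& dst_scale, DType src_ssq, DType src_scale) {
  if (dst_scale != 0 && dst_scale >= src_scale) {
    // dst already uses the larger scale: rescale src's sum of squares to it.
    dst_ssq += src_ssq * (src_scale / dst_scale) * (src_scale / dst_scale);
  } else if (src_scale != 0) {
    // src uses the larger scale: rescale dst's sum of squares and adopt src's scale.
    dst_ssq = src_ssq + dst_ssq * (dst_scale / src_scale) * (dst_scale / src_scale);
    dst_scale = src_scale;
  }
}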

Member

Thanks for the explanation!

@szha szha merged commit 43ad56c into apache:master Jul 11, 2018
@leezu leezu deleted the stablenrm2 branch July 11, 2018 16:55
XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018
* Add stable nrm2 Reducer

* Prefer scipy.linalg.norm over np.linalg.norm as it is numerically stable

* Update mshadow

* Add stable reducer merge

* Use stable merging of reducers in broadcast_reduce-inl.cuh

* Update mshadow

* Update tests