Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add matrix inversion operator in linalg #14963

Merged
merged 21 commits into from
May 20, 2019
Merged

Conversation

arcadiaphy
Copy link
Member

As title.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http:https://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

struct set_matrix : public mxnet::op::mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType **p, DType *m, int step) {
p[i] = m + i * step;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indent?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I'll fix it.

Copy link
Contributor

@larroy larroy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice stuff!

LINALG_CPU_GETRF(dgetrf, double)

#ifdef __CUDACC__

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment on what and why fill up the matrix this way?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK.

DType **A_ptr = static_cast<DType **>(A_ptr_buf.dptr); \
const Tensor<gpu, 3, DType> temp(work.dptr_, A.shape_, s); \
int *pivot = reinterpret_cast<int *>(temp.dptr_ + temp.shape_.Size()); \
int *info = pivot + A.size(0) * A.size(1); \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens on int32 overflow?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pivot's range is in [0, matrix_dim), I think creating a square matrix with int32 overflow dimension is not possible.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you know much more about this, but from what I understand A is the input matrix which gets overwritten by LU no? My question was if the product of A.size(0) and A.size(1) overflows, this can happen if both are bigger than 2^16 unless I'm mistaken. I have seen this bug in other places before we call Blas, it was nasty.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have tried to write a overflow case, it always fails on size checks in Tensor or Blob. I think it's the right way, for ndarray with overflow size, it should fail in advance and not reach the code above.

DType **B_ptr = static_cast<DType **>(B_ptr_buf.dptr); \
Tensor<gpu, 3, DType> temp(work.dptr_, A.shape_, s); \
int *pivot = reinterpret_cast<int *>(temp.dptr_ + temp.shape_.Size()); \
int *info = pivot + A.size(0) * A.size(1); \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as above

.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix");

NNVM_REGISTER_OP(_backward_linalg_inverse)
.set_num_inputs(3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we have 3 inputs?

Copy link
Member Author

@arcadiaphy arcadiaphy May 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I use ElemwiseGradUseInOut, so the 3 inputs are out_grad, input, output. Actually, input is not used in computing in_grad, I'll change it to ElemwiseGradUseOut.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, It looked strange to me.

@arcadiaphy
Copy link
Member Author

I'll merge this PR now and create another PR on matrix determinant which is depended upon this one.

@arcadiaphy arcadiaphy merged commit 3cbfe48 into apache:master May 20, 2019
@arcadiaphy arcadiaphy deleted the pr_linalg branch May 20, 2019 13:40
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
* add inverse cpu

* add comment

* add inverse backward cpu

* add inverse gpu

* able to compile

* fix

* fix

* guard for lower version cuda

* update docs

* update docs

* fix misaligned memory

* add test

* fix lint

* fix android

* fix indent

* change transfer gradient

* fix

* refactor test

* delete unnecessary copy

* trigger CI

* fix test
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants