
[MXNET-978] Higher Order Gradient Support arcsin, arccos. #15515

Merged

Conversation

kshitij12345
Contributor

Description

This PR intends to add support for higher order gradients of arcsin and arccos.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (here MXNET-978; except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • higher order gradient for arcsin, arccos.
  • unit test for the same (a standalone sanity check of the closed-form derivatives is sketched after this list).
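Not part of the PR's test suite, just a standalone NumPy sanity check of the closed-form derivatives these operators rely on (for y = arcsin(x), y' = (1 - x^2)^(-1/2) and y'' = x (1 - x^2)^(-3/2); arccos flips the sign of both), compared against central finite differences:

```python
# Standalone sanity check (not part of this PR's test suite): compare the
# closed-form first and second derivatives of arcsin/arccos against central
# finite differences, using plain NumPy.
import numpy as np

def arcsin_grads(x):
    # y = arcsin(x): y' = (1 - x^2)^(-1/2), y'' = x * (1 - x^2)^(-3/2)
    g = 1.0 / np.sqrt(1.0 - x ** 2)
    gg = x / (1.0 - x ** 2) ** 1.5
    return g, gg

def arccos_grads(x):
    # y = arccos(x): both derivatives are the negatives of the arcsin ones
    g, gg = arcsin_grads(x)
    return -g, -gg

def check(f, grads, x, eps=1e-4):
    g, gg = grads(x)
    g_num = (f(x + eps) - f(x - eps)) / (2.0 * eps)
    gg_num = (f(x + eps) - 2.0 * f(x) + f(x - eps)) / eps ** 2
    assert np.allclose(g, g_num, atol=1e-4)
    assert np.allclose(gg, gg_num, atol=1e-3)

x = np.linspace(-0.9, 0.9, 19)  # stay away from the singularities at +/- 1
check(np.arcsin, arcsin_grads, x)
check(np.arccos, arccos_grads, x)
print("closed-form first/second derivatives match finite differences")
```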

@kshitij12345 force-pushed the develop/add-higher-order/arcsin-arccos branch from 4c45f2f to 7daaf76 on July 11, 2019 14:37
Contributor

@ChaiBapchya left a comment


Barring that hacky way of testing (which makes sense but I'll wait for committers to approve it), LGTM! Thanks for your contribution!

@kshitij12345
Contributor Author

kshitij12345 commented Jul 14, 2019

Thank you.
Do note that the other alternative is 2.*(array - np.min(array))/np.ptp(array) - 1, but we would need to handle the case of a single-element array.
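A minimal sketch of that alternative, for illustration only (the zero-range guard for the single-element case is my own choice here, not something settled in this PR):

```python
# Illustrative only: rescale an array into [-1, 1] so it lies inside the domain
# of arcsin/arccos. np.ptp(array) is 0 for a single-element (or constant)
# array, so guard against the division by zero; the fallback value is arbitrary.
import numpy as np

def rescale_to_unit_range(array):
    array = np.asarray(array, dtype=float)
    spread = np.ptp(array)              # max - min
    if spread == 0:                     # single-element or constant input
        return np.zeros_like(array)
    return 2.0 * (array - np.min(array)) / spread - 1.0

print(rescale_to_unit_range([3.0, 4.0, 5.0]))  # [-1.  0.  1.]
print(rescale_to_unit_range([7.0]))            # [0.]
```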

@karan6181
Contributor

@mxnet-label-bot add [Operator, pr-awaiting-review]

@marcoabreu added the Operator and pr-awaiting-review (PR is waiting for code review) labels on Jul 16, 2019
@kshitij12345
Contributor Author

@apeforest @larroy @sxjscience Gentle ping for review. :)

Contributor

@larroy left a comment


Hi kshitij12345, thanks a lot for your contribution. In general it looks good. One question regarding the first output of the second gradient.

auto grad_grad_x = op.mul(dydx_mul_grad_x, grad_x_square_mul_x);

std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(op.mul(ograds[0], grad_x));
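For reference in the thread below, a sketch of the calculus this snippet appears to implement (my own reading of the variable names, not an authoritative description of the code): with y = arcsin(x),

```latex
\frac{dy}{dx} = (1 - x^2)^{-1/2},
\qquad
\frac{d^2y}{dx^2} = x\,(1 - x^2)^{-3/2} = x \left(\frac{dy}{dx}\right)^{3},
```

and the same identity y'' = x (y')^3 holds for arccos, whose first derivative differs only by a sign. Under that reading, grad_x_square_mul_x would be x * (dy/dx)^2, so multiplying it by a term that carries one more factor of dy/dx yields the second-derivative contribution, while ret.emplace_back(op.mul(ograds[0], grad_x)) would be the gradient with respect to the incoming head-gradient input, i.e. dy/dx scaled by ograds[0].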
Contributor


If the first input is y_grad, i.e. dL/dy, shouldn't this gradient be dL/(dy*dx)?
Didn't we have the convention of x_grad instead of grad_x?

Contributor Author


Sorry for the late reply.

If the first input is y_grad, i.e. dL/dy, shouldn't this gradient be dL/(dy*dx)?

I am not sure about the dL part, because we don't really use it in computing the loss function.

Didn't we have the convention of x_grad instead of grad_x?

Oops, thanks. This was an old PR; I will update the names.

Contributor


Sure, I guess calling it x_grad_y_grad is fine. Sorry, CI is flaky right now.

Contributor Author


We are not naming that particular variable x_grad_y_grad anywhere. Or am I confusing something?

Contributor Author


@larroy Can you elaborate on what you meant by x_grad_y_grad? I am slightly confused. Thanks.

@kshitij12345
Contributor Author

@apeforest @larroy @sxjscience
Gentle ping for review.

Contributor

@apeforest left a comment


LGTM!

Contributor

@larroy left a comment


LGTM.

If you have interest, one idea that has circulated around to enhance higher order gradients is to add an "FGradientSymbolic" function that gets triggered when the higher order gradient is not available and changes the graph so that the forward pass is expressed in terms of differentiable primitives (see the sketch below). We can talk more if you are interested.
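As an aside (my own toy illustration, not something in the PR or an existing nnvm hook), the "express it in terms of differentiable primitives" idea can be shown with a few lines of forward-mode autodiff: arcsin gets no hand-written second-order rule here, yet its second derivative falls out of differentiating its primitive-only gradient expression once more.

```python
# Conceptual sketch only (my own toy, not an MXNet/nnvm API): once an op's
# gradient is expressed through primitives the autodiff already knows, higher
# order gradients need no hand-written rule -- the rewritten expression is
# simply differentiated again. A tiny forward-mode dual number shows the idea.

class Dual:
    """A value plus its derivative; knows only primitive differentiation rules."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val,
                    self.val * other.dot + self.dot * other.val)

    def __sub__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val - other.val, self.dot - other.dot)

    def __rsub__(self, other):          # handles plain-number - Dual
        return Dual(other) - self

    def __pow__(self, p):               # p is a plain float exponent
        return Dual(self.val ** p, p * self.val ** (p - 1) * self.dot)

def arcsin_grad(x):
    # arcsin'(x) = (1 - x^2)^(-1/2), written only with the primitives above,
    # so the same machinery can differentiate it once more.
    return (1 - x * x) ** -0.5

x = Dual(0.5, 1.0)                      # seed dx/dx = 1
print(arcsin_grad(x).dot)               # ~0.7698 == 0.5 / (1 - 0.25) ** 1.5
```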

@sxjscience merged commit ed09547 into apache:master on Dec 18, 2019
@kshitij12345
Contributor Author

@larroy, I am quite interested in the idea. What would be a good place to talk? Slack?

@sxjscience
Member

@larroy @kshitij12345 @apeforest I think we can use Slack. What do you think?

@kshitij12345
Contributor Author

Slack sounds good to me.

@apeforest
Contributor

@kshitij12345 I have sent a Slack invite to [email protected]. Please accept. Thanks!

@larroy
Contributor

larroy commented Jan 4, 2020

Slack or mailing list are fine for me.

Labels: Operator, pr-awaiting-review (PR is waiting for code review)
Projects: None yet
Linked issues: None yet
7 participants