This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[fix] missing input log higher order. #15331

Merged
merged 13 commits into apache:master from fix/missing-input on Nov 19, 2019

Conversation

kshitij12345
Contributor

@larroy Thank you very much for catching this.

Sorry for the silly mistake.
Can we have a way to test this?

@apeforest @larroy

@Roshrini Roshrini added the pr-awaiting-review PR is waiting for code review label Jun 23, 2019
@roywei
Member

roywei commented Jul 8, 2019

@mxnet-label-bot add [Operator, Backend]

@marcoabreu marcoabreu added Backend Issues related to the backend of MXNet Operator labels Jul 8, 2019
@kshitij12345 kshitij12345 changed the title [fix] missing input log higher order [fix] missing input log higher order. Jul 8, 2019
@larroy
Contributor

larroy commented Jul 9, 2019

I'm still not sure what the meaning of the backward output for the head gradient input is, as we discussed before. This week we are at a conference, so we might be slow to respond.

I'm not sure how to test this. I think I would need to dump the graph and think about it, as I'm not sure in which Python variable the gradient of the head gradient is stored.

I think the PR fixes the issue, though. Would the operator have failed on a division without an argument? It looks like the tests don't execute the op, or do they?

Would it be better to set those outputs to zero, since we don't know how to use them? I'm fine with the fix proposed in this PR, though.

@apeforest
Contributor

@larroy Those outputs are needed for 3rd order and above gradients.

@apeforest
Contributor

@kshitij12345 https://github.com/apache/incubator-mxnet/pull/15331/files#diff-0dad60704ce39e602a1907aec6835375R1121 — the comment should actually be dL/dygrad. Could you please update it as well?

@apeforest
Contributor

apeforest commented Jul 10, 2019

@sxjscience Do we have a use case where the gradient on the gradient of the output y is needed, i.e. ygrad = dL/dy? How can we test the value of dG/dygrad, given that G is a function G(x, y, xgrad, ygrad) from R^n -> R?
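For reference, here is a minimal sketch of the kind of second-order check that can already be written with mx.autograd.grad (illustrative only, with made-up variable names; it exercises d(x_grad)/dx for log rather than the dG/dygrad term asked about above):

```python
import numpy as np
import mxnet as mx
from mxnet import nd, autograd

x = nd.random.uniform(low=1.0, high=2.0, shape=(3,))
x.attach_grad()
with autograd.record():
    y = nd.log(x)
    # First-order gradient of log: x_grad = 1/x, with the graph kept so that
    # x_grad itself can be differentiated again.
    x_grad = autograd.grad(heads=y, variables=x,
                           head_grads=nd.ones_like(y),
                           create_graph=True, retain_graph=True)[0]
# Back-propagating through x_grad puts d(x_grad)/dx into x.grad.
x_grad.backward()
# Analytically, d(1/x)/dx = -1/x^2.
expected = (-1.0 / (x ** 2)).asnumpy()
np.testing.assert_allclose(x.grad.asnumpy(), expected, rtol=1e-5, atol=1e-7)
```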

@apeforest
Contributor

Please update the comment; otherwise LGTM.

@kshitij12345
Contributor Author

https://github.com/apache/incubator-mxnet/blob/5171e1d92cfc5eefa2c20dfe8ac3fac5351ad19a/src/operator/tensor/elemwise_unary_op_basic.cc#L1120
dL/dygrad for this one, right?

@larroy @apeforest, I was also wondering whether we can check the number of inputs passed at compile time. I have observed that MakeNode gets the Op from the dynamic registry based on its name; however, we actually know the number of inputs and outputs for a given Op at compile time. I tried but couldn't figure it out. What are your thoughts? How easy or hard would it be to check for a valid number of inputs in MakeNode? It would help catch this sort of error at compile time itself.

@larroy
Contributor

larroy commented Jul 16, 2019

I think ograd[0] is dL/dx_grad

About the number of inputs, you are right that we could check. If it's more than one or two function calls, I think it's too much overhead, and it's going to get caught by the Python tests anyway. Also, if you don't return enough gradients, there's a check after the calls to FGradient, so I think it's not a big deal. Up to you if you can come up with something concise.
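If we did want something concise, a rough sketch of what a runtime check could look like (purely hypothetical helper, not existing MXNet code; a true compile-time check isn't really possible since ops are looked up by name in the dynamic registry):

```cpp
// Hypothetical sketch only: a MakeNode-style helper that fails fast when the
// caller passes the wrong number of inputs. Assumes the op sets the static
// num_inputs field; variadic ops that register get_num_inputs would need a
// functional variant of this check instead.
#include <nnvm/node.h>
#include <nnvm/op.h>
#include <dmlc/logging.h>
#include <string>
#include <vector>

inline nnvm::NodePtr MakeCheckedNode(const std::string& op_name,
                                     const std::string& node_name,
                                     const std::vector<nnvm::NodeEntry>& inputs) {
  nnvm::NodePtr node = nnvm::Node::Create();
  node->attrs.op = nnvm::Op::Get(op_name);
  node->attrs.name = node_name;
  // Fail fast if the caller passed the wrong number of inputs for this op.
  CHECK_EQ(inputs.size(), static_cast<size_t>(node->op()->num_inputs))
      << "Op " << op_name << " expects " << node->op()->num_inputs
      << " inputs but was given " << inputs.size();
  node->inputs = inputs;
  return node;
}
```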

@@ -1117,15 +1117,15 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log10,
 unary_bwd<mshadow_op::log10_grad>)
 .set_attr<nnvm::FGradient>("FGradient",
 [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
-// ograds[0]: dL/dxgrad
+// ograds[0]: dL/dygrad
Contributor

I think this is dL/dx_grad. The head gradient is the gradient with respect to the previous output, right? The previous output is x_grad, or dL/dx, so this thing is dL/(dL/dx), or dL/dx_grad, for lack of a better notation.

Contributor Author

I guess it should be dL/dy_grad, as we are computing/returning dL/dx_grad.
E.g.

y = f(x_grad)
L = g(y)  # x_grad formed part of the network and affected the loss

During backprop, by the chain rule,
dL/dx_grad = dL/dy * dy/dx_grad

In the comments, we have called dL/dy (from the above example) dL/dy_grad.

That is why we have
https://github.com/apache/incubator-mxnet/blob/5b95fb3ee3581ba20fe1def336621d68a811e17f/src/operator/tensor/elemwise_unary_op_basic.cc#L1111-L1112

These multiplications perform

dL/dx_grad = dL/dy * dy/dx_grad
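To spell this out for the plain log operator (a rough sketch of the math, not a quote of the code; here y = log(x), the inputs of _backward_log are y_grad = dL/dy and x, and L' stands for whatever scalar the second backward pass starts from):

$$x_{\text{grad}} = \frac{\partial L}{\partial x} = y_{\text{grad}} \cdot \frac{1}{x}, \qquad y_{\text{grad}} = \frac{\partial L}{\partial y}$$

so the two outputs of the second-order FGradient are, by the chain rule,

$$\frac{\partial L'}{\partial y_{\text{grad}}} = \text{ograds}[0] \cdot \frac{\partial x_{\text{grad}}}{\partial y_{\text{grad}}} = \frac{\text{ograds}[0]}{x}, \qquad \frac{\partial L'}{\partial x} = \text{ograds}[0] \cdot \frac{\partial x_{\text{grad}}}{\partial x} = -\,\text{ograds}[0] \cdot \frac{y_{\text{grad}}}{x^{2}}$$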

Contributor

@larroy larroy Jul 25, 2019

I think the notation is complicating things to excess, as it gets pretty hairy. It's the head gradient of the previous (output) node, which has the shape of x and of x_grad. So it has to be related to x, not y.

I think in Lagrange notation it would be $$F_{L_x}$$ (the derivative of some head function with respect to the derivative of the first loss with respect to x, i.e. x_grad).

Contributor Author

Oh, I get it now. If I understand correctly then, crudely, ograds[0] is how much x_grad affects L, and then we compute how x_grad changes with x. Makes sense now.

Thank you very much. Will reflect it in this and other PRs.

Contributor

@kshitij12345 I think what you write makes sense. I'm also unsure about the notation; maybe you can come up with a better one. If not, maybe we leave the comment out so we can merge the PR, as the code seems to do what's needed.

Contributor Author

Sure. Thanks Again.

 // inputs[0]: dL/dy
-// inputs[1]: x
+// inputs[1]: x (ElemwiseGradUseIn)
Contributor

nice comment, helps.

Contributor

@apeforest apeforest left a comment

LGTM

@karan6181
Contributor

@kshitij12345 Could you please resolve the merge conflict and ping the reviewers again to merge it? Thanks!

@kshitij12345
Contributor Author

kshitij12345 commented Aug 30, 2019

Sure. Thanks

@kshitij12345
Contributor Author

@larroy @apeforest Gentle ping for review.

@kshitij12345
Contributor Author

@apeforest @larroy Gentle Ping.

@sxjscience
Member

I guess we need to add a test case.

@kshitij12345
Contributor Author

@sxjscience I am not sure how to test this. I was expecting that this missing input would cause a problem when computing higher-order gradients. I tried the 3rd order, which was computed successfully. For the fourth order, I get
`Operator _backward_mul is non-differentiable because it didn't register FGradient attribute`
(a different problem). So I am not very sure how to write a test case.
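Roughly what the third-order attempt looked like (illustrative sketch with made-up variable names, not the exact script):

```python
import mxnet as mx
from mxnet import nd, autograd

x = nd.random.uniform(low=1.0, high=2.0, shape=(3,))
x.attach_grad()
with autograd.record():
    y = nd.log(x)
    # First order: dy/dx = 1/x, recorded so it can be differentiated again.
    g1 = autograd.grad(heads=y, variables=x, head_grads=nd.ones_like(y),
                       create_graph=True, retain_graph=True)[0]
    # Second order: d(1/x)/dx = -1/x^2, still recorded.
    g2 = autograd.grad(heads=g1, variables=x, head_grads=nd.ones_like(g1),
                       create_graph=True, retain_graph=True)[0]
# Third order ends up in x.grad (expected 2/x^3); repeating the pattern one
# level deeper is what hits the "_backward_mul is non-differentiable" error.
g2.backward()
print(x.grad)
```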

@kshitij12345
Contributor Author

@sxjscience @apeforest @larroy Gentle Ping.

Contributor

@larroy larroy left a comment

Sorry we are all quite busy. I think it's fine to merge this. We can do any additional refinements later.

LGTM

Contributor

@apeforest apeforest left a comment

LGTM. Sorry for the delayed response. We have been extremely busy in the past month.

@apeforest apeforest merged commit 60f53ed into apache:master Nov 19, 2019
@kshitij12345 kshitij12345 deleted the fix/missing-input branch November 22, 2019 07:21
@kshitij12345
Contributor Author

@apeforest Sure no worries. Thanks.

Labels
Backend (Issues related to the backend of MXNet), Operator, pr-awaiting-review (PR is waiting for code review)