-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-599] Partial shape infer for Slice #11406
Conversation
src/operator/tensor/matrix_op-inl.h
Outdated
@@ -674,13 +675,23 @@ inline void SetSliceOpOutputDimSize(const index_t i, const int b, | |||
const int e, const int s, | |||
TShape* oshape) { | |||
if (s > 0) { | |||
CHECK_LT(b, e) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]=" | |||
CHECK_LE(b, e) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]=" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the case b = e
required?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, now when shape of a dim is not known, b==e==0
src/operator/tensor/matrix_op-inl.h
Outdated
// for partial shape infer | ||
(*oshape)[i] = 0; | ||
} else { | ||
(*oshape)[i] = (e - b - 1) / s + 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should only fill in output shape dim size when the corresponding input shape dim size is non-zero.
That is if input shape is (0, 20), begin=(0, 0), end=(5, 10), it should return (0, 10), instead of (5, 10) as inferred output shape.
@sandeep-krishnamurthy Could you confirm?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes that is correct. Even if one of the shape is unknown (0) we cannot/should not infer output shape.
@@ -690,9 +701,10 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, | |||
CHECK_EQ(in_attrs->size(), 1U); | |||
CHECK_EQ(out_attrs->size(), 1U); | |||
const TShape& dshape = (*in_attrs)[0]; | |||
if (dshape.ndim() == 0 || dshape.Size() == 0) return false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we are removing one of the checks, please add
return !shape_is_none(dshape) && !shape_is_none(oshape)
at the end of this function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added
Thanks a lot @rahul003 - This is very useful fix for all Keras-MXNet users using slice operation. Overall looks good to me, except 1 change that Jun mentioned about unknown dimension should be unknown. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you. Verified with Keras-MXNet also.
@reminisce - Requesting you to kindly review the updates :-)
Updated the PR with suggested changes to set as 0 when axis is unknown. Summary of changes to help review: |
@@ -5740,6 +5740,30 @@ def test_slice_forward_backward(a, index): | |||
slice_sym = mx.sym.slice(data, begin=[0, None], end=[1, None], step=[2, -1]) | |||
check_numeric_gradient(slice_sym, [in_data]) | |||
|
|||
def test_slice_partial_infer(): | |||
var1 = mx.sym.var(name="data", shape=(0,20)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you modulize the test code? For example:
def check_slice_partial_shape_infer(data, begin, end, step, expected_out_shape):
out = mx.sym.slice(data, begin, end, step)
assert out.infer_shape_partial()[1][0] == expected_out_shape
check_slice_partial_shape_infer(var1, (None, None), (None, 10), None, (0, 10))
...
@reminisce Made the requested changes and added a couple more tests. Please check. Thanks! |
Thank you very much @rahul003 and @reminisce |
* slice change * add tests * add slice axis test * update check * Whitespace * bring back a check for ndim * trigger ci * update case when axis can't be inferred * test failure * refactor and add more tests * fix lint
Description
Fix for #11349
Allow partial shape inference of slice operator
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
@sandeep-krishnamurthy @reminisce Please review, thanks!