[MXNET-1233] Enable dynamic shape in CachedOp #13419
Conversation
@mxnet-label-bot add [pr-work-in-progress]
@mxnet-label-bot add [pr-awaiting-review]
@mxnet-label-bot remove [pr-work-in-progress]
src/imperative/cached_op.cc
Outdated
@@ -262,6 +262,29 @@ std::vector<nnvm::NodeEntry> CachedOp::Gradient(
  return ret;
}

bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx,
                                       const std::vector<NDArray*>& inputs) {
I wonder if it's better to check for operators with dynamic shape directly. Right now, it assumes that if a computation graph can't infer shape, it contains dynamic-shape operators. It would be better to write a check that works for both CachedOp and the symbol executor. Whether a graph contains dynamic shape is a property of the computation graph, and we can easily check it by traversing all operators in the graph.
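For illustration, a minimal sketch of that traversal-based check, assuming dynamic-shape operators were marked with a registered op attribute (the attribute name TIsDynamicShapeOp below is hypothetical, not an existing MXNet attribute):

#include <nnvm/graph.h>
#include <nnvm/op.h>

// Sketch only: "TIsDynamicShapeOp" is a hypothetical op attribute that would
// mark operators whose output shape is only known at runtime.
bool GraphContainsDynamicShapeOp(const nnvm::Graph& g) {
  static const auto& is_dyn = nnvm::Op::GetAttr<bool>("TIsDynamicShapeOp");
  bool found = false;
  // DFSVisit walks every node reachable from the graph outputs exactly once.
  nnvm::DFSVisit(g.outputs, [&](const nnvm::NodePtr& n) {
    if (!n->is_variable() && is_dyn.get(n->op(), false)) {
      found = true;
    }
  });
  return found;
}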
Per our discussion last time, I think our solution should be a naive implementation first, and then graph partitioning to speed things up.
arrays.reserve(num_entries);
for (auto& item : runtime.buff) {
  arrays.push_back(&item);
}
I wonder if we should buffer arrays from the previous run?
Why buffer stuff from the previous run? To save memory allocation overhead?
Context ctx = GetContext(node.source->attrs, ndinputs, ndoutputs, default_ctx);
auto invoke = [&](const OpStatePtr &state) {
  DispatchMode dispatch_mode = DispatchMode::kUndefined;
  SetShapeType(ctx, node.source->attrs, ndinputs, ndoutputs, &dispatch_mode);
Do we still infer shape here?
No. This function leverages the existing infer_shape to find out whether there is any dynamic-shape operator inside the graph.
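Roughly, the idea looks like the following sketch (not the exact code in this PR; RunShapeInference is a hypothetical stand-in for the existing infer-shape pass):

#include <vector>
#include <nnvm/graph.h>
#include <mxnet/ndarray.h>

// Hypothetical helper: runs the standard InferShape pass with the given input
// shapes and returns how many node entries remain with unknown shape.
size_t RunShapeInference(nnvm::Graph* g, std::vector<nnvm::TShape>&& shape_inputs);

// Sketch of the check: if shape inference cannot resolve every entry even with
// concrete input shapes, the graph is treated as containing dynamic shape.
bool CheckDynamicShapeExistsSketch(nnvm::Graph* g,
                                   const std::vector<mxnet::NDArray*>& inputs) {
  std::vector<nnvm::TShape> shape_inputs;
  shape_inputs.reserve(inputs.size());
  for (const mxnet::NDArray* nd : inputs) {
    shape_inputs.emplace_back(nd->shape());  // shapes of the concrete inputs
  }
  return RunShapeInference(g, std::move(shape_inputs)) != 0;
}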
auto fwd_node_id = idx.node_id(fwd_node);
cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
} else if (createop.count(node.source->op())) {
  // case 2: node is in createop
I think this is to handle stateful operators.
Yep, you are right. Should I change the comments?
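For readers following along, the createop branch corresponds roughly to this pattern (a sketch under assumptions, not the PR's exact code): stateful operators register FCreateOpState, and the state object created here is reused on later invocations, including backward.

#include <vector>
#include <nnvm/op.h>
#include <mxnet/op_attr_types.h>

// Sketch: create an OpStatePtr for a node if its operator is stateful.
mxnet::OpStatePtr CreateStateIfStateful(const nnvm::NodeAttrs& attrs,
                                        const mxnet::Context& ctx,
                                        const std::vector<nnvm::TShape>& in_shapes,
                                        const std::vector<int>& in_types) {
  static const auto& createop =
      nnvm::Op::GetAttr<mxnet::FCreateOpState>("FCreateOpState");
  if (createop.count(attrs.op)) {
    // "Case 2" above: the node's operator registers FCreateOpState.
    return createop[attrs.op](attrs, ctx, in_shapes, in_types);
  }
  return mxnet::OpStatePtr();  // stateless operators carry no extra state
}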
@@ -1002,6 +1009,18 @@ void RunGraph(const bool retain_graph,
              const DispatchModeVector &dispatch_modes,
              bool recording);

void NaiveRunGraph(const bool retain_graph,
I think these new functions deserve documentation of their inputs/outputs and what they do, to keep the code readable.
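For example, a documentation block along these lines could sit above the NaiveRunGraph declaration (a sketch; the exact parameter list follows the RunGraph convention shown above):

/*!
 * \brief Execute the graph node by node on the default context, inferring
 *        shapes and allocating outputs right before each operator runs, so
 *        operators whose output shape is only known at runtime are supported.
 * \param retain_graph whether to keep intermediate results for a later
 *        backward pass (remaining parameters mirror RunGraph)
 */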
@eric-haibin-lin another round of review?
@zheng-da - Can you please take a look at this PR again?
I am going to close this PR, split it into small pieces, and PR again.
Description
This PR enables dynamic shape in CachedOp, i.e. the backend supporting HybridBlock in Gluon. In the forward pass, we have to invoke operators one by one because dynamic shape prevents us from allocating memory ahead of time. The backward pass is unaffected, because after the forward pass the shape of everything is known.
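Conceptually, the naive forward pass behaves like the sketch below (simplified, not the PR's exact code; InferShapeForNode, AllocateOutputs, and InvokeOperator are hypothetical stand-ins for the real per-node helpers):

#include <vector>
#include <nnvm/graph.h>
#include <mxnet/ndarray.h>

// Hypothetical helpers standing in for the real per-node machinery:
std::vector<nnvm::TShape> InferShapeForNode(const nnvm::IndexedGraph::Node& node,
                                            const std::vector<mxnet::NDArray*>& arrays);
void AllocateOutputs(const nnvm::IndexedGraph::Node& node,
                     const std::vector<nnvm::TShape>& out_shapes,
                     const mxnet::Context& ctx,
                     std::vector<mxnet::NDArray*>* arrays);
void InvokeOperator(const nnvm::IndexedGraph::Node& node,
                    std::vector<mxnet::NDArray*>* arrays);

// Sketch of the naive forward pass: because output shapes may only be known
// once the inputs are concrete, each node is inferred, allocated, and executed
// individually instead of planning memory for the whole graph up front.
void NaiveRunGraphSketch(const nnvm::IndexedGraph& idx,
                         const mxnet::Context& default_ctx,
                         std::vector<mxnet::NDArray*>* arrays) {
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const nnvm::IndexedGraph::Node& node = idx[nid];
    if (node.source->is_variable()) continue;  // inputs already hold data
    // 1. Infer this node's output shapes from its now-concrete inputs.
    auto out_shapes = InferShapeForNode(node, *arrays);
    // 2. Allocate output NDArrays for exactly this node.
    AllocateOutputs(node, out_shapes, default_ctx, arrays);
    // 3. Execute the operator synchronously so downstream shapes become known.
    InvokeOperator(node, arrays);
  }
}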
CC: @zheng-da @szha @yidawang
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
- NaiveRunGraph in imperative mode, in which operators are executed in a synchronized manner.
- NaiveForward mode in CachedOp, which calls NaiveRunGraph.
- CheckDynamicShapeExists in CachedOp, which tells whether the graph contains an operator returning dynamic shape.

Comments