Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add MXEnginePushAsync and MXEnginePushSync C APIs #14615

Merged
merged 7 commits into from
Apr 12, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add PushAsyncPtr and PushSyncPtr APIs in engine
  • Loading branch information
yuxihu committed Apr 9, 2019
commit cdaf7446c720910dec62075f1b284763ae17fe80
41 changes: 41 additions & 0 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ class MXNET_API Engine {
typedef std::function<void(RunContext)> SyncFn;
/*! \brief Asynchronous operation to pass to engine. */
typedef std::function<void(RunContext, CallbackOnComplete)> AsyncFn;
/*! \brief Synchronous operation (function pointer) to pass to engine. */
typedef void (*SyncFnPtr)(RunContext, const std::shared_ptr<void>&);
yuxihu marked this conversation as resolved.
Show resolved Hide resolved
/*! \brief Asynchronous operation (function pointer) to pass to engine. */
typedef void (*AsyncFnPtr)(RunContext, CallbackOnComplete, const std::shared_ptr<void>&);
/*! \brief Variable pointer */
typedef engine::VarHandle VarHandle;
/*! \brief Operator pointer */
Expand Down Expand Up @@ -206,6 +210,43 @@ class MXNET_API Engine {
int priority = 0,
const char* opr_name = nullptr,
bool wait = false) = 0;
/*!
* \brief Push an asynchronous operation to the engine.
* \param exec_fn_ptr Execution function, this function takes a parameter
* on_complete that must be called when the execution
* completes.
* \param param the parameter set on calling exec_fn_ptr.
* \param exec_ctx Execution context.
* \param const_vars The variables that current operation will use but not
* mutate.
* \param mutable_vars The variables that current operation will mutate.
* \param prop Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operator name.
* \param wait Whether this is a WaitForVar operation.
*/
void PushAsyncPtr(AsyncFnPtr exec_fn_ptr, const std::shared_ptr<void>& param,
yuxihu marked this conversation as resolved.
Show resolved Hide resolved
yuxihu marked this conversation as resolved.
Show resolved Hide resolved
Context exec_ctx, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal, int priority = 0,
const char* opr_name = nullptr, bool wait = false);
/*!
* \brief Push an synchronous operation to the engine.
* \param exec_fn_ptr Execution function that executes the operation.
* \param param the parameter set on calling exec_fn_ptr.
* \param exec_ctx Execution context.
* \param const_vars The variables that current operation will use but not
* mutate.
* \param mutable_vars The variables that current operation will mutate.
* \param prop Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operator name.
*/
void PushSyncPtr(SyncFnPtr exec_fn_ptr, const std::shared_ptr<void>& param,
Context exec_ctx, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
yuxihu marked this conversation as resolved.
Show resolved Hide resolved
FnProperty prop = FnProperty::kNormal, int priority = 0,
const char* opr_name = nullptr);
/*!
* \brief Schedule the deletion of a variable.
*
Expand Down
2 changes: 2 additions & 0 deletions make/config/libmxnet.sym
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ _nn*
Java_org_apache_mxnet*
*NDArray*
*Engine*Get*
*Engine*PushAsyncPtr*
*Engine*PushSyncPtr*
*Storage*Get*
*on_enter_api*
*on_exit_api*
Expand Down
2 changes: 2 additions & 0 deletions make/config/libmxnet.ver
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
Java_org_apache_mxnet*;
*NDArray*;
*Engine*Get*;
*Engine*PushAsyncPtr*;
*Engine*PushSyncPtr*;
*Storage*Get*;
*on_enter_api*;
*on_exit_api*;
Expand Down
22 changes: 22 additions & 0 deletions src/engine/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,26 @@ Engine* Engine::Get() {
static Engine *inst = _GetSharedRef().get();
return inst;
}

void Engine::PushAsyncPtr(AsyncFnPtr exec_fn_ptr, const std::shared_ptr<void>& param,
Context exec_ctx, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop, int priority,
const char* opr_name, bool wait) {
auto exec_fn = [exec_fn_ptr, param](RunContext rctx,
yuxihu marked this conversation as resolved.
Show resolved Hide resolved
CallbackOnComplete on_complete) {
exec_fn_ptr(rctx, on_complete, param);
};
PushAsync(exec_fn, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name, wait);
}

void Engine::PushSyncPtr(SyncFnPtr exec_fn_ptr, const std::shared_ptr<void>& param,
Context exec_ctx, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop, int priority, const char* opr_name) {
auto exec_fn = [exec_fn_ptr, param](RunContext rctx) {
exec_fn_ptr(rctx, param);
};
PushSync(exec_fn, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
}
} // namespace mxnet