Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add MXEnginePushAsync and MXEnginePushSync C APIs #14615

Merged
merged 7 commits into from
Apr 12, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
avoid using shared_ptr for param in new APIs
  • Loading branch information
yuxihu committed Apr 9, 2019
commit d5a27b540fce4fb7ec8e28729303a11f90f6d3c7
16 changes: 10 additions & 6 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,11 @@ class MXNET_API Engine {
/*! \brief Asynchronous operation to pass to engine. */
typedef std::function<void(RunContext, CallbackOnComplete)> AsyncFn;
/*! \brief Synchronous operation (function pointer) to pass to engine. */
typedef void (*SyncFnPtr)(RunContext, const std::shared_ptr<void>&);
typedef void (*SyncFnPtr)(RunContext, void*);
/*! \brief Asynchronous operation (function pointer) to pass to engine. */
typedef void (*AsyncFnPtr)(RunContext, CallbackOnComplete, const std::shared_ptr<void>&);
typedef void (*AsyncFnPtr)(RunContext, CallbackOnComplete, void*);
/*! \brief Callback to free the param passed into AsyncFnPtr/SyncFnPtr. */
typedef void (*FnPtrParamDeleter)(void* param);
/*! \brief Variable pointer */
typedef engine::VarHandle VarHandle;
/*! \brief Operator pointer */
Expand Down Expand Up @@ -215,7 +217,8 @@ class MXNET_API Engine {
* \param exec_fn_ptr Execution function, this function takes a parameter
* on_complete that must be called when the execution
* completes.
* \param param the parameter set on calling exec_fn_ptr.
* \param param The parameter set on calling exec_fn_ptr, can be NULL.
* \param del The callback to free param, can be NULL.
* \param exec_ctx Execution context.
* \param const_vars The variables that current operation will use but not
* mutate.
Expand All @@ -225,15 +228,16 @@ class MXNET_API Engine {
* \param opr_name The operator name.
* \param wait Whether this is a WaitForVar operation.
*/
void PushAsyncPtr(AsyncFnPtr exec_fn_ptr, const std::shared_ptr<void>& param,
void PushAsyncPtr(AsyncFnPtr exec_fn_ptr, void* param, FnPtrParamDeleter del,
Context exec_ctx, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal, int priority = 0,
const char* opr_name = nullptr, bool wait = false);
/*!
* \brief Push an synchronous operation to the engine.
* \param exec_fn_ptr Execution function that executes the operation.
* \param param the parameter set on calling exec_fn_ptr.
* \param param The parameter set on calling exec_fn_ptr, can be NULL.
* \param del The callback to free param, can be NULL.
* \param exec_ctx Execution context.
* \param const_vars The variables that current operation will use but not
* mutate.
Expand All @@ -242,7 +246,7 @@ class MXNET_API Engine {
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operator name.
*/
void PushSyncPtr(SyncFnPtr exec_fn_ptr, const std::shared_ptr<void>& param,
void PushSyncPtr(SyncFnPtr exec_fn_ptr, void* param, FnPtrParamDeleter del,
Context exec_ctx, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
yuxihu marked this conversation as resolved.
Show resolved Hide resolved
FnProperty prop = FnProperty::kNormal, int priority = 0,
Expand Down
41 changes: 32 additions & 9 deletions src/engine/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,25 +68,48 @@ Engine* Engine::Get() {
return inst;
}

void Engine::PushAsyncPtr(AsyncFnPtr exec_fn_ptr, const std::shared_ptr<void>& param,
void Engine::PushAsyncPtr(AsyncFnPtr exec_fn_ptr, void* param, FnPtrParamDeleter del,
Context exec_ctx, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop, int priority,
const char* opr_name, bool wait) {
auto exec_fn = [exec_fn_ptr, param](RunContext rctx,
CallbackOnComplete on_complete) {
exec_fn_ptr(rctx, on_complete, param);
};
AsyncFn exec_fn;
if (del == nullptr) {
exec_fn = [exec_fn_ptr, param](RunContext rctx,
CallbackOnComplete on_complete) {
exec_fn_ptr(rctx, on_complete, param);
};
} else {
// Wrap param in a shared_ptr with del as deleter such that del will be
// called when the lambda goes out of scope.
std::shared_ptr<void> shared_param(param, del);
exec_fn = [exec_fn_ptr, shared_param](RunContext rctx,
CallbackOnComplete on_complete) {
exec_fn_ptr(rctx, on_complete, shared_param.get());
};
}

PushAsync(exec_fn, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name, wait);
}

void Engine::PushSyncPtr(SyncFnPtr exec_fn_ptr, const std::shared_ptr<void>& param,
void Engine::PushSyncPtr(SyncFnPtr exec_fn_ptr, void* param, FnPtrParamDeleter del,
Context exec_ctx, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop, int priority, const char* opr_name) {
auto exec_fn = [exec_fn_ptr, param](RunContext rctx) {
exec_fn_ptr(rctx, param);
};
SyncFn exec_fn;
if (del == nullptr) {
exec_fn = [exec_fn_ptr, param](RunContext rctx) {
exec_fn_ptr(rctx, param);
};
} else {
// Wrap param in a shared_ptr with del as deleter such that del will be
// called when the lambda goes out of scope.
std::shared_ptr<void> shared_param(param, del);
exec_fn = [exec_fn_ptr, shared_param](RunContext rctx) {
exec_fn_ptr(rctx, shared_param.get());
};
}

PushSync(exec_fn, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
}
} // namespace mxnet