Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
avoid using shared_ptr for param in new APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
yuxihu committed Apr 8, 2019
1 parent 1ff1dc2 commit 49f118d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 15 deletions.
16 changes: 10 additions & 6 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,11 @@ class MXNET_API Engine {
/*! \brief Asynchronous operation to pass to engine. */
typedef std::function<void(RunContext, CallbackOnComplete)> AsyncFn;
/*! \brief Synchronous operation (function pointer) to pass to engine. */
typedef void (*SyncFnPtr)(RunContext, const std::shared_ptr<void>&);
typedef void (*SyncFnPtr)(RunContext, void*);
/*! \brief Asynchronous operation (function pointer) to pass to engine. */
typedef void (*AsyncFnPtr)(RunContext, CallbackOnComplete, const std::shared_ptr<void>&);
typedef void (*AsyncFnPtr)(RunContext, CallbackOnComplete, void*);
/*! \brief Callback to free the param passed into AsyncFnPtr/SyncFnPtr. */
typedef void (*FnPtrParamDeleter)(void* param);
/*! \brief Variable pointer */
typedef engine::VarHandle VarHandle;
/*! \brief Operator pointer */
Expand Down Expand Up @@ -215,7 +217,8 @@ class MXNET_API Engine {
* \param exec_fn_ptr Execution function, this function takes a parameter
* on_complete that must be called when the execution
* completes.
* \param param the parameter set on calling exec_fn_ptr.
* \param param The parameter set on calling exec_fn_ptr, can be NULL.
* \param del The callback to free param, can be NULL.
* \param exec_ctx Execution context.
* \param const_vars The variables that current operation will use but not
* mutate.
Expand All @@ -225,15 +228,16 @@ class MXNET_API Engine {
* \param opr_name The operator name.
* \param wait Whether this is a WaitForVar operation.
*/
void PushAsyncPtr(AsyncFnPtr exec_fn_ptr, const std::shared_ptr<void>& param,
void PushAsyncPtr(AsyncFnPtr exec_fn_ptr, void* param, FnPtrParamDeleter del,
Context exec_ctx, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal, int priority = 0,
const char* opr_name = nullptr, bool wait = false);
/*!
* \brief Push an synchronous operation to the engine.
* \param exec_fn_ptr Execution function that executes the operation.
* \param param the parameter set on calling exec_fn_ptr.
* \param param The parameter set on calling exec_fn_ptr, can be NULL.
* \param del The callback to free param, can be NULL.
* \param exec_ctx Execution context.
* \param const_vars The variables that current operation will use but not
* mutate.
Expand All @@ -242,7 +246,7 @@ class MXNET_API Engine {
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operator name.
*/
void PushSyncPtr(SyncFnPtr exec_fn_ptr, const std::shared_ptr<void>& param,
void PushSyncPtr(SyncFnPtr exec_fn_ptr, void* param, FnPtrParamDeleter del,
Context exec_ctx, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal, int priority = 0,
Expand Down
41 changes: 32 additions & 9 deletions src/engine/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,25 +68,48 @@ Engine* Engine::Get() {
return inst;
}

void Engine::PushAsyncPtr(AsyncFnPtr exec_fn_ptr, const std::shared_ptr<void>& param,
void Engine::PushAsyncPtr(AsyncFnPtr exec_fn_ptr, void* param, FnPtrParamDeleter del,
Context exec_ctx, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop, int priority,
const char* opr_name, bool wait) {
auto exec_fn = [exec_fn_ptr, param](RunContext rctx,
CallbackOnComplete on_complete) {
exec_fn_ptr(rctx, on_complete, param);
};
AsyncFn exec_fn;
if (del == nullptr) {
exec_fn = [exec_fn_ptr, param](RunContext rctx,
CallbackOnComplete on_complete) {
exec_fn_ptr(rctx, on_complete, param);
};
} else {
// Wrap param in a shared_ptr with del as deleter such that del will be
// called when the lambda goes out of scope.
std::shared_ptr<void> shared_param(param, del);
exec_fn = [exec_fn_ptr, shared_param](RunContext rctx,
CallbackOnComplete on_complete) {
exec_fn_ptr(rctx, on_complete, shared_param.get());
};
}

PushAsync(exec_fn, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name, wait);
}

void Engine::PushSyncPtr(SyncFnPtr exec_fn_ptr, const std::shared_ptr<void>& param,
void Engine::PushSyncPtr(SyncFnPtr exec_fn_ptr, void* param, FnPtrParamDeleter del,
Context exec_ctx, std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop, int priority, const char* opr_name) {
auto exec_fn = [exec_fn_ptr, param](RunContext rctx) {
exec_fn_ptr(rctx, param);
};
SyncFn exec_fn;
if (del == nullptr) {
exec_fn = [exec_fn_ptr, param](RunContext rctx) {
exec_fn_ptr(rctx, param);
};
} else {
// Wrap param in a shared_ptr with del as deleter such that del will be
// called when the lambda goes out of scope.
std::shared_ptr<void> shared_param(param, del);
exec_fn = [exec_fn_ptr, shared_param](RunContext rctx) {
exec_fn_ptr(rctx, shared_param.get());
};
}

PushSync(exec_fn, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
}
} // namespace mxnet

0 comments on commit 49f118d

Please sign in to comment.