
Custom Operator Random Number Generator Support #17762

Merged · 28 commits · Apr 8, 2020
Changes from 1 commit
fix subgraph segfault and improve cmake
rondogency committed Mar 27, 2020
commit e67ee886876e7f7a2da9c19dadc9647b05fe35ba
17 changes: 6 additions & 11 deletions CMakeLists.txt
@@ -602,7 +602,7 @@ if(USE_CUDA)
message("-- CUDA: Using the following NVCC architecture flags ${CUDA_ARCH_FLAGS}")
set(arch_code_list)
foreach(arch_str ${CUDA_ARCH_FLAGS})
if((arch_str MATCHES ".*sm_[0-9]+"))
if((arch_str MATCHES ".*sm_[0-9]+"))
string( REGEX REPLACE ".*sm_([0-9]+)" "\\1" arch_code ${arch_str} )
list(APPEND arch_code_list ${arch_code})
endif()
@@ -732,7 +732,7 @@ elseif(MSVC)
"$<$<COMPILE_LANGUAGE:CUDA>:--gpu-code=sm_${arch},compute_${arch}>"
)
target_compile_options(
mxnet_${arch}
mxnet_${arch}
PRIVATE "$<$<AND:$<CONFIG:DEBUG>,$<COMPILE_LANGUAGE:CUDA>>:-Xcompiler=-MTd -Gy /bigobj>")
target_compile_options(
mxnet_${arch}
@@ -761,26 +761,21 @@ elseif(MSVC)
endif()
endif()

# extension libraries (custom operators, custom subgraphs) are built by default
add_library(customop_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc)
add_library(subgraph_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_subgraph/subgraph_lib.cc)
target_include_directories(customop_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
target_include_directories(subgraph_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
if (USE_CUDA)
if(USE_CUDA)
add_library(customop_gpu_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/relu_lib.cu)
target_include_directories(customop_gpu_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
endif()
if(UNIX)
target_compile_options(customop_lib PUBLIC -shared)
target_compile_options(subgraph_lib PUBLIC -shared)
if (USE_CUDA)
target_compile_options(customop_gpu_lib PUBLIC -shared)
endif()
elseif(MSVC)
if(MSVC)
target_compile_options(customop_lib PUBLIC /LD)
target_compile_options(subgraph_lib PUBLIC /LD)
set_target_properties(customop_lib PROPERTIES PREFIX "lib")
set_target_properties(subgraph_lib PROPERTIES PREFIX "lib")
if (USE_CUDA)
if(USE_CUDA)
target_compile_options(customop_gpu_lib PUBLIC "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fPIC>")
set_target_properties(customop_gpu_lib PROPERTIES PREFIX "lib")
endif()
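The shared libraries built by these targets are meant to be loaded into MXNet at runtime rather than linked into libmxnet. A minimal sketch of doing that through the C API is shown below; the library path and the `build/` directory are assumptions about the build layout (on Windows the suffix would be `.dll`), and the program is illustrative rather than part of the repository.

```cpp
// Sketch: load a compiled extension library through the C API declared in
// include/mxnet/c_api.h. The path "build/libcustomop_gpu_lib.so" is an
// assumption, not something fixed by the project.
#include <iostream>
#include <mxnet/c_api.h>

int main() {
  // MXLoadLib registers every custom operator found in the shared library.
  if (MXLoadLib("build/libcustomop_gpu_lib.so") != 0) {
    std::cerr << "MXLoadLib failed: " << MXGetLastError() << std::endl;
    return 1;
  }
  std::cout << "extension library loaded" << std::endl;
  return 0;
}
```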
8 changes: 4 additions & 4 deletions example/extensions/lib_custom_op/test_relu.py
@@ -72,10 +72,10 @@
r2 = mx.nd.relu(b)
mx.nd.waitall()
t3 = time.time()
print("Custom GPU running time:")
print(t2 - t1)
print("Native GPU running time:")
print(t3 - t2)
print("Custom GPU running time ms:")
print((t2 - t1) * 1000)
print("Native GPU running time ms:")
print((t3 - t2) * 1000)

print("--------start testing noisy relu---------")

2 changes: 1 addition & 1 deletion include/mxnet/lib_api.h
@@ -474,7 +474,7 @@ class OpResource {

/*! \brief get pointer to gpu random number inited and seeded states */
/* this global states are located on gpu, of type curandStatePhilox4_32_10_t */
/* note that if you are usign cpu build, then it will return a nullptr */
/* note that if you are using cpu build, it will return a nullptr */
void* get_gpu_rand_states() {
return gpu_rand_states;
}
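To illustrate what a custom library does with this pointer: on a CUDA build the returned states can be cast to `curandStatePhilox4_32_10_t*` and consumed inside a kernel. The sketch below is illustrative only; the kernel name, the `num_states` parameter, and the launch configuration are assumptions rather than anything defined in this header.

```cuda
// Sketch of a custom GPU operator consuming the pre-seeded states returned by
// OpResource::get_gpu_rand_states(). Kernel name and sizes are illustrative.
#include <curand_kernel.h>

__global__ void noisy_relu_kernel(const float* in, float* out, int64_t n,
                                  curandStatePhilox4_32_10_t* states,
                                  int num_states) {
  int id = blockIdx.x * blockDim.x + threadIdx.x;
  if (id < num_states) {
    // Work on a register copy so the global state advances consistently.
    curandStatePhilox4_32_10_t thread_state = states[id];
    for (int64_t i = id; i < n; i += num_states) {
      float noisy = in[i] + curand_normal(&thread_state);
      out[i] = noisy > 0.0f ? noisy : 0.0f;
    }
    states[id] = thread_state;  // persist so later calls continue the sequence
  }
}

// Host side (inside the custom op's GPU forward function), assuming `res` is
// the OpResource passed in and the launch covers at least num_states threads:
//   auto* states =
//       static_cast<curandStatePhilox4_32_10_t*>(res.get_gpu_rand_states());
//   noisy_relu_kernel<<<(num_states + 255) / 256, 256>>>(in, out, n, states,
//                                                        num_states);
```

On a CPU-only build the same call returns a nullptr, so a library that supports both contexts has to branch before casting.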
33 changes: 21 additions & 12 deletions src/c_api/c_api.cc
@@ -108,6 +108,7 @@ void CustomFComputeDispatcher(const std::string op_name,
const nnvm::NodeAttrs* attrs,
const opCallFStatefulComp_t callFStatefulComp,
int stateful_forward_flag,
bool is_subgraph_op,
const OpStatePtr* state_ptr,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
@@ -146,6 +147,7 @@ void CustomFComputeDispatcher(const std::string op_name,
in_dims.push_back(in_nd->shape().ndim());
in_types.push_back(in_nd->dtype());
in_verIDs.push_back(in_nd->version());
// string repr of supported context for custom library, currently only "cpu" and "gpu"
const char* ctx_str = in_nd->ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
in_dev_type.push_back(ctx_str);

@@ -256,14 +258,16 @@ void CustomFComputeDispatcher(const std::string op_name,

// get mxnet inited and seeded rng states and pass to lib_api.h
void *cpu_states = nullptr, *gpu_states = nullptr;
mxnet::common::random::RandGenerator<cpu, float> *pgen_cpu =
ctx.requested[1].get_parallel_random<cpu, float>();
cpu_states = pgen_cpu->GetStates();
if (!is_subgraph_op) {
mxnet::common::random::RandGenerator<cpu, float> *pgen_cpu =
ctx.requested[1].get_parallel_random<cpu, float>();
cpu_states = pgen_cpu->GetStates();
#if MXNET_USE_CUDA
mxnet::common::random::RandGenerator<gpu, float> *pgen_gpu =
ctx.requested[1].get_parallel_random<gpu, float>();
gpu_states = pgen_gpu->GetStates();
mxnet::common::random::RandGenerator<gpu, float> *pgen_gpu =
ctx.requested[1].get_parallel_random<gpu, float>();
gpu_states = pgen_gpu->GetStates();
#endif
}

CHECK((fcomp_fp != nullptr && state_ptr == nullptr)
|| (fcomp_fp == nullptr && state_ptr != nullptr))
@@ -837,7 +841,8 @@ int MXLoadLib(const char *path) {
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr,
callFStatefulComp, 1, &state_ptr, ctx, inputs, req, outputs);
callFStatefulComp, 1, isSubgraphOp, &state_ptr,
ctx, inputs, req, outputs);
};
if (createop_map.count("cpu") > 0)
regOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", fstate_forward, plevel);
@@ -853,12 +858,14 @@ int MXLoadLib(const char *path) {
CHECK_GT(forward_ctx_map.count("cpu"), 0);
fcomp_t fcomp = forward_ctx_map.at("cpu");
CustomFComputeDispatcher(name_str, callFComp, fcomp, &attrs,
nullptr, 0, nullptr, ctx, inputs, req, outputs);
nullptr, 0, isSubgraphOp, nullptr,
ctx, inputs, req, outputs);
} else if (ctx.run_ctx.ctx.dev_mask() == Context::kGPU) {
CHECK_GT(forward_ctx_map.count("gpu"), 0);
fcomp_t fcomp = forward_ctx_map.at("gpu");
CustomFComputeDispatcher(name_str, callFComp, fcomp, &attrs,
nullptr, 0, nullptr, ctx, inputs, req, outputs);
nullptr, 0, isSubgraphOp, nullptr,
ctx, inputs, req, outputs);
}
};
if (forward_ctx_map.count("cpu") > 0)
@@ -902,7 +909,7 @@ int MXLoadLib(const char *path) {
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr,
callFStatefulComp, 0, &state_ptr,
callFStatefulComp, 0, isSubgraphOp, &state_ptr,
ctx, inputs, req, outputs);
};
gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", fstate_backward, plevel);
@@ -917,7 +924,8 @@ int MXLoadLib(const char *path) {
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CustomFComputeDispatcher(name_str, callFComp, fcomp_back_cpu, &attrs,
nullptr, 0, nullptr, ctx, inputs, req, outputs);
nullptr, 0, isSubgraphOp, nullptr,
ctx, inputs, req, outputs);
};
gradOp.set_attr<FComputeEx>("FComputeEx<cpu>", backward_cpu_lambda, plevel);
}
@@ -929,7 +937,8 @@ int MXLoadLib(const char *path) {
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CustomFComputeDispatcher(name_str, callFComp, fcomp_back_gpu, &attrs,
nullptr, 0, nullptr, ctx, inputs, req, outputs);
nullptr, 0, isSubgraphOp, nullptr,
ctx, inputs, req, outputs);
};
gradOp.set_attr<FComputeEx>("FComputeEx<gpu>", backward_gpu_lambda, plevel);
}
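For context on the new `is_subgraph_op` guard above: the dispatcher reads the seeded RNG states out of `ctx.requested[1]`, which is only populated when the operator's registration asks for the parallel random resource; the subgraph operator presumably makes no such request, so indexing `ctx.requested` for RNG states would read past the end of the vector (the segfault this commit fixes). The sketch below shows the general shape of a request under which `ctx.requested[1]` would be valid; the function name is illustrative and the exact index ordering is an assumption, not the registration code used by MXLoadLib.

```cpp
// Illustrative only: a resource request under which ctx.requested[0] is temp
// space and ctx.requested[1] is the seeded parallel random resource read by
// CustomFComputeDispatcher above.
#include <vector>
#include <mxnet/op_attr_types.h>
#include <mxnet/resource.h>

std::vector<mxnet::ResourceRequest> CustomOpResourceRequest(
    const nnvm::NodeAttrs& /*attrs*/) {
  return {mxnet::ResourceRequest::kTempSpace,
          mxnet::ResourceRequest::kParallelRandom};
}
```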
2 changes: 1 addition & 1 deletion src/common/random_generator.cu
@@ -72,7 +72,7 @@ void RandGenerator<gpu, float>::FreeState(RandGenerator<gpu> *inst) {

template<>
void* RandGenerator<gpu, float>::GetStates() {
return (void*)states_;
return static_cast<void*>(states_);
}

} // namespace random
4 changes: 2 additions & 2 deletions tests/python/gpu/test_extensions_gpu.py
@@ -75,8 +75,8 @@ def test_custom_op_gpu():
assert_almost_equal(in_grad_base[0].asnumpy(), in_grad[0].asnumpy(), rtol=1e-3, atol=1e-3)

# test custom noisy relu producing deterministic result given same seed managed by mxnet
d1 = mx.nd.ones(shape=(10,10), ctx=mx.cpu())
d2 = mx.nd.ones(shape=(10,10), ctx=mx.gpu())
d1 = mx.nd.ones(shape=(10,10,10), ctx=mx.cpu())
d2 = mx.nd.ones(shape=(10,10,10), ctx=mx.gpu())

mx.random.seed(128, ctx=mx.cpu())
r1 = mx.nd.my_noisy_relu(d1)