diff --git a/CMakeLists.txt b/CMakeLists.txt
index f62dd0d0fdad..9dce131473b6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -584,7 +584,7 @@ if(USE_CUDA)
   message("-- CUDA: Using the following NVCC architecture flags ${CUDA_ARCH_FLAGS}")
   set(arch_code_list)
   foreach(arch_str ${CUDA_ARCH_FLAGS})
-    if((arch_str MATCHES ".*sm_[0-9]+"))
+    if((arch_str MATCHES ".*sm_[0-9]+"))
       string( REGEX REPLACE ".*sm_([0-9]+)" "\\1" arch_code ${arch_str} )
       list(APPEND arch_code_list ${arch_code})
     endif()
@@ -719,7 +719,7 @@ elseif(MSVC)
          "$<$<COMPILE_LANGUAGE:CUDA>:--gpu-code=sm_${arch},compute_${arch}>"
        )
        target_compile_options(
-         mxnet_${arch}
+         mxnet_${arch}
         PRIVATE "$<$<AND:$<CONFIG:Debug>,$<COMPILE_LANGUAGE:CUDA>>:-Xcompiler=-MTd -Gy /bigobj>")
       target_compile_options(
         mxnet_${arch}
@@ -748,26 +748,21 @@ elseif(MSVC)
   endif()
 endif()

+# extension libraries (custom operators, custom subgraphs) are built by default
 add_library(customop_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc)
 add_library(subgraph_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_subgraph/subgraph_lib.cc)
 target_include_directories(customop_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
 target_include_directories(subgraph_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
-if (USE_CUDA)
+if(USE_CUDA)
   add_library(customop_gpu_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/relu_lib.cu)
   target_include_directories(customop_gpu_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
 endif()
-if(UNIX)
-  target_compile_options(customop_lib PUBLIC -shared)
-  target_compile_options(subgraph_lib PUBLIC -shared)
-  if (USE_CUDA)
-    target_compile_options(customop_gpu_lib PUBLIC -shared)
-  endif()
-elseif(MSVC)
+if(MSVC)
   target_compile_options(customop_lib PUBLIC /LD)
   target_compile_options(subgraph_lib PUBLIC /LD)
   set_target_properties(customop_lib PROPERTIES PREFIX "lib")
   set_target_properties(subgraph_lib PROPERTIES PREFIX "lib")
-  if (USE_CUDA)
+  if(USE_CUDA)
     target_compile_options(customop_gpu_lib PUBLIC "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fPIC>")
     set_target_properties(customop_gpu_lib PROPERTIES PREFIX "lib")
   endif()
diff --git a/example/extensions/lib_custom_op/relu_lib.cu b/example/extensions/lib_custom_op/relu_lib.cu
index 3beb68c20fa7..60112ee4b1e5 100644
--- a/example/extensions/lib_custom_op/relu_lib.cu
+++ b/example/extensions/lib_custom_op/relu_lib.cu
@@ -20,12 +20,14 @@
 /*!
  * Copyright (c) 2020 by Contributors
  * \file relu_lib.cu
- * \brief simple custom relu operator implemented using CUDA function
+ * \brief simple custom relu and noisy relu operator implemented using CUDA function
  */

 #include <iostream>
 #include "lib_api.h"

+#define NumThreadPerBlock 256 // mxnet recommended number of cuda threads per block
+
 __global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   if (tid < N)
@@ -72,9 +74,9 @@ MXReturnValue forwardGPU(std::map<std::string, std::string> attrs,
   mx_stream_t cuda_stream = res.get_cuda_stream();
   int64_t N = inputs[0].size();

-  int block = 256;
-  int grid = (N + (block - 1)) / block;
-  relu_gpu_forward<<<grid,block,0,cuda_stream>>>(out_data, in_data, N);
+  int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock;
+
+  relu_gpu_forward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(out_data, in_data, N);

   return MX_SUCCESS;
 }
@@ -89,9 +91,9 @@ MXReturnValue backwardGPU(std::map<std::string, std::string> attrs,
   mx_stream_t cuda_stream = res.get_cuda_stream();
   int64_t N = inputs[0].size();

-  int block = 256;
-  int grid = (N + (block - 1)) / block;
-  relu_gpu_backward<<<grid,block,0,cuda_stream>>>(in_grad, out_grad, in_data, N);
+  int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock;
+
+  relu_gpu_backward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(in_grad, out_grad, in_data, N);

   return MX_SUCCESS;
 }
@@ -180,6 +182,80 @@ REGISTER_OP(my_state_relu)
 .setCreateOpState(createOpStateCPU, "cpu")
 .setCreateOpState(createOpStateGPU, "gpu");

+/*
+ * Below is the noisy ReLU operator example.
+ * Noisy ReLU is ReLU extended to include Gaussian noise:
+ *   forward  - add Gaussian noise generated from a normal distribution to each unit
+ *   backward - the gradient does not need to change, since the noise is constant
+ */
+
+#define NumRandomPerThread 64 // mxnet recommended random numbers generated per thread
+
+__global__ void noisy_relu_gpu_forward(float *out, float *in, int64_t N, mx_gpu_rand_t* states, int step) {
+  // the launcher logic ensures tid is less than MX_NUM_GPU_RANDOM_STATES
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  // each thread generates a unique sequence of random numbers
+  mx_gpu_rand_t thread_state = states[tid];
+  // each thread works on a <step>-sized chunk of the input
+  int start = tid * step;
+  int end = start + step;
+  for (int i=start; i<end && i<N; ++i) {
+    float noise = curand_normal(&thread_state);
+    out[i] = in[i] + noise > 0 ? in[i] + noise : 0;
+  }
+}
+
+MXReturnValue noisyForwardCPU(std::map<std::string, std::string> attrs,
+                              std::vector<MXTensor> inputs,
+                              std::vector<MXTensor> outputs,
+                              OpResource res) {
+  float* in_data = inputs[0].data<float>();
+  float* out_data = outputs[0].data<float>();
+
+  mx_cpu_rand_t* states = res.get_cpu_rand_states();
+  std::normal_distribution<float> dist_normal;
+
+  for (int i=0; i<inputs[0].size(); ++i) {
+    float noise = dist_normal(states[0]);
+    out_data[i] = in_data[i] + noise > 0 ? in_data[i] + noise : 0;
+  }
+  return MX_SUCCESS;
+}
+
+MXReturnValue noisyForwardGPU(std::map<std::string, std::string> attrs,
+                              std::vector<MXTensor> inputs,
+                              std::vector<MXTensor> outputs,
+                              OpResource res) {
+  float* in_data = inputs[0].data<float>();
+  float* out_data = outputs[0].data<float>();
+
+  mx_stream_t cuda_stream = res.get_cuda_stream();
+  int64_t N = inputs[0].size();
+
+  // below is the mxnet recommended workflow for parallel random number generation
+  int nthread = (N + NumRandomPerThread - 1) / NumRandomPerThread;
+  // we should not launch more threads than the number of mxnet-managed GPU random states
+  int num_thread_need = nthread < MX_NUM_GPU_RANDOM_STATES ? nthread : MX_NUM_GPU_RANDOM_STATES;
+  // each cuda thread processes the [step * tid, step * tid + step) slice of the input tensor
+  int step = (N + num_thread_need - 1) / num_thread_need;
+  // this ensures the number of parallel threads does not exceed the number of mxnet random states
+  int num_block = (num_thread_need + NumThreadPerBlock - 1) / NumThreadPerBlock;
+
+  noisy_relu_gpu_forward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(
+                                out_data, in_data, N, res.get_gpu_rand_states(), step);
+
+  return MX_SUCCESS;
+}
+
+REGISTER_OP(my_noisy_relu)
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferShape(inferShape)
+.setForward(noisyForwardCPU, "cpu")
+.setForward(noisyForwardGPU, "gpu")
+.setBackward(backwardCPU, "cpu")
+.setBackward(backwardGPU, "gpu");
+
 MXReturnValue initialize(int version) {
   if (version >= 10400) {
     std::cout << "MXNet version " << version << " supported" << std::endl;
diff --git a/example/extensions/lib_custom_op/test_relu.py b/example/extensions/lib_custom_op/test_relu.py
index 03d02f32d633..a37ea25b2ba4 100644
--- a/example/extensions/lib_custom_op/test_relu.py
+++ b/example/extensions/lib_custom_op/test_relu.py
@@ -35,13 +35,13 @@
 a = mx.nd.array([[-2,-1],[1,2]], ctx=mx.cpu())
 b = mx.nd.array([[-2,-1],[1,2]], ctx=mx.gpu())

-print("--------start ndarray compute---------")
+print("--------ndarray compute---------")
 print(mx.nd.my_relu(a))
 print(mx.nd.my_relu(b))
 print(mx.nd.my_state_relu(a))
 print(mx.nd.my_state_relu(b))

-print("--------start symbolic compute--------")
+print("--------symbolic compute--------")
 c = mx.sym.Variable('c')
 d = mx.sym.Variable('d')
 e = mx.sym.my_relu(c)
@@ -55,30 +55,41 @@
 print(out)
 print(out_base)

-print("--------start backward compute--------")
+print("--------backward compute--------")
 out_grad = mx.nd.ones((2,2), ctx=mx.gpu())
 exe.backward([out_grad])
 exe_base.backward([out_grad])
 print(in_grad)
 print(in_grad_base)

-print("--------start testing larger ndarray---------")
-a = mx.nd.uniform(shape=(100,100,100), ctx=mx.cpu())
+print("--------test ndarray with size of 1 million---------")
 b = mx.nd.uniform(shape=(100,100,100), ctx=mx.gpu())
 mx.nd.waitall()
 t1 = time.time()
-r1 = mx.nd.my_relu(a)
+r1 = mx.nd.my_relu(b)
 mx.nd.waitall()
 t2 = time.time()
-r2 = mx.nd.my_relu(b)
+r2 = mx.nd.relu(b)
 mx.nd.waitall()
 t3 = time.time()
-r3 = mx.nd.relu(b)
-mx.nd.waitall()
-t4 = time.time()
-print("CPU running time:")
-print(t2 - t1)
-print("GPU running time:")
-print(t3 - t2)
-print("Baseline GPU running time:")
-print(t4 - t3)
+print("Custom ReLU running time in ms:")
+print((t2 - t1) * 1000)
+print("Native ReLU running time in ms:")
+print((t3 - t2) * 1000)
+
+print("--------test noisy relu identical sequence---------")
+
+a = mx.nd.ones(shape=(13,5), ctx=mx.cpu())
+b = mx.nd.ones(shape=(13,5), ctx=mx.gpu())
+
+mx.random.seed(128, ctx=mx.cpu())
+print(mx.nd.my_noisy_relu(a))
+
+mx.random.seed(128, ctx=mx.cpu())
+print(mx.nd.my_noisy_relu(a))
+
+mx.random.seed(128, ctx=mx.gpu())
+print(mx.nd.my_noisy_relu(b))
+
+mx.random.seed(128, ctx=mx.gpu())
+print(mx.nd.my_noisy_relu(b))
diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h
index fd526ee4172f..c793a30c96d9 100644
--- a/include/mxnet/lib_api.h
+++ b/include/mxnet/lib_api.h
@@ -38,8 +38,14 @@
 #include
 #include
 #include
+#include <random>

-#define MX_LIBRARY_VERSION 5
+#if defined(__NVCC__)
+  #include <curand_kernel.h>
+#endif
+
+/* Make sure to update the version number every time you make changes */
+#define MX_LIBRARY_VERSION 6

 /*!
  * \brief For loading multiple custom op libraries in Linux, exporting same symbol multiple
@@ -395,8 +401,8 @@ struct MXTensor {
            stype == oth.stype;
   }

-  // For dense, data_ptr points to data.
-  // For sparse, data_ptr points to MXSparse.
+  // For dense, data_ptr points to 1D flattened tensor data
+  // For sparse, data_ptr points to MXSparse
   void *data_ptr;

   // shape is in [2,3,4] format to represent high-dim tensor
@@ -426,9 +432,17 @@ typedef void (*sparse_malloc_t)(void*, int, int, int, void**, int64_t**, int64_t**);
 #if defined(__NVCC__)
   typedef cudaStream_t mx_stream_t;
+  typedef curandStatePhilox4_32_10_t mx_gpu_rand_t;
 #else
   typedef void* mx_stream_t;
+  typedef void* mx_gpu_rand_t;
 #endif
+typedef std::mt19937 mx_cpu_rand_t;
+
+/*! \brief MXNet initialized random states for each device, used for parallelism */
+/* Each thread should generate a unique sequence of random numbers from its own state */
+#define MX_NUM_CPU_RANDOM_STATES 1024
+#define MX_NUM_GPU_RANDOM_STATES 32768

 /*!
  * \brief provide resource APIs memory allocation mechanism to Forward/Backward functions
@@ -437,10 +451,12 @@ class OpResource {
  public:
   OpResource(xpu_malloc_t cpu_malloc_fp, void* cpu_alloc_fp,
              xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream,
-             sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp)
+             sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp,
+             void* rng_cpu_states, void* rng_gpu_states)
     : cpu_malloc(cpu_malloc_fp), gpu_malloc(gpu_malloc_fp),
       cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream),
-      sparse_malloc(sparse_malloc_fp), sparse_alloc(sparse_alloc_fp) {}
+      sparse_malloc(sparse_malloc_fp), sparse_alloc(sparse_alloc_fp),
+      rand_cpu_states(rng_cpu_states), rand_gpu_states(rng_gpu_states) {}

   /*! \brief allocate cpu memory controlled by MXNet */
   void* alloc_cpu(int size) {
@@ -463,6 +479,19 @@ class OpResource {
                   &(sparse->data), &(sparse->indices), &(sparse->indptr));
   }

+  /*! \brief get pointer to initialized and seeded random number states located on CPU */
+  /* Access each state by states[id], where id must be less than MX_NUM_CPU_RANDOM_STATES */
+  mx_cpu_rand_t* get_cpu_rand_states() {
+    return static_cast<mx_cpu_rand_t*>(rand_cpu_states);
+  }
+
+  /*! \brief get pointer to initialized and seeded random number states located on GPU */
+  /* Access each state by states[id], where id must be less than MX_NUM_GPU_RANDOM_STATES */
+  /* Note that in a cpu-only build this returns a nullptr */
+  mx_gpu_rand_t* get_gpu_rand_states() {
+    return static_cast<mx_gpu_rand_t*>(rand_gpu_states);
+  }
+
  private:
   /*! \brief allocation lambda function */
   xpu_malloc_t cpu_malloc, gpu_malloc;
@@ -474,6 +503,8 @@ class OpResource {
   sparse_malloc_t sparse_malloc;
   /*! \brief lambda function to return allocated sparse memory handle */
   void *sparse_alloc;
+  /*! \brief cpu and gpu rng fully initialized and seeded states */
+  void *rand_cpu_states, *rand_gpu_states;
 };

 /*!
@@ -997,7 +1028,8 @@ typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys,
                              void** in_indices, void** out_indices,
                              void** in_indptr, void** out_indptr,
                              int64_t* in_indices_shapes, int64_t* out_indices_shapes,
-                             int64_t* in_indptr_shapes, int64_t* out_indptr_shapes);
+                             int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
+                             void* rng_cpu_states, void* rng_gpu_states);

 #define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
 typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* keys,
@@ -1026,7 +1058,8 @@ typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op,
                                      void** in_indices, void** out_indices,
                                      void** in_indptr, void** out_indptr,
                                      int64_t* in_indices_shapes, int64_t* out_indices_shapes,
-                                     int64_t* in_indptr_shapes, int64_t* out_indptr_shapes);
+                                     int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
+                                     void* rng_cpu_states, void* rng_gpu_states);

 #define MXLIB_PARTREGSIZE_STR "_partRegSize"
 typedef int (*partRegSize_t)(void);
@@ -1284,7 +1317,8 @@ extern "C" {
                   int* instypes, int* outstypes, void** in_indices, void** out_indices,
                   void** in_indptr, void** out_indptr,
                   int64_t* in_indices_shapes, int64_t* out_indices_shapes,
-                  int64_t* in_indptr_shapes, int64_t* out_indptr_shapes) {
+                  int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
+                  void* rng_cpu_states, void* rng_gpu_states) {
     // create map of attributes from list
     std::map<std::string, std::string> attrs;
     for (int i = 0; i < num; i++) {
@@ -1345,7 +1379,7 @@ extern "C" {
     }

     OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
-                   cuda_stream, sparse_malloc, sparse_alloc);
+                   cuda_stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states);

     return fcomp(attrs, inputs, outputs, res);
   }
@@ -1419,7 +1453,8 @@ extern "C" {
                           int* instypes, int* outstypes, void** in_indices, void** out_indices,
                           void** in_indptr, void** out_indptr,
                           int64_t* in_indices_shapes, int64_t* out_indices_shapes,
-                          int64_t* in_indptr_shapes, int64_t* out_indptr_shapes) {
+                          int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
+                          void* rng_cpu_states, void* rng_gpu_states) {
     // create a vector of tensors for inputs
     std::vector<MXTensor> inputs(num_in);
     // create a vector for sparse inputs
@@ -1476,7 +1511,7 @@ extern "C" {
     }

     OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
-                   stream, sparse_malloc, sparse_alloc);
+                   stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states);

     CustomStatefulOp* op_ptr = reinterpret_cast<CustomStatefulOp*>(state_op);
     if (is_forward) {
diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h
index e7b419309cb7..a5a9b8e35e57 100644
--- a/include/mxnet/random_generator.h
+++ b/include/mxnet/random_generator.h
@@ -96,6 +96,11 @@ class RandGenerator<cpu, DType> {
     for (int i = 0; i < kNumRandomStates; ++i) (states_ + i)->seed(seed + i);
   }

+  // export global random states, used by c++ custom operator
+  MSHADOW_XINLINE void* GetStates() {
+    return static_cast<void*>(states_);
+  }
+
  private:
   std::mt19937 *states_;
 };  // class RandGenerator<cpu, DType>
@@ -165,6 +170,9 @@ class RandGenerator<gpu, DType> {

   void Seed(mshadow::Stream<gpu> *s, uint32_t seed);

+  // export global random states, used by c++ custom operator
+  void* GetStates();
+
  private:
   curandStatePhilox4_32_10_t *states_;
 };  // class RandGenerator<gpu, DType>
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 949a59406c11..6b0eb4f1c769 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -146,6 +146,7 @@ void CustomFComputeDispatcher(const std::string op_name,
     in_dims.push_back(in_nd->shape().ndim());
     in_types.push_back(in_nd->dtype());
     in_verIDs.push_back(in_nd->version());
+    // string repr of supported context for custom library, currently only "cpu" and "gpu"
     const char* ctx_str = in_nd->ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
     in_dev_type.push_back(ctx_str);

@@ -187,7 +188,9 @@ void CustomFComputeDispatcher(const std::string op_name,
   }

   // get memory resource and mxnet backend streams
-  const Resource &resource = ctx.requested[0];
+  CHECK(ctx.requested.size() >= 2)
+    << "Custom operator should register at least memory resource and parallel random resource";
+  const Resource &resource = ctx.requested.at(0);
   mshadow::Stream<mxnet::cpu> *cpu_stream = ctx.get_stream<mxnet::cpu>();
   mshadow::Stream<mxnet::gpu> *gpu_stream = ctx.get_stream<mxnet::gpu>();

@@ -222,7 +225,7 @@ void CustomFComputeDispatcher(const std::string op_name,
     }
   };

-  // create lambda without captures so that we can cast it to function pointer
+  // create no-capture lambda so that we can cast it to function pointer
   // lambda with captures cannot be cast to function pointer and pass to lib_api.h
   // this needs to be a lambda function so that we can do the decltype cast
   typedef decltype(cpu_alloc) alloc_type_cpu;
   auto cpu_malloc = [](void* _cpu_alloc, int size) {
     alloc_type_cpu* cpualloc = static_cast<alloc_type_cpu*>(_cpu_alloc);
     // call cpu_alloc to actually allocate memory and return the pointer
     return static_cast<void*>((*cpualloc)(size));
   };
+
   typedef decltype(gpu_alloc) alloc_type_gpu;
   auto gpu_malloc = [](void* _gpu_alloc, int size) {
     alloc_type_gpu* gpualloc = static_cast<alloc_type_gpu*>(_gpu_alloc);
@@ -248,11 +252,22 @@ void CustomFComputeDispatcher(const std::string op_name,
   // get actual cudaStream_t out of mxnet gpu stream and pass to lib_api.h
   void *cuda_stream = nullptr;
 #if MXNET_USE_CUDA
-  if (inputs[0].ctx().dev_mask() == Context::kGPU) {
+  if ((inputs.size() > 0 && inputs[0].ctx().dev_mask() == Context::kGPU) ||
+      (outputs.size() > 0 && outputs[0].ctx().dev_mask() == Context::kGPU)) {
     cuda_stream = static_cast<void*>(gpu_stream->stream_);
   }
 #endif

+  // get mxnet initialized and seeded RNG states and pass to lib_api.h
+  void *rng_cpu_states = nullptr, *rng_gpu_states = nullptr;
+  using mxnet::common::random::RandGenerator;
+  RandGenerator<cpu, float> *pgen_cpu = ctx.requested.at(1).get_parallel_random<cpu, float>();
+  rng_cpu_states = pgen_cpu->GetStates();
+#if MXNET_USE_CUDA
+  RandGenerator<gpu, float> *pgen_gpu = ctx.requested.at(1).get_parallel_random<gpu, float>();
+  rng_gpu_states = pgen_gpu->GetStates();
+#endif
+
   CHECK((fcomp_fp != nullptr && state_ptr == nullptr)
         || (fcomp_fp == nullptr && state_ptr != nullptr))
     << "Can only register either regular op or stateful op for '" << op_name << "'";
@@ -275,7 +290,8 @@ void CustomFComputeDispatcher(const std::string op_name,
                       sparse_malloc, &sparse_alloc, in_stypes.data(), out_stypes.data(),
                       in_indices.data(), out_indices.data(),
                       in_indptr.data(), out_indptr.data(),
                       in_indices_shapes.data(), out_indices_shapes.data(),
-                      in_indptr_shapes.data(), out_indptr_shapes.data()))
+                      in_indptr_shapes.data(), out_indptr_shapes.data(),
+                      rng_cpu_states, rng_gpu_states))
       << "Error calling FCompute for custom operator '" << op_name << "'";
   }

@@ -299,7 +315,8 @@ void CustomFComputeDispatcher(const std::string op_name,
                       in_indices.data(), out_indices.data(),
                       in_indptr.data(), out_indptr.data(),
                       in_indices_shapes.data(), out_indices_shapes.data(),
-                      in_indptr_shapes.data(), out_indptr_shapes.data()))
+                      in_indptr_shapes.data(), out_indptr_shapes.data(),
+                      rng_cpu_states, rng_gpu_states))
       << "Error calling FStatefulCompute for custom operator '" << op_name << "'";
   }
 }
@@ -356,7 +373,6 @@ int MXLoadLib(const char *path) {

   partCallSupportedOps_t callSupportedOps = get_func<partCallSupportedOps_t>(lib,
                                            const_cast<char*>(MXLIB_PARTCALLSUPPORTEDOPS_STR));
-
   partCallReviewSubgraph_t callReviewSubgraph = get_func<partCallReviewSubgraph_t>(lib,
                                            const_cast<char*>(MXLIB_PARTCALLREVIEWSUBGRAPH_STR));

@@ -435,7 +451,7 @@ int MXLoadLib(const char *path) {
   /*
    * Below are a series of lambda functions that will be registered in the NNVM op registration
    * Each one has the standard MXNet signature and converts to types supported by externally
-   * registered operators.
+   * registered operators.
    */

   // lambda function to call parse attributes
@@ -726,7 +742,8 @@ int MXLoadLib(const char *path) {
   };

   auto resc_req = [=](const NodeAttrs& attrs) {
-    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace,
+                                        ResourceRequest::kParallelRandom};
   };

   // library author should implement and return a 'state' which points to an instance
@@ -790,6 +807,8 @@ int MXLoadLib(const char *path) {
       // TODO(samskalicky): enable constant overwriting of registration multiple times
       plevel++;
     }
+    // define supported resources for both subgraph ops and regular ops
+    regOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
     if (!isSubgraphOp) {
       regOp.set_attr_parser(attr_parser);
       regOp.set_num_inputs(num_inputs);
@@ -797,7 +816,6 @@ int MXLoadLib(const char *path) {
       regOp.set_attr<nnvm::FInferType>("FInferType", infer_type, plevel);
       regOp.set_attr<FInferStorageType>("FInferStorageType", infer_storage_type, plevel);
       regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
-      regOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
       // optionally add fmutate inputs if user specified a function
       if (mutate_fp != nullptr)
         regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs", mutate_inputs, plevel);
@@ -809,8 +827,6 @@ int MXLoadLib(const char *path) {
       regOp.set_attr<mxnet::FInferShape>("FInferShape", DefaultSubgraphOpShape, plevel);
       regOp.set_attr<FInferStorageType>("FInferStorageType",
                                         DefaultSubgraphOpStorageType, plevel);
-      regOp.set_attr<FResourceRequest>("FResourceRequest",
-                                       DefaultSubgraphOpResourceRequest, plevel);
       regOp.set_attr<nnvm::FMutateInputs>("FMutateInputs",
                                           DefaultSubgraphOpMutableInputs, plevel);
     }
@@ -888,8 +904,7 @@ int MXLoadLib(const char *path) {
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
         CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr,
-                                 callFStatefulComp, 0, &state_ptr,
-                                 ctx, inputs, req, outputs);
+                                 callFStatefulComp, 0, &state_ptr, ctx, inputs, req, outputs);
       };
       gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", fstate_backward, plevel);
       gradOp.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", fstate_backward, plevel);
diff --git a/src/common/random_generator.cu b/src/common/random_generator.cu
index a2d3e0d911e3..8f7b95985d02 100644
--- a/src/common/random_generator.cu
+++ b/src/common/random_generator.cu
@@ -70,6 +70,11 @@ void RandGenerator<gpu, float>::FreeState(RandGenerator<gpu, float> *inst) {
   CUDA_CALL(cudaFree(inst->states_));
 }

+template<>
+void* RandGenerator<gpu, float>::GetStates() {
+  return static_cast<void*>(states_);
+}
+
 }  // namespace random
 }  // namespace common
 }  // namespace mxnet
diff --git a/tests/python/gpu/test_extensions_gpu.py b/tests/python/gpu/test_extensions_gpu.py
index 08930a3986ef..8315b49660f3 100644
--- a/tests/python/gpu/test_extensions_gpu.py
+++ b/tests/python/gpu/test_extensions_gpu.py
@@ -68,8 +68,24 @@ def test_custom_op_gpu():
     out_base = exe_base.forward()
     assert_almost_equal(out_base[0].asnumpy(), out[0].asnumpy(), rtol=1e-3, atol=1e-3)

-    # test backward
+    # test custom relu backward
     out_grad = mx.nd.ones((2,2), ctx=mx.gpu())
     exe.backward([out_grad])
     exe_base.backward([out_grad])
     assert_almost_equal(in_grad_base[0].asnumpy(), in_grad[0].asnumpy(), rtol=1e-3, atol=1e-3)
+
+    # test custom noisy relu producing deterministic result given same seed managed by mxnet
+    d1 = mx.nd.ones(shape=(10,10,10), ctx=mx.cpu())
+    d2 = mx.nd.ones(shape=(10,10,10), ctx=mx.gpu())
+
+    mx.random.seed(128, ctx=mx.cpu())
+    r1 = mx.nd.my_noisy_relu(d1)
+    mx.random.seed(128, ctx=mx.cpu())
+    r2 = mx.nd.my_noisy_relu(d1)
+    assert_almost_equal(r1.asnumpy(), r2.asnumpy(), rtol=1e-3, atol=1e-3)
+
+    mx.random.seed(128, ctx=mx.gpu())
+    r3 = mx.nd.my_noisy_relu(d2)
+    mx.random.seed(128, ctx=mx.gpu())
+    r4 = mx.nd.my_noisy_relu(d2)
+    assert_almost_equal(r3.asnumpy(), r4.asnumpy(), rtol=1e-3, atol=1e-3)
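
Usage sketch: once the customop_gpu_lib target above is built, the new my_noisy_relu operator is loaded and called like the existing custom ops in test_relu.py. The library file name and path below are assumptions based on the CMake target name; adjust them to match your actual build output.

    # minimal sketch, assuming a GPU build and that the customop_gpu_lib target
    # produced libcustomop_gpu_lib.so in the current directory (path is an assumption)
    import mxnet as mx

    mx.library.load('libcustomop_gpu_lib.so')  # registers my_relu, my_noisy_relu, ...

    a = mx.nd.ones((2, 3), ctx=mx.gpu())

    # seeding the mxnet-managed RNG states makes the noisy relu output reproducible,
    # mirroring the determinism check added in test_extensions_gpu.py
    mx.random.seed(128, ctx=mx.gpu())
    print(mx.nd.my_noisy_relu(a))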