Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Custom Operator Random Number Generator Support #17762

Merged
merged 28 commits into from
Apr 8, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
misc fix
  • Loading branch information
rondogency committed Feb 27, 2020
commit 87d3314141636535c69ab12aef8ce16302cb7b3e
3 changes: 3 additions & 0 deletions example/extensions/lib_custom_op/gemm_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ MXReturnValue forward(std::map<std::string, std::string> attrs,
unsigned k = inputs[0].shape[1];
unsigned m = inputs[1].shape[1];

int random = res.get_randint();
std::cout << random << std::endl;

gemm(A, B, C, n, k, m);
}
return MX_SUCCESS;
Expand Down
2 changes: 0 additions & 2 deletions example/extensions/lib_custom_op/relu_lib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ MXReturnValue forwardCPU(std::map<std::string, std::string> attrs,
OpResource res) {
float* in_data = inputs[0].data<float>();
float* out_data = outputs[0].data<float>();
int random = res.get_randint();
std::cout << random << std::endl;
for (int i=0; i<inputs[0].size(); i++) {
out_data[i] = in_data[i] > 0 ? in_data[i] : 0;
}
Expand Down
4 changes: 2 additions & 2 deletions include/mxnet/lib_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ union RandomType {
double d;
};

typedef RandomType (*rng_caller_t)(void*, char*);
typedef RandomType (*rng_caller_t)(void*, const char*);

#if defined(__NVCC__)
typedef cudaStream_t mx_stream_t;
Expand Down Expand Up @@ -404,7 +404,7 @@ class OpResource {
return static_cast<mx_stream_t>(cuda_stream);
}

int gen_randint() {
int get_randint() {
RandomType ret = rng_caller_nocap(rng_caller, "rand");
return ret.i;
}
Expand Down
4 changes: 2 additions & 2 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ void CustomFComputeDispatcher(const std::string op_name,
mxnet::common::random::RandGenerator<cpu, double> *pgen =
ctx.requested[1].get_parallel_random<cpu, double>();

auto rng_caller = [&](char *rand_type) {
auto rng_caller = [&](const char *rand_type) {
LOG(INFO) << "rng_caller called";
typename mxnet::common::random::RandGenerator<cpu, double>::Impl genImpl(pgen, 1);
std::string rand_str(rand_type);
Expand All @@ -194,7 +194,7 @@ void CustomFComputeDispatcher(const std::string op_name,
};

typedef decltype(rng_caller) type_rng_caller;
auto rng_caller_nocap = [](void *rng_call, char *rand_type) {
auto rng_caller_nocap = [](void *rng_call, const char *rand_type) {
type_rng_caller* rngcaller = static_cast<type_rng_caller*>(rng_call);
return (*rngcaller)(rand_type);
};
Expand Down