Skip to content

Commit

Permalink
Fast single row predict API v2 (#3268)
Browse files Browse the repository at this point in the history
* Fix bug introduced in PR #2992 for Fast predict

* Faster Fast predict API

* Add const to SingleRow Fast methods
  • Loading branch information
AlbertoEAF authored Aug 5, 2020
1 parent a9f5654 commit b5027de
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 38 deletions.
42 changes: 21 additions & 21 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -862,13 +862,21 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
* Release the ``FastConfig`` by passing its handle to ``LGBM_FastConfigFree`` when no longer needed.
*
* \param handle Booster handle
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \param num_col Number of columns
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_fastConfig FastConfig object with which you can call ``LGBM_BoosterPredictForCSRSingleRowFast``
* \return 0 when it succeeds, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
const int predict_type,
const int num_iteration,
const int data_type,
const int64_t num_col,
const char* parameter,
Expand Down Expand Up @@ -901,25 +909,17 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle h
* \param data Pointer to the data space
* \param nindptr Number of rows in the matrix + 1
* \param nelem Number of nonzero elements in the matrix
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param[out] out_len Length of output result
* \param[out] out_result Pointer to array with predictions
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fastConfig_handle,
const void* indptr,
int indptr_type,
const int indptr_type,
const int32_t* indices,
const void* data,
int64_t nindptr,
int64_t nelem,
int predict_type,
int num_iteration,
const int64_t nindptr,
const int64_t nelem,
int64_t* out_len,
double* out_result);

Expand Down Expand Up @@ -1042,15 +1042,23 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
* Release the ``FastConfig`` by passing its handle to ``LGBM_FastConfigFree`` when no longer needed.
*
* \param handle Booster handle
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \param ncol Number of columns
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_fastConfig FastConfig object with which you can call ``LGBM_BoosterPredictForMatSingleRowFast``
* \return 0 when it succeeds, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
int data_type,
int32_t ncol,
const int predict_type,
const int num_iteration,
const int data_type,
const int32_t ncol,
const char* parameter,
FastConfigHandle *out_fastConfig);

Expand All @@ -1070,20 +1078,12 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle h
*
* \param fastConfig_handle FastConfig object handle returned by ``LGBM_BoosterPredictForMatSingleRowFastInit``
* \param data Single-row array data (no other way than row-major form).
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param num_iteration Number of iteration for prediction, <= 0 means no limit
* \param[out] out_len Length of output result
* \param[out] out_result Pointer to array with predictions
* \return 0 when it succeeds, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fastConfig_handle,
const void* data,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result);

Expand Down
28 changes: 18 additions & 10 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1769,13 +1769,15 @@ int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
struct FastConfig {
FastConfig(Booster *const booster_ptr,
const char *parameter,
const int predict_type_,
const int data_type_,
const int32_t num_cols) : booster(booster_ptr), data_type(data_type_), ncol(num_cols) {
const int32_t num_cols) : booster(booster_ptr), predict_type(predict_type_), data_type(data_type_), ncol(num_cols) {
config.Set(Config::Str2Map(parameter));
}

Booster* const booster;
Config config;
const int predict_type;
const int data_type;
const int32_t ncol;
};
Expand Down Expand Up @@ -1939,6 +1941,8 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
}

int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
const int predict_type,
const int num_iteration,
const int data_type,
const int64_t num_col,
const char* parameter,
Expand All @@ -1953,32 +1957,33 @@ int LGBM_BoosterPredictForCSRSingleRowFastInit(BoosterHandle handle,
auto fastConfig_ptr = std::unique_ptr<FastConfig>(new FastConfig(
reinterpret_cast<Booster*>(handle),
parameter,
predict_type,
data_type,
static_cast<int32_t>(num_col)));

if (fastConfig_ptr->config.num_threads > 0) {
omp_set_num_threads(fastConfig_ptr->config.num_threads);
}

fastConfig_ptr->booster->SetSingleRowPredictor(num_iteration, predict_type, fastConfig_ptr->config);

*out_fastConfig = fastConfig_ptr.release();
API_END();
}

int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fastConfig_handle,
const void* indptr,
int indptr_type,
const int indptr_type,
const int32_t* indices,
const void* data,
int64_t nindptr,
int64_t nelem,
int predict_type,
int num_iteration,
const int64_t nindptr,
const int64_t nelem,
int64_t* out_len,
double* out_result) {
API_BEGIN();
FastConfig *fastConfig = reinterpret_cast<FastConfig*>(fastConfig_handle);
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, fastConfig->data_type, nindptr, nelem);
fastConfig->booster->PredictSingleRow(predict_type, fastConfig->ncol,
fastConfig->booster->PredictSingleRow(fastConfig->predict_type, fastConfig->ncol,
get_row_fun, fastConfig->config, out_result, out_len);
API_END();
}
Expand Down Expand Up @@ -2082,6 +2087,8 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
}

int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
const int predict_type,
const int num_iteration,
const int data_type,
const int32_t ncol,
const char* parameter,
Expand All @@ -2090,28 +2097,29 @@ int LGBM_BoosterPredictForMatSingleRowFastInit(BoosterHandle handle,
auto fastConfig_ptr = std::unique_ptr<FastConfig>(new FastConfig(
reinterpret_cast<Booster*>(handle),
parameter,
predict_type,
data_type,
ncol));

if (fastConfig_ptr->config.num_threads > 0) {
omp_set_num_threads(fastConfig_ptr->config.num_threads);
}

fastConfig_ptr->booster->SetSingleRowPredictor(num_iteration, predict_type, fastConfig_ptr->config);

*out_fastConfig = fastConfig_ptr.release();
API_END();
}

int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fastConfig_handle,
const void* data,
const int predict_type,
const int num_iteration,
int64_t* out_len,
double* out_result) {
API_BEGIN();
FastConfig *fastConfig = reinterpret_cast<FastConfig*>(fastConfig_handle);
// Single row in row-major format:
auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, fastConfig->ncol, fastConfig->data_type, 1);
fastConfig->booster->PredictSingleRow(predict_type, fastConfig->ncol,
fastConfig->booster->PredictSingleRow(fastConfig->predict_type, fastConfig->ncol,
get_row_fun, fastConfig->config,
out_result, out_len);
API_END();
Expand Down
9 changes: 2 additions & 7 deletions swig/lightgbmlib.i
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,11 @@
int LGBM_BoosterPredictForMatSingleRowFastCriticalSWIG(JNIEnv *jenv,
jdoubleArray data,
FastConfigHandle handle,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result) {
double* data0 = (double*)jenv->GetPrimitiveArrayCritical(data, 0);

int ret = LGBM_BoosterPredictForMatSingleRowFast(handle, data0, predict_type,
num_iteration, out_len, out_result);
int ret = LGBM_BoosterPredictForMatSingleRowFast(handle, data0, out_len, out_result);

jenv->ReleasePrimitiveArrayCritical(data, data0, JNI_ABORT);

Expand Down Expand Up @@ -174,8 +171,6 @@
FastConfigHandle handle,
int indptr_type,
int64_t nelem,
int predict_type,
int num_iteration,
int64_t* out_len,
double* out_result) {
// Alternatives
Expand All @@ -191,7 +186,7 @@
int32_t ind[2] = { 0, numNonZeros };

int ret = LGBM_BoosterPredictForCSRSingleRowFast(handle, ind, indptr_type, indices0, values0, 2,
nelem, predict_type, num_iteration, out_len, out_result);
nelem, out_len, out_result);

jenv->ReleasePrimitiveArrayCritical(values, values0, JNI_ABORT);
jenv->ReleasePrimitiveArrayCritical(indices, indices0, JNI_ABORT);
Expand Down

0 comments on commit b5027de

Please sign in to comment.