Skip to content

Commit

Permalink
[python] add start_iteration to python predict interface (#3058)
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyu1994 committed Aug 4, 2020
1 parent 1d59a04 commit 5b7c0ed
Show file tree
Hide file tree
Showing 14 changed files with 178 additions and 91 deletions.
8 changes: 4 additions & 4 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename),
R_AS_INT(data_has_header), pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
R_AS_INT(data_has_header), pred_type, 0, R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
R_CHAR_PTR(result_filename)));
R_API_END();
}
Expand All @@ -565,7 +565,7 @@ LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int64_t len = 0;
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row),
pred_type, R_AS_INT(num_iteration), &len));
pred_type, 0, R_AS_INT(num_iteration), &len));
R_INT_PTR(out_len)[0] = static_cast<int>(len);
R_API_END();
}
Expand Down Expand Up @@ -599,7 +599,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
nrow, pred_type, 0, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
R_API_END();
}

Expand All @@ -625,7 +625,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
int64_t out_len;
CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
pred_type, 0, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));

R_API_END();
}
Expand Down
5 changes: 3 additions & 2 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class LIGHTGBM_EXPORT Boosting {
*/
virtual void GetPredictAt(int data_idx, double* result, int64_t* out_len) = 0;

virtual int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const = 0;
virtual int NumPredictOneRow(int start_iteration, int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const = 0;

/*!
* \brief Prediction for one record, not sigmoid transform
Expand Down Expand Up @@ -284,10 +284,11 @@ class LIGHTGBM_EXPORT Boosting {

/*!
* \brief Initial work for the prediction
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations used for prediction
* \param is_pred_contrib
*/
virtual void InitPredict(int num_iteration, bool is_pred_contrib) = 0;
virtual void InitPredict(int start_iteration, int num_iteration, bool is_pred_contrib) = 0;

/*!
* \brief Name of submodel
Expand Down
18 changes: 18 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param result_filename Filename of result file in which predictions will be written
Expand All @@ -684,6 +685,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* data_filename,
int data_has_header,
int predict_type,
int start_iteration,
int num_iteration,
const char* parameter,
const char* result_filename);
Expand All @@ -697,13 +699,15 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param[out] out_len Length of prediction
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
int num_row,
int predict_type,
int start_iteration,
int num_iteration,
int64_t* out_len);

Expand Down Expand Up @@ -736,6 +740,7 @@ LIGHTGBM_C_EXPORT int LGBM_FastConfigFree(FastConfigHandle fastConfig);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output result
Expand All @@ -752,6 +757,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t nelem,
int64_t num_col,
int predict_type,
int start_iteration,
int num_iteration,
const char* parameter,
int64_t* out_len,
Expand All @@ -775,6 +781,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
* \param num_col_or_row Number of columns for CSR or number of rows for CSC
* \param predict_type What should be predicted; only feature contributions are supported currently
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param matrix_type Type of matrix input and output, can be ``C_API_MATRIX_TYPE_CSR`` or ``C_API_MATRIX_TYPE_CSC``
Expand All @@ -794,6 +801,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictSparseOutput(BoosterHandle handle,
int64_t nelem,
int64_t num_col_or_row,
int predict_type,
int start_iteration,
int num_iteration,
const char* parameter,
int matrix_type,
Expand Down Expand Up @@ -835,6 +843,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFreePredictSparse(void* indptr, int32_t* indic
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output result
Expand All @@ -851,6 +860,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
int64_t nelem,
int64_t num_col,
int predict_type,
int start_iteration,
int num_iteration,
const char* parameter,
int64_t* out_len,
Expand Down Expand Up @@ -944,6 +954,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fa
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output result
Expand All @@ -960,6 +971,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
int64_t nelem,
int64_t num_row,
int predict_type,
int start_iteration,
int num_iteration,
const char* parameter,
int64_t* out_len,
Expand All @@ -983,6 +995,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output result
Expand All @@ -996,6 +1009,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
int32_t ncol,
int is_row_major,
int predict_type,
int start_iteration,
int num_iteration,
const char* parameter,
int64_t* out_len,
Expand All @@ -1019,6 +1033,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output result
Expand All @@ -1031,6 +1046,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
int ncol,
int is_row_major,
int predict_type,
int start_iteration,
int num_iteration,
const char* parameter,
int64_t* out_len,
Expand Down Expand Up @@ -1104,6 +1120,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fa
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
* - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
* - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
* \param start_iteration Start index of the iteration to predict
* \param num_iteration Number of iterations for prediction, <= 0 means no limit
* \param parameter Other parameters for prediction, e.g. early stopping for prediction
* \param[out] out_len Length of output result
Expand All @@ -1116,6 +1133,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMats(BoosterHandle handle,
int32_t nrow,
int32_t ncol,
int predict_type,
int start_iteration,
int num_iteration,
const char* parameter,
int64_t* out_len,
Expand Down
6 changes: 6 additions & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,12 @@ struct Config {

#pragma region Predict Parameters

// [no-save]
// desc = used only in ``prediction`` task
// desc = used to specify from which iteration to start the prediction
// desc = ``<= 0`` means from the first iteration
int start_iteration_predict = 0;

// [no-save]
// desc = used only in ``prediction`` task
// desc = used to specify how many trained iterations will be used in prediction
Expand Down
Loading

0 comments on commit 5b7c0ed

Please sign in to comment.