Skip to content

Commit

Permalink
update R documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyu1994 committed Aug 6, 2020
1 parent 2f27c4a commit e19742c
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 17 deletions.
Binary file added R-package/.RData
Binary file not shown.
4 changes: 2 additions & 2 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -708,8 +708,8 @@ Booster <- R6::R6Class(
#' If None or <= 0, starts from the first iteration.
#' @param num_iteration int or None, optional (default=None)
#' Limit number of iterations in the prediction.
#' If None, if the best iteration exists and start_iteration is None or <= 0, the best iteration is used;
#' otherwise, all iterations from start_iteration are used.
#' If None, if the best iteration exists and start_iteration is None or <= 0, the
#' best iteration is used; otherwise, all iterations from start_iteration are used.
#' If <= 0, all iterations from start_iteration are used (no limits).
#' @param rawscore whether the prediction should be returned in the for of original untransformed
#' sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
Expand Down
11 changes: 10 additions & 1 deletion R-package/man/predict.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 20 additions & 12 deletions R-package/tests/testthat/test_Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,31 +41,39 @@ test_that("start_iteration works correctly", {
, nrounds = 100L
, objective = "binary"
, save_name = tempfile(fileext = ".model")
, valids=list("test" = dtest)
, early_stopping_rounds=2L
, valids = list("test" = dtest)
, early_stopping_rounds = 2L
)
expect_true(lgb.is.Booster(bst))
pred1 <- predict(bst, data=test$data, rawscore=TRUE)
pred_contrib1 <- predict(bst, test$data, predcontrib=TRUE)
pred1 <- predict(bst, data = test$data, rawscore = TRUE)
pred_contrib1 <- predict(bst, test$data, predcontrib = TRUE)
pred2 <- rep(0.0, length(pred1))
pred_contrib2 <- rep(0.0, length(pred2))
step <- 11L
end_iter <- 99L
if (bst$best_iter != -1) {
end_iter <- bst$best_iter - 1
if (bst$best_iter != -1L) {
end_iter <- bst$best_iter - 1L
}
start_iters <- seq(0L, end_iter, by=step)
start_iters <- seq(0L, end_iter, by = step)
for (start_iter in start_iters) {
n_iter = min(c(end_iter - start_iter + 1, step))
inc_pred <- predict(bst, test$data, start_iteration=start_iter, num_iteration=n_iter, rawscore=TRUE)
inc_pred_contrib <- bst$predict(test$data, start_iteration=start_iter, num_iteration=n_iter, predcontrib=TRUE)
n_iter <- min(c(end_iter - start_iter + 1L, step))
inc_pred <- predict(bst, test$data
, start_iteration = start_iter
, num_iteration = n_iter
, rawscore = TRUE
)
inc_pred_contrib <- bst$predict(test$data
, start_iteration = start_iter
, num_iteration = n_iter
, predcontrib = TRUE
)
pred2 <- pred2 + inc_pred
pred_contrib2 <- pred_contrib2 + inc_pred_contrib
}
expect_equal(pred2, pred1)
expect_equal(pred_contrib2, pred_contrib1)

pred_leaf1 <- predict(bst, test$data, predleaf=TRUE)
pred_leaf2 <- predict(bst, test$data, start_iteration=0L, num_iteration=end_iter + 1, predleaf=TRUE)
pred_leaf1 <- predict(bst, test$data, predleaf = TRUE)
pred_leaf2 <- predict(bst, test$data, start_iteration = 0L, num_iteration = end_iter + 1L, predleaf = TRUE)
expect_equal(pred_leaf1, pred_leaf2)
})
3 changes: 1 addition & 2 deletions src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,7 @@ class GBDT : public GBDTBase {
start_iteration = std::min(start_iteration, num_iteration_for_pred_);
if (num_iteration > 0) {
num_iteration_for_pred_ = std::min(num_iteration, num_iteration_for_pred_ - start_iteration);
}
else {
} else {
num_iteration_for_pred_ = num_iteration_for_pred_ - start_iteration;
}
start_iteration_for_pred_ = start_iteration;
Expand Down

0 comments on commit e19742c

Please sign in to comment.