Skip to content

Commit

Permalink
[train] Enforce xgboost>=1.7 for XGBoostTrainer usage (#44269)
Browse files Browse the repository at this point in the history
`XGBoostTrainer` now relies on `xgboost.collective`, which was added in xgboost 1.7. Raise a warning if using a lower version.

Signed-off-by: Justin Yu <[email protected]>
  • Loading branch information
justinvyu committed Mar 25, 2024
1 parent 226b2cc commit a0b588b
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions python/ray/train/xgboost/xgboost_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, Optional

import xgboost
from packaging.version import Version

import ray.train
from ray.train import Checkpoint
Expand Down Expand Up @@ -152,6 +153,12 @@ def __init__(
metadata: Optional[Dict[str, Any]] = None,
**train_kwargs,
):
if Version(xgboost.__version__) < Version("1.7.0"):
raise ImportError(
"`XGBoostTrainer` requires the `xgboost` version to be >= 1.7.0. "
'Upgrade with: `pip install -U "xgboost>=1.7"`'
)

# TODO(justinvyu): [Deprecated] Remove in 2.11
if dmatrix_params != _DEPRECATED_VALUE:
raise DeprecationWarning(
Expand Down

0 comments on commit a0b588b

Please sign in to comment.