Commit: [train] Simplify ray.train.xgboost/lightgbm (2/n): Re-implement `XGBoostTrainer` as a lightweight `DataParallelTrainer` (#42767)

This PR re-implements `XGBoostTrainer` as a `DataParallelTrainer` that does not use `xgboost_ray` under the hood, in an effort to unify the trainer implementations and remove that external dependency.

Signed-off-by: Justin Yu <[email protected]>
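In practice, the new trainer follows the generic `DataParallelTrainer` pattern: you hand it a training function that calls `xgboost.train` directly on each worker. A minimal sketch of the new shape (the training-function body is elided here; the full, tested example is in the trainer docstring below):

```python
import ray.data
import ray.train
from ray.train.xgboost.v2 import XGBoostTrainer


def train_fn_per_worker(config: dict):
    # Each worker calls plain `xgboost.train` on its dataset shard;
    # Ray Train wires up the rabit collective behind the scenes.
    ...


trainer = XGBoostTrainer(
    train_fn_per_worker,
    datasets={"train": ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])},
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)
result = trainer.fit()
```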
Showing 8 changed files with 461 additions and 54 deletions.
`ray/train/xgboost/__init__.py` (updated):

```python
from ray.train.xgboost._xgboost_utils import RayTrainReportCallback
from ray.train.xgboost.config import XGBoostConfig
from ray.train.xgboost.xgboost_checkpoint import XGBoostCheckpoint
from ray.train.xgboost.xgboost_predictor import XGBoostPredictor
from ray.train.xgboost.xgboost_trainer import XGBoostTrainer

__all__ = [
    "RayTrainReportCallback",
    "XGBoostCheckpoint",
    "XGBoostConfig",
    "XGBoostPredictor",
    "XGBoostTrainer",
]
```
New file: `ray/train/xgboost/config.py`:
```python
import json
import logging
import os
from dataclasses import dataclass

from xgboost import RabitTracker

import ray
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import Backend, BackendConfig

logger = logging.getLogger(__name__)


@dataclass
class XGBoostConfig(BackendConfig):
    """Configuration for xgboost collective communication setup.

    Ray Train will set up the necessary coordinator processes and environment
    variables for your workers to communicate with each other.

    Additional configuration options can be passed into the
    `xgboost.collective.CommunicatorContext` that wraps your own
    `xgboost.train` code.

    See the `xgboost.collective` module for more information:
    https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/collective.py

    Args:
        xgboost_communicator: The backend to use for collective communication for
            distributed xgboost training. For now, only "rabit" is supported.
    """

    xgboost_communicator: str = "rabit"

    @property
    def backend_cls(self):
        if self.xgboost_communicator == "rabit":
            return _XGBoostRabitBackend

        raise NotImplementedError(f"Unsupported backend: {self.xgboost_communicator}")


class _XGBoostRabitBackend(Backend):
    def __init__(self):
        self._tracker = None

    def on_training_start(
        self, worker_group: WorkerGroup, backend_config: XGBoostConfig
    ):
        assert backend_config.xgboost_communicator == "rabit"

        # Set up the rabit tracker on the Train driver.
        num_workers = len(worker_group)
        rabit_args = {"DMLC_NUM_WORKER": num_workers}
        train_driver_ip = ray.util.get_node_ip_address()

        # NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
        # align with the Ray Train worker ranks.
        # The worker ranks are sorted by `DMLC_TASK_ID`,
        # which is set on each worker in `set_xgboost_env_vars` below.
        self._tracker = RabitTracker(
            host_ip=train_driver_ip, n_workers=num_workers, sortby="task"
        )
        rabit_args.update(self._tracker.worker_envs())
        self._tracker.start(num_workers)

        start_log = (
            "RabitTracker coordinator started with parameters:\n"
            f"{json.dumps(rabit_args, indent=2)}"
        )
        logger.debug(start_log)

        def set_xgboost_env_vars():
            import ray.train

            for k, v in rabit_args.items():
                os.environ[k] = str(v)

            # Ranks are assigned in increasing order of the worker's task id.
            # The task id embeds the Ray Train world rank, so sorting by
            # task id matches sorting by world rank.
            os.environ["DMLC_TASK_ID"] = (
                f"[xgboost.ray-rank={ray.train.get_context().get_world_rank():08}]:"
                f"{ray.get_runtime_context().get_actor_id()}"
            )

        worker_group.execute(set_xgboost_env_vars)

    def on_shutdown(self, worker_group: WorkerGroup, backend_config: XGBoostConfig):
        timeout = 5
        self._tracker.thread.join(timeout=timeout)

        if self._tracker.thread.is_alive():
            logger.warning(
                "During shutdown, the RabitTracker thread failed to join "
                f"within {timeout} seconds. "
                "The process will still be terminated as part of Ray actor cleanup."
            )
```
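Because `backend_cls` only recognizes `"rabit"`, any other communicator value fails fast with `NotImplementedError` when training starts. Selecting the backend explicitly looks like this (a minimal sketch; `train_fn_per_worker` stands in for a real training function):

```python
import ray.train
from ray.train.xgboost import XGBoostConfig
from ray.train.xgboost.v2 import XGBoostTrainer


def train_fn_per_worker(config: dict):
    ...


trainer = XGBoostTrainer(
    train_fn_per_worker,
    # Explicit, though "rabit" is already the default communicator.
    xgboost_config=XGBoostConfig(xgboost_communicator="rabit"),
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)
```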
New file: `ray/train/xgboost/v2.py` (the re-implemented trainer):
```python
import logging
from typing import Any, Callable, Dict, Optional, Union

import ray.train
from ray.train import Checkpoint
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.train.trainer import GenDataset
from ray.train.xgboost import XGBoostConfig

logger = logging.getLogger(__name__)


class XGBoostTrainer(DataParallelTrainer):
    """A Trainer for distributed data-parallel XGBoost training.

    Example
    -------

    .. testcode::

        import xgboost

        import ray.data
        import ray.train
        from ray.train.xgboost import RayTrainReportCallback
        from ray.train.xgboost.v2 import XGBoostTrainer

        def train_fn_per_worker(config: dict):
            from xgboost.collective import CommunicatorContext

            # (Optional) Add logic to resume training state from a checkpoint.
            # ray.train.get_checkpoint()

            # 1. Get the dataset shard for the worker and convert to a `xgboost.DMatrix`
            train_ds_iter, eval_ds_iter = (
                ray.train.get_dataset_shard("train"),
                ray.train.get_dataset_shard("validation"),
            )
            train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()

            train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
            train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
            eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]

            dtrain = xgboost.DMatrix(train_X, label=train_y)
            deval = xgboost.DMatrix(eval_X, label=eval_y)

            params = {
                "tree_method": "approx",
                "objective": "reg:squarederror",
                "eta": 1e-4,
                "subsample": 0.5,
                "max_depth": 2,
            }

            # 2. Do distributed data-parallel training with the `CommunicatorContext`.
            # Ray Train sets up the necessary coordinator processes and
            # environment variables for your workers to communicate with each other.
            with CommunicatorContext():
                bst = xgboost.train(
                    params,
                    dtrain=dtrain,
                    evals=[(deval, "validation")],
                    num_boost_round=10,
                    callbacks=[RayTrainReportCallback()],
                )

        train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
        eval_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(16)])
        trainer = XGBoostTrainer(
            train_fn_per_worker,
            datasets={"train": train_ds, "validation": eval_ds},
            scaling_config=ray.train.ScalingConfig(num_workers=4),
        )
        result = trainer.fit()
        booster = RayTrainReportCallback.get_model(result.checkpoint)

    .. testoutput::
        :hide:

        ...

    Args:
        train_loop_per_worker: The training function to execute on each worker.
            This function can either take in zero arguments or a single ``Dict``
            argument which is set by defining ``train_loop_config``.
            Within this function you can use any of the
            :ref:`Ray Train Loop utilities <train-loop-api>`.
        train_loop_config: A configuration ``Dict`` to pass in as an argument to
            ``train_loop_per_worker``.
            This is typically used for specifying hyperparameters.
        xgboost_config: The configuration for setting up the distributed xgboost
            backend. Defaults to using the "rabit" backend.
            See :class:`~ray.train.xgboost.XGBoostConfig` for more info.
        datasets: The Ray Datasets to use for training and validation.
        dataset_config: The configuration for ingesting the input ``datasets``.
            By default, all the Ray Datasets are split equally across workers.
            See :class:`~ray.train.DataConfig` for more details.
        scaling_config: The configuration for how to scale data parallel training.
            ``num_workers`` determines how many Python processes are used for
            training, and ``use_gpu`` determines whether or not each process
            should use GPUs. See :class:`~ray.train.ScalingConfig` for more info.
        run_config: The configuration for the execution of the training run.
            See :class:`~ray.train.RunConfig` for more info.
        resume_from_checkpoint: A checkpoint to resume training from.
            This checkpoint can be accessed from within ``train_loop_per_worker``
            by calling ``ray.train.get_checkpoint()``.
        metadata: Dict that should be made available via
            ``ray.train.get_context().get_metadata()`` and in
            ``checkpoint.get_metadata()`` for checkpoints saved from this Trainer.
            Must be JSON-serializable.
    """

    def __init__(
        self,
        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
        *,
        train_loop_config: Optional[Dict] = None,
        xgboost_config: Optional[XGBoostConfig] = None,
        scaling_config: Optional[ray.train.ScalingConfig] = None,
        run_config: Optional[ray.train.RunConfig] = None,
        datasets: Optional[Dict[str, GenDataset]] = None,
        dataset_config: Optional[ray.train.DataConfig] = None,
        metadata: Optional[Dict[str, Any]] = None,
        resume_from_checkpoint: Optional[Checkpoint] = None,
    ):
        super(XGBoostTrainer, self).__init__(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config=train_loop_config,
            backend_config=xgboost_config or XGBoostConfig(),
            scaling_config=scaling_config,
            dataset_config=dataset_config,
            run_config=run_config,
            datasets=datasets,
            resume_from_checkpoint=resume_from_checkpoint,
            metadata=metadata,
        )
```
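One concrete way to fill in the docstring's "(Optional) Add logic to resume training state from a checkpoint" step: load the previous booster from the checkpoint and pass it to xgboost's standard `xgb_model` argument so boosting continues where it left off. A hedged sketch (it assumes the checkpoint was written by `RayTrainReportCallback`, whose `get_model` classmethod is applied to `result.checkpoint` in the example above; the toy `DMatrix` is a placeholder):

```python
import numpy as np
import xgboost

import ray.train
from ray.train.xgboost import RayTrainReportCallback


def train_fn_per_worker(config: dict):
    from xgboost.collective import CommunicatorContext

    # Resume from the latest checkpoint, if one was provided
    # (e.g. via `resume_from_checkpoint` or automatic failure recovery).
    checkpoint = ray.train.get_checkpoint()
    xgb_model = RayTrainReportCallback.get_model(checkpoint) if checkpoint else None

    # Placeholder data; a real training function would build this from
    # `ray.train.get_dataset_shard(...)` as in the docstring example.
    dtrain = xgboost.DMatrix(np.arange(8.0).reshape(-1, 1), label=np.arange(8.0))

    with CommunicatorContext():
        xgboost.train(
            {"objective": "reg:squarederror"},
            dtrain=dtrain,
            num_boost_round=10,
            xgb_model=xgb_model,  # continue from the previous booster, if any
            callbacks=[RayTrainReportCallback()],
        )
```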