Skip to content

Commit

Permalink
[train] Simplify ray.train.xgboost/lightgbm (2/n): Re-implement `XGBoostTrainer` as a lightweight `DataParallelTrainer` (#42767)
Browse files Browse the repository at this point in the history

This PR re-implements `XGBoostTrainer` as a `DataParallelTrainer` that does not use `xgboost_ray` under the hood, in an effort to unify the trainer implementations and remove that external dependency.

---------

Signed-off-by: Justin Yu <[email protected]>
  • Loading branch information
justinvyu committed Feb 23, 2024
1 parent 6908b12 commit 62dbcb2
Show file tree
Hide file tree
Showing 8 changed files with 461 additions and 54 deletions.
3 changes: 2 additions & 1 deletion doc/source/train/doc_code/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# __basic_start__
import ray
from ray import tune
from ray import train, tune
from ray.tune import Tuner
from ray.train.xgboost import XGBoostTrainer

Expand All @@ -17,6 +17,7 @@
"max_depth": 4,
},
datasets={"train": dataset},
scaling_config=train.ScalingConfig(num_workers=2),
)

# Create Tuner
Expand Down
4 changes: 2 additions & 2 deletions python/ray/air/tests/test_resource_changing.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_gbdt_trainer(ray_start_8_cpus):
dataset_df["target"] = data_raw["target"]
train_ds = ray.data.from_pandas(dataset_df).repartition(16)
trainer = AssertingXGBoostTrainer(
datasets={TRAIN_DATASET_KEY: train_ds},
datasets={TRAIN_DATASET_KEY: train_ds, "validation": train_ds},
label_column="target",
scaling_config=ScalingConfig(num_workers=2, placement_strategy="SPREAD"),
params={
Expand All @@ -142,7 +142,7 @@ def test_gbdt_trainer(ray_start_8_cpus):
},
tune_config=TuneConfig(
mode="min",
metric="train-logloss",
metric="validation-logloss",
max_concurrent_trials=3,
scheduler=ResourceChangingScheduler(
ASHAScheduler(),
Expand Down
2 changes: 1 addition & 1 deletion python/ray/train/_internal/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(
local_world_size: int,
world_size: int,
trial_info: Optional[TrialInfo] = None,
dataset_shard: Optional[Dataset] = None,
dataset_shard: Optional[Dict[str, Dataset]] = None,
metadata: Dict[str, Any] = None,
checkpoint: Optional[Checkpoint] = None,
detailed_autofilled_metrics: bool = False,
Expand Down
56 changes: 8 additions & 48 deletions python/ray/train/tests/test_xgboost_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_resume_from_checkpoint(ray_start_4_cpus, tmpdir):
result = trainer.fit()
checkpoint = result.checkpoint
xgb_model = XGBoostTrainer.get_model(checkpoint)
assert get_num_trees(xgb_model) == 5
assert xgb_model.num_boosted_rounds() == 5

trainer = XGBoostTrainer(
scaling_config=scale_config,
Expand All @@ -108,7 +108,7 @@ def test_resume_from_checkpoint(ray_start_4_cpus, tmpdir):
)
result = trainer.fit()
model = XGBoostTrainer.get_model(result.checkpoint)
assert get_num_trees(model) == 10
assert model.num_boosted_rounds() == 10


@pytest.mark.parametrize(
Expand Down Expand Up @@ -158,22 +158,19 @@ def test_tune(ray_start_8_cpus):
trainer = XGBoostTrainer(
scaling_config=scale_config,
label_column="target",
params={**params, **{"max_depth": 1}},
params={**params, "max_depth": 1},
datasets={TRAIN_DATASET_KEY: train_dataset, "valid": valid_dataset},
)

tune.run(
trainer.as_trainable(),
config={"params": {"max_depth": tune.randint(2, 4)}},
num_samples=2,
tuner = tune.Tuner(
trainer,
param_space={"params": {"max_depth": tune.grid_search([2, 4])}},
)

# Make sure original Trainer is not affected.
assert trainer.params["max_depth"] == 1
results = tuner.fit()
assert sorted([r.config["params"]["max_depth"] for r in results]) == [2, 4]


def test_validation(ray_start_4_cpus):
train_dataset = ray.data.from_pandas(train_df)
valid_dataset = ray.data.from_pandas(test_df)
with pytest.raises(KeyError, match=TRAIN_DATASET_KEY):
XGBoostTrainer(
Expand All @@ -182,43 +179,6 @@ def test_validation(ray_start_4_cpus):
params=params,
datasets={"valid": valid_dataset},
)
with pytest.raises(KeyError, match="dmatrix_params"):
XGBoostTrainer(
scaling_config=ScalingConfig(num_workers=2),
label_column="target",
params=params,
dmatrix_params={"data": {}},
datasets={TRAIN_DATASET_KEY: train_dataset, "valid": valid_dataset},
)


def test_distributed_data_loading(ray_start_4_cpus):
"""Checks that XGBoostTrainer does distributed data loading for Datasets."""

class DummyXGBoostTrainer(XGBoostTrainer):
def _train(self, params, dtrain, **kwargs):
assert dtrain.distributed
return super()._train(params=params, dtrain=dtrain, **kwargs)

train_dataset = ray.data.from_pandas(train_df)

trainer = DummyXGBoostTrainer(
scaling_config=ScalingConfig(num_workers=2),
label_column="target",
params=params,
datasets={TRAIN_DATASET_KEY: train_dataset},
)

assert trainer.dmatrix_params[TRAIN_DATASET_KEY]["distributed"]
trainer.fit()


def test_xgboost_trainer_resources():
"""`trainer_resources` is not allowed in the scaling config"""
with pytest.raises(ValueError):
XGBoostTrainer._validate_scaling_config(
ScalingConfig(trainer_resources={"something": 1})
)


def test_callback_get_model(tmp_path):
Expand Down
2 changes: 2 additions & 0 deletions python/ray/train/xgboost/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from ray.train.xgboost._xgboost_utils import RayTrainReportCallback
from ray.train.xgboost.config import XGBoostConfig
from ray.train.xgboost.xgboost_checkpoint import XGBoostCheckpoint
from ray.train.xgboost.xgboost_predictor import XGBoostPredictor
from ray.train.xgboost.xgboost_trainer import XGBoostTrainer

__all__ = [
"RayTrainReportCallback",
"XGBoostCheckpoint",
"XGBoostConfig",
"XGBoostPredictor",
"XGBoostTrainer",
]
96 changes: 96 additions & 0 deletions python/ray/train/xgboost/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import json
import logging
import os
from dataclasses import dataclass

from xgboost import RabitTracker

import ray
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import Backend, BackendConfig

logger = logging.getLogger(__name__)


@dataclass
class XGBoostConfig(BackendConfig):
    """Configuration for xgboost collective communication setup.

    Ray Train will set up the necessary coordinator processes and environment
    variables for your workers to communicate with each other.

    Additional configuration options can be passed into the
    `xgboost.collective.CommunicatorContext` that wraps your own `xgboost.train` code.

    See the `xgboost.collective` module for more information:
    https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/collective.py

    Args:
        xgboost_communicator: The backend to use for collective communication for
            distributed xgboost training. For now, only "rabit" is supported.
    """

    xgboost_communicator: str = "rabit"

    @property
    def backend_cls(self):
        # Only the rabit communicator has a backend implementation; reject
        # anything else up front.
        if self.xgboost_communicator != "rabit":
            raise NotImplementedError(f"Unsupported backend: {self.xgboost_communicator}")

        return _XGBoostRabitBackend


class _XGBoostRabitBackend(Backend):
    """Backend that coordinates distributed xgboost training via a rabit tracker.

    The tracker runs on the Train driver; workers discover it through the
    `DMLC_*` environment variables set in `on_training_start`.
    """

    def __init__(self):
        # Created lazily in `on_training_start`; stays None if setup never runs.
        self._tracker = None

    def on_training_start(
        self, worker_group: WorkerGroup, backend_config: XGBoostConfig
    ):
        """Start the rabit tracker on the driver and configure worker env vars."""
        assert backend_config.xgboost_communicator == "rabit"

        # Set up the rabit tracker on the Train driver.
        num_workers = len(worker_group)
        rabit_args = {"DMLC_NUM_WORKER": num_workers}
        train_driver_ip = ray.util.get_node_ip_address()

        # NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
        # align with Ray Train worker ranks.
        # The worker ranks will be sorted by `DMLC_TASK_ID`,
        # which is defined below in `set_xgboost_env_vars`.
        self._tracker = RabitTracker(
            host_ip=train_driver_ip, n_workers=num_workers, sortby="task"
        )
        rabit_args.update(self._tracker.worker_envs())
        self._tracker.start(num_workers)

        start_log = (
            "RabitTracker coordinator started with parameters:\n"
            f"{json.dumps(rabit_args, indent=2)}"
        )
        logger.debug(start_log)

        def set_xgboost_env_vars():
            import ray.train

            for k, v in rabit_args.items():
                os.environ[k] = str(v)

            # Ranks are assigned in increasing order of the worker's task id.
            # This task id will be sorted by increasing world rank.
            os.environ["DMLC_TASK_ID"] = (
                f"[xgboost.ray-rank={ray.train.get_context().get_world_rank():08}]:"
                f"{ray.get_runtime_context().get_actor_id()}"
            )

        worker_group.execute(set_xgboost_env_vars)

    def on_shutdown(self, worker_group: WorkerGroup, backend_config: XGBoostConfig):
        """Best-effort join of the tracker thread; warn (don't raise) on timeout."""
        # Fix: `on_training_start` may never have run (e.g. worker group setup
        # failed before training started), in which case there is no tracker
        # to join and the previous code raised AttributeError on None.
        if self._tracker is None:
            return

        timeout = 5
        self._tracker.thread.join(timeout=timeout)

        if self._tracker.thread.is_alive():
            logger.warning(
                "During shutdown, the RabitTracker thread failed to join "
                f"within {timeout} seconds. "
                "The process will still be terminated as part of Ray actor cleanup."
            )
136 changes: 136 additions & 0 deletions python/ray/train/xgboost/v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import logging
from typing import Any, Callable, Dict, Optional, Union

import ray.train
from ray.train import Checkpoint
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.train.trainer import GenDataset
from ray.train.xgboost import XGBoostConfig

logger = logging.getLogger(__name__)


class XGBoostTrainer(DataParallelTrainer):
    """A Trainer for distributed data-parallel XGBoost training.

    Example
    -------

    .. testcode::

        import xgboost

        import ray.data
        import ray.train
        from ray.train.xgboost import RayTrainReportCallback
        from ray.train.xgboost.v2 import XGBoostTrainer

        def train_fn_per_worker(config: dict):
            from xgboost.collective import CommunicatorContext

            # (Optional) Add logic to resume training state from a checkpoint.
            # ray.train.get_checkpoint()

            # 1. Get the dataset shard for the worker and convert to a `xgboost.DMatrix`
            train_ds_iter, eval_ds_iter = (
                ray.train.get_dataset_shard("train"),
                ray.train.get_dataset_shard("validation"),
            )
            train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()

            train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
            train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
            eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]

            dtrain = xgboost.DMatrix(train_X, label=train_y)
            deval = xgboost.DMatrix(eval_X, label=eval_y)

            params = {
                "tree_method": "approx",
                "objective": "reg:squarederror",
                "eta": 1e-4,
                "subsample": 0.5,
                "max_depth": 2,
            }

            # 2. Do distributed data-parallel training with the `CommunicatorContext`.
            # Ray Train sets up the necessary coordinator processes and
            # environment variables for your workers to communicate with each other.
            with CommunicatorContext():
                bst = xgboost.train(
                    params,
                    dtrain=dtrain,
                    evals=[(deval, "validation")],
                    num_boost_round=10,
                    callbacks=[RayTrainReportCallback()],
                )

        train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
        eval_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(16)])
        trainer = XGBoostTrainer(
            train_fn_per_worker,
            datasets={"train": train_ds, "validation": eval_ds},
            scaling_config=ray.train.ScalingConfig(num_workers=4),
        )
        result = trainer.fit()
        booster = RayTrainReportCallback.get_model(result.checkpoint)

    .. testoutput::
        :hide:

        ...

    Args:
        train_loop_per_worker: The training function to execute on each worker.
            This function can either take in zero arguments or a single ``Dict``
            argument which is set by defining ``train_loop_config``.
            Within this function you can use any of the
            :ref:`Ray Train Loop utilities <train-loop-api>`.
        train_loop_config: A configuration ``Dict`` to pass in as an argument to
            ``train_loop_per_worker``.
            This is typically used for specifying hyperparameters.
        xgboost_config: The configuration for setting up the distributed xgboost
            backend. Defaults to using the "rabit" backend.
            See :class:`~ray.train.xgboost.XGBoostConfig` for more info.
        datasets: The Ray Datasets to use for training and validation.
        dataset_config: The configuration for ingesting the input ``datasets``.
            By default, all the Ray Datasets are split equally across workers.
            See :class:`~ray.train.DataConfig` for more details.
        scaling_config: The configuration for how to scale data parallel training.
            ``num_workers`` determines how many Python processes are used for training,
            and ``use_gpu`` determines whether or not each process should use GPUs.
            See :class:`~ray.train.ScalingConfig` for more info.
        run_config: The configuration for the execution of the training run.
            See :class:`~ray.train.RunConfig` for more info.
        resume_from_checkpoint: A checkpoint to resume training from.
            This checkpoint can be accessed from within ``train_loop_per_worker``
            by calling ``ray.train.get_checkpoint()``.
        metadata: Dict that should be made available via
            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
            for checkpoints saved from this Trainer. Must be JSON-serializable.
    """

    def __init__(
        self,
        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
        *,
        train_loop_config: Optional[Dict] = None,
        xgboost_config: Optional[XGBoostConfig] = None,
        scaling_config: Optional[ray.train.ScalingConfig] = None,
        run_config: Optional[ray.train.RunConfig] = None,
        datasets: Optional[Dict[str, GenDataset]] = None,
        dataset_config: Optional[ray.train.DataConfig] = None,
        metadata: Optional[Dict[str, Any]] = None,
        resume_from_checkpoint: Optional[Checkpoint] = None,
    ):
        # Fall back to the default (rabit) backend config when none is given.
        backend_config = xgboost_config or XGBoostConfig()
        super().__init__(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config=train_loop_config,
            backend_config=backend_config,
            scaling_config=scaling_config,
            dataset_config=dataset_config,
            run_config=run_config,
            datasets=datasets,
            resume_from_checkpoint=resume_from_checkpoint,
            metadata=metadata,
        )
Loading

0 comments on commit 62dbcb2

Please sign in to comment.