[train] Simplify `ray.train.xgboost`/`lightgbm` (2/n): Re-implement `XGBoostTrainer` as a lightweight `DataParallelTrainer` #42767
Conversation
Signed-off-by: Justin Yu <[email protected]>
Very neat solution! I have some general questions:
```python
train_ds = ray.train.get_dataset_shard("train")
train_df = train_ds.to_pandas()
X, y = train_df.drop("y", axis=1), train_df["y"]
dtrain = xgboost.DMatrix(X, label=y)
```

Two options:

```python
bst_model = None
num_boost_rounds_per_iter = ...  # boosting rounds per reporting iteration
for i in range(num_iters):
    bst_model = xgboost.train(
        ..., xgb_model=bst_model,  # continue training from bst_model
        num_boost_round=num_boost_rounds_per_iter,
    )
    ray.train.report(..., checkpoint=...)
```
Here's a summary of the enhancements achieved by this proposal once we fully migrate to the v2 trainers.
Let me know if there are any "regressions" that I'm missing with this change.
python/ray/train/xgboost/config.py
Outdated
```python
# Set up the rabit tracker on the Train driver.
num_workers = len(worker_group)
self.rabit_args = {"DMLC_NUM_WORKER": num_workers}
train_driver_ip = ray.util.get_node_ip_address()
```
Should this be the driver's IP, or the rank 0 worker's?
The rabit process should be on the "driver", which in this case is the Trainer. All ranks connect to the driver rabit process.
Take a look at how this dask distributed xgboost test sets it up: https://github.com/dmlc/xgboost/blob/662854c7d75ef1ec543ee0db73098227de5be59c/tests/test_distributed/test_with_dask/test_with_dask.py#L1619-L1654
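For illustration, here's a minimal sketch (helper name and values are hypothetical) of the kind of `DMLC_*` environment variables the driver-side tracker hands each worker so the collective can rendezvous with it:

```python
def build_rabit_worker_env(tracker_ip: str, tracker_port: int, num_workers: int) -> dict:
    """Hypothetical helper: the env vars each worker needs so that xgboost's
    rabit/collective layer can connect back to the tracker on the driver."""
    return {
        "DMLC_TRACKER_URI": tracker_ip,         # where the driver-side tracker listens
        "DMLC_TRACKER_PORT": str(tracker_port),
        "DMLC_NUM_WORKER": str(num_workers),    # total world size
    }

# Each worker would export these before entering CommunicatorContext().
env = build_rabit_worker_env("10.0.0.1", 9091, 4)
```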
```python
from xgboost.collective import CommunicatorContext

with CommunicatorContext():
    ...
```
Is this needed?
Yup, this is the thing that actually connects the worker to the collective group (kind of like `torch.distributed.init_process_group`).
Usually you need to pass a bunch of args in here, but the environment variables that I set above take care of that. We could consider making this a Ray Train utility, but I feel like keeping the native usage is pretty simple.
I was trying to do it for the user here, but that didn't end up working, since the context needs to directly wrap the user code, it seems.
Oh I see. @woshiyyya brought up a similar thing for Torch where we might want to set the torch device, which also needs to modify the user code since it has to be run in the same thread.
Maybe we can have some sort of decorator abstraction that surrounds the user's train function?
Yes. I was trying to set the torch CUDA device by default, but it only works when we call it inside the training function. I'm thinking we can have a function decorator so that we can inject environment setup in an elegant way, e.g.

```python
@ray.train.context(framework="xgboost")
def train_func():
    ...
```

(Need a better name for this decorator...)
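A rough sketch of what such a decorator could look like (all names here are hypothetical, with a no-op context manager standing in for real framework setup like xgboost's `CommunicatorContext` or `torch.cuda.set_device`):

```python
import contextlib
import functools

@contextlib.contextmanager
def _noop_setup():
    # Stand-in for framework-specific setup/teardown that must run in the
    # same thread as the user's train function.
    yield

# Hypothetical registry mapping framework name -> setup context manager.
_SETUP_BY_FRAMEWORK = {"xgboost": _noop_setup}

def train_context(framework: str):
    """Hypothetical decorator: runs framework setup in the same thread
    that executes the user's train function."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            with _SETUP_BY_FRAMEWORK[framework]():
                return fn(*args, **kwargs)
        return wrapper
    return decorator

@train_context(framework="xgboost")
def train_func():
    return "trained"
```

Because the wrapper enters the context inside the call itself, the setup runs on the worker thread that executes the user code, which is the constraint discussed above.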
Seems that it's more crucial for the new XGBoostTrainer API. We could consider having this decorator so users don't have to think about it.
Interesting, we may be able to get rid of the `Trainer`s and mirror the Ray Core API more. 😆
So clean!
```diff
@@ -20,6 +20,7 @@ def __init__(
     self,
     datasets_to_split: Union[Literal["all"], List[str]] = "all",
     execution_options: Optional[ExecutionOptions] = None,
+    convert_to_data_iterator: bool = True,
```
Nit: This seems unclear for people who don't know what `data_iterator` refers to. Consider renaming it to `streaming_execution=True` or `materialize=False`?
Update: we're going to have the user call `DataIterator.materialize` instead.
python/ray/train/xgboost/config.py
Outdated
```python
# Ranks are assigned in increasing order of the worker's task id.
# This task id will be sorted by increasing world rank.
os.environ["DMLC_TASK_ID"] = (
    f"[xgboost.ray-rank={ray.train.get_context().get_world_rank()}]:"
    f"{ray.get_runtime_context().get_actor_id()}"
)
```
Is there a strict interface that this needs to follow?
Nope, this format was inspired by `xgboost.dask`: https://github.com/dmlc/xgboost/blob/7cc256e246e68a6c641ecb57a138e3c8a721c55e/python-package/xgboost/dask/__init__.py#L237
The strings will just be sorted by world rank: https://github.com/ray-project/ray/pull/42767/files#diff-e40514dc5fac1d96905235c5090fbeeb747889a95a3dd86228f8737348911236R55-R60
Hmm so both here and in Dask it'll end up sorting by the rank strings rather than integers. I guess that's fine if Dask is doing it, but maybe we can update the documentation to match? (Or we can prepend some zeros to the world rank).
Oh yeah, string comparison will mess up if you have more than 10 workers. I do actually want the ranks to all match up, so prepending a few 0s makes sense.
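To illustrate the string-sorting pitfall and the zero-padding fix (the helper name below is hypothetical, mirroring the task-id format from the diff above):

```python
def make_task_id(world_rank: int, actor_id: str, num_workers: int) -> str:
    # Zero-pad the rank so that lexicographic sorting of the task-id
    # strings matches numeric rank order.
    width = len(str(num_workers - 1))
    return f"[xgboost.ray-rank={world_rank:0{width}d}]:{actor_id}"

# Unpadded ranks sort incorrectly past 9 workers: "rank=10" < "rank=2".
unpadded = [f"rank={r}" for r in range(12)]
assert sorted(unpadded) != unpadded

# Zero-padded task ids sort in true rank order.
padded = [make_task_id(r, f"actor-{r}", 12) for r in range(12)]
assert sorted(padded) == padded
```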
```python
# TODO(justinvyu): [Deprecated] Remove in 2.11
if dmatrix_params != _DEPRECATED_VALUE:
    raise DeprecationWarning(
        "`dmatrix_params` is deprecated, since XGBoostTrainer no longer "
        "depends on the `xgboost_ray.RayDMatrix` utility."
    )
```
Is there any alternative needed here for functional parity?
I think the closest thing would be passing these params into `xgboost.DMatrix(...)`.
This would be a new feature though, since the original usage was to pass extra params as `xgboost_ray.RayDMatrix` constructor args, and none of those apply anymore.
Let's just keep it as deprecated?
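The sentinel-default pattern behind the deprecation check above can be sketched in isolation (function name illustrative):

```python
_DEPRECATED_VALUE = object()  # sentinel: distinguishes "not passed" from any real value

def trainer_init(dmatrix_params=_DEPRECATED_VALUE):
    # Only complain if the caller actually passed the deprecated argument;
    # callers who omit it are unaffected.
    if dmatrix_params is not _DEPRECATED_VALUE:
        raise DeprecationWarning(
            "`dmatrix_params` is deprecated, since XGBoostTrainer no longer "
            "depends on the `xgboost_ray.RayDMatrix` utility."
        )
```

A unique `object()` sentinel (rather than `None`) means even an explicit `dmatrix_params=None` is detected as "passed" and rejected.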
```python
eval_X, eval_y = eval_df.drop(label_column, axis=1), eval_df[label_column]
evals.append((xgboost.DMatrix(eval_X, label=eval_y), eval_name))

with CommunicatorContext():
    ...
```
Will we (eventually) move this into the `train_func_context`?
Yes, I'll add that in a followup!
TL;DR

This PR re-implements `XGBoostTrainer` as a `DataParallelTrainer` that does not use `xgboost_ray` under the hood, in an effort to unify the trainer implementations and remove that external dependency.

Motivation

- The current trainers depend on a pinned `xgboost_ray`/`lightgbm_ray` version. An upstream change can cause `xgboost_ray` to no longer work → `ray.train.xgboost.XGBoostTrainer` breaks → the Ray Train team needs to patch a fix in this separate package, make a release, then update the pinned package version in CI.
- `xgboost_ray` introduces significant code complexity. `XGBoostTrainer` and `LightGBMTrainer` are data parallel trainers, but go through a completely different code path than other `DataParallelTrainer` implementations instead of sharing the common `DataParallelTrainer` execution logic. This continues the `TorchTrainer` unification effort.
- `xgboost_ray` and `lightgbm_ray` are designed to be run independently, so they implement their own execution loop with resource scheduling logic and error handling. There is a huge overlap with the external libraries and Tune, and it's very difficult to navigate between the 2 codebases as a maintainer.

PR Summary

- Introduces `ray.train.xgboost.v2.XGBoostTrainer` and `ray.train.lightgbm.v2.LightGBMTrainer`, which do not depend on `xgboost_ray` and `lightgbm_ray`. These are implemented as lightweight `DataParallelTrainer`s. Users are able to pass in their own training function.
- Re-implements `ray.train.xgboost.XGBoostTrainer` and `ray.train.lightgbm.LightGBMTrainer` on top of the v2 counterparts.
- Removes the `xgboost_ray` and `lightgbm_ray` dependencies from Ray Train starting immediately from 2.10. (Will do in a follow-up PR.)

Related issue number
Checks

- I've signed off every commit (`git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a new method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.