[train] Simplify ray.train.xgboost/lightgbm (1/n): Align frequency-based and `checkpoint_at_end` checkpoint formats (#42111)

Centralizes checkpoint saving and loading implementations around the utility callbacks `ray.train.xgboost/lightgbm.RayTrainReportCallback`.
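For orientation, here is a minimal sketch (not part of the diff) of the consolidated flow: a Tune function trainable attaches `ray.train.lightgbm.RayTrainReportCallback` and later reloads the best booster via `RayTrainReportCallback.get_model`. The synthetic data, the validation set name `"eval"`, and the parameter values are illustrative assumptions.

```python
import lightgbm
import numpy as np

from ray import tune
from ray.train.lightgbm import RayTrainReportCallback


def train_fn(config):
    # Tiny synthetic binary-classification dataset, just to make the sketch concrete.
    X = np.random.rand(100, 4)
    y = (X[:, 0] > 0.5).astype(int)
    train_set = lightgbm.Dataset(X[:80], label=y[:80])
    valid_set = lightgbm.Dataset(X[80:], label=y[80:], reference=train_set)

    lightgbm.train(
        {"objective": "binary", "metric": ["binary_logloss"], **config},
        train_set,
        valid_sets=[valid_set],
        valid_names=["eval"],
        num_boost_round=10,
        callbacks=[
            # Reports metrics every iteration; checkpoints every 5 iterations
            # and once more at the end of training.
            RayTrainReportCallback(
                metrics={"loss": "eval-binary_logloss"},
                frequency=5,
                checkpoint_at_end=True,
            )
        ],
    )


results = tune.Tuner(train_fn).fit()
best = results.get_best_result(metric="loss", mode="min")
booster = RayTrainReportCallback.get_model(best.checkpoint)
```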

---------

Signed-off-by: Justin Yu <[email protected]>
justinvyu committed Feb 13, 2024
1 parent 1cc18ea commit eb8950b
Showing 21 changed files with 588 additions and 1,264 deletions.
2 changes: 2 additions & 0 deletions doc/source/train/api/api.rst
@@ -96,6 +96,7 @@ XGBoost
:toctree: doc/

~train.xgboost.XGBoostTrainer
~train.xgboost.RayTrainReportCallback


LightGBM
@@ -106,6 +107,7 @@ LightGBM
:toctree: doc/

~train.lightgbm.LightGBMTrainer
~train.lightgbm.RayTrainReportCallback


.. _ray-train-configs-api:
5 changes: 2 additions & 3 deletions doc/source/tune/api/integration.rst
@@ -14,7 +14,6 @@ PyTorch Lightning (tune.integration.pytorch_lightning)
:nosignatures:
:toctree: doc/

~tune.integration.pytorch_lightning.TuneReportCallback
~tune.integration.pytorch_lightning.TuneReportCheckpointCallback

.. _tune-integration-xgboost:
@@ -24,9 +23,9 @@ XGBoost (tune.integration.xgboost)

.. autosummary::
:nosignatures:
:template: autosummary/class_without_autosummary.rst
:toctree: doc/

~tune.integration.xgboost.TuneReportCallback
~tune.integration.xgboost.TuneReportCheckpointCallback


@@ -37,7 +36,7 @@ LightGBM (tune.integration.lightgbm)

.. autosummary::
:nosignatures:
:template: autosummary/class_without_autosummary.rst
:toctree: doc/

~tune.integration.lightgbm.TuneReportCallback
~tune.integration.lightgbm.TuneReportCheckpointCallback
676 changes: 53 additions & 623 deletions doc/source/tune/examples/tune-xgboost.ipynb

Large diffs are not rendered by default.

65 changes: 27 additions & 38 deletions python/ray/train/gbdt_trainer.py
@@ -1,14 +1,10 @@
import logging
import os
import tempfile
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional, Type

from ray import train, tune
from ray._private.dict import flatten_dict
from ray.train import Checkpoint, RunConfig, ScalingConfig
from ray.train.constants import MODEL_KEY, TRAIN_DATASET_KEY
from ray.train.constants import TRAIN_DATASET_KEY
from ray.train.trainer import BaseTrainer, GenDataset
from ray.tune import Trainable
from ray.tune.execution.placement_groups import PlacementGroupFactory
@@ -224,17 +220,28 @@ def _get_dmatrices(
for k, v in self.datasets.items()
}

@classmethod
def get_model(cls, checkpoint: Checkpoint, checkpoint_cls: Type[Any]) -> Any:
raise NotImplementedError

def _load_checkpoint(
self,
checkpoint: Checkpoint,
) -> Any:
raise NotImplementedError
# TODO(justinvyu): [code_removal] Remove in 2.11.
raise DeprecationWarning(
"The internal method `_load_checkpoint` deprecated and will be removed. "
f"See `{self.__class__.__name__}.get_model` instead."
)

def _train(self, **kwargs):
raise NotImplementedError

def _save_model(self, model: Any, path: str):
raise NotImplementedError
# TODO(justinvyu): [code_removal] Remove in 2.11.
raise DeprecationWarning(
"The internal method `_save_model` is deprecated and will be removed."
)

def _model_iteration(self, model: Any) -> int:
raise NotImplementedError
@@ -269,24 +276,6 @@ def _repartition_datasets_to_match_num_actors(self):
self._ray_params.num_actors
)

def _checkpoint_at_end(self, model, evals_result: dict) -> None:
# We need to call session.report to save checkpoints, so we report
# the last received metrics (possibly again).
result_dict = flatten_dict(evals_result, delimiter="-")
for k in list(result_dict):
result_dict[k] = result_dict[k][-1]

if getattr(self._tune_callback_checkpoint_cls, "_report_callbacks_cls", None):
# Deprecate: Remove in Ray 2.8
with tune.checkpoint_dir(step=self._model_iteration(model)) as cp_dir:
self._save_model(model, path=os.path.join(cp_dir, MODEL_KEY))
tune.report(**result_dict)
else:
with tempfile.TemporaryDirectory() as checkpoint_dir:
self._save_model(model, path=checkpoint_dir)
checkpoint = Checkpoint.from_directory(checkpoint_dir)
train.report(result_dict, checkpoint=checkpoint)

def training_loop(self) -> None:
config = self.train_kwargs.copy()
config[self._num_iterations_argument] = self.num_boost_round
@@ -299,21 +288,28 @@ def training_loop(self) -> None:

init_model = None
if self.starting_checkpoint:
init_model = self._load_checkpoint(self.starting_checkpoint)
init_model = self.__class__.get_model(self.starting_checkpoint)

config.setdefault("verbose_eval", False)
config.setdefault("callbacks", [])

if not any(
has_user_supplied_callback = any(
isinstance(cb, self._tune_callback_checkpoint_cls)
for cb in config["callbacks"]
):
# Only add our own callback if it hasn't been added before
)
if not has_user_supplied_callback:
# Only add our own default callback if the user hasn't supplied one.
checkpoint_frequency = (
self.run_config.checkpoint_config.checkpoint_frequency
)

checkpoint_at_end = self.run_config.checkpoint_config.checkpoint_at_end
if checkpoint_at_end is None:
# Defaults to True
checkpoint_at_end = True

callback = self._tune_callback_checkpoint_cls(
filename=MODEL_KEY, frequency=checkpoint_frequency
frequency=checkpoint_frequency, checkpoint_at_end=checkpoint_at_end
)

config["callbacks"] += [callback]
@@ -336,7 +332,7 @@ def training_loop(self) -> None:
f"({self._num_iterations_argument}={num_iterations})."
)

model = self._train(
self._train(
params=self.params,
dtrain=train_dmatrix,
evals_result=evals_result,
@@ -345,13 +341,6 @@ def training_loop(self) -> None:
**config,
)

checkpoint_at_end = self.run_config.checkpoint_config.checkpoint_at_end
if checkpoint_at_end is None:
checkpoint_at_end = True

if checkpoint_at_end:
self._checkpoint_at_end(model, evals_result)

def _generate_trainable_cls(self) -> Type["Trainable"]:
trainable_cls = super()._generate_trainable_cls()
trainer_cls = self.__class__
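The `training_loop()` changes above derive the default callback's `frequency` and `checkpoint_at_end` from the run's `CheckpointConfig`. A sketch of what that looks like from the user's side; the dataset, parameters, and worker count are illustrative assumptions, not part of this commit.

```python
import ray.data
from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.xgboost import XGBoostTrainer

# Illustrative tabular dataset with a binary "target" label column.
train_ds = ray.data.from_items([{"x": float(i), "target": i % 2} for i in range(100)])

trainer = XGBoostTrainer(
    label_column="target",
    params={"objective": "binary:logistic"},
    datasets={"train": train_ds},
    num_boost_round=20,
    scaling_config=ScalingConfig(num_workers=2),
    run_config=RunConfig(
        # training_loop() translates these into the default report callback's
        # `frequency` and `checkpoint_at_end` arguments when the user hasn't
        # supplied a callback of their own.
        checkpoint_config=CheckpointConfig(
            checkpoint_frequency=5,
            checkpoint_at_end=True,
        )
    ),
)
result = trainer.fit()
```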
2 changes: 2 additions & 0 deletions python/ray/train/lightgbm/__init__.py
@@ -1,8 +1,10 @@
from ray.train.lightgbm._lightgbm_utils import RayTrainReportCallback
from ray.train.lightgbm.lightgbm_checkpoint import LightGBMCheckpoint
from ray.train.lightgbm.lightgbm_predictor import LightGBMPredictor
from ray.train.lightgbm.lightgbm_trainer import LightGBMTrainer

__all__ = [
"RayTrainReportCallback",
"LightGBMCheckpoint",
"LightGBMPredictor",
"LightGBMTrainer",
166 changes: 166 additions & 0 deletions python/ray/train/lightgbm/_lightgbm_utils.py
@@ -0,0 +1,166 @@
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

from lightgbm.basic import Booster
from lightgbm.callback import CallbackEnv

from ray import train
from ray.train import Checkpoint
from ray.tune.utils import flatten_dict
from ray.util.annotations import PublicAPI


@PublicAPI(stability="beta")
class RayTrainReportCallback:
"""Creates a callback that reports metrics and checkpoints model.
Args:
metrics: Metrics to report. If this is a list,
each item should be a metric key reported by LightGBM,
and it will be reported to Ray Train/Tune under the same name.
This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
which can be used to rename LightGBM default metrics.
filename: Customize the saved checkpoint file type by passing
a filename. Defaults to "model.txt".
frequency: How often to save checkpoints, in terms of iterations.
Defaults to 0 (no checkpoints are saved during training).
checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
results_postprocessing_fn: An optional Callable that takes in
the metrics dict that will be reported (after it has been flattened)
and returns a modified dict.
Examples
--------
Reporting checkpoints and metrics to Ray Tune when running many
independent xgboost trials (without data parallelism within a trial).
.. testcode::
:skipif: True
import lightgbm
from ray.train.lightgbm import RayTrainReportCallback
config = {
# ...
"metric": ["binary_logloss", "binary_error"],
}
# Report only log loss to Tune after each validation epoch.
bst = lightgbm.train(
...,
callbacks=[
RayTrainReportCallback(
metrics={"loss": "eval-binary_logloss"}, frequency=1
)
],
)
Loading a model from a checkpoint reported by this callback.
.. testcode::
:skipif: True
from ray.train.lightgbm import RayTrainReportCallback
# Get a `Checkpoint` object that is saved by the callback during training.
result = trainer.fit()
booster = RayTrainReportCallback.get_model(result.checkpoint)
"""

CHECKPOINT_NAME = "model.txt"

def __init__(
self,
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
filename: str = CHECKPOINT_NAME,
frequency: int = 0,
checkpoint_at_end: bool = True,
results_postprocessing_fn: Optional[
Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
] = None,
):
if isinstance(metrics, str):
metrics = [metrics]
self._metrics = metrics
self._filename = filename
self._frequency = frequency
self._checkpoint_at_end = checkpoint_at_end
self._results_postprocessing_fn = results_postprocessing_fn

@classmethod
def get_model(
cls, checkpoint: Checkpoint, filename: str = CHECKPOINT_NAME
) -> Booster:
"""Retrieve the model stored in a checkpoint reported by this callback.
Args:
checkpoint: The checkpoint object returned by a training run.
The checkpoint should be saved by an instance of this callback.
filename: The filename to load the model from, which should match
the filename used when creating the callback.
"""
with checkpoint.as_directory() as checkpoint_path:
return Booster(model_file=Path(checkpoint_path, filename).as_posix())

def _get_report_dict(self, evals_log: Dict[str, Dict[str, list]]) -> dict:
result_dict = flatten_dict(evals_log, delimiter="-")
if not self._metrics:
report_dict = result_dict
else:
report_dict = {}
for key in self._metrics:
if isinstance(self._metrics, dict):
metric = self._metrics[key]
else:
metric = key
report_dict[key] = result_dict[metric]
if self._results_postprocessing_fn:
report_dict = self._results_postprocessing_fn(report_dict)
return report_dict

def _get_eval_result(self, env: CallbackEnv) -> dict:
eval_result = {}
for entry in env.evaluation_result_list:
data_name, eval_name, result = entry[0:3]
if len(entry) > 4:
stdv = entry[4]
suffix = "-mean"
else:
stdv = None
suffix = ""
if data_name not in eval_result:
eval_result[data_name] = {}
eval_result[data_name][eval_name + suffix] = result
if stdv is not None:
eval_result[data_name][eval_name + "-stdv"] = stdv
return eval_result

@contextmanager
def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
yield Checkpoint.from_directory(temp_checkpoint_dir)

def __call__(self, env: CallbackEnv) -> None:
eval_result = self._get_eval_result(env)
report_dict = self._get_report_dict(eval_result)

on_last_iter = env.iteration == env.end_iteration - 1
checkpointing_disabled = self._frequency == 0
        # Ex: if frequency=2, checkpoint_at_end=True, and num_boost_rounds=10,
        # checkpoints are saved at iterations 1, 3, 5, 7, and 9 (counting from 0);
        # the final one doubles as the checkpoint_at_end checkpoint.
should_checkpoint = (
not checkpointing_disabled and (env.iteration + 1) % self._frequency == 0
) or (on_last_iter and self._checkpoint_at_end)

if should_checkpoint:
with self._get_checkpoint(model=env.model) as checkpoint:
train.report(report_dict, checkpoint=checkpoint)
else:
train.report(report_dict)
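To make the checkpointing schedule in `__call__` concrete, here is a tiny standalone sketch (plain Python, no Ray needed) of which iterations satisfy the `should_checkpoint` condition under the values used in the code comment above.

```python
# Mirrors the `should_checkpoint` condition for a hypothetical run with
# frequency=2, checkpoint_at_end=True, and 10 boosting rounds.
frequency, checkpoint_at_end, end_iteration = 2, True, 10

checkpoint_iterations = [
    i
    for i in range(end_iteration)
    if (frequency > 0 and (i + 1) % frequency == 0)
    or (i == end_iteration - 1 and checkpoint_at_end)
]
print(checkpoint_iterations)  # [1, 3, 5, 7, 9], counting iterations from 0
```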
15 changes: 12 additions & 3 deletions python/ray/train/lightgbm/lightgbm_checkpoint.py
@@ -23,12 +23,17 @@ def from_model(
booster: lightgbm.Booster,
*,
preprocessor: Optional["Preprocessor"] = None,
path: Optional[str] = None,
) -> "LightGBMCheckpoint":
"""Create a :py:class:`~ray.train.Checkpoint` that stores a LightGBM model.
Args:
booster: The LightGBM model to store in the checkpoint.
preprocessor: A fitted preprocessor to be applied before inference.
path: The path to the directory where the checkpoint file will be saved.
This should start as an empty directory, since the *entire*
directory will be treated as the checkpoint when reported.
By default, a temporary directory will be created.
Returns:
An :py:class:`LightGBMCheckpoint` containing the specified ``Estimator``.
@@ -44,10 +49,14 @@ def from_model(
>>> model = lightgbm.LGBMClassifier().fit(train_X, train_y)
>>> checkpoint = LightGBMCheckpoint.from_model(model.booster_)
"""
tempdir = tempfile.mkdtemp()
booster.save_model(Path(tempdir, cls.MODEL_FILENAME).as_posix())
checkpoint_path = Path(path or tempfile.mkdtemp())

checkpoint = cls.from_directory(tempdir)
if not checkpoint_path.is_dir():
raise ValueError(f"`path` must be a directory, but got: {checkpoint_path}")

booster.save_model(checkpoint_path.joinpath(cls.MODEL_FILENAME).as_posix())

checkpoint = cls.from_directory(checkpoint_path.as_posix())
if preprocessor:
checkpoint.set_preprocessor(preprocessor)

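The `from_model` change above adds an optional `path` argument. A short sketch of using it; the synthetic data and the fresh temporary directory are assumptions for illustration.

```python
import tempfile

import lightgbm
import numpy as np

from ray.train.lightgbm import LightGBMCheckpoint

X = np.random.rand(50, 3)
y = (X[:, 0] > 0.5).astype(int)
model = lightgbm.LGBMClassifier(n_estimators=5).fit(X, y)

# Persist the checkpoint into a caller-provided empty directory instead of
# letting from_model() create its own temporary directory.
checkpoint_dir = tempfile.mkdtemp()
checkpoint = LightGBMCheckpoint.from_model(model.booster_, path=checkpoint_dir)
print(checkpoint.path)  # the same directory that was passed in
```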