[train] RayTrainReportCallback should only save a checkpoint on rank 0 for xgboost/lightgbm (#45083)

This PR adds a condition so that a checkpoint is saved and reported only on the
rank 0 worker for xgboost and lightgbm. This prevents unnecessary checkpoints
from being created, since all data-parallel workers hold the same model state.
Note: this also accounts for usage in Tune, where
`ray.train.get_context().get_world_rank()` returns `None`.

It also fixes `checkpoint_at_end` for the xgboost callback to avoid duplicate
checkpoints.
---------

Signed-off-by: Justin Yu <[email protected]>
Co-authored-by: Hongpeng Guo <[email protected]>
justinvyu and hongpeng-guo committed May 9, 2024
1 parent 45c2c4f commit 112e859
Showing 4 changed files with 94 additions and 38 deletions.
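
Before the per-file diffs, here is a minimal standalone sketch of the rank-0 condition described above, as seen from the callback's reporting path. The helper name `rank0_checkpoint` is made up for illustration; it mirrors the new `_get_checkpoint` and assumes a Ray Train session is initialized:

```python
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator, Optional

import ray.train
from ray.train import Checkpoint


@contextmanager
def rank0_checkpoint(model, filename: str) -> Iterator[Optional[Checkpoint]]:
    # Only rank 0 materializes a checkpoint. get_world_rank() returns None
    # under Tune without Train, which should also checkpoint.
    if ray.train.get_context().get_world_rank() in (0, None):
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            model.save_model(Path(temp_checkpoint_dir, filename).as_posix())
            yield Checkpoint.from_directory(temp_checkpoint_dir)
    else:
        yield None


# Every worker reports metrics, but only rank 0 attaches a checkpoint
# (ray.train.report accepts checkpoint=None):
#
#     with rank0_checkpoint(booster, "model.txt") as checkpoint:
#         ray.train.report(metrics, checkpoint=checkpoint)
```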
30 changes: 17 additions & 13 deletions python/ray/train/lightgbm/_lightgbm_utils.py
@@ -6,7 +6,7 @@
from lightgbm.basic import Booster
from lightgbm.callback import CallbackEnv

-from ray import train
+import ray.train
from ray.train import Checkpoint
from ray.tune.utils import flatten_dict
from ray.util.annotations import PublicAPI
@@ -142,25 +142,29 @@ def _get_eval_result(self, env: CallbackEnv) -> dict:

    @contextmanager
    def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
-        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
-            model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
-            yield Checkpoint.from_directory(temp_checkpoint_dir)
+        if ray.train.get_context().get_world_rank() in (0, None):
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
+                yield Checkpoint.from_directory(temp_checkpoint_dir)
+        else:
+            yield None

    def __call__(self, env: CallbackEnv) -> None:
        eval_result = self._get_eval_result(env)
        report_dict = self._get_report_dict(eval_result)

-        # Ex: if frequency=2, checkpoint_at_end=True and num_boost_rounds=11,
-        # you will checkpoint at iterations 1, 3, 5, ..., 9, and 10 (checkpoint_at_end)
-        # (iterations count from 0)
        on_last_iter = env.iteration == env.end_iteration - 1
-        checkpointing_disabled = self._frequency == 0
-        should_checkpoint = (
-            not checkpointing_disabled and (env.iteration + 1) % self._frequency == 0
-        ) or (on_last_iter and self._checkpoint_at_end)
+        # Ex: if frequency=2, checkpoint_at_end=True and num_boost_rounds=10,
+        # you will checkpoint at iterations 1, 3, 5, ..., and 9 (checkpoint_at_end)
+        # (counting from 0)
+        should_checkpoint_at_end = on_last_iter and self._checkpoint_at_end
+        should_checkpoint_with_frequency = (
+            self._frequency != 0 and (env.iteration + 1) % self._frequency == 0
+        )
+        should_checkpoint = should_checkpoint_at_end or should_checkpoint_with_frequency

        if should_checkpoint:
            with self._get_checkpoint(model=env.model) as checkpoint:
-                train.report(report_dict, checkpoint=checkpoint)
+                ray.train.report(report_dict, checkpoint=checkpoint)
        else:
-            train.report(report_dict)
+            ray.train.report(report_dict)
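
As a sanity check on the comment above, here is a small standalone sketch (not part of the diff; the helper name is made up) that mirrors the new frequency logic and enumerates which 0-indexed iterations receive a checkpoint:

```python
def checkpointed_iterations(
    num_boost_round: int, frequency: int, checkpoint_at_end: bool
) -> list:
    """Enumerate the 0-indexed iterations that the new logic checkpoints."""
    result = []
    for iteration in range(num_boost_round):
        on_last_iter = iteration == num_boost_round - 1
        should_checkpoint_at_end = on_last_iter and checkpoint_at_end
        should_checkpoint_with_frequency = (
            frequency != 0 and (iteration + 1) % frequency == 0
        )
        if should_checkpoint_at_end or should_checkpoint_with_frequency:
            result.append(iteration)
    return result


# Matches the comment above and the updated test expectations below:
assert checkpointed_iterations(10, 2, True) == [1, 3, 5, 7, 9]
assert checkpointed_iterations(25, 4, True) == [3, 7, 11, 15, 19, 23, 24]
assert checkpointed_iterations(25, 5, True) == [4, 9, 14, 19, 24]
```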
32 changes: 27 additions & 5 deletions python/ray/train/tests/test_lightgbm_trainer.py
@@ -1,4 +1,5 @@
import math
+from unittest import mock

import lightgbm as lgbm
import pandas as pd
@@ -10,7 +11,7 @@
from ray import tune
from ray.train import ScalingConfig
from ray.train.constants import TRAIN_DATASET_KEY
-from ray.train.lightgbm import LightGBMTrainer
+from ray.train.lightgbm import LightGBMTrainer, RayTrainReportCallback


@pytest.fixture
@@ -101,10 +102,11 @@ def test_resume_from_checkpoint(ray_start_6_cpus, tmpdir):
@pytest.mark.parametrize(
    "freq_end_expected",
    [
-        (4, True, 7),  # 4, 8, 12, 16, 20, 24, 25
-        (4, False, 6),  # 4, 8, 12, 16, 20, 24
-        (5, True, 5),  # 5, 10, 15, 20, 25
-        (0, True, 1),
+        # With num_boost_round=25 and 0-indexed iterations, the checkpoints will be at:
+        (4, True, 7),  # 3, 7, 11, 15, 19, 23, 24 (end)
+        (4, False, 6),  # 3, 7, 11, 15, 19, 23
+        (5, True, 5),  # 4, 9, 14, 19, 24
+        (0, True, 1),  # 24 (end)
        (0, False, 0),
    ],
)
@@ -166,6 +168,26 @@ def test_validation(ray_start_6_cpus):
    )


@pytest.mark.parametrize("rank", [None, 0, 1])
def test_checkpoint_only_on_rank0(rank):
"""Tests that the callback only reports checkpoints on rank 0,
or if the rank is not available (Tune usage)."""
callback = RayTrainReportCallback(frequency=2, checkpoint_at_end=True)

booster = mock.MagicMock()

with mock.patch("ray.train.get_context") as mock_get_context:
mock_context = mock.MagicMock()
mock_context.get_world_rank.return_value = rank
mock_get_context.return_value = mock_context

with callback._get_checkpoint(booster) as checkpoint:
if rank in (0, None):
assert checkpoint
else:
assert not checkpoint


if __name__ == "__main__":
    import sys

38 changes: 26 additions & 12 deletions python/ray/train/tests/test_xgboost_trainer.py
@@ -1,4 +1,4 @@
-import json
+from unittest import mock

import pandas as pd
import pytest
@@ -43,11 +43,6 @@ def ray_start_8_cpus():
}


-def get_num_trees(booster: xgb.Booster) -> int:
-    data = [json.loads(d) for d in booster.get_dump(dump_format="json")]
-    return len(data)
-
-
def test_fit(ray_start_4_cpus):
    train_dataset = ray.data.from_pandas(train_df)
    valid_dataset = ray.data.from_pandas(test_df)
@@ -114,12 +109,11 @@ def test_resume_from_checkpoint(ray_start_4_cpus, tmpdir):
@pytest.mark.parametrize(
    "freq_end_expected",
    [
-        (4, True, 7),  # 4, 8, 12, 16, 20, 24, 25
-        (4, False, 6),  # 4, 8, 12, 16, 20, 24
-        # TODO(justinvyu): [simplify_xgb]
-        # Fix this checkpoint_at_end/checkpoint_frequency overlap behavior.
-        # (5, True, 5),  # 5, 10, 15, 20, 25
-        (0, True, 1),  # end
+        # With num_boost_round=25 and 0-indexed iterations, the checkpoints will be at:
+        (4, True, 7),  # 3, 7, 11, 15, 19, 23, 24 (end)
+        (4, False, 6),  # 3, 7, 11, 15, 19, 23
+        (5, True, 5),  # 4, 9, 14, 19, 24
+        (0, True, 1),  # 24 (end)
        (0, False, 0),
    ],
)
@@ -152,6 +146,26 @@ def test_checkpoint_freq(ray_start_4_cpus, freq_end_expected):
    assert cp_paths == sorted(cp_paths), str(cp_paths)


@pytest.mark.parametrize("rank", [None, 0, 1])
def test_checkpoint_only_on_rank0(rank):
"""Tests that the callback only reports checkpoints on rank 0,
or if the rank is not available (Tune usage)."""
callback = RayTrainReportCallback(frequency=2, checkpoint_at_end=True)

booster = mock.MagicMock()

with mock.patch("ray.train.get_context") as mock_get_context:
mock_context = mock.MagicMock()
mock_context.get_world_rank.return_value = rank
mock_get_context.return_value = mock_context

with callback._get_checkpoint(booster) as checkpoint:
if rank in (0, None):
assert checkpoint
else:
assert not checkpoint


def test_tune(ray_start_8_cpus):
    train_dataset = ray.data.from_pandas(train_df)
    valid_dataset = ray.data.from_pandas(test_df)
32 changes: 24 additions & 8 deletions python/ray/train/xgboost/_xgboost_utils.py
@@ -6,7 +6,7 @@

from xgboost.core import Booster

-from ray import train
+import ray.train
from ray.train import Checkpoint
from ray.tune.utils import flatten_dict
from ray.util.annotations import PublicAPI
@@ -118,6 +118,9 @@ def __init__(
        # so that the latest metrics can be reported with the checkpoint
        # at the end of training.
        self._evals_log = None
+        # Keep track of the last checkpoint iteration to avoid double-checkpointing
+        # when using `checkpoint_at_end=True`.
+        self._last_checkpoint_iteration = None

    @classmethod
    def get_model(
@@ -163,9 +166,13 @@ def _get_report_dict(self, evals_log):

    @contextmanager
    def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
-        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
-            model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
-            yield Checkpoint(temp_checkpoint_dir)
+        # NOTE: The world rank is None for Tune usage without Train.
+        if ray.train.get_context().get_world_rank() in (0, None):
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
+                yield Checkpoint(temp_checkpoint_dir)
+        else:
+            yield None

    def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
        self._evals_log = evals_log
@@ -178,17 +185,26 @@ def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):

        report_dict = self._get_report_dict(evals_log)
        if should_checkpoint:
+            self._last_checkpoint_iteration = epoch
            with self._get_checkpoint(model=model) as checkpoint:
-                train.report(report_dict, checkpoint=checkpoint)
+                ray.train.report(report_dict, checkpoint=checkpoint)
        else:
-            train.report(report_dict)
+            ray.train.report(report_dict)

-    def after_training(self, model: Booster):
+    def after_training(self, model: Booster) -> Booster:
        if not self._checkpoint_at_end:
            return model

+        if (
+            self._last_checkpoint_iteration is not None
+            and model.num_boosted_rounds() - 1 == self._last_checkpoint_iteration
+        ):
+            # Avoids a duplicate checkpoint if the checkpoint frequency happens
+            # to align with the last iteration.
+            return model
+
        report_dict = self._get_report_dict(self._evals_log) if self._evals_log else {}
        with self._get_checkpoint(model=model) as checkpoint:
-            train.report(report_dict, checkpoint=checkpoint)
+            ray.train.report(report_dict, checkpoint=checkpoint)

        return model
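
To illustrate the duplicate-checkpoint fix above, here is a standalone sketch (not part of the diff; `FakeBooster` is a made-up stand-in for `xgboost.Booster`) of the case where the checkpoint frequency lands exactly on the last boosting round:

```python
class FakeBooster:
    """Made-up stand-in for xgboost.Booster, just enough for this sketch."""

    def __init__(self, num_rounds: int):
        self._num_rounds = num_rounds

    def num_boosted_rounds(self) -> int:
        return self._num_rounds


# With frequency=5 and num_boost_round=25, after_iteration already
# checkpointed at epoch 24 (the last 0-indexed iteration), so
# _last_checkpoint_iteration == 24.
last_checkpoint_iteration = 24
model = FakeBooster(num_rounds=25)

# after_training's new guard: the end-of-training checkpoint is skipped
# because the frequency-based checkpoint already covered the final model.
should_skip = (
    last_checkpoint_iteration is not None
    and model.num_boosted_rounds() - 1 == last_checkpoint_iteration
)
assert should_skip  # no duplicate checkpoint at the end
```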
