[train] New persistence mode: Deprecate experimental distributed checkpointing configs #39279

Merged (2 commits) on Sep 6, 2023
6 changes: 0 additions & 6 deletions doc/source/train/user-guides/checkpoints.rst
@@ -324,12 +324,6 @@ checkpoints to disk), a :py:class:`~ray.train.CheckpointConfig` can be passed in
:py:class:`~ray.train.CheckpointConfig`,
please ensure that the metric is always reported together with the checkpoints.

**[Experimental] Distributed Checkpoints**: For model parallel workloads where the models do not fit in a single GPU worker,
it will be important to save and upload the model that is partitioned across different workers. You
can enable this by setting `_checkpoint_keep_all_ranks=True` to retain the model checkpoints across workers,
and `_checkpoint_upload_from_workers=True` to upload their checkpoints to cloud directly in :class:`~ray.train.CheckpointConfig`. This functionality works for any trainer that inherits from :class:`~ray.train.data_parallel_trainer.DataParallelTrainer`.



.. _train-dl-loading-checkpoints:

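For context, the experimental usage that this removed passage described (and which the rest of this PR deprecates) looked roughly like the sketch below. This is illustrative only: `RunConfig` wrapping the `CheckpointConfig` and the `num_to_keep` value are assumptions, not part of this diff.

```python
from ray.train import CheckpointConfig, RunConfig

# Old experimental usage, now removed: both private flags were set on
# CheckpointConfig to keep per-rank checkpoints and upload them directly
# from the workers. With this PR merged, constructing CheckpointConfig
# with either flag raises a DeprecationWarning (see the config.py diff below).
run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=3,
        _checkpoint_keep_all_ranks=True,
        _checkpoint_upload_from_workers=True,
    )
)
```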
42 changes: 31 additions & 11 deletions python/ray/air/config.py
@@ -45,6 +45,7 @@

MAX = "max"
MIN = "min"
_DEPRECATED_VALUE = "DEPRECATED"


logger = logging.getLogger(__name__)
@@ -608,26 +609,45 @@ class CheckpointConfig:
This attribute is only supported by trainers that don't take in
custom training loops. Defaults to True for trainers that support it
and False for generic function trainables.
_checkpoint_keep_all_ranks: If True, will save checkpoints from all ranked
training workers. If False, only checkpoint from rank 0 worker is kept.
NOTE: This API is experimental and subject to change between minor
releases.
_checkpoint_upload_from_workers: If True, distributed workers
will upload their checkpoints to cloud directly. This is to avoid the
need for transferring large checkpoint files to the training worker
group coordinator for persistence. NOTE: This API is experimental and
subject to change between minor releases.
_checkpoint_keep_all_ranks: This experimental config is deprecated.
This behavior is now controlled by reporting `checkpoint=None`
in the workers that shouldn't persist a checkpoint.
For example, if you only want the rank 0 worker to persist a checkpoint
(e.g., in standard data parallel training), then you should save and
report a checkpoint if `ray.train.get_context().get_world_rank() == 0`
and `None` otherwise.
_checkpoint_upload_from_workers: This experimental config is deprecated.
Uploading checkpoint directly from the worker is now the default behavior.
"""

num_to_keep: Optional[int] = None
checkpoint_score_attribute: Optional[str] = None
checkpoint_score_order: Optional[str] = MAX
checkpoint_frequency: Optional[int] = 0
checkpoint_at_end: Optional[bool] = None
_checkpoint_keep_all_ranks: Optional[bool] = False
_checkpoint_upload_from_workers: Optional[bool] = False
_checkpoint_keep_all_ranks: Optional[bool] = _DEPRECATED_VALUE
_checkpoint_upload_from_workers: Optional[bool] = _DEPRECATED_VALUE

def __post_init__(self):
if self._checkpoint_keep_all_ranks != _DEPRECATED_VALUE:
raise DeprecationWarning(
"The experimental `_checkpoint_keep_all_ranks` config is deprecated. "
"This behavior is now controlled by reporting `checkpoint=None` "
"in the workers that shouldn't persist a checkpoint. "
"For example, if you only want the rank 0 worker to persist a "
"checkpoint (e.g., in standard data parallel training), "
"then you should save and report a checkpoint if "
"`ray.train.get_context().get_world_rank() == 0` "
"and `None` otherwise."
)

if self._checkpoint_upload_from_workers != _DEPRECATED_VALUE:
raise DeprecationWarning(
"The experimental `_checkpoint_upload_from_workers` config is "
"deprecated. Uploading checkpoint directly from the worker is "
"now the default behavior."
)

if self.num_to_keep is not None and self.num_to_keep <= 0:
raise ValueError(
f"Received invalid num_to_keep: "
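A minimal sketch of the replacement pattern the updated docstring describes, assuming the `ray.train.report` / `ray.train.get_context()` API and the `DataParallelTrainer` mentioned in the removed docs; the training loop body, metrics, and the saved `model.pt` file are placeholders:

```python
import os
import tempfile

import ray.train
from ray.train import Checkpoint, ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer


def train_loop_per_worker(config):
    for epoch in range(config["num_epochs"]):
        metrics = {"epoch": epoch, "val_loss": 0.0}  # placeholder metrics

        if ray.train.get_context().get_world_rank() == 0:
            # Only the rank 0 worker persists a checkpoint; under the new
            # persistence mode it is uploaded directly from this worker.
            with tempfile.TemporaryDirectory() as tmpdir:
                with open(os.path.join(tmpdir, "model.pt"), "wb") as f:
                    f.write(b"placeholder model state")
                ray.train.report(
                    metrics, checkpoint=Checkpoint.from_directory(tmpdir)
                )
        else:
            # All other ranks report the same metrics with no checkpoint.
            ray.train.report(metrics, checkpoint=None)


trainer = DataParallelTrainer(
    train_loop_per_worker,
    train_loop_config={"num_epochs": 2},
    scaling_config=ScalingConfig(num_workers=2),
)
# result = trainer.fit()
```

Every worker still calls `report` once per epoch so the workers stay in sync; only rank 0 attaches a checkpoint, which replaces the old `_checkpoint_keep_all_ranks` / `_checkpoint_upload_from_workers` flags.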
1 change: 0 additions & 1 deletion python/ray/train/tests/test_lightning_deepspeed.py
@@ -59,7 +59,6 @@ def test_deepspeed_stages(ray_start_6_cpus_4_gpus, tmpdir, stage, test_restore):
num_to_keep=3,
checkpoint_score_attribute="val_loss",
checkpoint_score_order="min",
_checkpoint_keep_all_ranks=True,
),
),
)
18 changes: 0 additions & 18 deletions python/ray/tune/tune.py
@@ -329,8 +329,6 @@ def run(
checkpoint_score_attr: Optional[str] = None, # Deprecated (2.7)
checkpoint_freq: int = 0, # Deprecated (2.7)
checkpoint_at_end: bool = False, # Deprecated (2.7)
checkpoint_keep_all_ranks: bool = False, # Deprecated (2.7)
checkpoint_upload_from_workers: bool = False, # Deprecated (2.7)
chdir_to_trial_dir: bool = _DEPRECATED_VALUE, # Deprecated (2.8)
local_dir: Optional[str] = None,
# == internal only ==
@@ -740,22 +738,6 @@ class and registered trainables.
DeprecationWarning,
)
checkpoint_config.checkpoint_at_end = checkpoint_at_end
if checkpoint_keep_all_ranks:
warnings.warn(
"checkpoint_keep_all_ranks is deprecated and will be removed. "
"use checkpoint_config._checkpoint_keep_all_ranks instead.",
DeprecationWarning,
)
checkpoint_config._checkpoint_keep_all_ranks = checkpoint_keep_all_ranks
if checkpoint_upload_from_workers:
warnings.warn(
"checkpoint_upload_from_workers is deprecated and will be removed. "
"use checkpoint_config._checkpoint_upload_from_workers instead.",
DeprecationWarning,
)
checkpoint_config._checkpoint_upload_from_workers = (
checkpoint_upload_from_workers
)

if chdir_to_trial_dir != _DEPRECATED_VALUE:
warnings.warn(