[AIR] Distributed checkpointing (#34709)
Signed-off-by: Jun Gong <[email protected]>
Jun Gong committed May 12, 2023
1 parent 710c17a commit f936826
Showing 16 changed files with 537 additions and 60 deletions.
59 changes: 40 additions & 19 deletions python/ray/air/_internal/checkpoint_manager.py
@@ -55,6 +55,7 @@ class _TrackedCheckpoint:
into `"evaluation/episode_reward_mean"`.
node_ip: IP of the node where the checkpoint was generated. Defaults
to the current node.
rank: Rank of the worker that generated the checkpoint. Defaults to 0.
"""

def __init__(
@@ -64,12 +65,14 @@ def __init__(
checkpoint_id: Optional[int] = None,
metrics: Optional[Dict] = None,
node_ip: Optional[str] = None,
rank: Optional[int] = 0,
):
from ray.tune.result import NODE_IP

self.dir_or_data = dir_or_data
self.id = checkpoint_id
self.storage_mode = storage_mode
self.rank = rank

self.metrics = flatten_dict(metrics) if metrics else {}
self.node_ip = node_ip or self.metrics.get(NODE_IP, None)
@@ -296,7 +299,7 @@ def __init__(
# always available).
self._checkpoints_to_clean_up = set()

self._delete_fn = delete_fn
self.set_delete_fn(delete_fn)

def set_delete_fn(
self, delete_fn: Optional[Callable[["_TrackedCheckpoint"], None]]
@@ -309,7 +312,10 @@ def set_delete_fn(
"""
self._delete_fn = delete_fn

def register_checkpoint(self, checkpoint: _TrackedCheckpoint):
def register_checkpoints(
self,
checkpoints: Union[_TrackedCheckpoint, List[_TrackedCheckpoint]],
):
"""Register new checkpoint and add to bookkeeping.
This method will register a new checkpoint and add it to the internal
@@ -318,23 +324,27 @@ def register_checkpoint(self, checkpoint: _TrackedCheckpoint):
checkpoints should be deleted.
Args:
checkpoint: Tracked checkpoint object to add to bookkeeping.
checkpoints: Tracked checkpoint object(s) to add to bookkeeping.
"""
checkpoint.id = checkpoint.id or self._latest_checkpoint_id
if not isinstance(checkpoints, list):
checkpoints = [checkpoints]

if checkpoint.storage_mode == CheckpointStorage.MEMORY:
self._replace_latest_memory_checkpoint(checkpoint)
for checkpoint in checkpoints:
checkpoint.id = checkpoint.id or self._latest_checkpoint_id

if self._persist_memory_checkpoints:
persisted_checkpoint = copy.copy(checkpoint)
persisted_checkpoint.storage_mode = CheckpointStorage.PERSISTENT
if checkpoint.storage_mode == CheckpointStorage.MEMORY:
self._replace_latest_memory_checkpoint(checkpoint)

if self._persist_memory_checkpoints:
persisted_checkpoint = copy.copy(checkpoint)
persisted_checkpoint.storage_mode = CheckpointStorage.PERSISTENT
else:
persisted_checkpoint = None
else:
persisted_checkpoint = None
else:
persisted_checkpoint = checkpoint
persisted_checkpoint = checkpoint

if persisted_checkpoint and self._checkpoint_strategy.num_to_keep != 0:
self._process_persistent_checkpoint(persisted_checkpoint)
if persisted_checkpoint and self._checkpoint_strategy.num_to_keep != 0:
self._process_persistent_checkpoint(persisted_checkpoint)

self._latest_checkpoint_id += 1

@@ -405,29 +415,40 @@ def _get_checkpoint_score(
checkpoint.id,
)

def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint):
def _process_persistent_checkpoint(
self,
checkpoint: _TrackedCheckpoint,
next_checkpoint_path: Optional[str] = None,
):
# Note(jungong): Only track the rank-0 checkpoint as the best / worst
# checkpoint. For checkpoints from non-rank-0 workers we only care about
# their data; they do not represent a Trial checkpoint different from
# the rank-0 one.
if checkpoint.rank > 0:
return

assert checkpoint.storage_mode == CheckpointStorage.PERSISTENT
next_checkpoint_path = next_checkpoint_path or self._get_next_checkpoint_path()

checkpoint_score = self._get_checkpoint_score(checkpoint)
wrapped_checkpoint = _HeapCheckpointWrapper(
priority=checkpoint_score, tracked_checkpoint=checkpoint
)

if self._checkpoint_strategy.num_to_keep is None:
# Keep all checkpoints
checkpoint.commit(path=self._get_next_checkpoint_path())
checkpoint.commit(path=next_checkpoint_path)
self._replace_latest_persisted_checkpoint(checkpoint)
self._top_persisted_checkpoints.append(wrapped_checkpoint)
elif (
len(self._top_persisted_checkpoints) < self._checkpoint_strategy.num_to_keep
):
checkpoint.commit(path=next_checkpoint_path)
# Heap is not full yet, so keep this checkpoint
checkpoint.commit(path=self._get_next_checkpoint_path())
heapq.heappush(self._top_persisted_checkpoints, wrapped_checkpoint)
self._replace_latest_persisted_checkpoint(checkpoint)
elif wrapped_checkpoint.priority >= self._top_persisted_checkpoints[0].priority:
checkpoint.commit(path=next_checkpoint_path)
# Priority is higher than current worst checkpoint, so replace worst
checkpoint.commit(path=self._get_next_checkpoint_path())
worst_checkpoint = heapq.heappushpop(
self._top_persisted_checkpoints, wrapped_checkpoint
).tracked_checkpoint
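To illustrate the new list-based registration API and the rank handling above, here is a minimal sketch (not part of this commit) of registering checkpoints reported by two ranked workers for the same step. It assumes the same imports used by python/ray/air/tests/test_checkpoint_manager.py; the "score" metric name and values are illustrative, and these underscore-prefixed classes are internal APIs.

from ray.air import CheckpointConfig
from ray.air._internal.checkpoint_manager import (
    CheckpointStorage,
    _CheckpointManager,
    _TrackedCheckpoint,
)

cpm = _CheckpointManager(
    checkpoint_strategy=CheckpointConfig(
        num_to_keep=2, checkpoint_score_attribute="score"
    )
)

# Checkpoints produced by a distributed worker group: rank 0 plus one
# additional ranked worker.
rank0_checkpoint = _TrackedCheckpoint(
    {"data": 0},
    storage_mode=CheckpointStorage.PERSISTENT,
    metrics={"score": 0.9},
    rank=0,
)
rank1_checkpoint = _TrackedCheckpoint(
    {"data": 1},
    storage_mode=CheckpointStorage.PERSISTENT,
    metrics={"score": 0.9},
    rank=1,
)

# register_checkpoints() now accepts a single checkpoint or a list.
cpm.register_checkpoints([rank0_checkpoint, rank1_checkpoint])

# _process_persistent_checkpoint() returns early for rank > 0, so only the
# rank-0 checkpoint is tracked as a best/worst candidate; the rank-1
# checkpoint only contributes its data.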
25 changes: 25 additions & 0 deletions python/ray/air/_internal/util.py
@@ -3,10 +3,12 @@
from contextlib import closing
import logging
import queue
import shutil
import threading
from typing import Optional

import numpy as np
from pathlib import Path

import ray
from ray.air.constants import _ERROR_REPORT_TIMEOUT
@@ -119,3 +121,26 @@ def join(self, timeout=None):
def _estimate_avail_object_store_memory() -> int:
"""Estimates total object store memory available in the cluster."""
return ray.available_resources()["object_store_memory"]


def _copy_dir_ignore_conflicts(src_dir: Path, dst_dir: Path):
"""This is a workaround for python < 3.8 where shutil.copytree does not
support dirs_exist_ok=True.
We go through the contents of the folder and copy items manually,
ignoring files that already exist in the destination.
TODO(jungong): remove this workaround when we drop support for python < 3.8.
"""
for inner in src_dir.iterdir():
dest = dst_dir / inner.name
if inner.is_dir():
if not dest.exists():
dest.mkdir(parents=True)
_copy_dir_ignore_conflicts(inner, dest)
else:
if not dest.exists():
shutil.copy2(str(inner.absolute()), str(dest.absolute()))
else:
# Ignore and don't overwrite the existing file.
pass
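For reference, a small usage sketch of the helper above (not part of this commit; the file names and contents are illustrative). Unlike shutil.copytree(..., dirs_exist_ok=True) on Python >= 3.8, conflicting files already present in the destination are left untouched rather than overwritten:

import tempfile
from pathlib import Path

from ray.air._internal.util import _copy_dir_ignore_conflicts

tmpdir = Path(tempfile.mkdtemp())
src_dir, dst_dir = tmpdir / "src", tmpdir / "dst"
src_dir.mkdir()
dst_dir.mkdir()

(src_dir / "model.pt").write_text("rank 1 weights")
(dst_dir / "model.pt").write_text("rank 0 weights")  # conflicting file

_copy_dir_ignore_conflicts(src_dir, dst_dir)

# The pre-existing file in dst_dir is kept as-is; only new files are copied.
assert (dst_dir / "model.pt").read_text() == "rank 0 weights"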
14 changes: 8 additions & 6 deletions python/ray/air/checkpoint.py
@@ -26,6 +26,7 @@
read_file_from_uri,
upload_to_uri,
)
from ray.air._internal.util import _copy_dir_ignore_conflicts
from ray.air.constants import PREPROCESSOR_KEY, CHECKPOINT_ID_ATTR
from ray.util.annotations import DeveloperAPI, PublicAPI

@@ -559,21 +560,22 @@ def _to_directory(self, path: str, move_instead_of_copy: bool = False) -> None:
if local_path:
local_path_pathlib = Path(local_path).resolve()
if local_path_pathlib != path_pathlib:
if path_pathlib.exists():
shutil.rmtree(str(path_pathlib.absolute()))
# If this exists on the local path, just copy over
if move_instead_of_copy:
os.makedirs(str(path_pathlib.absolute()), exist_ok=True)
self._local_path = str(path_pathlib.absolute())
for inner in local_path_pathlib.iterdir():
dest = path_pathlib / inner.name
if dest.exists():
# Ignore files that already exist.
# For example, checkpoints from every rank may all have
# the same .is_checkpoint file.
continue
shutil.move(
str(inner.absolute()), str(path_pathlib.absolute())
)
else:
shutil.copytree(
str(local_path_pathlib.absolute()),
str(path_pathlib.absolute()),
)
_copy_dir_ignore_conflicts(local_path_pathlib, path_pathlib)
elif external_path:
# If this exists on external storage (e.g. cloud), download
download_from_uri(uri=external_path, local_path=path, filelock=False)
12 changes: 11 additions & 1 deletion python/ray/air/config.py
@@ -601,14 +601,24 @@ class CheckpointConfig:
This attribute is only supported by trainers that don't take in
custom training loops. Defaults to True for trainers that support it
and False for generic function trainables.
_checkpoint_keep_all_ranks: If True, checkpoints from all ranked
training workers are saved. If False, only the checkpoint from the
rank-0 worker is kept.
NOTE: This API is experimental and subject to change between minor
releases.
_checkpoint_upload_from_workers: If True, distributed workers
upload their checkpoints to cloud storage directly. This avoids
transferring large checkpoint files to the training worker
group coordinator for persistence. NOTE: This API is experimental and
subject to change between minor releases.
"""

num_to_keep: Optional[int] = None
checkpoint_score_attribute: Optional[str] = None
checkpoint_score_order: str = MAX
checkpoint_frequency: int = 0
checkpoint_at_end: Optional[bool] = None
_checkpoint_keep_all_ranks: bool = False
_checkpoint_upload_from_workers: bool = False

def __post_init__(self):
if self.num_to_keep is not None and self.num_to_keep <= 0:
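A minimal sketch of how the two experimental flags above might be set through a trainer's run configuration. The RunConfig wiring is the standard Ray AIR API of this release and is assumed here rather than shown in this diff; the num_to_keep and score settings are illustrative, and a cloud upload destination would normally be configured alongside this (omitted here):

from ray.air import CheckpointConfig, RunConfig

run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="loss",
        checkpoint_score_order="min",
        # Experimental: keep checkpoints from every ranked worker instead of
        # only the rank-0 worker.
        _checkpoint_keep_all_ranks=True,
        # Experimental: workers upload their checkpoints to cloud storage
        # directly instead of routing them through the coordinator.
        _checkpoint_upload_from_workers=True,
    ),
)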
12 changes: 6 additions & 6 deletions python/ray/air/tests/test_checkpoint_manager.py
@@ -11,7 +11,7 @@ def test_unlimited_persistent_checkpoints():
cpm = _CheckpointManager(checkpoint_strategy=CheckpointConfig(num_to_keep=None))

for i in range(10):
cpm.register_checkpoint(
cpm.register_checkpoints(
_TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.PERSISTENT)
)

@@ -22,7 +22,7 @@ def test_limited_persistent_checkpoints():
cpm = _CheckpointManager(checkpoint_strategy=CheckpointConfig(num_to_keep=2))

for i in range(10):
cpm.register_checkpoint(
cpm.register_checkpoints(
_TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.PERSISTENT)
)

@@ -41,7 +41,7 @@ def __post_init__(self):
cpm = _CheckpointManager(checkpoint_strategy=_CheckpointConfig(num_to_keep=0))

for i in range(10):
cpm.register_checkpoint(
cpm.register_checkpoints(
_TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.PERSISTENT)
)

@@ -53,7 +53,7 @@ def test_dont_persist_memory_checkpoints():
cpm._persist_memory_checkpoints = False

for i in range(10):
cpm.register_checkpoint(
cpm.register_checkpoints(
_TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.MEMORY)
)

@@ -65,7 +65,7 @@ def test_persist_memory_checkpoints():
cpm._persist_memory_checkpoints = True

for i in range(10):
cpm.register_checkpoint(
cpm.register_checkpoints(
_TrackedCheckpoint({"data": i}, storage_mode=CheckpointStorage.MEMORY)
)

@@ -83,7 +83,7 @@ def test_keep_best_checkpoints():
cpm._persist_memory_checkpoints = True

for i in range(10):
cpm.register_checkpoint(
cpm.register_checkpoints(
_TrackedCheckpoint(
{"data": i},
storage_mode=CheckpointStorage.MEMORY,
29 changes: 29 additions & 0 deletions python/ray/air/tests/test_checkpoints.py
@@ -15,6 +15,7 @@
import ray
from ray.air._internal.remote_storage import _ensure_directory, delete_at_uri
from ray.air._internal.uri_utils import URI
from ray.air._internal.util import _copy_dir_ignore_conflicts
from ray.air.checkpoint import _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY, Checkpoint
from ray.air.constants import MAX_REPR_LENGTH, PREPROCESSOR_KEY
from ray.data import Preprocessor
@@ -159,6 +160,34 @@ def test_directory_move_instead_of_copy(self):
assert new_recovered_checkpoint.foo == "bar"
assert not list(Path(path).glob("*"))

def test_copy_dir_ignore_conflicts(self):
tmpdir = Path(tempfile.mkdtemp())

src_dir = tmpdir / "src"
dst_dir = tmpdir / "dst"

src_dir.mkdir()
dst_dir.mkdir()

(src_dir / "foo.txt").touch()
(src_dir / "bar.txt").touch()
(src_dir / "a").mkdir()
(src_dir / "a" / "a.txt").touch()
(src_dir / "b").mkdir()
(src_dir / "b" / "b.txt").touch()

# Has a file conflict.
(dst_dir / "foo.txt").touch()
# Has a directory conflict.
(dst_dir / "a").mkdir()

_copy_dir_ignore_conflicts(src_dir, dst_dir)

assert (dst_dir / "foo.txt").exists()
assert (dst_dir / "bar.txt").exists()
assert (dst_dir / "a" / "a.txt").exists()
assert (dst_dir / "b" / "b.txt").exists()

def test_uri(self):
checkpoint = StubCheckpoint.from_dict({"spam": "ham"})
assert "foo" in checkpoint._SERIALIZED_ATTRS