Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[train] Storage: Change class trainable save_checkpoint implementations #38554

Merged
merged 7 commits into from
Aug 21, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update typehints
Signed-off-by: Kai Fricke <[email protected]>
  • Loading branch information
Kai Fricke committed Aug 17, 2023
commit 711eefafa8468af2d4020c602ad95b850f7a0835
18 changes: 7 additions & 11 deletions python/ray/tune/trainable/trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,7 @@ def step(self):
"""
raise NotImplementedError

def save_checkpoint(self, checkpoint_dir: str) -> Optional[Union[str, Dict]]:
def save_checkpoint(self, checkpoint_dir: str) -> Optional[Dict]:
"""Subclasses should override this to implement ``save()``.

Warning:
Expand All @@ -1395,11 +1395,9 @@ def save_checkpoint(self, checkpoint_dir: str) -> Optional[Union[str, Dict]]:
the provided path may be temporary and moved.

Returns:
A dict or string. If string, the return value is expected to be
the `checkpoint_dir`. If dict, the return value will
be automatically serialized by Tune. In both cases, the return value
is exactly what will be passed to ``Trainable.load_checkpoint()``
upon restore.
A dict or None. If dict, the return value will
be automatically serialized by Tune. In that case,
``Trainable.load_checkpoint()`` will receive the dict upon restore.

Example:
>>> trainable, trainable1, trainable2 = ... # doctest: +SKIP
Expand All @@ -1412,7 +1410,7 @@ def save_checkpoint(self, checkpoint_dir: str) -> Optional[Union[str, Dict]]:
"""
raise NotImplementedError

def load_checkpoint(self, checkpoint: Union[Dict, str]):
def load_checkpoint(self, checkpoint: Optional[Dict]):
"""Subclasses should override this to implement restore().

Warning:
Expand Down Expand Up @@ -1462,10 +1460,8 @@ def load_checkpoint(self, checkpoint: Union[Dict, str]):

Args:
checkpoint: If dict, the return value is as
returned by `save_checkpoint`. If a string, then it is
a checkpoint path that may have a different prefix than that
returned by `save_checkpoint`. The directory structure
underneath the `checkpoint_dir` from `save_checkpoint` is preserved.
returned by ``save_checkpoint``. Otherwise, the directory
the checkpoint was stored in.
"""
raise NotImplementedError

Expand Down
Loading