Skip to content

Commit

Permalink
[Train] Make get_checkpoint top level function in ray.train (ray-proj…
Browse files Browse the repository at this point in the history
…ect#37906)

This is bringing the API up-to-date with ray-project/enhancements#36
  • Loading branch information
pcmoritz committed Jul 29, 2023
1 parent a193ad6 commit f8a4a79
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 6 deletions.
3 changes: 2 additions & 1 deletion python/ray/train/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ray._private.usage import usage_lib
from ray.train._internal.data_config import DataConfig
from ray.train._internal.session import get_dataset_shard, report
from ray.train._internal.session import get_checkpoint, get_dataset_shard, report
from ray.train.backend import BackendConfig
from ray.train.constants import TRAIN_DATASET_KEY
from ray.train.context import get_context
Expand All @@ -13,6 +13,7 @@
usage_lib.record_library_usage("train")

__all__ = [
"get_checkpoint",
"get_context",
"get_dataset_shard",
"report",
Expand Down
5 changes: 0 additions & 5 deletions python/ray/train/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import threading
from typing import TYPE_CHECKING, Optional

from ray.air import Checkpoint
from ray.train._internal import session
from ray.util.annotations import PublicAPI

Expand All @@ -27,10 +26,6 @@ def wrapped(func):
class TrainContext:
"""Context for Ray training executions."""

@_copy_doc(session.get_checkpoint)
def get_checkpoint(self) -> Optional[Checkpoint]:
return session.get_checkpoint()

@_copy_doc(session.get_experiment_name)
def get_experiment_name(self) -> str:
return session.get_experiment_name()
Expand Down

0 comments on commit f8a4a79

Please sign in to comment.