Skip to content

Commit

Permalink
Annotate more api (ray-project#26501)
Browse files Browse the repository at this point in the history
Signed-off-by: Stefan van der Kleij <[email protected]>
  • Loading branch information
amogkam authored and Stefan van der Kleij committed Aug 18, 2022
1 parent 4a62235 commit 09c3d37
Show file tree
Hide file tree
Showing 15 changed files with 35 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/ray/train/horovod/horovod_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@

from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.train.horovod.config import HorovodConfig
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="alpha")
class HorovodTrainer(DataParallelTrainer):
"""A Trainer for data parallel Horovod training.
Expand Down
2 changes: 2 additions & 0 deletions python/ray/train/huggingface/huggingface_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
from ray.air.checkpoint import Checkpoint
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.train.predictor import Predictor
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="alpha")
class HuggingFacePredictor(Predictor):
"""A predictor for HuggingFace Transformers PyTorch models.
Expand Down
2 changes: 2 additions & 0 deletions python/ray/train/huggingface/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
)
from ray.air._internal.torch_utils import load_torch_model
from ray.air.checkpoint import Checkpoint
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="alpha")
def load_checkpoint(
checkpoint: Checkpoint,
model: Union[Type[transformers.modeling_utils.PreTrainedModel], torch.nn.Module],
Expand Down
2 changes: 2 additions & 0 deletions python/ray/train/lightgbm/lightgbm_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.train.lightgbm.utils import load_checkpoint
from ray.train.predictor import Predictor
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="alpha")
class LightGBMPredictor(Predictor):
"""A predictor for LightGBM models.
Expand Down
3 changes: 3 additions & 0 deletions python/ray/train/lightgbm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
)
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="alpha")
def to_air_checkpoint(
path: str,
booster: lightgbm.Booster,
Expand All @@ -39,6 +41,7 @@ def to_air_checkpoint(
return checkpoint


@PublicAPI(stability="alpha")
def load_checkpoint(
checkpoint: Checkpoint,
) -> Tuple[lightgbm.Booster, Optional["Preprocessor"]]:
Expand Down
2 changes: 2 additions & 0 deletions python/ray/train/rl/rl_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from ray.rllib.utils.typing import EnvType
from ray.train.predictor import Predictor
from ray.train.rl.utils import load_checkpoint
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="alpha")
class RLPredictor(Predictor):
"""A predictor for RLlib policies.
Expand Down
2 changes: 2 additions & 0 deletions python/ray/train/rl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import EnvType
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
Expand All @@ -16,6 +17,7 @@
RL_CONFIG_FILE = "config.pkl"


@PublicAPI(stability="alpha")
def load_checkpoint(
checkpoint: Checkpoint,
env: Optional[EnvType] = None,
Expand Down
2 changes: 2 additions & 0 deletions python/ray/train/sklearn/sklearn_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from ray.train.sklearn._sklearn_utils import _set_cpu_params
from ray.train.sklearn.utils import load_checkpoint
from ray.util.joblib import register_ray
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="alpha")
class SklearnPredictor(Predictor):
"""A predictor for scikit-learn compatible estimators.
Expand Down
3 changes: 3 additions & 0 deletions python/ray/train/sklearn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
)
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="alpha")
def to_air_checkpoint(
path: str,
estimator: BaseEstimator,
Expand All @@ -41,6 +43,7 @@ def to_air_checkpoint(
return checkpoint


@PublicAPI(stability="alpha")
def load_checkpoint(
checkpoint: Checkpoint,
) -> Tuple[BaseEstimator, Optional["Preprocessor"]]:
Expand Down
2 changes: 2 additions & 0 deletions python/ray/train/tensorflow/tensorflow_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from ray.air.checkpoint import Checkpoint
from ray.train.data_parallel_trainer import _load_checkpoint
from ray.train._internal.dl_predictor import DLPredictor
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
class TensorflowPredictor(DLPredictor):
"""A predictor for TensorFlow models.
Expand Down
3 changes: 3 additions & 0 deletions python/ray/train/tensorflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.train.data_parallel_trainer import _load_checkpoint
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="alpha")
def to_air_checkpoint(
model: keras.Model, preprocessor: Optional["Preprocessor"] = None
) -> Checkpoint:
Expand All @@ -29,6 +31,7 @@ def to_air_checkpoint(
return checkpoint


@PublicAPI(stability="alpha")
def load_checkpoint(
checkpoint: Checkpoint,
model: Union[Callable[[], tf.keras.Model], Type[tf.keras.Model], tf.keras.Model],
Expand Down
2 changes: 2 additions & 0 deletions python/ray/train/torch/torch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
from ray.air.checkpoint import Checkpoint
from ray.train.torch.utils import load_checkpoint
from ray.train._internal.dl_predictor import DLPredictor
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
class TorchPredictor(DLPredictor):
"""A predictor for PyTorch models.
Expand Down
3 changes: 3 additions & 0 deletions python/ray/train/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.train.data_parallel_trainer import _load_checkpoint
from ray.air._internal.torch_utils import load_torch_model
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="alpha")
def to_air_checkpoint(
model: torch.nn.Module, preprocessor: Optional["Preprocessor"] = None
) -> Checkpoint:
Expand All @@ -29,6 +31,7 @@ def to_air_checkpoint(
return checkpoint


@PublicAPI(stability="alpha")
def load_checkpoint(
checkpoint: Checkpoint, model: Optional[torch.nn.Module] = None
) -> Tuple[torch.nn.Module, Optional["Preprocessor"]]:
Expand Down
3 changes: 3 additions & 0 deletions python/ray/train/xgboost/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
)
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="alpha")
def to_air_checkpoint(
path: str,
booster: xgboost.Booster,
Expand All @@ -39,6 +41,7 @@ def to_air_checkpoint(
return checkpoint


@PublicAPI(stability="alpha")
def load_checkpoint(
checkpoint: Checkpoint,
) -> Tuple[xgboost.Booster, Optional["Preprocessor"]]:
Expand Down
2 changes: 2 additions & 0 deletions python/ray/train/xgboost/xgboost_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.train.predictor import Predictor
from ray.train.xgboost.utils import load_checkpoint
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="alpha")
class XGBoostPredictor(Predictor):
"""A predictor for XGBoost models.
Expand Down

0 comments on commit 09c3d37

Please sign in to comment.