
[Train] Move to beta (ray-project#20378)
amogkam committed Nov 16, 2021
1 parent ca90c63 commit 4f88796
Showing 12 changed files with 34 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.rst
@@ -21,7 +21,7 @@ Ray is packaged with the following libraries for accelerating machine learning w

- `Tune`_: Scalable Hyperparameter Tuning
- `RLlib`_: Scalable Reinforcement Learning
-- `Train`_: Distributed Deep Learning (alpha)
+- `Train`_: Distributed Deep Learning (beta)
- `Datasets`_: Distributed Data Loading and Compute (beta)

As well as libraries for taking ML and distributed apps to production:
2 changes: 1 addition & 1 deletion doc/source/train/train.rst
@@ -18,7 +18,7 @@ The main features are:

.. note::

-This API is in its Alpha release (as of Ray 1.7) and may be revised in
+This API is in its Beta release (as of Ray 1.9) and may be revised in
future Ray releases. If you encounter any bugs, please file an
`issue on GitHub`_.
If you are looking for the previous API documentation, see :ref:`sgd-index`.
3 changes: 3 additions & 0 deletions python/ray/train/backend.py
@@ -14,6 +14,7 @@
from ray.train.session import init_session, get_session, shutdown_session
from ray.train.utils import RayDataset, check_for_failure, Singleton
from ray.train.worker_group import WorkerGroup
+from ray.util.annotations import DeveloperAPI
from ray.util.placement_group import get_current_placement_group, \
remove_placement_group

@@ -24,6 +25,7 @@
logger = logging.getLogger(__name__)


+@DeveloperAPI
class BackendConfig:
"""Parent class for configurations of training backend."""

@@ -32,6 +34,7 @@ def backend_cls(self):
raise NotImplementedError


+@DeveloperAPI
class Backend(metaclass=Singleton):
"""Singleton for distributed communication backend.
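
BackendConfig and Backend are marked DeveloperAPI here, i.e. they are the extension point for plugging in a custom communication backend rather than an end-user API. A minimal sketch of such a subclass; the on_start/on_shutdown hook names and their WorkerGroup argument are not shown in this diff and are assumptions:

    from dataclasses import dataclass

    from ray.train.backend import Backend, BackendConfig
    from ray.train.worker_group import WorkerGroup


    @dataclass
    class MyBackendConfig(BackendConfig):
        timeout_s: int = 30  # illustrative option

        @property
        def backend_cls(self):
            return MyBackend


    class MyBackend(Backend):
        def on_start(self, worker_group: WorkerGroup, backend_config: MyBackendConfig):
            # Assumed hook: initialize the communication layer on the workers.
            pass

        def on_shutdown(self, worker_group: WorkerGroup, backend_config: MyBackendConfig):
            # Assumed hook: tear the communication layer down again.
            pass
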
2 changes: 2 additions & 0 deletions python/ray/train/checkpoint.py
@@ -12,6 +12,7 @@
TUNE_CHECKPOINT_ID
from ray.train.session import TrainingResult
from ray.train.utils import construct_path
+from ray.util import PublicAPI

if TUNE_INSTALLED:
from ray import tune
@@ -24,6 +25,7 @@
logger = logging.getLogger(__name__)


+@PublicAPI(stability="beta")
@dataclass
class CheckpointStrategy:
"""Configurable parameters for defining the Train checkpointing strategy.
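
CheckpointStrategy is promoted to a beta public API. A usage sketch, assuming the num_to_keep / checkpoint_score_attribute / checkpoint_score_order fields and the checkpoint_strategy argument to Trainer.run (none of which are shown in this diff):

    from ray.train import Trainer
    from ray.train.checkpoint import CheckpointStrategy

    # Keep only the single best checkpoint, ranked by the reported "loss"
    # metric with lower values treated as better (field names assumed).
    strategy = CheckpointStrategy(
        num_to_keep=1,
        checkpoint_score_attribute="loss",
        checkpoint_score_order="min",
    )

    trainer = Trainer(backend="torch", num_workers=2)
    trainer.start()
    trainer.run(train_func, checkpoint_strategy=strategy)  # train_func: user training function
    trainer.shutdown()
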
2 changes: 2 additions & 0 deletions python/ray/train/horovod.py
@@ -11,10 +11,12 @@
from horovod.ray.runner import Coordinator
from horovod.ray.utils import detect_nics, nics_to_env_var
from horovod.runner.common.util import secret, timeout
+from ray.util import PublicAPI

logger = logging.getLogger(__name__)


+@PublicAPI(stability="beta")
@dataclass
class HorovodConfig(BackendConfig):
"""Configurations for Horovod setup.
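
HorovodConfig likewise becomes a beta public API; it is passed as the backend of a Trainer. A short sketch (the worker count and GPU flag are placeholders):

    from ray.train import Trainer
    from ray.train.horovod import HorovodConfig

    trainer = Trainer(backend=HorovodConfig(), num_workers=4, use_gpu=True)
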
8 changes: 8 additions & 0 deletions python/ray/train/session.py
@@ -14,6 +14,7 @@
DETAILED_AUTOFILLED_KEYS, TIME_THIS_ITER_S, PID, TIMESTAMP, TIME_TOTAL_S,
NODE_IP, TRAINING_ITERATION, HOSTNAME, DATE, RESULT_FETCH_TIMEOUT)
from ray.train.utils import PropagatingThread, RayDataset
+from ray.util import PublicAPI


class TrainingResultType(Enum):
@@ -245,6 +246,7 @@ def shutdown_session():
_session = None


+@PublicAPI(stability="beta")
def get_dataset_shard(
dataset_name: Optional[str] = None) -> Optional[RayDataset]:
"""Returns the Ray Dataset or DatasetPipeline shard for this worker.
@@ -299,6 +301,7 @@ def train_func():
return shard


+@PublicAPI(stability="beta")
def report(**kwargs) -> None:
"""Reports all keyword arguments to Train as intermediate results.
@@ -326,6 +329,7 @@ def train_func():
session.report(**kwargs)


+@PublicAPI(stability="beta")
def world_rank() -> int:
"""Get the world rank of this worker.
@@ -350,6 +354,7 @@ def train_func():
return session.world_rank


+@PublicAPI(stability="beta")
def local_rank() -> int:
"""Get the local rank of this worker (rank of the worker on its node).
@@ -373,6 +378,7 @@ def train_func():
return session.local_rank


+@PublicAPI(stability="beta")
def load_checkpoint() -> Optional[Dict]:
"""Loads checkpoint data onto the worker.
@@ -403,6 +409,7 @@ def train_func():
return session.loaded_checkpoint


+@PublicAPI(stability="beta")
def save_checkpoint(**kwargs) -> None:
"""Checkpoints all keyword arguments to Train as restorable state.
@@ -428,6 +435,7 @@ def train_func():
session.checkpoint(**kwargs)


+@PublicAPI(stability="beta")
def world_size() -> int:
"""Get the current world size (i.e. total number of workers) for this run.
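
The functions promoted to PublicAPI above (get_dataset_shard, report, world_rank, local_rank, load_checkpoint, save_checkpoint, world_size) form the per-worker session API that is called from inside the user's training function. A minimal sketch of how they fit together; the metric names and the fake loss computation are illustrative only:

    from ray import train


    def train_func(config):
        # Resume from the latest checkpoint if one was restored for this run.
        checkpoint = train.load_checkpoint() or {}
        start_epoch = checkpoint.get("epoch", 0)

        for epoch in range(start_epoch, config["num_epochs"]):
            loss = 1.0 / (epoch + 1)  # stand-in for real training work
            if train.world_rank() == 0:
                print(f"epoch {epoch} across {train.world_size()} workers")
            train.save_checkpoint(epoch=epoch + 1)
            train.report(loss=loss, epoch=epoch)
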
2 changes: 2 additions & 0 deletions python/ray/train/tensorflow.py
@@ -9,10 +9,12 @@
from ray.train.session import shutdown_session
from ray.train.utils import get_address_and_port
from ray.train.worker_group import WorkerGroup
+from ray.util import PublicAPI

logger = logging.getLogger(__name__)


+@PublicAPI(stability="beta")
@dataclass
class TensorflowConfig(BackendConfig):
@property
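
TensorflowConfig follows the same pattern as the other backend configs; a one-line sketch of selecting it (worker count is a placeholder):

    from ray.train import Trainer
    from ray.train.tensorflow import TensorflowConfig

    trainer = Trainer(backend=TensorflowConfig(), num_workers=2)
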
4 changes: 4 additions & 0 deletions python/ray/train/torch.py
@@ -14,13 +14,15 @@

import torch
import torch.distributed as dist
+from ray.util import PublicAPI
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler, DataLoader, \
IterableDataset, SequentialSampler

logger = logging.getLogger(__name__)


+@PublicAPI(stability="beta")
@dataclass
class TorchConfig(BackendConfig):
"""Configuration for torch process group setup.
@@ -202,6 +204,7 @@ def get_device() -> torch.device:
return device


+@PublicAPI(stability="beta")
def prepare_model(
model: torch.nn.Module,
move_to_device: bool = True,
@@ -246,6 +249,7 @@ def prepare_model(
return model


+@PublicAPI(stability="beta")
def prepare_data_loader(data_loader: torch.utils.data.DataLoader,
add_dist_sampler: bool = True,
move_to_device: bool = True) -> \
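
prepare_model and prepare_data_loader are the beta utilities that wrap a model in DistributedDataParallel and attach a DistributedSampler, so the training function itself stays free of explicit torch.distributed boilerplate. A sketch of a training function that would be passed to Trainer.run; the toy model and data are placeholders:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from ray.train.torch import prepare_data_loader, prepare_model


    def train_func():
        dataset = TensorDataset(torch.randn(128, 4), torch.randn(128, 1))
        loader = prepare_data_loader(DataLoader(dataset, batch_size=16))
        model = prepare_model(torch.nn.Linear(4, 1))  # moved to device, wrapped in DDP
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        for x, y in loader:
            loss = torch.nn.functional.mse_loss(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
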
5 changes: 5 additions & 0 deletions python/ray/train/trainer.py
@@ -21,6 +21,8 @@
# Ray Train should be usable even if Tune is not installed.
from ray.train.utils import construct_path
from ray.train.worker_group import WorkerGroup
+from ray.util import PublicAPI
+from ray.util.annotations import DeveloperAPI

if TUNE_INSTALLED:
from ray import tune
@@ -62,6 +64,7 @@ def get_backend_config_cls(backend_name) -> type:
return config_cls


+@PublicAPI(stability="beta")
class Trainer:
"""A class for enabling seamless distributed deep learning.
@@ -512,6 +515,7 @@ def train_epoch(self):
return TrainWorkerGroup(worker_group)


+@DeveloperAPI
class TrainWorkerGroup:
"""A container for a group of Ray actors.
@@ -557,6 +561,7 @@ def shutdown(self, patience_s: float = 5):
self._worker_group.shutdown(patience_s=patience_s)


+@DeveloperAPI
class TrainingIterator:
"""An iterator over Train results. Returned by ``trainer.run_iterator``."""

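
Trainer is the main entry point promoted to beta, while TrainWorkerGroup and TrainingIterator stay DeveloperAPI. The basic lifecycle, sketched with a train_func like the session.py example above (worker count and config are placeholders):

    from ray.train import Trainer

    trainer = Trainer(backend="torch", num_workers=2)
    trainer.start()
    results = trainer.run(train_func, config={"num_epochs": 3})
    # trainer.run_iterator(...) would instead return the DeveloperAPI
    # TrainingIterator and yield intermediate results as workers report them.
    trainer.shutdown()
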
4 changes: 2 additions & 2 deletions python/ray/util/sgd/tf/tf_trainer.py
@@ -6,13 +6,13 @@

from ray.tune import Trainable
from ray.tune.utils.placement_groups import PlacementGroupFactory
-from ray.util.annotations import PublicAPI
+from ray.util.annotations import Deprecated
from ray.util.sgd.tf.tf_runner import TFRunner

logger = logging.getLogger(__name__)


-@PublicAPI(stability="beta")
+@Deprecated
class TFTrainer:
def __init__(self,
model_creator,
4 changes: 2 additions & 2 deletions python/ray/util/sgd/torch/torch_trainer.py
@@ -11,7 +11,7 @@

import ray
from ray.util import log_once
-from ray.util.annotations import PublicAPI
+from ray.util.annotations import Deprecated
from ray.util.sgd.torch.worker_group import LocalWorkerGroup, \
RemoteWorkerGroup, DeactivatedWorkerGroup
from ray.util.sgd.utils import NUM_SAMPLES, BATCH_SIZE
@@ -48,7 +48,7 @@ def _remind_gpu_usage(use_gpu):
"enable GPU usage. ")


-@PublicAPI(stability="beta")
+@Deprecated
class TorchTrainer:
"""Train a PyTorch model using distributed PyTorch.
4 changes: 2 additions & 2 deletions python/ray/util/sgd/torch/training_operator.py
@@ -8,7 +8,7 @@
import torch.nn as nn
from filelock import FileLock

-from ray.util.annotations import PublicAPI
+from ray.util.annotations import Deprecated
from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection,
NUM_SAMPLES)
from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS,
@@ -54,7 +54,7 @@ def _is_multiple(component):
return isinstance(component, Iterable) and len(component) > 1


-@PublicAPI(stability="beta")
+@Deprecated
class TrainingOperator:
"""Abstract class to define training and validation state and logic.
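
All three ray.util.sgd entry points touched above (TFTrainer, TorchTrainer, TrainingOperator) flip from PublicAPI to Deprecated in this commit; the intended replacement is the ray.train Trainer shown earlier. A rough sketch of the mapping, with the old SGD import paths written from memory and worth double-checking:

    # Deprecated SGD v1 style:
    #     from ray.util.sgd.torch import TorchTrainer, TrainingOperator
    #     trainer = TorchTrainer(training_operator_cls=MyOperator, num_workers=2)
    # Ray Train (beta) style:
    from ray.train import Trainer

    trainer = Trainer(backend="torch", num_workers=2)
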
