
Commit

[AIR] Introduce DatasetIterator for bulk and streaming ingest (ray-project#31470)

Introduces ray.data.DatasetIterator, which exposes the same iteration-based interfaces as Dataset:

    iter_batches()
    to_tf()
    iter_torch_batches()
    stats()

This interface replaces Dataset and DatasetPipeline as the default data iterator interface in AIR trainers. Since both bulk and streaming ingest now use the same interface, this PR also hard-deprecates use_stream_api and stream_window_size (previously experimental). These are replaced with a single option, max_object_store_memory_fraction, which specifies the fraction of Ray's object store memory to use for ingest. The value defaults to -1, meaning bulk ingest.

This also simplifies the configs for specifying bulk/streaming ingest with global shuffle. Previously, global_shuffle=True would shuffle once before training (bulk ingest with Dataset) or once before each epoch (streaming ingest with DatasetPipeline). Now the preprocessed dataset is always shuffled once before each epoch.

For backwards compatibility in v2.3, DatasetIterator currently forwards unsupported methods to Dataset or DatasetPipeline.
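
A minimal sketch of the new usage, condensed from the doc-code changes in this commit (dataset contents, batch size, and epoch count are illustrative):

import ray
from ray.air import session
from ray.air.config import DatasetConfig, ScalingConfig
from ray.data import DatasetIterator
from ray.train.torch import TorchTrainer


def train_loop_per_worker():
    # The same DatasetIterator interface is returned for bulk and streaming ingest.
    data_shard: DatasetIterator = session.get_dataset_shard("train")
    for _ in range(10):  # 10 epochs
        for batch in data_shard.iter_batches(batch_size=128):
            print("Do some training on batch", batch)
    # View iteration stats for performance debugging.
    print(data_shard.stats())


my_trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=1),
    datasets={"train": ray.data.range_tensor(1000)},
    # Streaming ingest using ~20% of object store memory per window;
    # the default of -1 keeps bulk ingest.
    dataset_config={"train": DatasetConfig(max_object_store_memory_fraction=0.2)},
)
my_trainer.fit()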

Signed-off-by: Stephanie Wang <[email protected]>
stephanie-wang committed Jan 12, 2023
1 parent 05d6f30 commit 835d1d5
Showing 30 changed files with 987 additions and 234 deletions.
1 change: 1 addition & 0 deletions doc/source/data/api/api.rst
@@ -8,6 +8,7 @@ Ray Datasets API

input_output.rst
dataset.rst
dataset_iterator.rst
dataset_pipeline.rst
grouped_dataset.rst
dataset_context.rst
1 change: 1 addition & 0 deletions doc/source/data/api/dataset.rst
@@ -85,6 +85,7 @@ Consuming Datasets
Dataset.show
Dataset.take
Dataset.take_all
Dataset.iterator
Dataset.iter_rows
Dataset.iter_batches
Dataset.iter_torch_batches
16 changes: 16 additions & 0 deletions doc/source/data/api/dataset_iterator.rst
@@ -0,0 +1,16 @@
.. _dataset-iterator-api:

DatasetIterator API
===================

.. currentmodule:: ray.data

.. autoclass:: DatasetIterator

.. autosummary::
:toctree: doc/

DatasetIterator.iter_batches
DatasetIterator.iter_torch_batches
DatasetIterator.to_tf
DatasetIterator.stats
1 change: 1 addition & 0 deletions doc/source/data/api/dataset_pipeline.rst
@@ -68,6 +68,7 @@ Consuming DatasetPipelines
DatasetPipeline.show_windows
DatasetPipeline.take
DatasetPipeline.take_all
DatasetPipeline.iterator
DatasetPipeline.iter_rows
DatasetPipeline.iter_batches
DatasetPipeline.iter_torch_batches
65 changes: 42 additions & 23 deletions doc/source/ray-air/check-ingest.rst
@@ -28,7 +28,9 @@ Let's walk through the stages of what happens when ``Trainer.fit()`` is called.
on the train dataset passed to the Trainer, followed by :py:meth:`prep.transform() <ray.data.preprocessor.Preprocessor.transform>`
on remaining datasets.

**Training**: Then, AIR passes the preprocessed dataset to Train workers (Ray actors) launched by the Trainer. Each worker calls :py:func:`get_dataset_shard <ray.air.session.get_dataset_shard>` to get a handle to its assigned data shard, and then calls one of :py:meth:`iter_batches() <ray.data.Dataset.iter_batches>`, :py:meth:`iter_torch_batches() <ray.data.Dataset.iter_torch_batches>`, or :meth:`~ray.data.Dataset.to_tf` to loop over the data.
**Training**: Then, AIR passes the preprocessed dataset to Train workers (Ray actors) launched by the Trainer. Each worker calls :func:`~ray.air.session.get_dataset_shard` to get a handle to its assigned data shard.
This returns a :class:`~ray.data.DatasetIterator`, which can be used to loop over the data with :meth:`~ray.data.DatasetIterator.iter_batches`, :meth:`~ray.data.Dataset.iter_torch_batches`, or :meth:`~ray.data.Dataset.to_tf`.
Each of these returns a batch iterator for one epoch (a full pass over the original dataset).

Getting Started
---------------
@@ -43,6 +45,8 @@ The following is a simple example of how to configure ingest for a dummy :py:cla

.. _air-configure-ingest:

For local development and testing, you can also use the helper function :meth:`~ray.air.util.check_ingest.make_local_dataset_iterator` to get a local :class:`~ray.data.DatasetIterator`.

Configuring Ingest
------------------
You can use the :py:class:`~ray.air.config.DatasetConfig` object to configure how Datasets are preprocessed and split across training workers.
@@ -83,20 +87,31 @@ Enabling Streaming Ingest

You should use bulk ingest when:

* you have enough memory to fit data blocks in cluster object store;
* your preprocessing step is expensive per each epoch; and
* you want best performance when both or either the above conditions are met.
* you have enough memory to fit data blocks in cluster object store; or
* your preprocessing transform is expensive to recompute on each epoch

.. tabbed:: Streaming Ingest (experimental)

In streaming ingest mode, :py:func:`~ray.air.session.get_dataset_shard` returns a :py:class:`~ray.data.dataset_pipeline.DatasetPipeline` pipeline that
can be used to read data in a streaming way.
To enable streaming ingest, set ``use_stream_api=True`` in the dataset config.

By default, this will tell AIR to load *windows* of 1GiB of data into memory at a time.
Performance can be increased with larger window sizes, which can be adjusted using the
``stream_window_size`` config.
A reasonable stream window size is something like 20% of available object store memory.
In streaming ingest mode, instead of loading the entire dataset into the
Ray object store at once, AIR will load a fraction of the dataset at a
time. This can be desirable when the dataset is very large, and caching it
all at once would cause expensive disk spilling. The downside is that the
dataset will have to be preprocessed on each epoch, which may be more
expensive. Preprocessing is overlapped with training computation, but
overall training throughput may still decrease if preprocessing is more
expensive than the training computation (forward pass, backward pass,
gradient sync).

To enable this mode, use the :py:meth:`max_object_store_memory_fraction
<ray.air.config.DatasetConfig>` argument. This argument defaults to -1,
meaning that bulk ingest should be used and the entire dataset should be
computed and cached before training starts.

Use a float value 0 or greater to indicate the "window" size, i.e. the
maximum fraction of object store memory that should be used at once. A
reasonable value is 0.2, meaning 20% of available object store memory.
Larger window sizes can improve performance by increasing parallelism. A
window size of 1 or greater will likely result in spilling.

.. literalinclude:: doc_code/air_ingest.py
:language: python
@@ -105,10 +120,11 @@ Enabling Streaming Ingest

Use streaming ingest when:

* you have large datasets that don't fit into memory;
* you want to process small chunks or blocks per window;
* you can use small windows with small data blocks minimizing or avoiding memory starvation or OOM errors; and
* your preprocessing step is not a bottleneck or not an expensive operation since it's re-executed on each pass over the data.
* you have large datasets that don't fit into memory; and
* re-executing the preprocessing step on each epoch is faster than caching the preprocessed dataset on disk and reloading from disk on each epoch

Note that this feature is experimental and the actual object store memory
usage may vary. Please file a `GitHub issue <https://github.com/ray-project/ray/issues>`_ if you run into problems.

.. _air-shuffle:

@@ -120,12 +136,12 @@ Shuffling or data randomization is important for training high-quality models. B
.. tabbed:: Local Shuffling

Local shuffling is the recommended approach for randomizing data order. To use local shuffle,
simply specify a non-zero ``local_shuffle_buffer_size`` as an argument to :py:meth:`iter_batches() <ray.data.Dataset.iter_batches>`.
simply specify a non-zero ``local_shuffle_buffer_size`` as an argument to :meth:`~ray.data.DatasetIterator.iter_batches`.
The iterator will then use a local buffer of the given size to randomize record order. The
larger the buffer size, the more randomization will be applied, but it will also use more
memory.

See :meth:`ds.iter_batches() <ray.data.Dataset.iter_batches>` for more details.
See :meth:`~ray.data.DatasetIterator.iter_batches` for more details.

.. literalinclude:: doc_code/air_ingest.py
:language: python
@@ -143,9 +159,11 @@ Shuffling or data randomization is important for training high-quality models. B
Global shuffling provides more uniformly random (decorrelated) samples and is carried
out via a distributed map-reduce operation. This higher quality shuffle can often lead
to more precision gain per training step, but it is also an expensive distributed
operation and will decrease the ingest throughput. As long as the shuffled ingest
throughput matches or exceeds the model training (forward pass, backward pass, gradient sync)
throughput, this higher-quality shuffle shouldn't slow down the overall training.
operation and will decrease the ingest throughput. The shuffle step is overlapped with
training computation, so as long as the shuffled ingest throughput matches
or exceeds the model training (forward pass, backward pass, gradient sync)
throughput, this higher-quality shuffle shouldn't slow down the overall
training.

If global shuffling *is* causing the ingest throughput to become the training
bottleneck, local shuffling may be a better option.
@@ -241,7 +259,9 @@ Debugging Ingest with the ``DummyTrainer``

Data ingest problems can be challenging to debug when combined in a full training pipeline. To isolate data
ingest issues from other possible training problems, we provide the :py:class:`~ray.air.util.check_ingest.DummyTrainer`
utility class that can be used to debug ingest problems. Let's walk through using DummyTrainer to understand
utility class that can be used to debug ingest problems.
You can also use the helper function :meth:`~ray.air.util.check_ingest.make_local_dataset_iterator` to get a local :class:`~ray.data.DatasetIterator` for debugging purposes.
Let's walk through using ``DummyTrainer`` to understand
and resolve an ingest misconfiguration.

Setting it up
@@ -394,4 +414,3 @@ How do I shard validation and test datasets?
By default only the `"train"` Dataset is sharded. To also shard validation and test datasets, you can configure the ``dataset_config``
that is passed to your ``Trainer``.
See the :ref:`Splitting Auxiliary Datasets <air-splitting-aux-datasets>` section for a full example.

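The global shuffle behavior described in check-ingest.rst comes down to a single DatasetConfig flag; a minimal sketch, mirroring the datasets_train.py change later in this commit (the dataset name is illustrative):

from ray.air.config import DatasetConfig

# With this commit, global_shuffle=True reshuffles the preprocessed dataset once
# before each epoch, for both bulk and streaming ingest.
dataset_config = {
    "train": DatasetConfig(global_shuffle=True),
    # To combine with streaming ingest (fraction illustrative):
    # "train": DatasetConfig(global_shuffle=True, max_object_store_memory_fraction=0.2),
}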
31 changes: 14 additions & 17 deletions doc/source/ray-air/doc_code/air_ingest.py
@@ -88,7 +88,7 @@
# __config_4__
import ray
from ray.air import session
from ray.data import Dataset
from ray.data import DatasetIterator
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

@@ -97,8 +97,8 @@


def train_loop_per_worker():
# Get a handle to the worker's assigned Dataset shard.
data_shard: Dataset = session.get_dataset_shard("train")
# Get a handle to the worker's assigned DatasetIterator shard.
data_shard: DatasetIterator = session.get_dataset_shard("train")

# Manually iterate over the data 10 times (10 epochs).
for _ in range(10):
@@ -123,7 +123,7 @@ def train_loop_per_worker():
# __config_5__
import ray
from ray.air import session
from ray.data import DatasetPipeline
from ray.data import DatasetIterator
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig, DatasetConfig

@@ -132,29 +132,26 @@ def train_loop_per_worker():


def train_loop_per_worker():
# A DatasetPipeline object is returned when `use_stream_api` is set.
data_shard: DatasetPipeline = session.get_dataset_shard("train")
data_shard: DatasetIterator = session.get_dataset_shard("train")

# Use iter_epochs(10) to iterate over 10 epochs of data.
for epoch in data_shard.iter_epochs(10):
for batch in epoch.iter_batches():
# Iterate over 10 epochs of data.
for _ in range(10):
for batch in data_shard.iter_batches():
print("Do some training on batch", batch)

# View the stats for performance debugging.
print(data_shard.stats())


# Set N = 200 bytes for this toy example. Typically, you'd set N >= 1GiB.
N = 200

my_trainer = TorchTrainer(
train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=1),
datasets={
"train": ray.data.range_tensor(1000),
},
dataset_config={
"train": DatasetConfig(use_stream_api=True, stream_window_size=N),
# Use 20% of object store memory.
"train": DatasetConfig(max_object_store_memory_fraction=0.2),
},
preprocessor=preprocessor,
)
@@ -164,13 +161,13 @@ def train_loop_per_worker():
# __global_shuffling_start__
import ray
from ray.air import session
from ray.data import Dataset
from ray.data import DatasetIterator
from ray.train.torch import TorchTrainer
from ray.air.config import DatasetConfig, ScalingConfig


def train_loop_per_worker():
data_shard: Dataset = session.get_dataset_shard("train")
data_shard: DatasetIterator = session.get_dataset_shard("train")

# Iterate over 10 epochs of data.
for epoch in range(10):
@@ -197,13 +194,13 @@ def train_loop_per_worker():
# __local_shuffling_start__
import ray
from ray.air import session
from ray.data import Dataset
from ray.data import DatasetIterator
from ray.train.torch import TorchTrainer
from ray.air.config import DatasetConfig, ScalingConfig


def train_loop_per_worker():
data_shard: Dataset = session.get_dataset_shard("train")
data_shard: DatasetIterator = session.get_dataset_shard("train")

# Iterate over 10 epochs of data.
for epoch in range(10):
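
The local-shuffling example above is truncated before the iteration call; a sketch of how the buffer size described in check-ingest.rst would be passed to the iterator (batch and buffer sizes are illustrative):

import ray
from ray.air import session
from ray.air.config import ScalingConfig
from ray.data import DatasetIterator
from ray.train.torch import TorchTrainer


def train_loop_per_worker():
    data_shard: DatasetIterator = session.get_dataset_shard("train")
    for _ in range(10):
        # A non-zero local_shuffle_buffer_size randomizes record order within a
        # per-worker buffer; larger buffers shuffle more but use more memory.
        for batch in data_shard.iter_batches(
            batch_size=64, local_shuffle_buffer_size=256
        ):
            print("Do some training on batch", batch)


my_trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=1),
    datasets={"train": ray.data.range_tensor(1000)},
)
my_trainer.fit()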
15 changes: 15 additions & 0 deletions doc/source/ray-air/package-ref.rst
@@ -129,8 +129,23 @@ Abstract Classes

.. automethod:: __init__

.. autoclass:: ray.air.util.check_ingest.DummyTrainer
:members:
:show-inheritance:

.. automethod:: __init__

.. _air-results-ref:

Dataset Iteration
#################

.. autoclass:: ray.data.DatasetIterator
:members:
:noindex:

.. autofunction:: ray.air.util.check_ingest.make_local_dataset_iterator

Training Result
###############

38 changes: 22 additions & 16 deletions doc/source/ray-core/_examples/datasets_train/datasets_train.py
@@ -406,12 +406,20 @@ def train_func(config):
print(f"Device: {device}")

# Setup data.
train_dataset_pipeline = session.get_dataset_shard("train")
train_dataset_epoch_iterator = train_dataset_pipeline.iter_epochs()
test_dataset = session.get_dataset_shard("test")
test_torch_dataset = test_dataset.to_torch(
label_column="label", batch_size=batch_size, drop_last=True
)
train_dataset_iterator = session.get_dataset_shard("train")
test_dataset_iterator = session.get_dataset_shard("test")

def to_torch_dataset(torch_batch_iterator):
for batch in torch_batch_iterator:
label_column = "label"
labels = batch[label_column].unsqueeze(1)
features = [
batch[col_name].unsqueeze(1)
for col_name in batch
if col_name != label_column
]
inputs = torch.cat(features, dim=1)
yield inputs, labels

net = Net(
n_layers=num_layers,
@@ -429,12 +437,9 @@ def train_func(config):

print("Starting training...")
for epoch in range(num_epochs):
train_dataset = next(train_dataset_epoch_iterator)

train_torch_dataset = train_dataset.to_torch(
label_column="label", batch_size=batch_size
train_torch_dataset = to_torch_dataset(
train_dataset_iterator.iter_torch_batches(batch_size=batch_size)
)

train_running_loss, train_num_correct, train_num_total = train_epoch(
train_torch_dataset, net, device, criterion, optimizer
)
Expand All @@ -444,6 +449,11 @@ def train_func(config):
f"{train_num_correct} / {train_num_total} = {train_acc:.4f}"
)

test_torch_dataset = to_torch_dataset(
test_dataset_iterator.iter_torch_batches(
batch_size=batch_size, drop_last=True
)
)
test_running_loss, test_num_correct, test_num_total = test_epoch(
test_torch_dataset, net, device, criterion
)
@@ -623,11 +633,7 @@ def train_func(config):
resources_per_worker=resources_per_worker,
),
run_config=RunConfig(callbacks=callbacks),
dataset_config={
"train": DatasetConfig(
use_stream_api=True, stream_window_size=-1, global_shuffle=True
)
},
dataset_config={"train": DatasetConfig(global_shuffle=True)},
)
results = trainer.fit()
state_dict = results.checkpoint.to_dict()["model"]
6 changes: 6 additions & 0 deletions python/ray/air/_internal/util.py
@@ -8,6 +8,7 @@

import numpy as np

import ray
from ray.air.constants import _ERROR_REPORT_TIMEOUT

logger = logging.getLogger(__name__)
@@ -113,3 +114,8 @@ def run(self):
def join(self, timeout=None):
super(RunnerThread, self).join(timeout)
return self._ret


def _estimate_avail_object_store_memory() -> int:
"""Estimates total object store memory available in the cluster."""
return ray.available_resources()["object_store_memory"]
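
This helper presumably feeds the max_object_store_memory_fraction computation; a rough sketch of the arithmetic it enables (the actual wiring lives in parts of the diff not shown on this page):

import ray

ray.init()
# Fraction of the cluster's object store to use per streaming window (illustrative).
max_object_store_memory_fraction = 0.2
avail = ray.available_resources()["object_store_memory"]
window_size_bytes = int(max_object_store_memory_fraction * avail)
print(f"Streaming window size: {window_size_bytes} bytes")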