[data] Add DataIterator.materialize (#43210)
This PR introduces a `DataIterator.materialize` API that fully executes and consumes a data iterator, returning its contents as a `MaterializedDataset` for further processing.
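
For example, a minimal usage sketch (the `range` dataset and single split below are illustrative only, not part of this PR):

```python
import ray

ds = ray.data.range(10)

# streaming_split returns one DataIterator per shard.
(it,) = ds.streaming_split(1)

# materialize() consumes the iterator and returns a MaterializedDataset,
# which supports the regular Dataset API for further processing.
materialized = it.materialize()
print(materialized.count())  # -> 10
```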

This API is needed to support model training in Ray Train that requires the full dataset up front. For example, `xgboost` needs to consider the full dataset to fit decision trees and expects that dataset to be materialized in memory.

The `get_dataset_shard` API, which bridges Ray Data and Ray Train, calls `streaming_split` on the dataset, where the number of splits is the number of training workers. This works well for SGD-style training schemes (typical for Torch and TensorFlow users), since the gradient is estimated on a small batch of data at a time. Fitting decision trees, however, requires searching for the best split over the entire dataset, so batch-by-batch data loading is not suitable. A rough sketch of the intended use is shown below.
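
A hedged sketch of how a Ray Train training function might use this for tree-based training (the training function, the `"label"` column name, and the xgboost parameters here are hypothetical, not part of this PR):

```python
import xgboost
from ray import train


def train_func(config):
    # get_dataset_shard returns a DataIterator backed by streaming_split.
    shard = train.get_dataset_shard("train")

    # Tree-based training needs the whole shard at once, so materialize it
    # instead of iterating over it batch by batch.
    pdf = shard.materialize().to_pandas()
    dmatrix = xgboost.DMatrix(pdf.drop(columns=["label"]), label=pdf["label"])
    return xgboost.train({"objective": "binary:logistic"}, dmatrix, num_boost_round=10)
```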

Signed-off-by: Justin Yu <[email protected]>
justinvyu committed Feb 16, 2024
1 parent 63510e5 commit e221c6e
Showing 3 changed files with 84 additions and 3 deletions.
1 change: 1 addition & 0 deletions doc/source/data/api/data_iterator.rst
@@ -14,4 +14,5 @@ DataIterator API
DataIterator.iter_batches
DataIterator.iter_torch_batches
DataIterator.to_tf
DataIterator.materialize
DataIterator.stats
46 changes: 43 additions & 3 deletions python/ray/data/iterator.py
@@ -17,6 +17,11 @@
import numpy as np

from ray.data._internal.block_batching.iter_batches import iter_batches
from ray.data._internal.block_list import BlockList
from ray.data._internal.execution.legacy_compat import _block_list_to_bundles
from ray.data._internal.logical.operators.input_data_operator import InputData
from ray.data._internal.logical.optimizers import LogicalPlan
from ray.data._internal.plan import ExecutionPlan
from ray.data._internal.stats import DatasetStats, StatsManager
from ray.data.block import (
Block,
@@ -34,11 +39,13 @@

from ray.data.dataset import (
CollatedData,
MaterializedDataset,
Schema,
TensorFlowTensorBatchType,
TorchBatchType,
)


T = TypeVar("T")


@@ -647,9 +654,11 @@ def make_generator():
key: convert_pandas_to_torch_tensor(
    batch,
    feature_columns[key],
    feature_column_dtypes[key]
    if isinstance(feature_column_dtypes, dict)
    else feature_column_dtypes,
    (
        feature_column_dtypes[key]
        if isinstance(feature_column_dtypes, dict)
        else feature_column_dtypes
    ),
    unsqueeze=unsqueeze_feature_tensors,
)
for key in feature_columns
@@ -852,6 +861,37 @@ def generator():
)
return dataset.with_options(options)

    def materialize(self) -> "MaterializedDataset":
        """Execute and materialize this data iterator into object store memory.

        .. note::
            This method triggers the execution and materializes all blocks
            of the iterator, returning its contents as a
            :class:`~ray.data.dataset.MaterializedDataset` for further processing.
        """

        from ray.data.dataset import MaterializedDataset

        # Fully consume the underlying block iterator.
        block_iter, stats, owned_by_consumer = self._to_block_iterator()

        block_refs_and_metadata = list(block_iter)
        block_refs = [block_ref for block_ref, _ in block_refs_and_metadata]
        metadata = [metadata for _, metadata in block_refs_and_metadata]

        # Wrap the consumed blocks in a new execution plan and logical plan so
        # the result behaves like any other materialized dataset.
        block_list = BlockList(
            block_refs, metadata, owned_by_consumer=owned_by_consumer
        )
        ref_bundles = _block_list_to_bundles(block_list, owned_by_consumer)
        logical_plan = LogicalPlan(InputData(input_data=ref_bundles))
        return MaterializedDataset(
            ExecutionPlan(
                block_list,
                stats,
                run_by_consumer=owned_by_consumer,
            ),
            logical_plan,
        )

    def __del__(self):
        # Clear metrics on deletion in case the iterator was not fully consumed.
        StatsManager.clear_iteration_metrics(self._get_dataset_tag())
40 changes: 40 additions & 0 deletions python/ray/data/tests/test_iterator.py
@@ -1,3 +1,4 @@
import threading
from typing import Dict
from unittest.mock import MagicMock, patch

@@ -177,6 +178,45 @@ def collate_fn(batch: Dict[str, np.ndarray]):
), iter_batches_calls_kwargs


def test_iterator_to_materialized_dataset(ray_start_regular_shared):
    """Tests that `DataIterator.materialize` fully consumes the
    iterator and returns a `MaterializedDataset` view of the data
    that can be used to interact with the full dataset
    (e.g. load it all into memory)."""
    ds = ray.data.range(10)
    num_splits = 2
    iters = ds.streaming_split(num_splits, equal=True)

    def consume_in_parallel(fn):
        runners = [
            threading.Thread(target=fn, args=(it, i)) for i, it in enumerate(iters)
        ]
        [r.start() for r in runners]
        [r.join() for r in runners]

    materialized_ds = {}
    shard_data = {}

    def materialize(it, i):
        materialized_ds[i] = it.materialize()

    def iter_batches(it, i):
        data = []
        for batch in it.iter_batches():
            data.extend(batch["id"].tolist())
        shard_data[i] = data

    consume_in_parallel(materialize)
    consume_in_parallel(iter_batches)

    # Check that the materialized datasets contain the same data as the
    # original iterators.
    for i in range(num_splits):
        assert sorted(materialized_ds[i].to_pandas()["id"].tolist()) == sorted(
            shard_data[i]
        )


if __name__ == "__main__":
import sys
