[Datasets] Enable lazy execution by default (#31286)
This PR enables lazy execution by default. See ray-project/enhancements#19 for motivation. The change includes:
* Change the `Dataset` constructor to `Dataset.__init__(lazy: bool = True)`, and remove the `defer_execution` field, as it is no longer needed.
* `read_api.py:read_datasource()` now returns a lazy `Dataset`, while still eagerly computing the first input block.
* Add `ds.fully_executed()` calls to the unit tests that need materialized results, to keep them passing (see the sketch below).
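
For illustration only (not part of this commit), a minimal sketch of what lazy-by-default means for user code, using the APIs exercised elsewhere in this diff (`ray.data.range`, `map_batches`, `fully_executed`, `count`):

```python
import ray

# Creating a dataset now builds a lazy plan; only the first input block is
# computed eagerly (for metadata), per the read_datasource() change above.
ds = ray.data.range(1000)

# Transformations are recorded in the plan rather than executed immediately.
ds = ds.map_batches(lambda batch: [v * 2 for v in batch])

# Callers (e.g. unit tests) that need fully materialized blocks force execution.
ds = ds.fully_executed()

# Consuming APIs such as count() and take() also trigger execution.
assert ds.count() == 1000
```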

TODO:
- [x] Fix all unit tests
- [x] #31459
- [x] #31460 
- [ ] Remove the behavior of eagerly computing the first block for read
- [ ] #31417
- [ ] Update documentation
c21 committed Jan 6, 2023
1 parent 955e756 commit 9cb9c0e
Showing 10 changed files with 64 additions and 47 deletions.
@@ -580,7 +580,7 @@ def train_func(config):
read_dataset(data_path)
)

num_columns = len(train_dataset.schema().names)
num_columns = len(train_dataset.schema(fetch_if_missing=True).names)
# remove label column.
num_features = num_columns - 1

5 changes: 5 additions & 0 deletions python/ray/data/_internal/plan.py
@@ -251,6 +251,11 @@ def schema(
self.execute()
else:
return None
elif self._in_blocks is not None and self._snapshot_blocks is None:
# If the plan only has input blocks, we execute it so the snapshot has output.
# This applies to newly created datasets, e.g. the initial dataset from
# read and the output datasets of Dataset.split().
self.execute()
# Snapshot is now guaranteed to be the output of the final stage or None.
blocks = self._snapshot_blocks
if not blocks:
30 changes: 14 additions & 16 deletions python/ray/data/dataset.py
@@ -177,19 +177,19 @@ class Dataset(Generic[T]):
>>> ds = ray.data.range(1000)
>>> # Transform in parallel with map_batches().
>>> ds.map_batches(lambda batch: [v * 2 for v in batch])
Dataset(num_blocks=..., num_rows=1000, schema=<class 'int'>)
Dataset(num_blocks=..., num_rows=..., schema=...)
>>> # Compute max.
>>> ds.max()
999
>>> # Group the data.
>>> ds.groupby(lambda x: x % 3).count()
Dataset(num_blocks=..., num_rows=3, schema=<class 'tuple'>)
Dataset(num_blocks=..., num_rows=..., schema=...)
>>> # Shuffle this dataset randomly.
>>> ds.random_shuffle()
Dataset(num_blocks=..., num_rows=1000, schema=<class 'int'>)
Dataset(num_blocks=..., num_rows=..., schema=...)
>>> # Sort it back in order.
>>> ds.sort()
Dataset(num_blocks=..., num_rows=1000, schema=<class 'int'>)
Dataset(num_blocks=..., num_rows=..., schema=...)
Since Datasets are just lists of Ray object refs, they can be passed
between Ray tasks and actors without incurring a copy. Datasets support
@@ -202,9 +202,7 @@ def __init__(
self,
plan: ExecutionPlan,
epoch: int,
lazy: bool,
*,
defer_execution: bool = False,
lazy: bool = True,
):
"""Construct a Dataset (internal API).
@@ -219,7 +217,7 @@ def __init__(
self._epoch = epoch
self._lazy = lazy

if not lazy and not defer_execution:
if not lazy:
self._plan.execute(allow_clear_input_blocks=False)

@staticmethod
@@ -243,7 +241,7 @@ def map(
>>> # Transform python objects.
>>> ds = ray.data.range(1000)
>>> ds.map(lambda x: x * 2)
Dataset(num_blocks=..., num_rows=1000, schema=<class 'int'>)
Dataset(num_blocks=..., num_rows=..., schema=...)
>>> # Transform Arrow records.
>>> ds = ray.data.from_items(
... [{"value": i} for i in range(1000)])
@@ -804,7 +802,7 @@ def flat_map(
>>> import ray
>>> ds = ray.data.range(1000)
>>> ds.flat_map(lambda x: [x, x ** 2, x ** 3])
Dataset(num_blocks=..., num_rows=3000, schema=<class 'int'>)
Dataset(num_blocks=..., num_rows=..., schema=...)
Time complexity: O(dataset size / parallelism)
@@ -872,7 +870,7 @@ def filter(
>>> import ray
>>> ds = ray.data.range(100)
>>> ds.filter(lambda x: x % 2 == 0)
Dataset(num_blocks=..., num_rows=50, schema=<class 'int'>)
Dataset(num_blocks=..., num_rows=..., schema=...)
Time complexity: O(dataset size / parallelism)
@@ -966,10 +964,10 @@ def random_shuffle(
>>> ds = ray.data.range(100)
>>> # Shuffle this dataset randomly.
>>> ds.random_shuffle()
Dataset(num_blocks=..., num_rows=100, schema=<class 'int'>)
Dataset(num_blocks=..., num_rows=..., schema=...)
>>> # Shuffle this dataset with a fixed random seed.
>>> ds.random_shuffle(seed=12345)
Dataset(num_blocks=..., num_rows=100, schema=<class 'int'>)
Dataset(num_blocks=..., num_rows=..., schema=...)
Time complexity: O(dataset size / parallelism)
@@ -1012,7 +1010,7 @@ def randomize_block_order(
"""

plan = self._plan.with_stage(RandomizeBlocksStage(seed))
return Dataset(plan, self._epoch, self._lazy, defer_execution=True)
return Dataset(plan, self._epoch, self._lazy)

def random_sample(
self, fraction: float, *, seed: Optional[int] = None
@@ -1533,7 +1531,7 @@ def groupby(self, key: Optional[KeyFn]) -> "GroupedDataset[T]":
>>> import ray
>>> # Group by a key function and aggregate.
>>> ray.data.range(100).groupby(lambda x: x % 3).count()
Dataset(num_blocks=..., num_rows=3, schema=<class 'tuple'>)
Dataset(num_blocks=..., num_rows=..., schema=...)
>>> # Group by an Arrow table column and aggregate.
>>> ray.data.from_items([
... {"A": x % 3, "B": x} for x in range(100)]).groupby(
@@ -1933,7 +1931,7 @@ def sort(
>>> # Sort using the entire record as the key.
>>> ds = ray.data.range(100)
>>> ds.sort()
Dataset(num_blocks=..., num_rows=100, schema=<class 'int'>)
Dataset(num_blocks=..., num_rows=..., schema=...)
>>> # Sort by a single column in descending order.
>>> ds = ray.data.from_items(
... [{"value": i} for i in range(1000)])
6 changes: 3 additions & 3 deletions python/ray/data/read_api.py
@@ -339,9 +339,9 @@ def read_datasource(
block_list.ensure_metadata_for_first_block()

return Dataset(
ExecutionPlan(block_list, block_list.stats(), run_by_consumer=False),
0,
False,
plan=ExecutionPlan(block_list, block_list.stats(), run_by_consumer=False),
epoch=0,
lazy=True,
)


32 changes: 23 additions & 9 deletions python/ray/data/tests/test_dataset.py
@@ -337,10 +337,10 @@ def test_zip(ray_start_regular_shared):
ds1 = ray.data.range(5, parallelism=5)
ds2 = ray.data.range(5, parallelism=5).map(lambda x: x + 1)
ds = ds1.zip(ds2)
assert ds.schema() == tuple
assert ds.schema(fetch_if_missing=True) == tuple
assert ds.take() == [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
with pytest.raises(ValueError):
ds.zip(ray.data.range(3))
ds.zip(ray.data.range(3)).fully_executed()


def test_zip_pandas(ray_start_regular_shared):
Expand All @@ -366,8 +366,8 @@ def test_zip_arrow(ray_start_regular_shared):
lambda r: {"a": r["value"] + 1, "b": r["value"] + 2}
)
ds = ds1.zip(ds2)
assert "{id: int64, a: int64, b: int64}" in str(ds)
assert ds.count() == 5
assert "{id: int64, a: int64, b: int64}" in str(ds)
result = [r.as_pydict() for r in ds.take()]
assert result[0] == {"id": 0, "a": 1, "b": 2}

@@ -749,6 +749,7 @@ def test_tensors_sort(ray_start_regular_shared):
def test_tensors_inferred_from_map(ray_start_regular_shared):
# Test map.
ds = ray.data.range(10, parallelism=10).map(lambda _: np.ones((4, 4)))
ds.fully_executed()
assert str(ds) == (
"Dataset(num_blocks=10, num_rows=10, "
"schema={__value__: ArrowTensorType(shape=(4, 4), dtype=double)})"
@@ -758,6 +759,7 @@ def test_tensors_inferred_from_map(ray_start_regular_shared):
ds = ray.data.range(16, parallelism=4).map_batches(
lambda _: np.ones((3, 4, 4)), batch_size=2
)
ds.fully_executed()
assert str(ds) == (
"Dataset(num_blocks=4, num_rows=24, "
"schema={__value__: ArrowTensorType(shape=(4, 4), dtype=double)})"
@@ -767,6 +769,7 @@ def test_tensors_inferred_from_map(ray_start_regular_shared):
ds = ray.data.range(10, parallelism=10).flat_map(
lambda _: [np.ones((4, 4)), np.ones((4, 4))]
)
ds.fully_executed()
assert str(ds) == (
"Dataset(num_blocks=10, num_rows=20, "
"schema={__value__: ArrowTensorType(shape=(4, 4), dtype=double)})"
@@ -776,6 +779,7 @@ def test_tensors_inferred_from_map(ray_start_regular_shared):
ds = ray.data.range(16, parallelism=4).map_batches(
lambda _: pd.DataFrame({"a": [np.ones((4, 4))] * 3}), batch_size=2
)
ds.fully_executed()
assert str(ds) == (
"Dataset(num_blocks=4, num_rows=24, "
"schema={a: TensorDtype(shape=(4, 4), dtype=float64)})"
@@ -785,6 +789,7 @@ def test_tensors_inferred_from_map(ray_start_regular_shared):
lambda _: pd.DataFrame({"a": [np.ones((2, 2)), np.ones((3, 3))]}),
batch_size=2,
)
ds.fully_executed()
assert str(ds) == (
"Dataset(num_blocks=4, num_rows=16, "
"schema={a: TensorDtype(shape=(None, None), dtype=float64)})"
@@ -1456,24 +1461,29 @@ def test_empty_dataset(ray_start_regular_shared):

ds = ray.data.range(1)
ds = ds.filter(lambda x: x > 1)
ds.fully_executed()
assert str(ds) == "Dataset(num_blocks=1, num_rows=0, schema=Unknown schema)"

# Test map on empty dataset.
ds = ray.data.from_items([])
ds = ds.map(lambda x: x)
ds.fully_executed()
assert ds.count() == 0

# Test filter on empty dataset.
ds = ray.data.from_items([])
ds = ds.filter(lambda: True)
ds.fully_executed()
assert ds.count() == 0


def test_schema(ray_start_regular_shared):
ds = ray.data.range(10, parallelism=10)
ds2 = ray.data.range_table(10, parallelism=10)
ds3 = ds2.repartition(5)
ds3.fully_executed()
ds4 = ds3.map(lambda x: {"a": "hi", "b": 1.0}).limit(5).repartition(1)
ds4.fully_executed()
assert str(ds) == "Dataset(num_blocks=10, num_rows=10, schema=<class 'int'>)"
assert str(ds2) == "Dataset(num_blocks=10, num_rows=10, schema={value: int64})"
assert str(ds3) == "Dataset(num_blocks=5, num_rows=10, schema={value: int64})"
@@ -2284,7 +2294,7 @@ def test_drop_columns(ray_start_regular_shared, tmp_path):
]
# Test dropping non-existent column
with pytest.raises(KeyError):
ds.drop_columns(["dummy_col", "col1", "col2"])
ds.drop_columns(["dummy_col", "col1", "col2"]).fully_executed()


def test_select_columns(ray_start_regular_shared):
@@ -2315,13 +2325,13 @@ def test_select_columns(ray_start_regular_shared):
]
# Test selecting a column that is not in the dataset schema
with pytest.raises(KeyError):
each_ds.select_columns(cols=["col1", "col2", "dummy_col"])
each_ds.select_columns(cols=["col1", "col2", "dummy_col"]).fully_executed()

# Test simple
ds3 = ray.data.range(10)
assert ds3.dataset_format() == "simple"
with pytest.raises(ValueError):
ds3.select_columns(cols=[])
ds3.select_columns(cols=[]).fully_executed()


def test_map_batches_basic(ray_start_regular_shared, tmp_path):
@@ -2684,11 +2694,13 @@ def mutate(df):
ds = ray.data.range_table(num_rows, parallelism=num_blocks).repartition(num_blocks)
# Convert to Pandas blocks.
ds = ds.map_batches(lambda df: df, batch_format="pandas", batch_size=None)
ds.fully_executed()

# Apply UDF that mutates the batches, which should fail since the batch is
# read-only.
with pytest.raises(ValueError, match="tried to mutate a zero-copy read-only batch"):
ds.map_batches(mutate, batch_size=batch_size, zero_copy_batch=True)
ds = ds.map_batches(mutate, batch_size=batch_size, zero_copy_batch=True)
ds.fully_executed()


BLOCK_BUNDLING_TEST_CASES = [
@@ -2710,10 +2722,12 @@ def test_map_batches_block_bundling_auto(

# Blocks should be bundled up to the batch size.
ds1 = ds.map_batches(lambda x: x, batch_size=batch_size)
ds1.fully_executed()
assert ds1.num_blocks() == math.ceil(num_blocks / max(batch_size // block_size, 1))

# Blocks should not be bundled up when batch_size is not specified.
ds2 = ds.map_batches(lambda x: x)
ds2.fully_executed()
assert ds2.num_blocks() == num_blocks


@@ -2796,7 +2810,7 @@ def good_fn(row):
ds = ray.data.range(10, parallelism=1)
error_message = "Current row has different columns compared to previous rows."
with pytest.raises(ValueError) as e:
ds.map(bad_fn)
ds.map(bad_fn).fully_executed()
assert error_message in str(e.value)
ds_map = ds.map(good_fn)
assert ds_map.take() == [{"a": "hello1", "b": "hello2"} for _ in range(10)]
@@ -5364,7 +5378,7 @@ def f(x):
compute_strategy = ray.data.ActorPoolStrategy()
ray.data.range(10, parallelism=10).map_batches(
f, batch_size=1, compute=compute_strategy
)
).fully_executed()
expected_max_num_workers = math.ceil(
num_cpus * (1 / compute_strategy.ready_to_total_workers_ratio)
)
5 changes: 4 additions & 1 deletion python/ray/data/tests/test_dataset_parquet.py
@@ -194,7 +194,10 @@ def prefetch_file_metadata(self, pieces):

# Expect precomputed row counts and block sizes to be missing.
assert ds._meta_count() is None
assert ds._plan._snapshot_blocks.size_bytes() == -1
assert (
ds._plan._snapshot_blocks is None
or ds._plan._snapshot_blocks.size_bytes() == -1
)

# Expect to lazily compute all metadata correctly.
assert ds._plan.execute()._num_computed() == 1
4 changes: 4 additions & 0 deletions python/ray/data/tests/test_dynamic_block_split.py
@@ -82,10 +82,13 @@ def test_dataset(
assert ds.size_bytes() >= 0.7 * block_size * num_blocks * num_tasks

map_ds = ds.map_batches(lambda x: x)
map_ds.fully_executed()
assert map_ds.num_blocks() == num_tasks
map_ds = ds.map_batches(lambda x: x, batch_size=num_blocks * num_tasks)
map_ds.fully_executed()
assert map_ds.num_blocks() == 1
map_ds = ds.map(lambda x: x)
map_ds.fully_executed()
assert map_ds.num_blocks() == num_blocks * num_tasks

ds_list = ds.split(5)
@@ -109,6 +112,7 @@ def test_dataset(
assert ds.groupby("one").count().count() == num_blocks * num_tasks

new_ds = ds.zip(ds)
new_ds.fully_executed()
assert new_ds.num_blocks() == num_blocks * num_tasks

assert len(ds.take(5)) == 5
15 changes: 2 additions & 13 deletions python/ray/data/tests/test_optimize.py
@@ -63,6 +63,7 @@ def test_memory_sanity(shutdown_only):
info = ray.init(num_cpus=1, object_store_memory=500e6)
ds = ray.data.range(10)
ds = ds.map(lambda x: np.ones(100 * 1024 * 1024, dtype=np.uint8))
ds.fully_executed()
meminfo = memory_summary(info.address_info["address"], stats_only=True)

# Sanity check spilling is happening as expected.
@@ -291,23 +292,11 @@ def _assert_has_stages(stages, stage_names):


def test_stage_linking(ray_start_regular_shared):
# NOTE: This tests the internals of `ExecutionPlan`, which is bad practice. Remove
# this test once we have proper unit testing of `ExecutionPlan`.
# Test eager dataset.
ds = ray.data.range(10)
assert len(ds._plan._stages_before_snapshot) == 0
assert len(ds._plan._stages_after_snapshot) == 0
assert len(ds._plan._last_optimized_stages) == 0
ds = ds.map(lambda x: x + 1)
_assert_has_stages(ds._plan._stages_before_snapshot, ["map"])
assert len(ds._plan._stages_after_snapshot) == 0
_assert_has_stages(ds._plan._last_optimized_stages, ["read->map"])

# Test lazy dataset.
ds = ray.data.range(10).lazy()
assert len(ds._plan._stages_before_snapshot) == 0
assert len(ds._plan._stages_after_snapshot) == 0
assert len(ds._plan._last_optimized_stages) == 0
assert ds._plan._last_optimized_stages is None
ds = ds.map(lambda x: x + 1)
assert len(ds._plan._stages_before_snapshot) == 0
_assert_has_stages(ds._plan._stages_after_snapshot, ["map"])