[Data] Remove option to disable block splitting (ray-project#38235)
Dynamic block splitting has been enabled by default for a couple of releases. To improve the maintainability of our codebase, this PR removes the option to disable block splitting and removes related code paths.

Signed-off-by: Balaji Veeramani <[email protected]>
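
For context, a minimal standalone sketch (not Ray Data internals) of the primitive this PR standardizes on: as of the Ray version this commit targets, a task declared with `num_returns="dynamic"` can yield a variable number of outputs, so a read or map task can split a large result into many blocks without knowing the block count up front. The task and names below are hypothetical illustrations.

```python
import ray

ray.init(ignore_reinit_error=True)

# Hypothetical task illustrating dynamic returns; not part of this PR.
@ray.remote(num_returns="dynamic")
def read_and_split(n_rows: int, max_rows_per_block: int):
    # Each yielded value is stored as its own object in the object store.
    for start in range(0, n_rows, max_rows_per_block):
        yield list(range(start, min(start + max_rows_per_block, n_rows)))

# The remote call returns a ref to an ObjectRefGenerator; ray.get() returns
# the generator, and iterating it produces one ObjectRef per output block.
ref_generator = ray.get(read_and_split.remote(10, 4))
blocks = [ray.get(block_ref) for block_ref in ref_generator]
print(blocks)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```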
bveeramani committed Aug 9, 2023
1 parent 9b747da commit 8afcf9b
Showing 11 changed files with 87 additions and 240 deletions.
62 changes: 19 additions & 43 deletions python/ray/data/_internal/compute.py
```diff
@@ -77,8 +77,6 @@ def _apply(
         if fn_kwargs is None:
             fn_kwargs = {}
 
-        context = DataContext.get_current()
-
         # Handle empty datasets.
         if block_list.initial_num_blocks() == 0:
             return block_list
```
```diff
@@ -96,37 +94,20 @@ def _apply(
             name = name.title()
         map_bar = ProgressBar(name, total=len(block_bundles))
 
-        if context.block_splitting_enabled:
-            map_block = cached_remote_fn(_map_block_split).options(
-                num_returns="dynamic", **remote_args
-            )
-            refs = [
-                map_block.remote(
-                    block_fn,
-                    [f for m in ms for f in m.input_files],
-                    fn,
-                    len(bs),
-                    *(bs + fn_args),
-                    **fn_kwargs,
-                )
-                for bs, ms in block_bundles
-            ]
-        else:
-            map_block = cached_remote_fn(_map_block_nosplit).options(
-                **dict(remote_args, num_returns=2)
-            )
-            all_refs = [
-                map_block.remote(
-                    block_fn,
-                    [f for m in ms for f in m.input_files],
-                    fn,
-                    len(bs),
-                    *(bs + fn_args),
-                    **fn_kwargs,
-                )
-                for bs, ms in block_bundles
-            ]
-            data_refs, refs = map(list, zip(*all_refs))
+        map_block = cached_remote_fn(_map_block_split).options(
+            num_returns="dynamic", **remote_args
+        )
+        refs = [
+            map_block.remote(
+                block_fn,
+                [f for m in ms for f in m.input_files],
+                fn,
+                len(bs),
+                *(bs + fn_args),
+                **fn_kwargs,
+            )
+            for bs, ms in block_bundles
+        ]
 
         in_block_owned_by_consumer = block_list._owned_by_consumer
         # Release input block references.
```
```diff
@@ -155,17 +136,12 @@ def _apply(
                 raise e from None
 
         new_blocks, new_metadata = [], []
-        if context.block_splitting_enabled:
-            for ref_generator in results:
-                refs = list(ref_generator)
-                metadata = ray.get(refs.pop(-1))
-                assert len(metadata) == len(refs)
-                new_blocks += refs
-                new_metadata += metadata
-        else:
-            for block, metadata in zip(data_refs, results):
-                new_blocks.append(block)
-                new_metadata.append(metadata)
+        for ref_generator in results:
+            refs = list(ref_generator)
+            metadata = ray.get(refs.pop(-1))
+            assert len(metadata) == len(refs)
+            new_blocks += refs
+            new_metadata += metadata
         return BlockList(
             list(new_blocks),
             list(new_metadata),
```
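The surviving code path relies on a convention visible in the loop over `results` above: each task's generator yields N block refs followed by one final ref holding the list of N `BlockMetadata` objects. A sketch of that consumption pattern, assuming `ref_generator` comes from a `num_returns="dynamic"` task as in the example near the top (the helper name is hypothetical):

```python
import ray

def collect_blocks_and_metadata(ref_generator):
    refs = list(ref_generator)
    # By convention, the last ref holds the metadata list for all the others.
    metadata = ray.get(refs.pop(-1))
    assert len(metadata) == len(refs)
    return refs, metadata
```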
144 changes: 43 additions & 101 deletions python/ray/data/_internal/lazy_block_list.py
```diff
@@ -310,48 +310,29 @@ def _get_blocks_with_metadata(
         This will block on the completion of the underlying read tasks and will fetch
         all block metadata outputted by those tasks.
         """
-        context = DataContext.get_current()
         block_refs, meta_refs = [], []
         for block_ref, meta_ref in self._iter_block_partition_refs():
             block_refs.append(block_ref)
             meta_refs.append(meta_ref)
-        if context.block_splitting_enabled:
-            # If block splitting is enabled, fetch the partitions through generator.
-            read_progress_bar = ProgressBar("Read progress", total=len(block_refs))
-            # Handle duplicates (e.g. due to unioning the same dataset).
-            unique_refs = list(set(block_refs))
-            generators = read_progress_bar.fetch_until_complete(unique_refs)
-
-            ref_to_blocks = {}
-            ref_to_metadata = {}
-            for ref, generator in zip(unique_refs, generators):
-                refs_list = list(generator)
-                meta = ray.get(refs_list.pop(-1))
-                ref_to_blocks[ref] = refs_list
-                ref_to_metadata[ref] = meta
-
-            output_block_refs = []
-            for idx, ref in enumerate(block_refs):
-                output_block_refs += ref_to_blocks[ref]
-                self._cached_metadata[idx] = ref_to_metadata[ref]
-            return output_block_refs, self._flatten_metadata(self._cached_metadata)
-        if all(meta is not None for meta in self._cached_metadata):
-            # Short-circuit on cached metadata.
-            return block_refs, self._flatten_metadata(self._cached_metadata)
-        if not meta_refs:
-            # Short-circuit on empty set of block partitions.
-            assert not block_refs, block_refs
-            return [], []
-        read_progress_bar = ProgressBar("Read progress", total=len(meta_refs))
-        # Fetch the metadata in bulk.
-        # Handle duplicates (e.g. due to unioning the same dataset).
-        unique_meta_refs = set(meta_refs)
-        metadata = read_progress_bar.fetch_until_complete(list(unique_meta_refs))
-        ref_to_data = {
-            meta_ref: data for meta_ref, data in zip(unique_meta_refs, metadata)
-        }
-        self._cached_metadata = [[ref_to_data[meta_ref]] for meta_ref in meta_refs]
-        return block_refs, self._flatten_metadata(self._cached_metadata)
+        # If block splitting is enabled, fetch the partitions through generator.
+        read_progress_bar = ProgressBar("Read progress", total=len(block_refs))
+        # Handle duplicates (e.g. due to unioning the same dataset).
+        unique_refs = list(set(block_refs))
+        generators = read_progress_bar.fetch_until_complete(unique_refs)
+
+        ref_to_blocks = {}
+        ref_to_metadata = {}
+        for ref, generator in zip(unique_refs, generators):
+            refs_list = list(generator)
+            meta = ray.get(refs_list.pop(-1))
+            ref_to_blocks[ref] = refs_list
+            ref_to_metadata[ref] = meta
+
+        output_block_refs = []
+        for idx, ref in enumerate(block_refs):
+            output_block_refs += ref_to_blocks[ref]
+            self._cached_metadata[idx] = ref_to_metadata[ref]
+        return output_block_refs, self._flatten_metadata(self._cached_metadata)
 
     def compute_to_blocklist(self) -> BlockList:
         """Launch all tasks and return a concrete BlockList."""
```
```diff
@@ -392,16 +373,10 @@ def ensure_metadata_for_first_block(self) -> Optional[BlockMetadata]:
             pass
         else:
             # This blocks until the underlying read task is finished.
-            if DataContext.get_current().block_splitting_enabled:
-                # If block splitting is enabled, get metadata as the last element
-                # in generator.
-                generator = ray.get(block_partition_ref)
-                blocks_ref = list(generator)
-                metadata = ray.get(blocks_ref[-1])
-                self._cached_metadata[0] = metadata
-            else:
-                metadata = ray.get(metadata_ref)
-                self._cached_metadata[0] = [metadata]
+            generator = ray.get(block_partition_ref)
+            blocks_ref = list(generator)
+            metadata = ray.get(blocks_ref[-1])
+            self._cached_metadata[0] = metadata
         return metadata
 
     def iter_blocks(self) -> Iterator[ObjectRef[Block]]:
```
```diff
@@ -449,7 +424,6 @@ def iter_blocks_with_metadata(
         Returns:
             An iterator of block references and the corresponding block metadata.
         """
-        context = DataContext.get_current()
         outer = self
 
         class Iter:
```
```diff
@@ -464,27 +438,15 @@ def __iter__(self):
             def __next__(self):
                 while not self._buffer:
                     self._pos += 1
-                    if context.block_splitting_enabled:
-                        generator_ref, _ = next(self._base_iter)
-                        generator = ray.get(generator_ref)
-                        refs = list(generator)
-                        # This blocks until the read task completes, returning
-                        # fully-specified block metadata for each output block.
-                        metadata = ray.get(refs.pop(-1))
-                        assert len(metadata) == len(refs)
-                        for block_ref, meta in zip(refs, metadata):
-                            self._buffer.append((block_ref, meta))
-                    else:
-                        block_ref, metadata_ref = next(self._base_iter)
-                        if block_for_metadata:
-                            # This blocks until the read task completes, returning
-                            # fully-specified block metadata.
-                            metadata = ray.get(metadata_ref)
-                        else:
-                            # This does not block, returning (possibly under-specified)
-                            # pre-read block metadata.
-                            metadata = outer._tasks[self._pos].get_metadata()
-                        self._buffer.append((block_ref, metadata))
+                    generator_ref, _ = next(self._base_iter)
+                    generator = ray.get(generator_ref)
+                    refs = list(generator)
+                    # This blocks until the read task completes, returning
+                    # fully-specified block metadata for each output block.
+                    metadata = ray.get(refs.pop(-1))
+                    assert len(metadata) == len(refs)
+                    for block_ref, meta in zip(refs, metadata):
+                        self._buffer.append((block_ref, meta))
                 return self._buffer.pop(0)
 
         return Iter()
```
```diff
@@ -571,12 +533,6 @@ def _get_or_compute(
                         self._block_partition_meta_refs[j],
                     ) = self._submit_task(j)
             assert self._block_partition_refs[i], self._block_partition_refs
-            if not DataContext.get_current().block_splitting_enabled:
-                # Only check block metadata object reference if dynamic block
-                # splitting is off.
-                assert self._block_partition_meta_refs[
-                    i
-                ], self._block_partition_meta_refs
             trace_allocation(
                 self._block_partition_refs[i], f"LazyBlockList.get_or_compute({i})"
             )
```
```diff
@@ -602,32 +558,18 @@ def _submit_task(
             ray.get(stats_actor.record_start.remote(self._stats_uuid))
             self._execution_started = True
         task = self._tasks[task_idx]
-        context = DataContext.get_current()
-        if context.block_splitting_enabled:
-            return (
-                cached_remote_fn(_execute_read_task_split)
-                .options(num_returns="dynamic", **self._remote_args)
-                .remote(
-                    i=task_idx,
-                    task=task,
-                    context=DataContext.get_current(),
-                    stats_uuid=self._stats_uuid,
-                    stats_actor=stats_actor,
-                ),
-                None,
-            )
-        else:
-            return (
-                cached_remote_fn(_execute_read_task_nosplit)
-                .options(num_returns=2, **self._remote_args)
-                .remote(
-                    i=task_idx,
-                    task=task,
-                    context=DataContext.get_current(),
-                    stats_uuid=self._stats_uuid,
-                    stats_actor=stats_actor,
-                )
-            )
+        return (
+            cached_remote_fn(_execute_read_task_split)
+            .options(num_returns="dynamic", **self._remote_args)
+            .remote(
+                i=task_idx,
+                task=task,
+                context=DataContext.get_current(),
+                stats_uuid=self._stats_uuid,
+                stats_actor=stats_actor,
+            ),
+            None,
+        )
 
     def _num_computed(self) -> int:
         i = 0
```
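A note on the dedup logic retained in `_get_blocks_with_metadata` above: when a dataset is unioned with itself, the same block ref can appear at multiple positions, so the code fetches each unique ref once and then rebuilds the per-position results. A generic sketch of that pattern (the helper name is hypothetical):

```python
def fetch_once_per_unique(refs, fetch):
    """Fetch each distinct ref exactly once, then expand back to original order."""
    unique = list(set(refs))
    fetched = dict(zip(unique, map(fetch, unique)))  # one fetch per unique ref
    return [fetched[r] for r in refs]  # duplicates share the fetched value

# e.g. fetch_once_per_unique(block_refs, ray.get) mirrors how the method
# avoids re-reading blocks duplicated by union().
```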
16 changes: 6 additions & 10 deletions python/ray/data/_internal/stats.py
```diff
@@ -291,16 +291,12 @@ def to_summary(self) -> "DatasetStatsSummary":
             ac = self.stats_actor
             # TODO(chengsu): this is a super hack, clean it up.
             stats_map, self.time_total_s = ray.get(ac.get.remote(self.stats_uuid))
-            if DataContext.get_current().block_splitting_enabled:
-                # Only populate stats when stats from all read tasks are ready at
-                # stats actor.
-                if len(stats_map.items()) == len(self.stages["Read"]):
-                    self.stages["Read"] = []
-                    for _, blocks_metadata in sorted(stats_map.items()):
-                        self.stages["Read"] += blocks_metadata
-            else:
-                for i, metadata in stats_map.items():
-                    self.stages["Read"][i] = metadata[0]
+            # Only populate stats when stats from all read tasks are ready at
+            # stats actor.
+            if len(stats_map.items()) == len(self.stages["Read"]):
+                self.stages["Read"] = []
+                for _, blocks_metadata in sorted(stats_map.items()):
+                    self.stages["Read"] += blocks_metadata
 
         stages_stats = []
         is_substage = len(self.stages) > 1
```
7 changes: 0 additions & 7 deletions python/ray/data/context.py
```diff
@@ -31,10 +31,6 @@
 # which is very sensitive to the buffer size.
 DEFAULT_STREAMING_READ_BUFFER_SIZE = 32 * 1024 * 1024
 
-# Whether dynamic block splitting is enabled.
-# NOTE: disable dynamic block splitting when using Ray client.
-DEFAULT_BLOCK_SPLITTING_ENABLED = True
-
 # Whether pandas block format is enabled.
 # TODO (kfstorm): Remove this once stable.
 DEFAULT_ENABLE_PANDAS_BLOCK = True
```
```diff
@@ -147,7 +143,6 @@ class DataContext:
 
     def __init__(
         self,
-        block_splitting_enabled: bool,
         target_max_block_size: int,
         target_min_block_size: int,
         streaming_read_buffer_size: int,
```
```diff
@@ -178,7 +173,6 @@ def __init__(
         enable_progress_bars: bool,
     ):
         """Private constructor (use get_current() instead)."""
-        self.block_splitting_enabled = block_splitting_enabled
         self.target_max_block_size = target_max_block_size
         self.target_min_block_size = target_min_block_size
         self.streaming_read_buffer_size = streaming_read_buffer_size
```
```diff
@@ -224,7 +218,6 @@ def get_current() -> "DataContext":
         with _context_lock:
             if _default_context is None:
                 _default_context = DataContext(
-                    block_splitting_enabled=DEFAULT_BLOCK_SPLITTING_ENABLED,
                     target_max_block_size=DEFAULT_TARGET_MAX_BLOCK_SIZE,
                     target_min_block_size=DEFAULT_TARGET_MIN_BLOCK_SIZE,
                     streaming_read_buffer_size=DEFAULT_STREAMING_READ_BUFFER_SIZE,
```
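With `block_splitting_enabled` gone from `DataContext`, splitting is always on; the knob that remains for tuning split granularity is `target_max_block_size`, which stays in the constructor above. A usage sketch, assuming only the public fields shown in this diff:

```python
from ray.data.context import DataContext

ctx = DataContext.get_current()
# Blocks produced by read/map tasks are split to stay near this size.
ctx.target_max_block_size = 64 * 1024 * 1024  # 64 MiB
```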
12 changes: 2 additions & 10 deletions python/ray/data/datasource/datasource.py
```diff
@@ -5,7 +5,6 @@
 import numpy as np
 
 import ray
-from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
 from ray.data._internal.execution.interfaces import TaskContext
 from ray.data._internal.util import _check_pyarrow_version
 from ray.data.block import Block, BlockAccessor, BlockMetadata
```
```diff
@@ -204,7 +203,6 @@ def get_metadata(self) -> BlockMetadata:
         return self._metadata
 
     def __call__(self) -> Iterable[Block]:
-        context = DataContext.get_current()
         result = self._read_fn()
         if not hasattr(result, "__iter__"):
             DeprecationWarning(
```
```diff
@@ -213,14 +211,8 @@ def __call__(self) -> Iterable[Block]:
                 "`block`.".format(result)
             )
 
-        if context.block_splitting_enabled:
-            for block in result:
-                yield from self._do_additional_splits(block)
-        else:
-            builder = DelegatingBlockBuilder()
-            for block in result:
-                builder.add_block(block)
-            yield builder.build()
+        for block in result:
+            yield from self._do_additional_splits(block)
 
     def _set_additional_split_factor(self, k: int) -> None:
         self._additional_output_splits = k
```
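After this change, the result of `_read_fn` is always consumed as an iterable of blocks (a bare-block return is the deprecated path the warning above refers to), with each block subject to additional splits. A hypothetical read function in the expected shape:

```python
import pandas as pd

def read_fn():
    # Yield blocks rather than returning a single bare block.
    for start in range(0, 30, 10):
        yield pd.DataFrame({"id": range(start, start + 10)})
```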
9 changes: 0 additions & 9 deletions python/ray/data/tests/conftest.py
```diff
@@ -327,15 +327,6 @@ def enable_auto_log_stats(request):
     ctx.enable_auto_log_stats = original
 
 
-@pytest.fixture(params=[True])
-def enable_dynamic_block_splitting(request):
-    ctx = ray.data.context.DataContext.get_current()
-    original = ctx.block_splitting_enabled
-    ctx.block_splitting_enabled = request.param
-    yield request.param
-    ctx.block_splitting_enabled = original
-
-
 @pytest.fixture(params=[1024])
 def target_max_block_size(request):
     ctx = ray.data.context.DataContext.get_current()
```