Skip to content

Commit

Permalink
Make write an operator as part of the execution plan (#32015)
Browse files Browse the repository at this point in the history
  • Loading branch information
jianoaix committed Feb 8, 2023
1 parent befad81 commit aa504ae
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 130 deletions.
90 changes: 61 additions & 29 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2640,8 +2640,6 @@ def write_datasource(
ray_remote_args: Kwargs passed to ray.remote in the write tasks.
write_args: Additional write args to pass to the datasource.
"""

ctx = DatasetContext.get_current()
if ray_remote_args is None:
ray_remote_args = {}
path = write_args.get("path", None)
Expand All @@ -2655,37 +2653,71 @@ def write_datasource(
soft=False,
)

blocks, metadata = zip(*self._plan.execute().get_blocks_with_metadata())

# TODO(ekl) remove this feature flag.
if "RAY_DATASET_FORCE_LOCAL_METADATA" in os.environ:
write_results: List[ObjectRef[WriteResult]] = datasource.do_write(
blocks, metadata, ray_remote_args=ray_remote_args, **write_args
)
else:
# Prepare write in a remote task so that in Ray client mode, we
# don't do metadata resolution from the client machine.
do_write = cached_remote_fn(_do_write, retry_exceptions=False, num_cpus=0)
write_results: List[ObjectRef[WriteResult]] = ray.get(
do_write.remote(
datasource,
ctx,
blocks,
metadata,
if hasattr(datasource, "write"):
# If the write operator succeeds, the resulting Dataset is a list of
# WriteResult (one element per write task). Otherwise, an error will
# be raised. The Datasource can handle execution outcomes with the
# on_write_complete() and on_write_failed().
def transform(blocks: Iterable[Block], ctx, fn) -> Iterable[Block]:
return [[datasource.write(blocks, ctx, **write_args)]]

plan = self._plan.with_stage(
OneToOneStage(
"write",
transform,
"tasks",
ray_remote_args,
_wrap_arrow_serialization_workaround(write_args),
fn=lambda x: x,
)
)
try:
self._write_ds = Dataset(plan, self._epoch, self._lazy).fully_executed()
datasource.on_write_complete(
ray.get(self._write_ds._plan.execute().get_blocks())
)
except Exception as e:
datasource.on_write_failed([], e)
raise
else:
ctx = DatasetContext.get_current()
blocks, metadata = zip(*self._plan.execute().get_blocks_with_metadata())

progress = ProgressBar("Write Progress", len(write_results))
try:
progress.block_until_complete(write_results)
datasource.on_write_complete(ray.get(write_results))
except Exception as e:
datasource.on_write_failed(write_results, e)
raise
finally:
progress.close()
# TODO(ekl) remove this feature flag.
if "RAY_DATASET_FORCE_LOCAL_METADATA" in os.environ:
write_results: List[ObjectRef[WriteResult]] = datasource.do_write(
blocks, metadata, ray_remote_args=ray_remote_args, **write_args
)
else:
logger.warning(
"The Datasource.do_write() is deprecated in "
"Ray 2.4 and will be removed in future release. Use "
"Datasource.write() instead."
)
# Prepare write in a remote task so that in Ray client mode, we
# don't do metadata resolution from the client machine.
do_write = cached_remote_fn(
_do_write, retry_exceptions=False, num_cpus=0
)
write_results: List[ObjectRef[WriteResult]] = ray.get(
do_write.remote(
datasource,
ctx,
blocks,
metadata,
ray_remote_args,
_wrap_arrow_serialization_workaround(write_args),
)
)

progress = ProgressBar("Write Progress", len(write_results))
try:
progress.block_until_complete(write_results)
datasource.on_write_complete(ray.get(write_results))
except Exception as e:
datasource.on_write_failed(write_results, e)
raise
finally:
progress.close()

def iterator(self) -> DatasetIterator:
"""Return a :class:`~ray.data.DatasetIterator` that
Expand Down
44 changes: 31 additions & 13 deletions python/ray/data/datasource/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ray
from ray.data._internal.arrow_block import ArrowRow
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.util import _check_pyarrow_version
from ray.data.block import (
Block,
Expand All @@ -31,7 +32,7 @@ class Datasource(Generic[T]):
of how to implement readable and writable datasources.
Datasource instances must be serializable, since ``create_reader()`` and
``do_write()`` are called in remote tasks.
``write()`` are called in remote tasks.
"""

def create_reader(self, **read_args) -> "Reader[T]":
Expand All @@ -50,6 +51,25 @@ def prepare_read(self, parallelism: int, **read_args) -> List["ReadTask[T]"]:
"""Deprecated: Please implement create_reader() instead."""
raise NotImplementedError

def write(
self,
blocks: Iterable[Block],
**write_args,
) -> WriteResult:
"""Write blocks out to the datasource. This is used by a single write task.
Args:
blocks: List of data blocks.
write_args: Additional kwargs to pass to the datasource impl.
Returns:
The output of the write task.
"""
raise NotImplementedError

@Deprecated(
message="do_write() is deprecated in Ray 2.4. Use write() instead", warning=True
)
def do_write(
self,
blocks: List[ObjectRef[Block]],
Expand Down Expand Up @@ -319,35 +339,33 @@ def __init__(self):

def write(self, block: Block) -> str:
block = BlockAccessor.for_block(block)
if not self.enabled:
raise ValueError("disabled")
self.rows_written += block.num_rows()
return "ok"

def get_rows_written(self):
return self.rows_written

def set_enabled(self, enabled):
self.enabled = enabled

self.data_sink = DataSink.remote()
self.num_ok = 0
self.num_failed = 0
self.enabled = True

def do_write(
def write(
self,
blocks: List[ObjectRef[Block]],
metadata: List[BlockMetadata],
ray_remote_args: Dict[str, Any],
blocks: Iterable[Block],
ctx: TaskContext,
**write_args,
) -> List[ObjectRef[WriteResult]]:
) -> WriteResult:
tasks = []
if not self.enabled:
raise ValueError("disabled")
for b in blocks:
tasks.append(self.data_sink.write.remote(b))
return tasks
ray.get(tasks)
return "ok"

def on_write_complete(self, write_results: List[WriteResult]) -> None:
assert all(w == "ok" for w in write_results), write_results
assert all(w == ["ok"] for w in write_results), write_results
self.num_ok += 1

def on_write_failed(
Expand Down
57 changes: 27 additions & 30 deletions python/ray/data/datasource/file_based_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
)

from ray.data._internal.arrow_block import ArrowRow
from ray.data._internal.block_list import BlockMetadata
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.output_buffer import BlockOutputBuffer
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.util import _check_pyarrow_version, _resolve_custom_scheme
from ray.data.block import Block, BlockAccessor
from ray.data.context import DatasetContext
Expand Down Expand Up @@ -60,7 +60,7 @@ def _get_write_path_for_block(
*,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
dataset_uuid: Optional[str] = None,
block: Optional[ObjectRef[Block]] = None,
block: Optional[Block] = None,
block_index: Optional[int] = None,
file_format: Optional[str] = None,
) -> str:
Expand All @@ -77,7 +77,7 @@ def _get_write_path_for_block(
write a file out to the write path returned.
dataset_uuid: Unique identifier for the dataset that this block
belongs to.
block: Object reference to the block to write.
block: The block to write.
block_index: Ordered index of the block to write within its parent
dataset.
file_format: File format string for the block that can be used as
Expand All @@ -94,7 +94,7 @@ def __call__(
*,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
dataset_uuid: Optional[str] = None,
block: Optional[ObjectRef[Block]] = None,
block: Optional[Block] = None,
block_index: Optional[int] = None,
file_format: Optional[str] = None,
) -> str:
Expand Down Expand Up @@ -257,10 +257,10 @@ def _convert_block_to_tabular_block(
"then you need to implement `_convert_block_to_tabular_block."
)

def do_write(
def write(
self,
blocks: List[ObjectRef[Block]],
metadata: List[BlockMetadata],
blocks: Iterable[Block],
ctx: TaskContext,
path: str,
dataset_uuid: str,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
Expand All @@ -269,10 +269,9 @@ def do_write(
block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(),
write_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
_block_udf: Optional[Callable[[Block], Block]] = None,
ray_remote_args: Dict[str, Any] = None,
**write_args,
) -> List[ObjectRef[WriteResult]]:
"""Creates and returns write tasks for a file-based datasource."""
) -> WriteResult:
"""Write blocks for a file-based datasource."""
path, filesystem = _resolve_paths_and_filesystem(path, filesystem)
path = path[0]
if try_create_dir:
Expand All @@ -287,9 +286,6 @@ def do_write(
if open_stream_args is None:
open_stream_args = {}

if ray_remote_args is None:
ray_remote_args = {}

def write_block(write_path: str, block: Block):
logger.debug(f"Writing {write_path} file.")
fs = filesystem
Expand All @@ -305,29 +301,30 @@ def write_block(write_path: str, block: Block):
writer_args_fn=write_args_fn,
**write_args,
)

write_block = cached_remote_fn(write_block).options(**ray_remote_args)
# TODO: decide if we want to return richer object when the task
# succeeds.
return "ok"

file_format = self._FILE_EXTENSION
if isinstance(file_format, list):
file_format = file_format[0]

write_tasks = []
builder = DelegatingBlockBuilder()
for block in blocks:
builder.add_block(block)
block = builder.build()

if not block_path_provider:
block_path_provider = DefaultBlockWritePathProvider()
for block_idx, block in enumerate(blocks):
write_path = block_path_provider(
path,
filesystem=filesystem,
dataset_uuid=dataset_uuid,
block=block,
block_index=block_idx,
file_format=file_format,
)
write_task = write_block.remote(write_path, block)
write_tasks.append(write_task)

return write_tasks
write_path = block_path_provider(
path,
filesystem=filesystem,
dataset_uuid=dataset_uuid,
block=block,
block_index=ctx.task_idx,
file_format=file_format,
)
return write_block(write_path, block)

def _write_block(
self,
Expand Down
33 changes: 17 additions & 16 deletions python/ray/data/datasource/mongo_datasource.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import logging
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import Dict, List, Optional, TYPE_CHECKING

from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult
from ray.data.block import (
Block,
BlockAccessor,
BlockMetadata,
)
from ray.data._internal.remote_fn import cached_remote_fn
from ray.types import ObjectRef
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.execution.interfaces import TaskContext
from ray.util.annotations import PublicAPI
from typing import Iterable

if TYPE_CHECKING:
import pymongoarrow.api
Expand Down Expand Up @@ -37,15 +38,14 @@ class MongoDatasource(Datasource):
def create_reader(self, **kwargs) -> Reader:
return _MongoDatasourceReader(**kwargs)

def do_write(
def write(
self,
blocks: List[ObjectRef[Block]],
metadata: List[BlockMetadata],
ray_remote_args: Optional[Dict[str, Any]],
blocks: Iterable[Block],
ctx: TaskContext,
uri: str,
database: str,
collection: str,
) -> List[ObjectRef[WriteResult]]:
) -> WriteResult:
import pymongo

_validate_database_collection_exist(
Expand All @@ -59,15 +59,16 @@ def write_block(uri: str, database: str, collection: str, block: Block):
client = pymongo.MongoClient(uri)
write(client[database][collection], block)

if ray_remote_args is None:
ray_remote_args = {}

write_block = cached_remote_fn(write_block).options(**ray_remote_args)
write_tasks = []
builder = DelegatingBlockBuilder()
for block in blocks:
write_task = write_block.remote(uri, database, collection, block)
write_tasks.append(write_task)
return write_tasks
builder.add_block(block)
block = builder.build()

write_block(uri, database, collection, block)

# TODO: decide if we want to return richer object when the task
# succeeds.
return "ok"


class _MongoDatasourceReader(Reader):
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _get_write_path_for_block(
block_index=None,
file_format=None,
):
num_rows = BlockAccessor.for_block(ray.get(block)).num_rows()
num_rows = BlockAccessor.for_block(block).num_rows()
suffix = (
f"{block_index:06}_{num_rows:02}_{dataset_uuid}" f".test.{file_format}"
)
Expand Down
Loading

0 comments on commit aa504ae

Please sign in to comment.