Make write an operator as part of the execution plan #32015

Merged: 51 commits, Feb 8, 2023

Commits (51):
edc51bd  Fix read_tfrecords_benchmark nightly test (jianoaix, Dec 8, 2022)
61f4d6d  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Dec 14, 2022)
a33a943  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Dec 16, 2022)
36ebe52  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Dec 16, 2022)
ce6763e  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Dec 19, 2022)
0e2c29e  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Dec 21, 2022)
f2b6ed0  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Dec 22, 2022)
bb6c5c4  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Jan 4, 2023)
540fe79  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Jan 10, 2023)
edad7d0  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Jan 10, 2023)
60cc079  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Jan 11, 2023)
a3d3980  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Jan 12, 2023)
001579c  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Jan 17, 2023)
8aeed6c  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Jan 18, 2023)
7a9a49b  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Jan 19, 2023)
ef97167  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Jan 20, 2023)
6f0563c  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Jan 21, 2023)
bcec4d6  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Jan 24, 2023)
ddef4e5  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Jan 25, 2023)
fc9a175  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Jan 25, 2023)
f0e90b7  Merge branch 'master' of https://github.com/ray-project/ray (jianoaix, Jan 26, 2023)
0c820ea  Merge branch 'master' of https://github.com/ray-project/ray into writ… (jianoaix, Jan 26, 2023)
253da6a  Make write an operator as part of the execution plan (jianoaix, Jan 27, 2023)
514ec14  fix / fix do_write (jianoaix, Jan 28, 2023)
204ea5f  Merge branch 'master' of https://github.com/ray-project/ray into writ… (jianoaix, Jan 28, 2023)
b75bf49  Fix the merge (jianoaix, Jan 28, 2023)
6a28257  fix arg passing (jianoaix, Jan 28, 2023)
3082e52  lint (jianoaix, Jan 28, 2023)
5ddfdc0  Reconcile taskcontext (jianoaix, Jan 31, 2023)
3843ff4  Reconcile taskcontext continued (jianoaix, Jan 31, 2023)
012a413  Merge branch 'master' of https://github.com/ray-project/ray into writ… (jianoaix, Jan 31, 2023)
84a74f0  Use task context in write op (jianoaix, Jan 31, 2023)
bb2a474  fix test (jianoaix, Jan 31, 2023)
ad5f7c7  feedback: backward compatibility (jianoaix, Jan 31, 2023)
a77053a  fix (jianoaix, Jan 31, 2023)
554171a  test write fusion (jianoaix, Jan 31, 2023)
1ba1b9f  Result of write operator; datasource callbacks (jianoaix, Feb 1, 2023)
5ced246  Handle an empty list on failure (jianoaix, Feb 3, 2023)
43eca29  execute the plan in-place in write_datasource (jianoaix, Feb 3, 2023)
f25d54b  Keep write_datasource semantics diff-neutral regarding the plan (jianoaix, Feb 3, 2023)
c5ddf07  Merge branch 'master' of https://github.com/ray-project/ray into writ… (jianoaix, Feb 3, 2023)
1d58e13  disable the write_XX in new optimizer: it's not supported yet (jianoaix, Feb 3, 2023)
d309dbd  fix comment (jianoaix, Feb 3, 2023)
21a50db  refactor: do_write() calls direct_write() to reduce code duplication (jianoaix, Feb 4, 2023)
a84e27b  refactor: for mongo datasource do_write (jianoaix, Feb 4, 2023)
8879df0  backward compatible (jianoaix, Feb 7, 2023)
d6873e1  rename: direct_write -> write (jianoaix, Feb 7, 2023)
10ef980  unnecessary test removed (jianoaix, Feb 7, 2023)
48e9415  fix (jianoaix, Feb 7, 2023)
87dc925  deprecation message/logging (jianoaix, Feb 7, 2023)
b77ca8d  deprecation logging (jianoaix, Feb 7, 2023)
90 changes: 61 additions & 29 deletions python/ray/data/dataset.py
@@ -2640,8 +2640,6 @@ def write_datasource(
ray_remote_args: Kwargs passed to ray.remote in the write tasks.
write_args: Additional write args to pass to the datasource.
"""

ctx = DatasetContext.get_current()
if ray_remote_args is None:
ray_remote_args = {}
path = write_args.get("path", None)
@@ -2655,37 +2653,71 @@
soft=False,
)

blocks, metadata = zip(*self._plan.execute().get_blocks_with_metadata())

# TODO(ekl) remove this feature flag.
if "RAY_DATASET_FORCE_LOCAL_METADATA" in os.environ:
write_results: List[ObjectRef[WriteResult]] = datasource.do_write(
blocks, metadata, ray_remote_args=ray_remote_args, **write_args
)
else:
# Prepare write in a remote task so that in Ray client mode, we
# don't do metadata resolution from the client machine.
do_write = cached_remote_fn(_do_write, retry_exceptions=False, num_cpus=0)
write_results: List[ObjectRef[WriteResult]] = ray.get(
do_write.remote(
datasource,
ctx,
blocks,
metadata,
if hasattr(datasource, "write"):
# If the write operator succeeds, the resulting Dataset is a list of
# WriteResult (one element per write task). Otherwise, an error will
# be raised. The Datasource can handle execution outcomes with the
# on_write_complete() and on_write_failed().
def transform(blocks: Iterable[Block], ctx, fn) -> Iterable[Block]:
return [[datasource.write(blocks, ctx, **write_args)]]

plan = self._plan.with_stage(
OneToOneStage(
"write",
transform,
"tasks",
ray_remote_args,
_wrap_arrow_serialization_workaround(write_args),
fn=lambda x: x,
)
)
try:
self._write_ds = Dataset(plan, self._epoch, self._lazy).fully_executed()
datasource.on_write_complete(
ray.get(self._write_ds._plan.execute().get_blocks())
)
except Exception as e:
datasource.on_write_failed([], e)
raise
Contributor: This isn't the same semantics, right? on_write_complete()/on_write_failed() should be called on the driver only, not once per block.

Contributor Author: Yes, this is changed to per block/task. I'm not sure how to properly support this with write becoming an operator that can be fused with prior operators.

Do these callbacks have to be on the driver? It looks like they are not widely used yet, so the use case is unclear. But calling them per write task seems like acceptable semantics?

Contributor: You'd need these to implement commit semantics / roll back changes on failure, for example. They are part of the Datasource public API, so we shouldn't change them without a separate API discussion.

Contributor: I feel it's too rushed for the 2.3 release if we need to change the Datasource API at this moment, and I agree that a separate discussion needs to happen if we change the Datasource API here. How about taking a non-intrusive approach that doesn't change the Datasource API at all? Something like:

if hasattr(datasource, "_direct_write"):
  def transform(blocks: Iterable[Block], ctx) -> Iterator[Block]:
    yield from datasource._direct_write(blocks, ctx, **write_args)
  self._plan = self._plan.with_stage(transform, ...)
  write_results = self._plan.execute().get_blocks()
else:
  # Old existing code path, or we implement an AllToAllStage that should give us Block and BlockMetadata
  write_results = ...

try:
  progress.block_until_complete(write_results)
  datasource.on_write_complete(ray.get(write_results))
except Exception as e:
  datasource.on_write_failed(write_results, e)

Contributor Author: I think there is no issue blocked by not having the write operator/fusion at the moment, so I'd rather fix it properly instead of shipping a half-baked feature (2.3 is not the last release we can catch anyway).

Contributor (@ericl, Feb 3, 2023): I think the status-Dataset approach still works, right? Errors get raised as usual, so retries will work. In the write function here we can add a try/except around the entire Dataset execution operation, and call the right datasource callbacks if the entire thing fails or succeeds.

Contributor: Oh, do you mean it's hard to provide List[ObjectRef[WriteResult]] to the error callback? I think we can probably deprecate this part of the API for the new backend (pass an empty list). Afaik it's not really useful and would be best-effort in any scenario for a distributed system.

Contributor Author: Right, we will not know the resulting status of each individual task, which I think may be used for error handling (rollback, etc.) in the callback function. If we just try/except the entire execution and produce an empty list (as the result of executing the write operator), then yes, it should work.

Contributor: How about we go for that then? It seems the simplest approach, and we can also fix the new API to not have the list of error refs.

Contributor Author: Sounds good, I'll update the PR.

else:
ctx = DatasetContext.get_current()
blocks, metadata = zip(*self._plan.execute().get_blocks_with_metadata())

progress = ProgressBar("Write Progress", len(write_results))
try:
progress.block_until_complete(write_results)
datasource.on_write_complete(ray.get(write_results))
except Exception as e:
datasource.on_write_failed(write_results, e)
raise
finally:
progress.close()
# TODO(ekl) remove this feature flag.
if "RAY_DATASET_FORCE_LOCAL_METADATA" in os.environ:
write_results: List[ObjectRef[WriteResult]] = datasource.do_write(
blocks, metadata, ray_remote_args=ray_remote_args, **write_args
)
else:
logger.warning(
"Datasource.do_write() is deprecated in "
"Ray 2.4 and will be removed in a future release. Use "
"Datasource.write() instead."
)
# Prepare write in a remote task so that in Ray client mode, we
# don't do metadata resolution from the client machine.
do_write = cached_remote_fn(
_do_write, retry_exceptions=False, num_cpus=0
)
write_results: List[ObjectRef[WriteResult]] = ray.get(
do_write.remote(
datasource,
ctx,
blocks,
metadata,
ray_remote_args,
_wrap_arrow_serialization_workaround(write_args),
)
)

progress = ProgressBar("Write Progress", len(write_results))
try:
progress.block_until_complete(write_results)
datasource.on_write_complete(ray.get(write_results))
except Exception as e:
datasource.on_write_failed(write_results, e)
raise
finally:
progress.close()

def iterator(self) -> DatasetIterator:
"""Return a :class:`~ray.data.DatasetIterator` that
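Taken together, the new code path appends the write as a OneToOneStage, executes the plan, and then invokes the datasource callbacks once on the driver: on_write_complete() with the per-task results on success, on_write_failed() with an empty list on failure (per the thread above). Below is a minimal sketch of commit/rollback semantics layered on those callbacks. It assumes a hypothetical staging-directory sink; StagedTextDatasource, stage_dir, and the file layout are illustrative, not part of the Ray API.

import os
import shutil
from typing import Iterable, List

import ray
from ray.data.block import Block, BlockAccessor
from ray.data.datasource import Datasource
from ray.data._internal.execution.interfaces import TaskContext


class StagedTextDatasource(Datasource):
    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
        stage_dir: str = "/tmp/stage",
        **write_args,
    ) -> str:
        # Runs once per write task. Each task stages its own part file,
        # named by the task index carried in the TaskContext.
        os.makedirs(stage_dir, exist_ok=True)
        path = os.path.join(stage_dir, f"part-{ctx.task_idx}.txt")
        with open(path, "w") as f:
            for block in blocks:
                for row in BlockAccessor.for_block(block).iter_rows():
                    f.write(f"{row}\n")
        return path  # this task's WriteResult

    def on_write_complete(self, write_results: List[str]) -> None:
        # Driver-side "commit". With the operator path, each element is the
        # one-element block produced by a write task (cf. the updated
        # DummyOutputDatasource assertion below: w == ["ok"]).
        print("committed:", write_results)

    def on_write_failed(self, write_results: List[str], error: Exception) -> None:
        # Driver-side "rollback". write_results is an empty list here,
        # per the discussion above.
        shutil.rmtree("/tmp/stage", ignore_errors=True)


ray.data.range(100).write_datasource(StagedTextDatasource(), stage_dir="/tmp/stage")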
44 changes: 31 additions & 13 deletions python/ray/data/datasource/datasource.py
@@ -6,6 +6,7 @@
import ray
from ray.data._internal.arrow_block import ArrowRow
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.util import _check_pyarrow_version
from ray.data.block import (
Block,
@@ -31,7 +32,7 @@ class Datasource(Generic[T]):
of how to implement readable and writable datasources.

Datasource instances must be serializable, since ``create_reader()`` and
``do_write()`` are called in remote tasks.
``write()`` are called in remote tasks.
"""

def create_reader(self, **read_args) -> "Reader[T]":
@@ -50,6 +51,25 @@ def prepare_read(self, parallelism: int, **read_args) -> List["ReadTask[T]"]:
"""Deprecated: Please implement create_reader() instead."""
raise NotImplementedError

def write(
self,
blocks: Iterable[Block],
**write_args,
) -> WriteResult:
"""Write blocks out to the datasource. This is used by a single write task.

Args:
blocks: List of data blocks.
write_args: Additional kwargs to pass to the datasource impl.

Returns:
The output of the write task.
"""
raise NotImplementedError

@Deprecated(
message="do_write() is deprecated in Ray 2.4. Use write() instead", warning=True
)
def do_write(
self,
blocks: List[ObjectRef[Block]],
@@ -319,35 +339,33 @@ def __init__(self):

def write(self, block: Block) -> str:
block = BlockAccessor.for_block(block)
if not self.enabled:
raise ValueError("disabled")
self.rows_written += block.num_rows()
return "ok"

def get_rows_written(self):
return self.rows_written

def set_enabled(self, enabled):
self.enabled = enabled

self.data_sink = DataSink.remote()
self.num_ok = 0
self.num_failed = 0
self.enabled = True

def do_write(
def write(
self,
blocks: List[ObjectRef[Block]],
metadata: List[BlockMetadata],
ray_remote_args: Dict[str, Any],
blocks: Iterable[Block],
ctx: TaskContext,
**write_args,
) -> List[ObjectRef[WriteResult]]:
) -> WriteResult:
tasks = []
if not self.enabled:
raise ValueError("disabled")
for b in blocks:
tasks.append(self.data_sink.write.remote(b))
return tasks
ray.get(tasks)
return "ok"

def on_write_complete(self, write_results: List[WriteResult]) -> None:
assert all(w == "ok" for w in write_results), write_results
assert all(w == ["ok"] for w in write_results), write_results
self.num_ok += 1

def on_write_failed(
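The contract change is worth spelling out: do_write() ran once on the driver, received ObjectRefs to all blocks, and returned one ObjectRef[WriteResult] per task it launched; write() runs inside a single write task, receives that task's materialized blocks, and returns one WriteResult. Here is a sketch of a datasource straddling both, in the spirit of commit 21a50db (the deprecated method delegating to the new one). RowCountDatasource is an illustrative name, and the delegation shown is an assumption for migration purposes, not the pattern the PR itself uses.

from typing import Any, Dict, Iterable, List

import ray
from ray.data.block import Block, BlockAccessor
from ray.data.datasource import Datasource


class RowCountDatasource(Datasource):
    def write(self, blocks: Iterable[Block], ctx, **write_args) -> str:
        # New contract: one call per write task, blocks already materialized.
        n = sum(BlockAccessor.for_block(b).num_rows() for b in blocks)
        return f"wrote {n} rows"

    def do_write(
        self, blocks, metadata, ray_remote_args: Dict[str, Any], **write_args
    ) -> List["ray.ObjectRef"]:
        # Old contract, kept only for back-compat: one remote task per block,
        # each delegating to write(). ctx is unused by this sink, so None is
        # passed through.
        remote_write = ray.remote(num_cpus=0)(
            lambda block: self.write([block], None, **write_args)
        )
        return [remote_write.remote(b) for b in blocks]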
57 changes: 27 additions & 30 deletions python/ray/data/datasource/file_based_datasource.py
@@ -16,9 +16,9 @@
)

from ray.data._internal.arrow_block import ArrowRow
from ray.data._internal.block_list import BlockMetadata
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.output_buffer import BlockOutputBuffer
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.util import _check_pyarrow_version, _resolve_custom_scheme
from ray.data.block import Block, BlockAccessor
from ray.data.context import DatasetContext
@@ -60,7 +60,7 @@ def _get_write_path_for_block(
*,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
dataset_uuid: Optional[str] = None,
block: Optional[ObjectRef[Block]] = None,
block: Optional[Block] = None,
block_index: Optional[int] = None,
file_format: Optional[str] = None,
) -> str:
@@ -77,7 +77,7 @@
write a file out to the write path returned.
dataset_uuid: Unique identifier for the dataset that this block
belongs to.
block: Object reference to the block to write.
block: The block to write.
block_index: Ordered index of the block to write within its parent
dataset.
file_format: File format string for the block that can be used as
@@ -94,7 +94,7 @@ def __call__(
*,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
dataset_uuid: Optional[str] = None,
block: Optional[ObjectRef[Block]] = None,
block: Optional[Block] = None,
block_index: Optional[int] = None,
file_format: Optional[str] = None,
) -> str:
@@ -257,10 +257,10 @@ def _convert_block_to_tabular_block(
"then you need to implement `_convert_block_to_tabular_block."
)

def do_write(
def write(
self,
blocks: List[ObjectRef[Block]],
metadata: List[BlockMetadata],
blocks: Iterable[Block],
ctx: TaskContext,
path: str,
dataset_uuid: str,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
@@ -269,10 +269,9 @@
block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(),
write_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
_block_udf: Optional[Callable[[Block], Block]] = None,
ray_remote_args: Dict[str, Any] = None,
**write_args,
) -> List[ObjectRef[WriteResult]]:
"""Creates and returns write tasks for a file-based datasource."""
) -> WriteResult:
"""Write blocks for a file-based datasource."""
path, filesystem = _resolve_paths_and_filesystem(path, filesystem)
path = path[0]
if try_create_dir:
@@ -287,9 +286,6 @@
if open_stream_args is None:
open_stream_args = {}

if ray_remote_args is None:
ray_remote_args = {}

def write_block(write_path: str, block: Block):
logger.debug(f"Writing {write_path} file.")
fs = filesystem
@@ -305,29 +301,30 @@ def write_block(write_path: str, block: Block):
writer_args_fn=write_args_fn,
**write_args,
)

write_block = cached_remote_fn(write_block).options(**ray_remote_args)
# TODO: decide if we want to return a richer object when the task
# succeeds.
return "ok"

file_format = self._FILE_EXTENSION
if isinstance(file_format, list):
file_format = file_format[0]

write_tasks = []
builder = DelegatingBlockBuilder()
for block in blocks:
builder.add_block(block)
block = builder.build()

if not block_path_provider:
block_path_provider = DefaultBlockWritePathProvider()
for block_idx, block in enumerate(blocks):
write_path = block_path_provider(
path,
filesystem=filesystem,
dataset_uuid=dataset_uuid,
block=block,
block_index=block_idx,
file_format=file_format,
)
write_task = write_block.remote(write_path, block)
write_tasks.append(write_task)

return write_tasks
write_path = block_path_provider(
path,
filesystem=filesystem,
dataset_uuid=dataset_uuid,
block=block,
block_index=ctx.task_idx,
file_format=file_format,
)
return write_block(write_path, block)

def _write_block(
self,
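Note the signature change threaded through this file: _get_write_path_for_block() now receives the materialized Block rather than an ObjectRef[Block], and the block index comes from ctx.task_idx, since each write task now emits exactly one combined file. A sketch of a custom path provider under the new contract, mirroring the conftest.py fixture further down; RowCountPathProvider is an illustrative name, and write_csv() accepting block_path_provider is assumed from the surrounding API:

from typing import Optional

import pyarrow.fs

import ray
from ray.data.block import Block, BlockAccessor
from ray.data.datasource import BlockWritePathProvider


class RowCountPathProvider(BlockWritePathProvider):
    def _get_write_path_for_block(
        self,
        base_path: str,
        *,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        dataset_uuid: Optional[str] = None,
        block: Optional[Block] = None,
        block_index: Optional[int] = None,
        file_format: Optional[str] = None,
    ) -> str:
        # No ray.get() needed anymore: the block is local to the write task.
        num_rows = BlockAccessor.for_block(block).num_rows()
        return f"{base_path}/part-{block_index:06}-{num_rows}rows.{file_format}"


ds = ray.data.range_table(100)
ds.write_csv("/tmp/out", block_path_provider=RowCountPathProvider())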
33 changes: 17 additions & 16 deletions python/ray/data/datasource/mongo_datasource.py
@@ -1,15 +1,16 @@
import logging
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import Dict, List, Optional, TYPE_CHECKING

from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult
from ray.data.block import (
Block,
BlockAccessor,
BlockMetadata,
)
from ray.data._internal.remote_fn import cached_remote_fn
from ray.types import ObjectRef
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.execution.interfaces import TaskContext
from ray.util.annotations import PublicAPI
from typing import Iterable

if TYPE_CHECKING:
import pymongoarrow.api
@@ -37,15 +38,14 @@ class MongoDatasource(Datasource):
def create_reader(self, **kwargs) -> Reader:
return _MongoDatasourceReader(**kwargs)

def do_write(
def write(
self,
blocks: List[ObjectRef[Block]],
metadata: List[BlockMetadata],
ray_remote_args: Optional[Dict[str, Any]],
blocks: Iterable[Block],
ctx: TaskContext,
uri: str,
database: str,
collection: str,
) -> List[ObjectRef[WriteResult]]:
) -> WriteResult:
import pymongo

_validate_database_collection_exist(
@@ -59,15 +59,16 @@ def write_block(uri: str, database: str, collection: str, block: Block):
client = pymongo.MongoClient(uri)
write(client[database][collection], block)

if ray_remote_args is None:
ray_remote_args = {}

write_block = cached_remote_fn(write_block).options(**ray_remote_args)
write_tasks = []
builder = DelegatingBlockBuilder()
for block in blocks:
write_task = write_block.remote(uri, database, collection, block)
write_tasks.append(write_task)
return write_tasks
builder.add_block(block)
block = builder.build()

write_block(uri, database, collection, block)

# TODO: decide if we want to return a richer object when the task
# succeeds.
return "ok"


class _MongoDatasourceReader(Reader):
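For completeness, a usage sketch of the consolidated Mongo write path. It assumes pymongo/pymongoarrow are installed and a reachable MongoDB instance with an already-existing database and collection (the datasource validates both before writing); the URI and names are placeholders.

import ray
from ray.data.datasource.mongo_datasource import MongoDatasource

ds = ray.data.from_items([{"a": i, "b": 2 * i} for i in range(100)])

# With this PR, the write runs as part of the execution plan: each write
# task combines its input blocks via DelegatingBlockBuilder and writes the
# result in-process, returning "ok" rather than a list of task refs.
ds.write_datasource(
    MongoDatasource(),
    uri="mongodb:https://localhost:27017",
    database="my_db",
    collection="my_collection",
)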
2 changes: 1 addition & 1 deletion python/ray/data/tests/conftest.py
@@ -162,7 +162,7 @@ def _get_write_path_for_block(
block_index=None,
file_format=None,
):
num_rows = BlockAccessor.for_block(ray.get(block)).num_rows()
num_rows = BlockAccessor.for_block(block).num_rows()
suffix = (
f"{block_index:06}_{num_rows:02}_{dataset_uuid}" f".test.{file_format}"
)