Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make write an operator as part of the execution plan #32015

Merged
merged 51 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
edc51bd
Fix read_tfrecords_benchmark nightly test
jianoaix Dec 8, 2022
61f4d6d
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Dec 14, 2022
a33a943
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Dec 16, 2022
36ebe52
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Dec 16, 2022
ce6763e
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Dec 19, 2022
0e2c29e
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Dec 21, 2022
f2b6ed0
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Dec 22, 2022
bb6c5c4
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Jan 4, 2023
540fe79
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Jan 10, 2023
edad7d0
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Jan 10, 2023
60cc079
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Jan 11, 2023
a3d3980
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Jan 12, 2023
001579c
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Jan 17, 2023
8aeed6c
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Jan 18, 2023
7a9a49b
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Jan 19, 2023
ef97167
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Jan 20, 2023
6f0563c
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Jan 21, 2023
bcec4d6
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Jan 24, 2023
ddef4e5
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Jan 25, 2023
fc9a175
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Jan 25, 2023
f0e90b7
Merge branch 'master' of https://github.com/ray-project/ray
jianoaix Jan 26, 2023
0c820ea
Merge branch 'master' of https://github.com/ray-project/ray into writ…
jianoaix Jan 26, 2023
253da6a
Make write an operator as part of the execution plan
jianoaix Jan 27, 2023
514ec14
fix / fix do_write
jianoaix Jan 28, 2023
204ea5f
Merge branch 'master' of https://github.com/ray-project/ray into writ…
jianoaix Jan 28, 2023
b75bf49
Fix the merge
jianoaix Jan 28, 2023
6a28257
fix arg passing
jianoaix Jan 28, 2023
3082e52
lint
jianoaix Jan 28, 2023
5ddfdc0
Reconcile taskcontext
jianoaix Jan 31, 2023
3843ff4
Reconcile taskcontext continued
jianoaix Jan 31, 2023
012a413
Merge branch 'master' of https://github.com/ray-project/ray into writ…
jianoaix Jan 31, 2023
84a74f0
Use task context in write op
jianoaix Jan 31, 2023
bb2a474
fix test
jianoaix Jan 31, 2023
ad5f7c7
feedback: backward compatibility
jianoaix Jan 31, 2023
a77053a
fix
jianoaix Jan 31, 2023
554171a
test write fusion
jianoaix Jan 31, 2023
1ba1b9f
Result of write operator; datasource callbacks
jianoaix Feb 1, 2023
5ced246
Handle an empty list on failure
jianoaix Feb 3, 2023
43eca29
execute the plan in-place in write_datasource
jianoaix Feb 3, 2023
f25d54b
Keep write_datasource semantics diff-neutral regarding the plan
jianoaix Feb 3, 2023
c5ddf07
Merge branch 'master' of https://github.com/ray-project/ray into writ…
jianoaix Feb 3, 2023
1d58e13
disable the write_XX in new optimizer: it's not supported yet
jianoaix Feb 3, 2023
d309dbd
fix comment
jianoaix Feb 3, 2023
21a50db
refactor: do_write() calls direct_write() to reduce code duplication
jianoaix Feb 4, 2023
a84e27b
refactor: for mongo datasource do_write
jianoaix Feb 4, 2023
8879df0
backward compatible
jianoaix Feb 7, 2023
d6873e1
rename: direct_write -> write
jianoaix Feb 7, 2023
10ef980
unnecessary test removed
jianoaix Feb 7, 2023
48e9415
fix
jianoaix Feb 7, 2023
87dc925
deprecation message/logging
jianoaix Feb 7, 2023
b77ca8d
deprecation logging
jianoaix Feb 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix / fix do_write
  • Loading branch information
jianoaix committed Jan 28, 2023
commit 514ec14756cbc02a06d54729e2fc95fad3a8f6bc
18 changes: 1 addition & 17 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,6 @@
TFRecordDatasource,
WriteResult,
)
from ray.data.datasource.file_based_datasource import (
_wrap_arrow_serialization_workaround,
)
from ray.data.random_access_dataset import RandomAccessDataset
from ray.data.row import TableRow
from ray.types import ObjectRef
Expand Down Expand Up @@ -2674,7 +2671,7 @@ def write_datasource(

def transform(blocks: Iterable[Block], task_idx, fn) -> []:
try:
datasource.sync_write(blocks, task_idx, **write_args)
datasource.do_write(blocks, task_idx, **write_args)
datasource.on_write_complete([])
except Exception as e:
datasource.on_write_failed([], e)
Expand Down Expand Up @@ -4414,16 +4411,3 @@ def _sliding_window(iterable: Iterable, n: int):
for elem in it:
window.append(elem)
yield tuple(window)


def _do_write(
ds: Datasource,
ctx: DatasetContext,
blocks: List[Block],
meta: List[BlockMetadata],
ray_remote_args: Dict[str, Any],
write_args: Dict[str, Any],
) -> List[ObjectRef[WriteResult]]:
write_args = _unwrap_arrow_serialization_workaround(write_args)
DatasetContext._set_current(ctx)
return ds.do_write(blocks, meta, ray_remote_args=ray_remote_args, **write_args)
31 changes: 7 additions & 24 deletions python/ray/data/datasource/datasource.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import builtins
from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Generic, Iterable, List, Optional, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -52,22 +52,17 @@ def prepare_read(self, parallelism: int, **read_args) -> List["ReadTask[T]"]:

def do_write(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to make this change backwards compatible, like we do for the reader API?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a new method for direct writing out, and deprecated do_write().

self,
blocks: List[ObjectRef[Block]],
metadata: List[BlockMetadata],
ray_remote_args: Dict[str, Any],
blocks: Iterable[Block],
**write_args,
) -> List[ObjectRef[WriteResult]]:
"""Launch Ray tasks for writing blocks out to the datasource.
) -> WriteResult:
"""Write blocks out to the datasource. This is used by a single write task.

Args:
blocks: List of data block references. It is recommended that one
write task be generated per block.
metadata: List of block metadata.
ray_remote_args: Kwargs passed to ray.remote in the write tasks.
blocks: List of data blocks.
write_args: Additional kwargs to pass to the datasource impl.

Returns:
A list of the output of the write tasks.
The output of the write tasks.
"""
raise NotImplementedError

Expand Down Expand Up @@ -346,7 +341,7 @@ def get_num_failed(self):

self.data_sink = DataSink.remote()

def sync_write(
def do_write(
self,
blocks: Iterable[Block],
task_idx: int,
Expand All @@ -357,18 +352,6 @@ def sync_write(
tasks.append(self.data_sink.write.remote(b))
return ray.get(tasks)

def do_write(
self,
blocks: List[ObjectRef[Block]],
metadata: List[BlockMetadata],
ray_remote_args: Dict[str, Any],
**write_args,
) -> List[ObjectRef[WriteResult]]:
tasks = []
for b in blocks:
tasks.append(self.data_sink.write.remote(b))
return tasks

def on_write_complete(self, write_results: List[WriteResult]) -> None:
assert all(w == "ok" for w in write_results), write_results
self.data_sink.increment_ok.remote()
Expand Down
81 changes: 3 additions & 78 deletions python/ray/data/datasource/file_based_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
)

from ray.data._internal.arrow_block import ArrowRow
from ray.data._internal.block_list import BlockMetadata
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.output_buffer import BlockOutputBuffer
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.util import _check_pyarrow_version, _resolve_custom_scheme
from ray.data.block import Block, BlockAccessor
from ray.data.context import DatasetContext
from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult
from ray.data.datasource.datasource import Datasource, Reader, ReadTask
from ray.data.datasource.file_meta_provider import (
BaseFileMetadataProvider,
DefaultFileMetadataProvider,
Expand Down Expand Up @@ -258,8 +256,7 @@ def _convert_block_to_tabular_block(
"then you need to implement `_convert_block_to_tabular_block."
)

# This doesn't launch Ray tasks to write.
def sync_write(
def do_write(
self,
blocks: Iterable[Block],
task_idx: int,
Expand Down Expand Up @@ -323,79 +320,7 @@ def write_block(write_path: str, block: Block):
block_index=task_idx,
file_format=file_format,
)
write_block(write_path, block)

def do_write(
self,
blocks: List[ObjectRef[Block]],
metadata: List[BlockMetadata],
path: str,
dataset_uuid: str,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
try_create_dir: bool = True,
open_stream_args: Optional[Dict[str, Any]] = None,
block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(),
write_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
_block_udf: Optional[Callable[[Block], Block]] = None,
ray_remote_args: Dict[str, Any] = None,
**write_args,
) -> List[ObjectRef[WriteResult]]:
"""Creates and returns write tasks for a file-based datasource."""
path, filesystem = _resolve_paths_and_filesystem(path, filesystem)
path = path[0]
if try_create_dir:
# Arrow's S3FileSystem doesn't allow creating buckets by default, so we add
# a query arg enabling bucket creation if an S3 URI is provided.
tmp = _add_creatable_buckets_param_if_s3_uri(path)
filesystem.create_dir(tmp, recursive=True)
filesystem = _wrap_s3_serialization_workaround(filesystem)

_write_block_to_file = self._write_block

if open_stream_args is None:
open_stream_args = {}

if ray_remote_args is None:
ray_remote_args = {}

def write_block(write_path: str, block: Block):
logger.debug(f"Writing {write_path} file.")
fs = filesystem
if isinstance(fs, _S3FileSystemWrapper):
fs = fs.unwrap()
if _block_udf is not None:
block = _block_udf(block)

with fs.open_output_stream(write_path, **open_stream_args) as f:
_write_block_to_file(
f,
BlockAccessor.for_block(block),
writer_args_fn=write_args_fn,
**write_args,
)

write_block = cached_remote_fn(write_block).options(**ray_remote_args)

file_format = self._FILE_EXTENSION
if isinstance(file_format, list):
file_format = file_format[0]

write_tasks = []
if not block_path_provider:
block_path_provider = DefaultBlockWritePathProvider()
for block_idx, block in enumerate(blocks):
write_path = block_path_provider(
path,
filesystem=filesystem,
dataset_uuid=dataset_uuid,
block=block,
block_index=block_idx,
file_format=file_format,
)
write_task = write_block.remote(write_path, block)
write_tasks.append(write_task)

return write_tasks
return write_block(write_path, block)

def _write_block(
self,
Expand Down
37 changes: 2 additions & 35 deletions python/ray/data/datasource/mongo_datasource.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import logging
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import Dict, List, Optional, TYPE_CHECKING

from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult
from ray.data.block import (
Block,
BlockAccessor,
BlockMetadata,
)
from ray.data._internal.remote_fn import cached_remote_fn
from ray.types import ObjectRef
from ray.util.annotations import PublicAPI
from typing import Iterable
Expand Down Expand Up @@ -38,7 +37,7 @@ class MongoDatasource(Datasource):
def create_reader(self, **kwargs) -> Reader:
return _MongoDatasourceReader(**kwargs)

def sync_write(
def do_write(
self,
blocks: Iterable[Block],
task_idx: int,
Expand All @@ -65,38 +64,6 @@ def write_block(uri: str, database: str, collection: str, block: Block):
write_tasks.append(write_task)
return write_tasks

def do_write(
self,
blocks: List[ObjectRef[Block]],
metadata: List[BlockMetadata],
ray_remote_args: Optional[Dict[str, Any]],
uri: str,
database: str,
collection: str,
) -> List[ObjectRef[WriteResult]]:
import pymongo

_validate_database_collection_exist(
pymongo.MongoClient(uri), database, collection
)

def write_block(uri: str, database: str, collection: str, block: Block):
from pymongoarrow.api import write

block = BlockAccessor.for_block(block).to_arrow()
client = pymongo.MongoClient(uri)
write(client[database][collection], block)

if ray_remote_args is None:
ray_remote_args = {}

write_block = cached_remote_fn(write_block).options(**ray_remote_args)
write_tasks = []
for block in blocks:
write_task = write_block.remote(uri, database, collection, block)
write_tasks.append(write_task)
return write_tasks


class _MongoDatasourceReader(Reader):
def __init__(
Expand Down
43 changes: 27 additions & 16 deletions python/ray/data/tests/test_dataset_formats.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List, Union
from typing import List, Union

import pandas as pd
import pyarrow as pa
Expand All @@ -13,7 +13,7 @@

import ray
from ray.data._internal.arrow_block import ArrowRow
from ray.data.block import Block, BlockAccessor, BlockMetadata
from ray.data.block import Block, BlockAccessor
from ray.data.datasource import (
Datasource,
DummyOutputDatasource,
Expand All @@ -24,6 +24,7 @@
from ray.data.tests.mock_http_server import * # noqa
from ray.tests.conftest import * # noqa
from ray.types import ObjectRef
from typing import Iterable


def maybe_pipeline(ds, enabled):
Expand Down Expand Up @@ -227,6 +228,8 @@ def __init__(self):
self.rows_written = 0
self.enabled = True
self.node_ids = set()
self.num_ok = 0
self.num_failed = 0

def write(self, node_id: str, block: Block) -> str:
block = BlockAccessor.for_block(block)
Expand All @@ -245,40 +248,48 @@ def get_node_ids(self):
def set_enabled(self, enabled):
self.enabled = enabled

def increment_ok(self):
self.num_ok += 1

def get_num_ok(self):
return self.num_ok

def increment_failed(self):
self.num_failed += 1

def get_num_failed(self):
return self.num_failed

self.data_sink = DataSink.remote()
self.num_ok = 0
self.num_failed = 0

def do_write(
self,
blocks: List[ObjectRef[Block]],
metadata: List[BlockMetadata],
ray_remote_args: Dict[str, Any],
blocks: Iterable[Block],
task_idx: int,
**write_args,
) -> List[ObjectRef[WriteResult]]:
) -> WriteResult:
data_sink = self.data_sink

@ray.remote
def write(b):
node_id = ray.get_runtime_context().get_node_id()
return ray.get(data_sink.write.remote(node_id, b))

tasks = []
for b in blocks:
tasks.append(write.options(**ray_remote_args).remote(b))
return tasks
result = write(b)
return result

def on_write_complete(self, write_results: List[WriteResult]) -> None:
assert all(w == "ok" for w in write_results), write_results
self.num_ok += 1
self.data_sink.increment_ok.remote()

def on_write_failed(
self, write_results: List[ObjectRef[WriteResult]], error: Exception
) -> None:
self.num_failed += 1
self.data_sink.increment_failed.remote()


def test_write_datasource_ray_remote_args(ray_start_cluster):
ray.shutdown()
cluster = ray_start_cluster
cluster.add_node(
resources={"foo": 100},
Expand All @@ -298,8 +309,8 @@ def get_node_id():
ds = ray.data.range(100, parallelism=10)
# Pin write tasks to
ds.write_datasource(output, ray_remote_args={"resources": {"bar": 1}})
assert output.num_ok == 1
assert output.num_failed == 0
assert ray.get(output.data_sink.get_num_ok.remote()) == 10
assert ray.get(output.data_sink.get_num_failed.remote()) == 0
assert ray.get(output.data_sink.get_rows_written.remote()) == 100

node_ids = ray.get(output.data_sink.get_node_ids.remote())
Expand Down