[Data] Add concurrency for write APIs (#43177)
This PR adds a concurrency parameter to the write APIs, making writes consistent with the read and map-like APIs, so users can control concurrency throughout the read-map-write workflow.

Signed-off-by: Cheng Su <[email protected]>
c21 committed Feb 15, 2024
1 parent c6ab403 commit 638f77f
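A minimal usage sketch of the new knob (the output path, dataset size, and identity map function below are illustrative assumptions, not part of this PR): the `concurrency` argument on the write call caps how many write tasks run at once, mirroring the knob already exposed by the read and map-like APIs.

```python
# Read -> map -> write, each stage with an explicit concurrency cap.
# "/tmp/out" and the identity batch function are placeholder assumptions.
import ray

ds = ray.data.range(10_000)                              # read
ds = ds.map_batches(lambda batch: batch, concurrency=4)  # map: at most 4 tasks in flight
ds.write_parquet("/tmp/out", concurrency=2)              # write: at most 2 tasks in flight
```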
Showing 4 changed files with 117 additions and 14 deletions.
5 changes: 2 additions & 3 deletions python/ray/data/_internal/logical/operators/write_operator.py
@@ -1,6 +1,5 @@
from typing import Any, Dict, Optional, Union

from ray.data._internal.compute import TaskPoolStrategy
from ray.data._internal.logical.interfaces import LogicalOperator
from ray.data._internal.logical.operators.map_operator import AbstractMap
from ray.data.datasource.datasink import Datasink
@@ -15,6 +14,7 @@ def __init__(
input_op: LogicalOperator,
datasink_or_legacy_datasource: Union[Datasink, Datasource],
ray_remote_args: Optional[Dict[str, Any]] = None,
concurrency: Optional[int] = None,
**write_args,
):
if isinstance(datasink_or_legacy_datasource, Datasink):
@@ -32,5 +32,4 @@ def __init__(
)
self._datasink_or_legacy_datasource = datasink_or_legacy_datasource
self._write_args = write_args
# Always use task to write.
self._compute = TaskPoolStrategy()
self._concurrency = concurrency
2 changes: 2 additions & 0 deletions python/ray/data/_internal/planner/plan_write_op.py
@@ -1,5 +1,6 @@
from typing import Callable, Iterator, Union

from ray.data._internal.compute import TaskPoolStrategy
from ray.data._internal.execution.interfaces import PhysicalOperator
from ray.data._internal.execution.interfaces.task_context import TaskContext
from ray.data._internal.execution.operators.map_operator import MapOperator
@@ -51,4 +52,5 @@ def plan_write_op(op: Write, input_physical_dag: PhysicalOperator) -> PhysicalOp
target_max_block_size=None,
ray_remote_args=op._ray_remote_args,
min_rows_per_bundle=op._min_rows_per_bundled_input,
compute_strategy=TaskPoolStrategy(op._concurrency),
)
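For intuition, `compute_strategy=TaskPoolStrategy(op._concurrency)` still launches one write task per input bundle; the cap only limits how many of those tasks are in flight at once, and `None` leaves the pool size up to the scheduler. The sketch below is an illustrative stand-in using a thread pool, not Ray Data internals; `run_write_tasks`, `write_block`, and the fallback pool size of 8 are hypothetical names and values.

```python
# Illustrative approximation of a bounded task pool; not Ray Data code.
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, Optional

def run_write_tasks(
    blocks: Iterable[bytes],
    write_block: Callable[[bytes], None],
    concurrency: Optional[int] = None,
) -> None:
    # None -> system-chosen pool size (8 here is an arbitrary stand-in);
    # an int -> hard cap on in-flight tasks. Every block still gets its own task.
    max_workers = concurrency if concurrency is not None else 8
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        list(pool.map(write_block, blocks))
```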
116 changes: 106 additions & 10 deletions python/ray/data/dataset.py
@@ -2670,6 +2670,7 @@ def write_parquet(
arrow_parquet_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
num_rows_per_file: Optional[int] = None,
ray_remote_args: Dict[str, Any] = None,
concurrency: Optional[int] = None,
**arrow_parquet_args,
) -> None:
"""Writes the :class:`~ray.data.Dataset` to parquet files under the provided ``path``.
@@ -2727,6 +2728,10 @@ def write_parquet(
num_rows_per_file: The target number of rows to write to each file. If
``None``, Ray Data writes a system-chosen number of rows to each file.
ray_remote_args: Kwargs passed to :meth:`~ray.remote` in the write tasks.
concurrency: The maximum number of Ray tasks to run concurrently. Set this
to control number of tasks to run concurrently. This doesn't change the
total number of tasks run. By default, concurrency is dynamically
decided based on the available resources.
arrow_parquet_args: Options to pass to
`pyarrow.parquet.write_table() <https://arrow.apache.org/docs/python\
/generated/pyarrow.parquet.write_table.html\
@@ -2745,7 +2750,11 @@ def write_parquet(
block_path_provider=block_path_provider,
dataset_uuid=self._uuid,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)
self.write_datasink(
datasink,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
)

@ConsumptionAPI
def write_json(
@@ -2760,6 +2769,7 @@ def write_json(
pandas_json_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
num_rows_per_file: Optional[int] = None,
ray_remote_args: Dict[str, Any] = None,
concurrency: Optional[int] = None,
**pandas_json_args,
) -> None:
"""Writes the :class:`~ray.data.Dataset` to JSON and JSONL files.
@@ -2826,6 +2836,10 @@ def write_json(
num_rows_per_file: The target number of rows to write to each file. If
``None``, Ray Data writes a system-chosen number of rows to each file.
ray_remote_args: kwargs passed to :meth:`~ray.remote` in the write tasks.
concurrency: The maximum number of Ray tasks to run concurrently. Set this
to control number of tasks to run concurrently. This doesn't change the
total number of tasks run. By default, concurrency is dynamically
decided based on the available resources.
pandas_json_args: These args are passed to
`pandas.DataFrame.to_json() <https://pandas.pydata.org/docs/reference/\
api/pandas.DataFrame.to_json.html>`_,
@@ -2845,7 +2859,11 @@ def write_json(
block_path_provider=block_path_provider,
dataset_uuid=self._uuid,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)
self.write_datasink(
datasink,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
)

@PublicAPI(stability="alpha")
@ConsumptionAPI
@@ -2860,6 +2878,7 @@ def write_images(
arrow_open_stream_args: Optional[Dict[str, Any]] = None,
filename_provider: Optional[FilenameProvider] = None,
ray_remote_args: Dict[str, Any] = None,
concurrency: Optional[int] = None,
) -> None:
"""Writes the :class:`~ray.data.Dataset` to images.
@@ -2897,6 +2916,10 @@ def write_images(
implementation. Use this parameter to customize what your filenames
look like.
ray_remote_args: kwargs passed to :meth:`~ray.remote` in the write tasks.
concurrency: The maximum number of Ray tasks to run concurrently. Set this
to control number of tasks to run concurrently. This doesn't change the
total number of tasks run. By default, concurrency is dynamically
decided based on the available resources.
""" # noqa: E501
datasink = _ImageDatasink(
path,
@@ -2908,7 +2931,11 @@ def write_images(
filename_provider=filename_provider,
dataset_uuid=self._uuid,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)
self.write_datasink(
datasink,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
)

@ConsumptionAPI
def write_csv(
@@ -2923,6 +2950,7 @@ def write_csv(
arrow_csv_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
num_rows_per_file: Optional[int] = None,
ray_remote_args: Dict[str, Any] = None,
concurrency: Optional[int] = None,
**arrow_csv_args,
) -> None:
"""Writes the :class:`~ray.data.Dataset` to CSV files.
@@ -2988,6 +3016,10 @@ def write_csv(
num_rows_per_file: The target number of rows to write to each file. If
``None``, Ray Data writes a system-chosen number of rows to each file.
ray_remote_args: kwargs passed to :meth:`~ray.remote` in the write tasks.
concurrency: The maximum number of Ray tasks to run concurrently. Set this
to control number of tasks to run concurrently. This doesn't change the
total number of tasks run. By default, concurrency is dynamically
decided based on the available resources.
arrow_csv_args: Options to pass to `pyarrow.write.write_csv <https://\
arrow.apache.org/docs/python/generated/pyarrow.csv.write_csv.html\
#pyarrow.csv.write_csv>`_
@@ -3005,7 +3037,11 @@ def write_csv(
block_path_provider=block_path_provider,
dataset_uuid=self._uuid,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)
self.write_datasink(
datasink,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
)

@ConsumptionAPI
def write_tfrecords(
@@ -3020,6 +3056,7 @@ def write_tfrecords(
block_path_provider: Optional[BlockWritePathProvider] = None,
num_rows_per_file: Optional[int] = None,
ray_remote_args: Dict[str, Any] = None,
concurrency: Optional[int] = None,
) -> None:
"""Write the :class:`~ray.data.Dataset` to TFRecord files.
@@ -3079,6 +3116,10 @@ def write_tfrecords(
num_rows_per_file: The target number of rows to write to each file. If
``None``, Ray Data writes a system-chosen number of rows to each file.
ray_remote_args: kwargs passed to :meth:`~ray.remote` in the write tasks.
concurrency: The maximum number of Ray tasks to run concurrently. Set this
to control number of tasks to run concurrently. This doesn't change the
total number of tasks run. By default, concurrency is dynamically
decided based on the available resources.
"""
datasink = _TFRecordDatasink(
@@ -3092,7 +3133,11 @@ def write_tfrecords(
block_path_provider=block_path_provider,
dataset_uuid=self._uuid,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)
self.write_datasink(
datasink,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
)

@PublicAPI(stability="alpha")
@ConsumptionAPI
@@ -3108,6 +3153,7 @@ def write_webdataset(
num_rows_per_file: Optional[int] = None,
ray_remote_args: Dict[str, Any] = None,
encoder: Optional[Union[bool, str, callable, list]] = True,
concurrency: Optional[int] = None,
) -> None:
"""Writes the dataset to `WebDataset <https://webdataset.github.io/webdataset/>`_ files.
@@ -3154,6 +3200,10 @@ def write_webdataset(
num_rows_per_file: The target number of rows to write to each file. If
``None``, Ray Data writes a system-chosen number of rows to each file.
ray_remote_args: Kwargs passed to ``ray.remote`` in the write tasks.
concurrency: The maximum number of Ray tasks to run concurrently. Set this
to control number of tasks to run concurrently. This doesn't change the
total number of tasks run. By default, concurrency is dynamically
decided based on the available resources.
"""
datasink = _WebDatasetDatasink(
@@ -3167,7 +3217,11 @@ def write_webdataset(
block_path_provider=block_path_provider,
dataset_uuid=self._uuid,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)
self.write_datasink(
datasink,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
)

@ConsumptionAPI
def write_numpy(
@@ -3182,6 +3236,7 @@ def write_numpy(
block_path_provider: Optional[BlockWritePathProvider] = None,
num_rows_per_file: Optional[int] = None,
ray_remote_args: Dict[str, Any] = None,
concurrency: Optional[int] = None,
) -> None:
"""Writes a column of the :class:`~ray.data.Dataset` to .npy files.
@@ -3232,6 +3287,10 @@ def write_numpy(
num_rows_per_file: The target number of rows to write to each file. If
``None``, Ray Data writes a system-chosen number of rows to each file.
ray_remote_args: kwargs passed to :meth:`~ray.remote` in the write tasks.
concurrency: The maximum number of Ray tasks to run concurrently. Set this
to control number of tasks to run concurrently. This doesn't change the
total number of tasks run. By default, concurrency is dynamically
decided based on the available resources.
"""

datasink = _NumpyDatasink(
@@ -3245,14 +3304,19 @@ def write_numpy(
block_path_provider=block_path_provider,
dataset_uuid=self._uuid,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)
self.write_datasink(
datasink,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
)

@ConsumptionAPI
def write_sql(
self,
sql: str,
connection_factory: Callable[[], Connection],
ray_remote_args: Optional[Dict[str, Any]] = None,
concurrency: Optional[int] = None,
) -> None:
"""Write to a database that provides a
`Python DB API2-compliant <https://peps.python.org/pep-0249/>`_ connector.
@@ -3302,9 +3366,17 @@ def write_sql(
`Connection object <https://peps.python.org/pep-0249/#connection-objects>`_.
ray_remote_args: Keyword arguments passed to :meth:`~ray.remote` in the
write tasks.
concurrency: The maximum number of Ray tasks to run concurrently. Set this
to control number of tasks to run concurrently. This doesn't change the
total number of tasks run. By default, concurrency is dynamically
decided based on the available resources.
""" # noqa: E501
datasink = _SQLDatasink(sql=sql, connection_factory=connection_factory)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)
self.write_datasink(
datasink,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
)
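As a concrete use of the new parameter on `write_sql`, the hedged sketch below writes a tiny dataset into a local SQLite table with at most two concurrent write tasks. The database file, table name, and schema are illustrative assumptions, and the target table must exist before the write.

```python
# Hedged example: SQLite target, capped at 2 concurrent write tasks.
import sqlite3
import ray

# The table must exist; "example.db" and the "movie" schema are assumptions.
connection = sqlite3.connect("example.db")
connection.execute("CREATE TABLE IF NOT EXISTS movie(title TEXT, year INT)")
connection.commit()
connection.close()

ds = ray.data.from_items([{"title": "Monty Python and the Holy Grail", "year": 1975}])
ds.write_sql(
    "INSERT INTO movie VALUES(?, ?)",
    lambda: sqlite3.connect("example.db"),
    concurrency=2,
)
```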

@PublicAPI(stability="alpha")
@ConsumptionAPI
@@ -3314,6 +3386,7 @@ def write_mongo(
database: str,
collection: str,
ray_remote_args: Dict[str, Any] = None,
concurrency: Optional[int] = None,
) -> None:
"""Writes the :class:`~ray.data.Dataset` to a MongoDB database.
@@ -3360,6 +3433,10 @@ def write_mongo(
collection: The name of the collection in the database. This collection
must exist otherwise a ValueError is raised.
ray_remote_args: kwargs passed to :meth:`~ray.remote` in the write tasks.
concurrency: The maximum number of Ray tasks to run concurrently. Set this
to control number of tasks to run concurrently. This doesn't change the
total number of tasks run. By default, concurrency is dynamically
decided based on the available resources.
Raises:
ValueError: if ``database`` doesn't exist.
@@ -3370,7 +3447,11 @@ def write_mongo(
database=database,
collection=collection,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)
self.write_datasink(
datasink,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
)

@ConsumptionAPI
def write_bigquery(
@@ -3380,6 +3461,7 @@ def write_bigquery(
max_retry_cnt: int = 10,
overwrite_table: Optional[bool] = True,
ray_remote_args: Dict[str, Any] = None,
concurrency: Optional[int] = None,
) -> None:
"""Write the dataset to a BigQuery dataset table.
@@ -3415,6 +3497,10 @@ def write_bigquery(
exists. The default behavior is to overwrite the table.
``overwrite_table=False`` will append to the table if it exists.
ray_remote_args: Kwargs passed to ray.remote in the write tasks.
concurrency: The maximum number of Ray tasks to run concurrently. Set this
to control number of tasks to run concurrently. This doesn't change the
total number of tasks run. By default, concurrency is dynamically
decided based on the available resources.
""" # noqa: E501
if ray_remote_args is None:
ray_remote_args = {}
@@ -3435,7 +3521,11 @@ def write_bigquery(
max_retry_cnt=max_retry_cnt,
overwrite_table=overwrite_table,
)
self.write_datasink(datasink, ray_remote_args=ray_remote_args)
self.write_datasink(
datasink,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
)

@Deprecated
@ConsumptionAPI(pattern="Time complexity:")
@@ -3506,6 +3596,7 @@ def write_datasink(
datasink: Datasink,
*,
ray_remote_args: Dict[str, Any] = None,
concurrency: Optional[int] = None,
) -> None:
"""Writes the dataset to a custom :class:`~ray.data.Datasink`.
@@ -3514,6 +3605,10 @@ def write_datasink(
Args:
datasink: The :class:`~ray.data.Datasink` to write to.
ray_remote_args: Kwargs passed to ``ray.remote`` in the write tasks.
concurrency: The maximum number of Ray tasks to run concurrently. Set this
to control number of tasks to run concurrently. This doesn't change the
total number of tasks run. By default, concurrency is dynamically
decided based on the available resources.
""" # noqa: E501
if ray_remote_args is None:
ray_remote_args = {}
@@ -3534,6 +3629,7 @@ def write_datasink(
self._logical_plan.dag,
datasink,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
)
logical_plan = LogicalPlan(write_op)
