Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Support class constructor args for map() and flat_map() #38606

Merged
merged 1 commit into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/ray/data/_internal/logical/operators/map_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,15 @@ def __init__(
self,
input_op: LogicalOperator,
fn: UserDefinedFunction,
fn_constructor_args: Optional[Iterable[Any]] = None,
compute: Optional[Union[str, ComputeStrategy]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
super().__init__(
"Map",
input_op,
fn,
fn_constructor_args=fn_constructor_args,
compute=compute,
ray_remote_args=ray_remote_args,
)
Expand Down Expand Up @@ -192,13 +194,15 @@ def __init__(
self,
input_op: LogicalOperator,
fn: UserDefinedFunction,
fn_constructor_args: Optional[Iterable[Any]] = None,
compute: Optional[Union[str, ComputeStrategy]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
super().__init__(
"FlatMap",
input_op,
fn,
fn_constructor_args=fn_constructor_args,
compute=compute,
ray_remote_args=ray_remote_args,
)
Expand Down
22 changes: 19 additions & 3 deletions python/ray/data/_internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
import urllib.parse
from types import ModuleType
from typing import TYPE_CHECKING, Any, List, Optional, Union
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union

import numpy as np

Expand Down Expand Up @@ -406,10 +406,12 @@ def _split_list(arr: List[Any], num_splits: int) -> List[List[Any]]:


def validate_compute(
fn: "UserDefinedFunction", compute: Optional[Union[str, "ComputeStrategy"]]
fn: "UserDefinedFunction",
compute: Optional[Union[str, "ComputeStrategy"]],
fn_constructor_args: Optional[Iterable[Any]] = None,
) -> None:
# Lazily import these objects to avoid circular imports.
from ray.data._internal.compute import TaskPoolStrategy
from ray.data._internal.compute import ActorPoolStrategy, TaskPoolStrategy
from ray.data.block import CallableClass

if isinstance(fn, CallableClass) and (
Expand All @@ -421,6 +423,20 @@ def validate_compute(
"For example, use ``compute=ray.data.ActorPoolStrategy(size=n)``."
)

if fn_constructor_args is not None:
if compute is None or (
compute != "actors" and not isinstance(compute, ActorPoolStrategy)
):
raise ValueError(
"fn_constructor_args can only be specified if using the actor "
f"pool compute strategy, but got: {compute}"
)
if not isinstance(fn, CallableClass):
raise ValueError(
"fn_constructor_args can only be specified if providing a "
f"CallableClass instance for fn, but got: {fn}"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type of the compute arg is too loose. This makes us need duplicated null checks.

It'd be better to split this into 2 steps.

  1. a parse_compute util function that converts Optional[Union[str, "ComputeStrategy"]] to just ComputeStrategy.
  2. Then check the fn and args.

And by doing this, compute can only be None or str in the Dataset API level. under the hood, it's always the concrete type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense, we already have a get_compute() method.

Let me make a separate PR to do the refactoring, it would involve change AbstractUDFMap.compute and _plan_udf_map_op.



def capfirst(s: str):
"""Capitalize the first letter of a string
Expand Down
38 changes: 27 additions & 11 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def map(
fn: UserDefinedFunction[Dict[str, Any], Dict[str, Any]],
*,
compute: Optional[ComputeStrategy] = None,
fn_constructor_args: Optional[Iterable[Any]] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should also add fn_constructor_kwargs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer us to do the minimal we need here. We can add later if there's ask.

num_cpus: Optional[float] = None,
num_gpus: Optional[float] = None,
**ray_remote_args,
Expand Down Expand Up @@ -336,6 +337,9 @@ def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]:
tasks, ``ray.data.ActorPoolStrategy(size=n)`` to use a fixed-size actor
pool, or ``ray.data.ActorPoolStrategy(min_size=m, max_size=n)`` for an
autoscaling actor pool.
fn_constructor_args: Positional arguments to pass to ``fn``'s constructor.
You can only provide this if ``fn`` is a callable class. These arguments
are top-level arguments in the underlying Ray actor construction task.
num_cpus: The number of CPUs to reserve for each parallel map worker.
num_gpus: The number of GPUs to reserve for each parallel map worker. For
example, specify `num_gpus=1` to request 1 GPU for each parallel map
Expand All @@ -353,7 +357,7 @@ def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]:
:meth:`~Dataset.map_batches`
Call this method to transform batches of data.
""" # noqa: E501
validate_compute(fn, compute)
validate_compute(fn, compute, fn_constructor_args)

transform_fn = generate_map_rows_fn()

Expand All @@ -370,6 +374,7 @@ def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]:
compute,
ray_remote_args,
fn=fn,
fn_constructor_args=fn_constructor_args,
)
)

Expand All @@ -378,6 +383,7 @@ def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]:
map_op = MapRows(
logical_plan.dag,
fn,
fn_constructor_args=fn_constructor_args,
compute=compute,
ray_remote_args=ray_remote_args,
)
Expand Down Expand Up @@ -552,22 +558,20 @@ def map_fn_with_large_output(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarr
f"{batch_format}"
)

validate_compute(fn, compute)
validate_compute(fn, compute, fn_constructor_args)

if fn_constructor_args is not None or fn_constructor_kwargs is not None:
if fn_constructor_kwargs is not None:
if compute is None or (
compute != "actors" and not isinstance(compute, ActorPoolStrategy)
):
raise ValueError(
"fn_constructor_args and fn_constructor_kwargs can only be "
"specified if using the actor pool compute strategy, but got: "
f"{compute}"
"fn_constructor_kwargs can only be specified if using the actor "
f"pool compute strategy, but got: {compute}"
)
if not isinstance(fn, CallableClass):
raise ValueError(
"fn_constructor_args and fn_constructor_kwargs can only be "
"specified if providing a CallableClass instance for fn, but got: "
f"{fn}"
"fn_constructor_kwargs can only be specified if providing a "
f"CallableClass instance for fn, but got: {fn}"
)

transform_fn = generate_map_batches_fn(
Expand Down Expand Up @@ -789,6 +793,7 @@ def flat_map(
fn: UserDefinedFunction[Dict[str, Any], List[Dict[str, Any]]],
*,
compute: Optional[ComputeStrategy] = None,
fn_constructor_args: Optional[Iterable[Any]] = None,
num_cpus: Optional[float] = None,
num_gpus: Optional[float] = None,
**ray_remote_args,
Expand Down Expand Up @@ -833,6 +838,9 @@ def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]:
tasks, ``ray.data.ActorPoolStrategy(size=n)`` to use a fixed-size actor
pool, or ``ray.data.ActorPoolStrategy(min_size=m, max_size=n)`` for an
autoscaling actor pool.
fn_constructor_args: Positional arguments to pass to ``fn``'s constructor.
You can only provide this if ``fn`` is a callable class. These arguments
are top-level arguments in the underlying Ray actor construction task.
num_cpus: The number of CPUs to reserve for each parallel map worker.
num_gpus: The number of GPUs to reserve for each parallel map worker. For
example, specify `num_gpus=1` to request 1 GPU for each parallel map
Expand All @@ -848,7 +856,7 @@ def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]:
:meth:`~Dataset.map`
Call this method to transform one row at time.
"""
validate_compute(fn, compute)
validate_compute(fn, compute, fn_constructor_args)

transform_fn = generate_flat_map_fn()

Expand All @@ -859,14 +867,22 @@ def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]:
ray_remote_args["num_gpus"] = num_gpus

plan = self._plan.with_stage(
OneToOneStage("FlatMap", transform_fn, compute, ray_remote_args, fn=fn)
OneToOneStage(
"FlatMap",
transform_fn,
compute,
ray_remote_args,
fn=fn,
fn_constructor_args=fn_constructor_args,
)
)

logical_plan = self._logical_plan
if logical_plan is not None:
op = FlatMap(
input_op=logical_plan.dag,
fn=fn,
fn_constructor_args=fn_constructor_args,
compute=compute,
ray_remote_args=ray_remote_args,
)
Expand Down
37 changes: 35 additions & 2 deletions python/ray/data/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,47 @@ def __call__(self, x, arg, kwarg):
return x

# map_batches with args & kwargs
ds.map_batches(
result = ds.map_batches(
StatefulFnWithArgs,
compute=ray.data.ActorPoolStrategy(),
fn_args=(1,),
fn_kwargs={"kwarg": 2},
fn_constructor_args=(1,),
fn_constructor_kwargs={"kwarg": 2},
).take() == list(range(10))
).take()
assert sorted(extract_values("id", result)) == list(range(10)), result
Comment on lines +181 to +189
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is bugfix for previous test.


class StatefulFlatMapFnWithInitArg:
def __init__(self, arg):
self._arg = arg
assert arg == 1

def __call__(self, x):
return [x] * self._arg

# flat_map with args
result = ds.flat_map(
StatefulFlatMapFnWithInitArg,
compute=ray.data.ActorPoolStrategy(),
fn_constructor_args=(1,),
).take()
assert sorted(extract_values("id", result)) == list(range(10)), result

class StatefulMapFnWithInitArg:
def __init__(self, arg):
self._arg = arg
assert arg == 1

def __call__(self, x):
return x

# map with args
result = ds.map(
StatefulMapFnWithInitArg,
compute=ray.data.ActorPoolStrategy(),
fn_constructor_args=(1,),
).take()
assert sorted(extract_values("id", result)) == list(range(10)), result


def test_concurrent_callable_classes(shutdown_only):
Expand Down
Loading