-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
[Data] Support class constructor args for map() and flat_map() #38606
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -286,6 +286,7 @@ def map( | |
fn: UserDefinedFunction[Dict[str, Any], Dict[str, Any]], | ||
*, | ||
compute: Optional[ComputeStrategy] = None, | ||
fn_constructor_args: Optional[Iterable[Any]] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. should also add There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I prefer us to do the minimal we need here. We can add it later if there's an ask. |
||
num_cpus: Optional[float] = None, | ||
num_gpus: Optional[float] = None, | ||
**ray_remote_args, | ||
|
@@ -336,6 +337,9 @@ def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]: | |
tasks, ``ray.data.ActorPoolStrategy(size=n)`` to use a fixed-size actor | ||
pool, or ``ray.data.ActorPoolStrategy(min_size=m, max_size=n)`` for an | ||
autoscaling actor pool. | ||
fn_constructor_args: Positional arguments to pass to ``fn``'s constructor. | ||
You can only provide this if ``fn`` is a callable class. These arguments | ||
are top-level arguments in the underlying Ray actor construction task. | ||
num_cpus: The number of CPUs to reserve for each parallel map worker. | ||
num_gpus: The number of GPUs to reserve for each parallel map worker. For | ||
example, specify `num_gpus=1` to request 1 GPU for each parallel map | ||
|
@@ -353,7 +357,7 @@ def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]: | |
:meth:`~Dataset.map_batches` | ||
Call this method to transform batches of data. | ||
""" # noqa: E501 | ||
validate_compute(fn, compute) | ||
validate_compute(fn, compute, fn_constructor_args) | ||
|
||
transform_fn = generate_map_rows_fn() | ||
|
||
|
@@ -370,6 +374,7 @@ def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]: | |
compute, | ||
ray_remote_args, | ||
fn=fn, | ||
fn_constructor_args=fn_constructor_args, | ||
) | ||
) | ||
|
||
|
@@ -378,6 +383,7 @@ def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]: | |
map_op = MapRows( | ||
logical_plan.dag, | ||
fn, | ||
fn_constructor_args=fn_constructor_args, | ||
compute=compute, | ||
ray_remote_args=ray_remote_args, | ||
) | ||
|
@@ -552,22 +558,20 @@ def map_fn_with_large_output(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarr | |
f"{batch_format}" | ||
) | ||
|
||
validate_compute(fn, compute) | ||
validate_compute(fn, compute, fn_constructor_args) | ||
|
||
if fn_constructor_args is not None or fn_constructor_kwargs is not None: | ||
if fn_constructor_kwargs is not None: | ||
if compute is None or ( | ||
compute != "actors" and not isinstance(compute, ActorPoolStrategy) | ||
): | ||
raise ValueError( | ||
"fn_constructor_args and fn_constructor_kwargs can only be " | ||
"specified if using the actor pool compute strategy, but got: " | ||
f"{compute}" | ||
"fn_constructor_kwargs can only be specified if using the actor " | ||
f"pool compute strategy, but got: {compute}" | ||
) | ||
if not isinstance(fn, CallableClass): | ||
raise ValueError( | ||
"fn_constructor_args and fn_constructor_kwargs can only be " | ||
"specified if providing a CallableClass instance for fn, but got: " | ||
f"{fn}" | ||
"fn_constructor_kwargs can only be specified if providing a " | ||
f"CallableClass instance for fn, but got: {fn}" | ||
) | ||
|
||
transform_fn = generate_map_batches_fn( | ||
|
@@ -789,6 +793,7 @@ def flat_map( | |
fn: UserDefinedFunction[Dict[str, Any], List[Dict[str, Any]]], | ||
*, | ||
compute: Optional[ComputeStrategy] = None, | ||
fn_constructor_args: Optional[Iterable[Any]] = None, | ||
num_cpus: Optional[float] = None, | ||
num_gpus: Optional[float] = None, | ||
**ray_remote_args, | ||
|
@@ -833,6 +838,9 @@ def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]: | |
tasks, ``ray.data.ActorPoolStrategy(size=n)`` to use a fixed-size actor | ||
pool, or ``ray.data.ActorPoolStrategy(min_size=m, max_size=n)`` for an | ||
autoscaling actor pool. | ||
fn_constructor_args: Positional arguments to pass to ``fn``'s constructor. | ||
You can only provide this if ``fn`` is a callable class. These arguments | ||
are top-level arguments in the underlying Ray actor construction task. | ||
num_cpus: The number of CPUs to reserve for each parallel map worker. | ||
num_gpus: The number of GPUs to reserve for each parallel map worker. For | ||
example, specify `num_gpus=1` to request 1 GPU for each parallel map | ||
|
@@ -848,7 +856,7 @@ def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]: | |
:meth:`~Dataset.map` | ||
Call this method to transform one row at time. | ||
""" | ||
validate_compute(fn, compute) | ||
validate_compute(fn, compute, fn_constructor_args) | ||
|
||
transform_fn = generate_flat_map_fn() | ||
|
||
|
@@ -859,14 +867,22 @@ def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]: | |
ray_remote_args["num_gpus"] = num_gpus | ||
|
||
plan = self._plan.with_stage( | ||
OneToOneStage("FlatMap", transform_fn, compute, ray_remote_args, fn=fn) | ||
OneToOneStage( | ||
"FlatMap", | ||
transform_fn, | ||
compute, | ||
ray_remote_args, | ||
fn=fn, | ||
fn_constructor_args=fn_constructor_args, | ||
) | ||
) | ||
|
||
logical_plan = self._logical_plan | ||
if logical_plan is not None: | ||
op = FlatMap( | ||
input_op=logical_plan.dag, | ||
fn=fn, | ||
fn_constructor_args=fn_constructor_args, | ||
compute=compute, | ||
ray_remote_args=ray_remote_args, | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -178,14 +178,47 @@ def __call__(self, x, arg, kwarg): | |
return x | ||
|
||
# map_batches with args & kwargs | ||
ds.map_batches( | ||
result = ds.map_batches( | ||
StatefulFnWithArgs, | ||
compute=ray.data.ActorPoolStrategy(), | ||
fn_args=(1,), | ||
fn_kwargs={"kwarg": 2}, | ||
fn_constructor_args=(1,), | ||
fn_constructor_kwargs={"kwarg": 2}, | ||
).take() == list(range(10)) | ||
).take() | ||
assert sorted(extract_values("id", result)) == list(range(10)), result | ||
Comment on lines
+181
to
+189
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a bugfix for the previous test. |
||
|
||
class StatefulFlatMapFnWithInitArg: | ||
def __init__(self, arg): | ||
self._arg = arg | ||
assert arg == 1 | ||
|
||
def __call__(self, x): | ||
return [x] * self._arg | ||
|
||
# flat_map with args | ||
result = ds.flat_map( | ||
StatefulFlatMapFnWithInitArg, | ||
compute=ray.data.ActorPoolStrategy(), | ||
fn_constructor_args=(1,), | ||
).take() | ||
assert sorted(extract_values("id", result)) == list(range(10)), result | ||
|
||
class StatefulMapFnWithInitArg: | ||
def __init__(self, arg): | ||
self._arg = arg | ||
assert arg == 1 | ||
|
||
def __call__(self, x): | ||
return x | ||
|
||
# map with args | ||
result = ds.map( | ||
StatefulMapFnWithInitArg, | ||
compute=ray.data.ActorPoolStrategy(), | ||
fn_constructor_args=(1,), | ||
).take() | ||
assert sorted(extract_values("id", result)) == list(range(10)), result | ||
|
||
|
||
def test_concurrent_callable_classes(shutdown_only): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type of the compute arg is too loose. This makes us need duplicated null checks.
It'd be better to split this into 2 steps:
1. Add a `parse_compute` util function that converts `Optional[Union[str, "ComputeStrategy"]]` to just `ComputeStrategy`.
2. By doing this, `compute` can only be None or str at the Dataset API level; under the hood, it's always the concrete type.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense — we already have a `get_compute()` method. Let me make a separate PR to do the refactoring; it would involve changing `AbstractUDFMap.compute` and `_plan_udf_map_op`.