
[Data] Allow tasks to control concurrency in map-like APIs (#42637)
This PR adds concurrency control for tasks in map-like APIs, i.e., when the user calls `map_batches(fn, concurrency=...)` with a plain function. Each `TaskPoolMapOperator` now has a concurrency cap that limits the number of its concurrently running tasks.
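As a quick illustration of the new behavior, a minimal usage sketch (the function and dataset below are illustrative, not part of this PR):

```python
import ray

def add_one(batch):
    batch["id"] = batch["id"] + 1
    return batch

ds = ray.data.range(1000)
# With a plain function, `concurrency=2` now resolves to TaskPoolStrategy(size=2),
# so at most 2 map tasks run at the same time for this operator.
out = ds.map_batches(add_one, concurrency=2).take_all()
```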

Signed-off-by: Cheng Su <[email protected]>
c21 committed Jan 30, 2024
1 parent c1aaaa3 commit 46a9efe
Showing 8 changed files with 98 additions and 174 deletions.
18 changes: 17 additions & 1 deletion python/ray/data/_internal/compute.py
@@ -56,6 +56,20 @@ def _apply(

@DeveloperAPI
class TaskPoolStrategy(ComputeStrategy):
    def __init__(
        self,
        size: Optional[int] = None,
    ):
        """Construct TaskPoolStrategy for a Dataset transform.

        Args:
            size: Specify the maximum size of the task pool.
        """

        if size is not None and size < 1:
            raise ValueError("`size` must be >= 1", size)
        self.size = size
    def _apply(
        self,
        block_fn: BlockTransform,
@@ -148,7 +162,9 @@ def _apply(
)

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, TaskPoolStrategy) or other == "tasks"
        return (isinstance(other, TaskPoolStrategy) and self.size == other.size) or (
            other == "tasks" and self.size is None
        )
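A short sketch of the `TaskPoolStrategy` semantics introduced above, inferred from this diff (the assertions mirror the constructor and `__eq__` changes):

```python
from ray.data._internal.compute import TaskPoolStrategy

# An unsized strategy still compares equal to the legacy "tasks" string.
assert TaskPoolStrategy() == "tasks"

# A sized strategy only equals another strategy with the same size.
assert TaskPoolStrategy(size=2) == TaskPoolStrategy(size=2)
assert TaskPoolStrategy(size=2) != "tasks"

# Sizes below 1 are rejected at construction time.
try:
    TaskPoolStrategy(size=0)
except ValueError:
    pass
```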


@PublicAPI
@@ -10,9 +10,8 @@

# Default enabled backpressure policies and its config key.
# Use `DataContext.set_config` to config it.
# TODO(hchen): Enable ConcurrencyCapBackpressurePolicy by default.
# TODO(hchen): Enable StreamingOutputBackpressurePolicy by default.
ENABLED_BACKPRESSURE_POLICIES = []
ENABLED_BACKPRESSURE_POLICIES = [ConcurrencyCapBackpressurePolicy]
ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY = "backpressure_policies.enabled"
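With the policy now enabled by default, overriding the enabled set is only needed for experiments. A hedged sketch using the config key defined above (the import path for the policy class is an assumption based on this repository's module layout):

```python
import ray
from ray.data._internal.execution.backpressure_policies import (  # assumed path
    ConcurrencyCapBackpressurePolicy,
)

ctx = ray.data.DataContext.get_current()
# Explicitly set the enabled backpressure policies under the key above.
ctx.set_config("backpressure_policies.enabled", [ConcurrencyCapBackpressurePolicy])
```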


@@ -1,8 +1,10 @@
import logging
from typing import TYPE_CHECKING

import ray
from .backpressure_policy import BackpressurePolicy
from ray.data._internal.execution.operators.task_pool_map_operator import (
    TaskPoolMapOperator,
)

if TYPE_CHECKING:
    from ray.data._internal.execution.interfaces.physical_operator import (
@@ -16,67 +18,26 @@
class ConcurrencyCapBackpressurePolicy(BackpressurePolicy):
"""A backpressure policy that caps the concurrency of each operator.
The concurrency cap limits the number of concurrently running tasks.
It will be set to an intial value, and will ramp up exponentially.
The policy will limit the number of concurrently running tasks based on its
concurrency cap parameter.
The concrete stategy is as follows:
- Each PhysicalOperator is assigned an initial concurrency cap.
- An PhysicalOperator can run new tasks if the number of running tasks is less
than the cap.
- When the number of finished tasks reaches a threshold, the concurrency cap will
increase.
NOTE: Only support setting concurrency cap for `TaskPoolMapOperator` for now.
TODO(chengsu): Consolidate with actor scaling logic of `ActorPoolMapOperator`.
"""

    # Following are the default values followed by the config keys of the
    # available configs.
    # Use `DataContext.set_config` to config them.

    # The initial concurrency cap for each operator.
    INIT_CAP = 4
    INIT_CAP_CONFIG_KEY = "backpressure_policies.concurrency_cap.init_cap"
    # When the number of finished tasks reaches this threshold, the concurrency cap
    # will be multiplied by the multiplier.
    CAP_MULTIPLY_THRESHOLD = 0.5
    CAP_MULTIPLY_THRESHOLD_CONFIG_KEY = (
        "backpressure_policies.concurrency_cap.cap_multiply_threshold"
    )
    # The multiplier to multiply the concurrency cap by.
    CAP_MULTIPLIER = 2.0
    CAP_MULTIPLIER_CONFIG_KEY = "backpressure_policies.concurrency_cap.cap_multiplier"

    def __init__(self, topology: "Topology"):
        self._concurrency_caps: dict["PhysicalOperator", float] = {}

        data_context = ray.data.DataContext.get_current()
        self._init_cap = data_context.get_config(
            self.INIT_CAP_CONFIG_KEY, self.INIT_CAP
        )
        self._cap_multiplier = data_context.get_config(
            self.CAP_MULTIPLIER_CONFIG_KEY, self.CAP_MULTIPLIER
        )
        self._cap_multiply_threshold = data_context.get_config(
            self.CAP_MULTIPLY_THRESHOLD_CONFIG_KEY, self.CAP_MULTIPLY_THRESHOLD
        )

        assert self._init_cap > 0
        assert 0 < self._cap_multiply_threshold <= 1
        assert self._cap_multiplier >= 1
        for op, _ in topology.items():
            if isinstance(op, TaskPoolMapOperator) and op.get_concurrency() is not None:
                self._concurrency_caps[op] = op.get_concurrency()
            else:
                self._concurrency_caps[op] = float("inf")

        logger.debug(
            "ConcurrencyCapBackpressurePolicy initialized with config: "
            f"{self._init_cap}, {self._cap_multiply_threshold}, {self._cap_multiplier}"
            "ConcurrencyCapBackpressurePolicy initialized with: "
            f"{self._concurrency_caps}"
        )

        for op, _ in topology.items():
            self._concurrency_caps[op] = self._init_cap

    def can_add_input(self, op: "PhysicalOperator") -> bool:
        metrics = op.metrics
        while self._cap_multiplier > 1 and metrics.num_tasks_finished >= (
            self._concurrency_caps[op] * self._cap_multiply_threshold
        ):
            self._concurrency_caps[op] *= self._cap_multiplier
            logger.debug(
                f"Concurrency cap for {op} increased to {self._concurrency_caps[op]}"
            )
        return metrics.num_tasks_running < self._concurrency_caps[op]
        return op.metrics.num_tasks_running < self._concurrency_caps[op]
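To observe the cap end to end, here is a rough sketch (the tracking actor, timings, and sizes are purely illustrative) that records the peak number of map tasks running inside the UDF at once:

```python
import time

import ray

@ray.remote
class PeakTracker:
    """Illustrative helper: counts how many map tasks are inside the UDF at once."""

    def __init__(self):
        self.current = 0
        self.peak = 0

    def enter(self):
        self.current += 1
        self.peak = max(self.peak, self.current)

    def leave(self):
        self.current -= 1

    def get_peak(self):
        return self.peak

tracker = PeakTracker.remote()

def slow_identity(batch):
    ray.get(tracker.enter.remote())
    time.sleep(0.2)
    ray.get(tracker.leave.remote())
    return batch

ds = ray.data.range(64, parallelism=16)
ds.map_batches(slow_identity, concurrency=2).materialize()
# With the cap in place, this should print a value no greater than 2.
print(ray.get(tracker.get_peak.remote()))
```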
@@ -152,6 +152,7 @@ def create(
                name=name,
                target_max_block_size=target_max_block_size,
                min_rows_per_bundle=min_rows_per_bundle,
                concurrency=compute_strategy.size,
                ray_remote_args=ray_remote_args,
            )
        elif isinstance(compute_strategy, ActorPoolStrategy):
@@ -23,6 +23,7 @@ def __init__(
        target_max_block_size: Optional[int],
        name: str = "TaskPoolMap",
        min_rows_per_bundle: Optional[int] = None,
        concurrency: Optional[int] = None,
        ray_remote_args: Optional[Dict[str, Any]] = None,
    ):
        """Create a TaskPoolMapOperator instance.
@@ -37,6 +38,8 @@
                transform_fn, or None to use the block size. Setting the batch size is
                important for the performance of GPU-accelerated transform functions.
                The actual rows passed may be less if the dataset is small.
            concurrency: The maximum number of Ray tasks to use concurrently,
                or None to use as many tasks as possible.
            ray_remote_args: Customize the ray remote args for this op's tasks.
        """
        super().__init__(
@@ -47,6 +50,7 @@
            min_rows_per_bundle,
            ray_remote_args,
        )
        self._concurrency = concurrency

    def _add_bundled_input(self, bundle: RefBundle):
        # Submit the task as a normal Ray task.
@@ -114,3 +118,6 @@ def incremental_resource_usage(self) -> ExecutionResources:
            gpu=self._ray_remote_args.get("num_gpus", 0),
            object_store_memory=self._metrics.average_bytes_outputs_per_task,
        )

    def get_concurrency(self) -> Optional[int]:
        return self._concurrency
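A condensed sketch of how the new value flows from the compute strategy into the operator (the stand-in class below is hypothetical; the real `TaskPoolMapOperator` takes several more constructor arguments, as shown above):

```python
from typing import Optional

from ray.data._internal.compute import TaskPoolStrategy

# Produced by map_batches(fn, concurrency=4) when fn is a plain function.
strategy = TaskPoolStrategy(size=4)

class TaskPoolMapOperatorSketch:
    """Stand-in showing only the new concurrency plumbing."""

    def __init__(self, concurrency: Optional[int] = None):
        self._concurrency = concurrency

    def get_concurrency(self) -> Optional[int]:
        return self._concurrency

# MapOperator.create(...) passes compute_strategy.size through as `concurrency`;
# ConcurrencyCapBackpressurePolicy later reads it back via get_concurrency().
op = TaskPoolMapOperatorSketch(concurrency=strategy.size)
assert op.get_concurrency() == 4
```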
28 changes: 14 additions & 14 deletions python/ray/data/_internal/util.py
@@ -587,32 +587,32 @@ def get_compute_strategy(
        )
        return compute
    elif concurrency is not None:
        if not is_callable_class:
            # Currently do not support concurrency control with function,
            # i.e., running with Ray Tasks (`TaskPoolMapOperator`).
            logger.warning(
                "``concurrency`` is set, but ``fn`` is not a callable class: "
                f"{fn}. ``concurrency`` are currently only supported when "
                "``fn`` is a callable class."
            )
            return TaskPoolStrategy()

        if isinstance(concurrency, tuple):
            if (
                len(concurrency) == 2
                and isinstance(concurrency[0], int)
                and isinstance(concurrency[1], int)
            ):
                return ActorPoolStrategy(
                    min_size=concurrency[0], max_size=concurrency[1]
                )
                if is_callable_class:
                    return ActorPoolStrategy(
                        min_size=concurrency[0], max_size=concurrency[1]
                    )
                else:
                    raise ValueError(
                        "``concurrency`` is set as a tuple of integers, but ``fn`` "
                        f"is not a callable class: {fn}. Use ``concurrency=n`` to "
                        "control maximum number of workers to use."
                    )
            else:
                raise ValueError(
                    "``concurrency`` is expected to be set as a tuple of "
                    f"integers, but got: {concurrency}."
                )
    elif isinstance(concurrency, int):
        return ActorPoolStrategy(size=concurrency)
        if is_callable_class:
            return ActorPoolStrategy(size=concurrency)
        else:
            return TaskPoolStrategy(size=concurrency)
    else:
        raise ValueError(
            "``concurrency`` is expected to be set as an integer or a "