diff --git a/python/ray/data/_internal/compute.py b/python/ray/data/_internal/compute.py
index 8c96688925980..fecc415c8983c 100644
--- a/python/ray/data/_internal/compute.py
+++ b/python/ray/data/_internal/compute.py
@@ -56,6 +56,20 @@ def _apply(
 
 @DeveloperAPI
 class TaskPoolStrategy(ComputeStrategy):
+    def __init__(
+        self,
+        size: Optional[int] = None,
+    ):
+        """Construct TaskPoolStrategy for a Dataset transform.
+
+        Args:
+            size: Specify the maximum size of the task pool.
+        """
+
+        if size is not None and size < 1:
+            raise ValueError("`size` must be >= 1", size)
+        self.size = size
+
     def _apply(
         self,
         block_fn: BlockTransform,
@@ -148,7 +162,9 @@ def _apply(
         )
 
     def __eq__(self, other: Any) -> bool:
-        return isinstance(other, TaskPoolStrategy) or other == "tasks"
+        return (isinstance(other, TaskPoolStrategy) and self.size == other.size) or (
+            other == "tasks" and self.size is None
+        )
 
 
 @PublicAPI
diff --git a/python/ray/data/_internal/execution/backpressure_policy/__init__.py b/python/ray/data/_internal/execution/backpressure_policy/__init__.py
index 9f15fddc59cbf..57d52f96460fa 100644
--- a/python/ray/data/_internal/execution/backpressure_policy/__init__.py
+++ b/python/ray/data/_internal/execution/backpressure_policy/__init__.py
@@ -10,9 +10,8 @@
 
 # Default enabled backpressure policies and its config key.
 # Use `DataContext.set_config` to config it.
-# TODO(hchen): Enable ConcurrencyCapBackpressurePolicy by default.
 # TODO(hchen): Enable StreamingOutputBackpressurePolicy by default.
-ENABLED_BACKPRESSURE_POLICIES = []
+ENABLED_BACKPRESSURE_POLICIES = [ConcurrencyCapBackpressurePolicy]
 
 ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY = "backpressure_policies.enabled"
 
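With the change above, `ConcurrencyCapBackpressurePolicy` becomes enabled by default. A minimal sketch of opting back out through the config key defined above, using the same `DataContext.set_config` API that the removed test helper below relies on (illustrative only, not part of this patch):

```python
import ray
from ray.data._internal.execution.backpressure_policy import (
    ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY,
)

# Restore the previous default of no enabled backpressure policies.
ctx = ray.data.DataContext.get_current()
ctx.set_config(ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY, [])
```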
diff --git a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py
index 9052447c2587b..a52bd1f6ab9f7 100644
--- a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py
+++ b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py
@@ -1,8 +1,10 @@
 import logging
 from typing import TYPE_CHECKING
 
-import ray
 from .backpressure_policy import BackpressurePolicy
+from ray.data._internal.execution.operators.task_pool_map_operator import (
+    TaskPoolMapOperator,
+)
 
 if TYPE_CHECKING:
     from ray.data._internal.execution.interfaces.physical_operator import (
@@ -16,67 +18,26 @@ class ConcurrencyCapBackpressurePolicy(BackpressurePolicy):
     """A backpressure policy that caps the concurrency of each operator.
 
-    The concurrency cap limits the number of concurrently running tasks.
-    It will be set to an intial value, and will ramp up exponentially.
+    The policy limits the number of concurrently running tasks based on the
+    concurrency cap of each operator.
 
-    The concrete stategy is as follows:
-    - Each PhysicalOperator is assigned an initial concurrency cap.
-    - An PhysicalOperator can run new tasks if the number of running tasks is less
-      than the cap.
-    - When the number of finished tasks reaches a threshold, the concurrency cap will
-      increase.
+    NOTE: Only supports setting the concurrency cap for `TaskPoolMapOperator` for now.
+    TODO(chengsu): Consolidate with actor scaling logic of `ActorPoolMapOperator`.
     """
 
-    # Following are the default values followed by the config keys of the
-    # available configs.
-    # Use `DataContext.set_config` to config them.
-
-    # The intial concurrency cap for each operator.
-    INIT_CAP = 4
-    INIT_CAP_CONFIG_KEY = "backpressure_policies.concurrency_cap.init_cap"
-    # When the number of finished tasks reaches this threshold, the concurrency cap
-    # will be multiplied by the multiplier.
-    CAP_MULTIPLY_THRESHOLD = 0.5
-    CAP_MULTIPLY_THRESHOLD_CONFIG_KEY = (
-        "backpressure_policies.concurrency_cap.cap_multiply_threshold"
-    )
-    # The multiplier to multiply the concurrency cap by.
-    CAP_MULTIPLIER = 2.0
-    CAP_MULTIPLIER_CONFIG_KEY = "backpressure_policies.concurrency_cap.cap_multiplier"
-
     def __init__(self, topology: "Topology"):
         self._concurrency_caps: dict["PhysicalOperator", float] = {}
-        data_context = ray.data.DataContext.get_current()
-        self._init_cap = data_context.get_config(
-            self.INIT_CAP_CONFIG_KEY, self.INIT_CAP
-        )
-        self._cap_multiplier = data_context.get_config(
-            self.CAP_MULTIPLIER_CONFIG_KEY, self.CAP_MULTIPLIER
-        )
-        self._cap_multiply_threshold = data_context.get_config(
-            self.CAP_MULTIPLY_THRESHOLD_CONFIG_KEY, self.CAP_MULTIPLY_THRESHOLD
-        )
-
-        assert self._init_cap > 0
-        assert 0 < self._cap_multiply_threshold <= 1
-        assert self._cap_multiplier >= 1
+        for op, _ in topology.items():
+            if isinstance(op, TaskPoolMapOperator) and op.get_concurrency() is not None:
+                self._concurrency_caps[op] = op.get_concurrency()
+            else:
+                self._concurrency_caps[op] = float("inf")
 
         logger.debug(
-            "ConcurrencyCapBackpressurePolicy initialized with config: "
-            f"{self._init_cap}, {self._cap_multiply_threshold}, {self._cap_multiplier}"
+            "ConcurrencyCapBackpressurePolicy initialized with: "
+            f"{self._concurrency_caps}"
        )
 
-        for op, _ in topology.items():
-            self._concurrency_caps[op] = self._init_cap
-
     def can_add_input(self, op: "PhysicalOperator") -> bool:
-        metrics = op.metrics
-        while self._cap_multiplier > 1 and metrics.num_tasks_finished >= (
-            self._concurrency_caps[op] * self._cap_multiply_threshold
-        ):
-            self._concurrency_caps[op] *= self._cap_multiplier
-            logger.debug(
-                f"Concurrency cap for {op} increased to {self._concurrency_caps[op]}"
-            )
-        return metrics.num_tasks_running < self._concurrency_caps[op]
+        return op.metrics.num_tasks_running < self._concurrency_caps[op]
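The rewritten policy is now a static check: each operator's cap is fixed at construction time and never ramps up. A self-contained sketch of the same logic, using a hypothetical `FakeOp` in place of a real `PhysicalOperator`:

```python
from typing import Dict, List, Optional


class FakeOp:
    """Hypothetical stand-in for a PhysicalOperator (illustration only)."""

    def __init__(self, concurrency: Optional[int] = None):
        self._concurrency = concurrency
        self.num_tasks_running = 0

    def get_concurrency(self) -> Optional[int]:
        return self._concurrency


def build_caps(ops: List[FakeOp]) -> Dict[FakeOp, float]:
    # Mirrors __init__ above: an explicit concurrency becomes the cap;
    # every other operator is effectively uncapped.
    return {
        op: float("inf") if op.get_concurrency() is None else op.get_concurrency()
        for op in ops
    }


capped, uncapped = FakeOp(concurrency=2), FakeOp()
caps = build_caps([capped, uncapped])

capped.num_tasks_running = 2
assert not capped.num_tasks_running < caps[capped]  # at the cap: backpressured
assert uncapped.num_tasks_running < caps[uncapped]  # inf cap: never backpressured
```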
diff --git a/python/ray/data/_internal/execution/operators/map_operator.py b/python/ray/data/_internal/execution/operators/map_operator.py
index 774fde880f0a5..83eeb58f0a594 100644
--- a/python/ray/data/_internal/execution/operators/map_operator.py
+++ b/python/ray/data/_internal/execution/operators/map_operator.py
@@ -152,6 +152,7 @@ def create(
                 name=name,
                 target_max_block_size=target_max_block_size,
                 min_rows_per_bundle=min_rows_per_bundle,
+                concurrency=compute_strategy.size,
                 ray_remote_args=ray_remote_args,
             )
         elif isinstance(compute_strategy, ActorPoolStrategy):
diff --git a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py
index a3404897d4407..b46a039f44ff9 100644
--- a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py
+++ b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py
@@ -23,6 +23,7 @@ def __init__(
         target_max_block_size: Optional[int],
         name: str = "TaskPoolMap",
         min_rows_per_bundle: Optional[int] = None,
+        concurrency: Optional[int] = None,
         ray_remote_args: Optional[Dict[str, Any]] = None,
     ):
         """Create an TaskPoolMapOperator instance.
@@ -37,6 +38,8 @@ def __init__(
                 transform_fn, or None to use the block size. Setting the batch size is
                 important for the performance of GPU-accelerated transform functions.
                 The actual rows passed may be less if the dataset is small.
+            concurrency: The maximum number of Ray tasks to use concurrently,
+                or None to use as many tasks as possible.
             ray_remote_args: Customize the ray remote args for this op's tasks.
         """
         super().__init__(
@@ -47,6 +50,7 @@ def __init__(
             min_rows_per_bundle,
             ray_remote_args,
         )
+        self._concurrency = concurrency
 
     def _add_bundled_input(self, bundle: RefBundle):
         # Submit the task as a normal Ray task.
@@ -114,3 +118,6 @@ def incremental_resource_usage(self) -> ExecutionResources:
             gpu=self._ray_remote_args.get("num_gpus", 0),
             object_store_memory=self._metrics.average_bytes_outputs_per_task,
         )
+
+    def get_concurrency(self) -> Optional[int]:
+        return self._concurrency
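Taken together with the `map_operator.py` change above, a sized `TaskPoolStrategy` now flows into the operator as `concurrency` and can be read back via `get_concurrency()`. A quick sketch mirroring the updated unit test further down (internal APIs, illustrative only):

```python
from unittest.mock import MagicMock

from ray.data._internal.execution.operators.input_data_buffer import InputDataBuffer
from ray.data._internal.execution.operators.task_pool_map_operator import (
    TaskPoolMapOperator,
)

# A capped map operator fed by a trivial input buffer.
input_op = InputDataBuffer(input_data=[MagicMock()])
map_op = TaskPoolMapOperator(
    map_transformer=MagicMock(),
    input_op=input_op,
    target_max_block_size=None,
    concurrency=4,
)
assert map_op.get_concurrency() == 4
```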
diff --git a/python/ray/data/_internal/util.py b/python/ray/data/_internal/util.py
index d7b0b644f8bcb..d5dfedff84fe2 100644
--- a/python/ray/data/_internal/util.py
+++ b/python/ray/data/_internal/util.py
@@ -587,32 +587,32 @@ def get_compute_strategy(
         )
         return compute
     elif concurrency is not None:
-        if not is_callable_class:
-            # Currently do not support concurrency control with function,
-            # i.e., running with Ray Tasks (`TaskPoolMapOperator`).
-            logger.warning(
-                "``concurrency`` is set, but ``fn`` is not a callable class: "
-                f"{fn}. ``concurrency`` are currently only supported when "
-                "``fn`` is a callable class."
-            )
-            return TaskPoolStrategy()
-
         if isinstance(concurrency, tuple):
             if (
                 len(concurrency) == 2
                 and isinstance(concurrency[0], int)
                 and isinstance(concurrency[1], int)
             ):
-                return ActorPoolStrategy(
-                    min_size=concurrency[0], max_size=concurrency[1]
-                )
+                if is_callable_class:
+                    return ActorPoolStrategy(
+                        min_size=concurrency[0], max_size=concurrency[1]
+                    )
+                else:
+                    raise ValueError(
+                        "``concurrency`` is set as a tuple of integers, but ``fn`` "
+                        f"is not a callable class: {fn}. Use ``concurrency=n`` to "
+                        "control the maximum number of workers to use."
+                    )
             else:
                 raise ValueError(
                     "``concurrency`` is expected to be set as a tuple of "
                     f"integers, but got: {concurrency}."
                 )
         elif isinstance(concurrency, int):
-            return ActorPoolStrategy(size=concurrency)
+            if is_callable_class:
+                return ActorPoolStrategy(size=concurrency)
+            else:
+                return TaskPoolStrategy(size=concurrency)
         else:
             raise ValueError(
                 "``concurrency`` is expected to be set as an integer or a "
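At the API level, the dispatch above means `concurrency` now works for plain functions too, instead of logging a warning and ignoring the value. A sketch of the resulting behavior, mirroring the `test_map.py` expectations below (needs a running Ray cluster):

```python
import pytest
import ray


def udf(row):
    return row


class UDFClass:
    def __call__(self, row):
        return row


ds = ray.data.range(10)
ds.map(udf, concurrency=2).take_all()       # function: TaskPoolStrategy(size=2)
ds.map(UDFClass, concurrency=2).take_all()  # callable class: ActorPoolStrategy(size=2)
ds.map(UDFClass, concurrency=(2, 4)).take_all()  # autoscaling actor pool

# A tuple now requires a callable class; plain functions raise instead of
# silently falling back to an uncapped task pool.
with pytest.raises(ValueError, match="set as a tuple of integers"):
    ds.map(udf, concurrency=(2, 4)).take_all()
```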
diff --git a/python/ray/data/tests/test_backpressure_policies.py b/python/ray/data/tests/test_backpressure_policies.py
index 2673d069e802d..e852e182f4149 100644
--- a/python/ray/data/tests/test_backpressure_policies.py
+++ b/python/ray/data/tests/test_backpressure_policies.py
@@ -1,8 +1,8 @@
 import functools
+import math
 import time
 import unittest
 from collections import defaultdict
-from contextlib import contextmanager
 from unittest.mock import MagicMock, patch
 
 import numpy as np
@@ -14,6 +14,10 @@
     ConcurrencyCapBackpressurePolicy,
     StreamingOutputBackpressurePolicy,
 )
+from ray.data._internal.execution.operators.input_data_buffer import InputDataBuffer
+from ray.data._internal.execution.operators.task_pool_map_operator import (
+    TaskPoolMapOperator,
+)
 from ray.data.tests.conftest import restore_data_context  # noqa: F401
 from ray.data.tests.conftest import (
     CoreExecutionMetrics,
@@ -42,94 +46,43 @@ def tearDownClass(cls):
         data_context = ray.data.DataContext.get_current()
         data_context.remove_config(ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY)
 
-    @contextmanager
-    def _patch_config(self, init_cap, cap_multiply_threshold, cap_multiplier):
-        data_context = ray.data.DataContext.get_current()
-        data_context.set_config(
-            ConcurrencyCapBackpressurePolicy.INIT_CAP_CONFIG_KEY,
-            init_cap,
-        )
-        data_context.set_config(
-            ConcurrencyCapBackpressurePolicy.CAP_MULTIPLY_THRESHOLD_CONFIG_KEY,
-            cap_multiply_threshold,
-        )
-        data_context.set_config(
-            ConcurrencyCapBackpressurePolicy.CAP_MULTIPLIER_CONFIG_KEY,
-            cap_multiplier,
-        )
-        yield
-        data_context.remove_config(ConcurrencyCapBackpressurePolicy.INIT_CAP_CONFIG_KEY)
-        data_context.remove_config(
-            ConcurrencyCapBackpressurePolicy.CAP_MULTIPLY_THRESHOLD_CONFIG_KEY
-        )
-        data_context.remove_config(
-            ConcurrencyCapBackpressurePolicy.CAP_MULTIPLIER_CONFIG_KEY
-        )
-
     def test_basic(self):
-        op = MagicMock()
-        op.metrics = MagicMock(
-            num_tasks_running=0,
-            num_tasks_finished=0,
+        concurrency = 16
+        input_op = InputDataBuffer(input_data=[MagicMock()])
+        map_op_no_concurrency = TaskPoolMapOperator(
+            map_transformer=MagicMock(),
+            input_op=input_op,
+            target_max_block_size=None,
+        )
+        map_op = TaskPoolMapOperator(
+            map_transformer=MagicMock(),
+            input_op=map_op_no_concurrency,
+            target_max_block_size=None,
+            concurrency=concurrency,
         )
-        topology = {op: MagicMock()}
+        map_op.metrics.num_tasks_running = 0
+        map_op.metrics.num_tasks_finished = 0
+        topology = {
+            map_op: MagicMock(),
+            input_op: MagicMock(),
+            map_op_no_concurrency: MagicMock(),
+        }
 
-        init_cap = 4
-        cap_multiply_threshold = 0.5
-        cap_multiplier = 2.0
+        policy = ConcurrencyCapBackpressurePolicy(topology)
 
-        with self._patch_config(init_cap, cap_multiply_threshold, cap_multiplier):
-            policy = ConcurrencyCapBackpressurePolicy(topology)
+        self.assertEqual(policy._concurrency_caps[map_op], concurrency)
+        self.assertTrue(math.isinf(policy._concurrency_caps[input_op]))
+        self.assertTrue(math.isinf(policy._concurrency_caps[map_op_no_concurrency]))
 
-        self.assertEqual(policy._concurrency_caps[op], 4)
         # Gradually increase num_tasks_running to the cap.
-        for i in range(1, init_cap + 1):
-            self.assertTrue(policy.can_add_input(op))
-            op.metrics.num_tasks_running = i
+        for i in range(1, concurrency + 1):
+            self.assertTrue(policy.can_add_input(map_op))
+            map_op.metrics.num_tasks_running = i
         # Now num_tasks_running reaches the cap, so can_add_input should return False.
-        self.assertFalse(policy.can_add_input(op))
-
-        # If we increase num_task_finished to the threshold (4 * 0.5 = 2),
-        # it should trigger the cap to increase.
-        op.metrics.num_tasks_finished = init_cap * cap_multiply_threshold
-        self.assertEqual(policy.can_add_input(op), True)
-        self.assertEqual(policy._concurrency_caps[op], init_cap * cap_multiplier)
-
-        # Now the cap is 8 (4 * 2).
-        # If we increase num_tasks_finished directly to the next-level's threshold
-        # (8 * 2 * 0.5 = 8), it should trigger the cap to increase twice.
-        op.metrics.num_tasks_finished = (
-            policy._concurrency_caps[op] * cap_multiplier * cap_multiply_threshold
-        )
-        op.metrics.num_tasks_running = 0
-        self.assertEqual(policy.can_add_input(op), True)
-        self.assertEqual(policy._concurrency_caps[op], init_cap * cap_multiplier**3)
+        self.assertFalse(policy.can_add_input(map_op))
 
-    def test_config(self):
-        topology = {}
-        # Test good config.
-        with self._patch_config(10, 0.3, 1.5):
-            policy = ConcurrencyCapBackpressurePolicy(topology)
-        self.assertEqual(policy._init_cap, 10)
-        self.assertEqual(policy._cap_multiply_threshold, 0.3)
-        self.assertEqual(policy._cap_multiplier, 1.5)
-
-        with self._patch_config(10, 0.3, 1):
-            policy = ConcurrencyCapBackpressurePolicy(topology)
-        self.assertEqual(policy._init_cap, 10)
-        self.assertEqual(policy._cap_multiply_threshold, 0.3)
-        self.assertEqual(policy._cap_multiplier, 1)
-
-        # Test bad configs.
-        with self._patch_config(-1, 0.3, 1.5):
-            with self.assertRaises(AssertionError):
-                policy = ConcurrencyCapBackpressurePolicy(topology)
-        with self._patch_config(10, 1.1, 1.5):
-            with self.assertRaises(AssertionError):
-                policy = ConcurrencyCapBackpressurePolicy(topology)
-        with self._patch_config(10, 0.3, 0.5):
-            with self.assertRaises(AssertionError):
-                policy = ConcurrencyCapBackpressurePolicy(topology)
+        map_op.metrics.num_tasks_running = concurrency / 2
+        self.assertEqual(policy.can_add_input(map_op), True)
 
     def _create_record_time_actor(self):
         @ray.remote(num_cpus=0)
         class RecordTimeActor:
@@ -172,8 +125,8 @@ def test_e2e_normal(self):
         N = self.__class__._cluster_cpus
         ds = ray.data.range(N, parallelism=N)
         # Use different `num_cpus` to make sure they don't fuse.
-        ds = ds.map_batches(map_func1, batch_size=None, num_cpus=1)
-        ds = ds.map_batches(map_func2, batch_size=None, num_cpus=1.1)
+        ds = ds.map_batches(map_func1, batch_size=None, num_cpus=1, concurrency=1)
+        ds = ds.map_batches(map_func2, batch_size=None, num_cpus=1.1, concurrency=1)
         res = ds.take_all()
         self.assertEqual(len(res), N)
 
@@ -185,24 +138,6 @@ def test_e2e_normal(self):
         start2, end2 = ray.get(actor.get_start_and_end_time_for_op.remote(2))
         assert start1 < start2 < end1 < end2, (start1, start2, end1, end2)
 
-    def test_e2e_no_ramping_up(self):
-        """Test setting the multiplier to 1.0, which means no ramping up of the
-        concurrency cap."""
-        with self._patch_config(1, 1, 1):
-            actor = self._create_record_time_actor()
-            map_func1 = self._get_map_func(actor, 1)
-            N = self.__class__._cluster_cpus
-            ds = ray.data.range(N, parallelism=N)
-            ds = ds.map_batches(map_func1, batch_size=None, num_cpus=1)
-            res = ds.take_all()
-            self.assertEqual(len(res), N)
-
-            start, end = ray.get(
-                actor.get_start_and_end_time_for_all_tasks_of_op.remote(1)
-            )
-            for i in range(len(start) - 1):
-                assert start[i] < end[i] < start[i + 1], (i, start, end)
-
 
 class TestStreamOutputBackpressurePolicy(unittest.TestCase):
     """Tests for StreamOutputBackpressurePolicy."""
diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py
index a3dc65e921460..a08642c8caba3 100644
--- a/python/ray/data/tests/test_map.py
+++ b/python/ray/data/tests/test_map.py
@@ -250,8 +250,13 @@ def __call__(self, x):
     for fn in [udf, UDFClass]:
         # Test concurrency with None, single integer and a tuple of integers.
         for concurrency in [2, (2, 4)]:
-            result = ds.map(fn, concurrency=concurrency).take_all()
-            assert sorted(extract_values("id", result)) == list(range(10)), result
+            if fn == udf and concurrency == (2, 4):
+                error_message = "``concurrency`` is set as a tuple of integers"
+                with pytest.raises(ValueError, match=error_message):
+                    ds.map(fn, concurrency=concurrency).take_all()
+            else:
+                result = ds.map(fn, concurrency=concurrency).take_all()
+                assert sorted(extract_values("id", result)) == list(range(10)), result
 
     # Test concurrency with an illegal value.
     error_message = "``concurrency`` is expected to be set a"
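Finally, a sketch of the `TaskPoolStrategy` semantics introduced in `compute.py` at the top of this diff: sized pools compare by size and no longer alias the bare `"tasks"` string, and non-positive sizes are rejected:

```python
from ray.data._internal.compute import TaskPoolStrategy

assert TaskPoolStrategy() == "tasks"        # unsized: backward-compatible alias
assert TaskPoolStrategy(size=2) != "tasks"  # sized: no longer matches the alias
assert TaskPoolStrategy(size=2) == TaskPoolStrategy(size=2)

try:
    TaskPoolStrategy(size=0)
except ValueError as exc:
    print(exc)  # size must be >= 1
```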