[Data] Allow tasks to control concurrency in map-like APIs #42637

Merged · 6 commits · Jan 30, 2024
Changes from 1 commit
Fix unit test
Signed-off-by: Cheng Su <[email protected]>
c21 committed Jan 29, 2024
commit 2ca833df82830ac64f406821c6d0f0fed673fc8a
5 changes: 3 additions & 2 deletions python/ray/data/_internal/compute.py
@@ -162,8 +162,9 @@ def _apply(
         )
 
     def __eq__(self, other: Any) -> bool:
-        return (isinstance(other, TaskPoolStrategy) and self.size == other.size)\
-            or (other == "tasks" and self.size is None)
+        return (isinstance(other, TaskPoolStrategy) and self.size == other.size) or (
+            other == "tasks" and self.size is None
+        )
 
 
 @PublicAPI
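For context, here is a minimal, self-contained sketch of the equality semantics the reformatted __eq__ expresses. The class below is a simplified stand-in, not Ray's actual TaskPoolStrategy: an unsized strategy compares equal to the string alias "tasks", while a sized strategy only equals another strategy with the same size.

from typing import Any, Optional


class TaskPoolStrategyLike:
    """Simplified stand-in mirroring the __eq__ logic shown in the diff above."""

    def __init__(self, size: Optional[int] = None):
        self.size = size

    def __eq__(self, other: Any) -> bool:
        # Equal to another strategy with the same size, or to the "tasks"
        # alias when no explicit size was given.
        return (isinstance(other, TaskPoolStrategyLike) and self.size == other.size) or (
            other == "tasks" and self.size is None
        )


assert TaskPoolStrategyLike() == "tasks"                    # unsized matches the alias
assert TaskPoolStrategyLike(4) != "tasks"                   # sized does not
assert TaskPoolStrategyLike(4) == TaskPoolStrategyLike(4)   # same size compares equal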
(second file in this commit; file path not captured in the diff view)
@@ -63,9 +63,6 @@ def __init__(
         self._ray_remote_args_factory = None
         self._remote_args_for_metrics = copy.deepcopy(self._ray_remote_args)
 
-        # Initialize the back pressure policies to be empty list.
-        self._backpressure_policies = []
-
         # Bundles block references up to the min_rows_per_bundle target.
         self._block_ref_bundler = _BlockRefBundler(min_rows_per_bundle)
128 changes: 29 additions & 99 deletions python/ray/data/tests/test_backpressure_policies.py
@@ -2,7 +2,6 @@
 import time
 import unittest
 from collections import defaultdict
-from contextlib import contextmanager
 from unittest.mock import MagicMock, patch
 
 import numpy as np
@@ -14,6 +13,10 @@
     ConcurrencyCapBackpressurePolicy,
     StreamingOutputBackpressurePolicy,
 )
+from ray.data._internal.execution.operators.input_data_buffer import InputDataBuffer
+from ray.data._internal.execution.operators.task_pool_map_operator import (
+    TaskPoolMapOperator,
+)
 from ray.data.tests.conftest import restore_data_context  # noqa: F401
 from ray.data.tests.conftest import (
     CoreExecutionMetrics,
@@ -42,94 +45,39 @@ def tearDownClass(cls):
         data_context = ray.data.DataContext.get_current()
         data_context.remove_config(ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY)
 
-    @contextmanager
-    def _patch_config(self, init_cap, cap_multiply_threshold, cap_multiplier):
-        data_context = ray.data.DataContext.get_current()
-        data_context.set_config(
-            ConcurrencyCapBackpressurePolicy.INIT_CAP_CONFIG_KEY,
-            init_cap,
-        )
-        data_context.set_config(
-            ConcurrencyCapBackpressurePolicy.CAP_MULTIPLY_THRESHOLD_CONFIG_KEY,
-            cap_multiply_threshold,
-        )
-        data_context.set_config(
-            ConcurrencyCapBackpressurePolicy.CAP_MULTIPLIER_CONFIG_KEY,
-            cap_multiplier,
-        )
-        yield
-        data_context.remove_config(ConcurrencyCapBackpressurePolicy.INIT_CAP_CONFIG_KEY)
-        data_context.remove_config(
-            ConcurrencyCapBackpressurePolicy.CAP_MULTIPLY_THRESHOLD_CONFIG_KEY
-        )
-        data_context.remove_config(
-            ConcurrencyCapBackpressurePolicy.CAP_MULTIPLIER_CONFIG_KEY
-        )
-
     def test_basic(self):
-        op = MagicMock()
-        op.metrics = MagicMock(
-            num_tasks_running=0,
-            num_tasks_finished=0,
+        concurrency = 16
+        input_op = InputDataBuffer(input_data=[MagicMock()])
+        map_op_no_concurrency = TaskPoolMapOperator(
+            map_transformer=MagicMock(),
+            input_op=input_op,
+            target_max_block_size=None,
         )
-        topology = {op: MagicMock()}
+        map_op = TaskPoolMapOperator(
+            map_transformer=MagicMock(),
+            input_op=map_op_no_concurrency,
+            target_max_block_size=None,
+            concurrency=concurrency,
+        )
+        map_op.metrics.num_tasks_running = 0
+        map_op.metrics.num_tasks_finished = 0
+        topology = {map_op: MagicMock()}
 
-        init_cap = 4
-        cap_multiply_threshold = 0.5
-        cap_multiplier = 2.0
-
-        with self._patch_config(init_cap, cap_multiply_threshold, cap_multiplier):
-            policy = ConcurrencyCapBackpressurePolicy(topology)
+        policy = ConcurrencyCapBackpressurePolicy(topology)
 
-        self.assertEqual(policy._concurrency_caps[op], 4)
+        self.assertEqual(policy._concurrency_caps[map_op], concurrency)
+        self.assertTrue(input_op not in policy._concurrency_caps)
+        self.assertTrue(map_op_no_concurrency not in policy._concurrency_caps)
+
         # Gradually increase num_tasks_running to the cap.
-        for i in range(1, init_cap + 1):
-            self.assertTrue(policy.can_add_input(op))
-            op.metrics.num_tasks_running = i
+        for i in range(1, concurrency + 1):
+            self.assertTrue(policy.can_add_input(map_op))
+            map_op.metrics.num_tasks_running = i
         # Now num_tasks_running reaches the cap, so can_add_input should return False.
-        self.assertFalse(policy.can_add_input(op))
-
-        # If we increase num_task_finished to the threshold (4 * 0.5 = 2),
-        # it should trigger the cap to increase.
-        op.metrics.num_tasks_finished = init_cap * cap_multiply_threshold
-        self.assertEqual(policy.can_add_input(op), True)
-        self.assertEqual(policy._concurrency_caps[op], init_cap * cap_multiplier)
-
-        # Now the cap is 8 (4 * 2).
-        # If we increase num_tasks_finished directly to the next-level's threshold
-        # (8 * 2 * 0.5 = 8), it should trigger the cap to increase twice.
-        op.metrics.num_tasks_finished = (
-            policy._concurrency_caps[op] * cap_multiplier * cap_multiply_threshold
-        )
-        op.metrics.num_tasks_running = 0
-        self.assertEqual(policy.can_add_input(op), True)
-        self.assertEqual(policy._concurrency_caps[op], init_cap * cap_multiplier**3)
+        self.assertFalse(policy.can_add_input(map_op))
 
-    def test_config(self):
-        topology = {}
-        # Test good config.
-        with self._patch_config(10, 0.3, 1.5):
-            policy = ConcurrencyCapBackpressurePolicy(topology)
-            self.assertEqual(policy._init_cap, 10)
-            self.assertEqual(policy._cap_multiply_threshold, 0.3)
-            self.assertEqual(policy._cap_multiplier, 1.5)
-
-        with self._patch_config(10, 0.3, 1):
-            policy = ConcurrencyCapBackpressurePolicy(topology)
-            self.assertEqual(policy._init_cap, 10)
-            self.assertEqual(policy._cap_multiply_threshold, 0.3)
-            self.assertEqual(policy._cap_multiplier, 1)
-
-        # Test bad configs.
-        with self._patch_config(-1, 0.3, 1.5):
-            with self.assertRaises(AssertionError):
-                policy = ConcurrencyCapBackpressurePolicy(topology)
-        with self._patch_config(10, 1.1, 1.5):
-            with self.assertRaises(AssertionError):
-                policy = ConcurrencyCapBackpressurePolicy(topology)
-        with self._patch_config(10, 0.3, 0.5):
-            with self.assertRaises(AssertionError):
-                policy = ConcurrencyCapBackpressurePolicy(topology)
+        map_op.metrics.num_tasks_running = concurrency / 2
+        self.assertEqual(policy.can_add_input(map_op), True)
 
     def _create_record_time_actor(self):
         @ray.remote(num_cpus=0)
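To make the behavior the updated test_basic exercises easier to follow, here is a minimal sketch of the concurrency-cap idea, using hypothetical names rather than Ray's internal implementation: after this change the policy reads a per-operator cap (the concurrency the operator was created with) and admits new tasks only while the running-task count stays below it.

import math


class FakeOperator:
    """Stand-in operator exposing just the metric the sketch policy reads."""

    def __init__(self):
        self.metrics = type("Metrics", (), {"num_tasks_running": 0})()


class ConcurrencyCapPolicySketch:
    """Hypothetical simplification of the behavior the test above exercises."""

    def __init__(self, concurrency_caps):
        # e.g. {map_op: 16} for an operator created with concurrency=16;
        # operators without an explicit cap are treated as unlimited here.
        self._concurrency_caps = concurrency_caps

    def can_add_input(self, op) -> bool:
        cap = self._concurrency_caps.get(op, math.inf)
        return op.metrics.num_tasks_running < cap


op = FakeOperator()
policy = ConcurrencyCapPolicySketch({op: 2})
assert policy.can_add_input(op)         # below the cap
op.metrics.num_tasks_running = 2
assert not policy.can_add_input(op)     # at the cap, no new tasks admitted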
@@ -185,24 +133,6 @@ def test_e2e_normal(self):
         start2, end2 = ray.get(actor.get_start_and_end_time_for_op.remote(2))
         assert start1 < start2 < end1 < end2, (start1, start2, end1, end2)

-    def test_e2e_no_ramping_up(self):
-        """Test setting the multiplier to 1.0, which means no ramping up of the
-        concurrency cap."""
-        with self._patch_config(1, 1, 1):
-            actor = self._create_record_time_actor()
-            map_func1 = self._get_map_func(actor, 1)
-            N = self.__class__._cluster_cpus
-            ds = ray.data.range(N, parallelism=N)
-            ds = ds.map_batches(map_func1, batch_size=None, num_cpus=1)
-            res = ds.take_all()
-            self.assertEqual(len(res), N)
-
-            start, end = ray.get(
-                actor.get_start_and_end_time_for_all_tasks_of_op.remote(1)
-            )
-            for i in range(len(start) - 1):
-                assert start[i] < end[i] < start[i + 1], (i, start, end)

Review thread on the removed test_e2e_no_ramping_up:

Contributor: we also need to update the above test_e2e_normal to use the concurrency parameter.

Contributor Author (c21): thanks, updated.
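The reviewer's request refers to the user-facing feature of this PR: map-like APIs accept a concurrency argument that caps how many tasks run at once for that operator. A hedged usage sketch (the dataset size and cap below are illustrative, not taken from the tests):

import ray

# Illustrative only: cap the map_batches operator at 4 concurrent tasks,
# which is the knob the e2e tests are being updated to exercise.
ds = ray.data.range(1000)
ds = ds.map_batches(lambda batch: batch, batch_size=None, concurrency=4)
print(ds.take_all()[:3])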


class TestStreamOutputBackpressurePolicy(unittest.TestCase):
"""Tests for StreamOutputBackpressurePolicy."""