Skip to content

Commit

Permalink
[data] Make ActorPoolStrategy kill pool of actors if exception is raised
Browse files Browse the repository at this point in the history
  • Loading branch information
peytondmurray committed Jun 18, 2022
1 parent 9fe3c81 commit 815dba5
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 67 deletions.
147 changes: 80 additions & 67 deletions python/ray/data/_internal/compute.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import collections
from typing import TypeVar, Any, Union, Callable, List, Tuple, Optional
import logging
from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union

import ray
from ray.util.annotations import PublicAPI, DeveloperAPI
from ray.data._internal.block_list import BlockList
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data.block import (
Block,
BlockAccessor,
BlockExecStats,
BlockMetadata,
BlockPartition,
BlockExecStats,
)
from ray.data.context import DatasetContext, DEFAULT_SCHEDULING_STRATEGY
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.block_list import BlockList
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data.context import DEFAULT_SCHEDULING_STRATEGY, DatasetContext
from ray.util.annotations import DeveloperAPI, PublicAPI

logger = logging.getLogger(__name__)

T = TypeVar("T")
U = TypeVar("U")
Expand Down Expand Up @@ -203,73 +206,83 @@ def map_block_nosplit(
block_indices = {}
ready_workers = set()

while len(results) < orig_num_blocks:
ready, _ = ray.wait(
list(tasks.keys()), timeout=0.01, num_returns=1, fetch_local=False
)
if not ready:
if (
len(workers) < self.max_size
and len(ready_workers) / len(workers) > 0.8
):
w = BlockWorker.remote()
workers.append(w)
tasks[w.ready.remote()] = w
try:
while len(results) < orig_num_blocks:
ready, _ = ray.wait(
list(tasks.keys()), timeout=0.01, num_returns=1, fetch_local=False
)
if not ready:
if (
len(workers) < self.max_size
and len(ready_workers) / len(workers) > 0.8
):
w = BlockWorker.remote()
workers.append(w)
tasks[w.ready.remote()] = w
map_bar.set_description(
"Map Progress ({} actors {} pending)".format(
len(ready_workers), len(workers) - len(ready_workers)
)
)
continue

[obj_id] = ready
worker = tasks.pop(obj_id)

# Process task result.
if worker in ready_workers:
results.append(obj_id)
tasks_in_flight[worker] -= 1
map_bar.update(1)
else:
ready_workers.add(worker)
map_bar.set_description(
"Map Progress ({} actors {} pending)".format(
len(ready_workers), len(workers) - len(ready_workers)
)
)
continue

[obj_id] = ready
worker = tasks.pop(obj_id)

# Process task result.
if worker in ready_workers:
results.append(obj_id)
tasks_in_flight[worker] -= 1
map_bar.update(1)
# Schedule a new task.
while (
blocks_in
and tasks_in_flight[worker] < self.max_tasks_in_flight_per_actor
):
block, meta = blocks_in.pop()
if context.block_splitting_enabled:
ref = worker.map_block_split.remote(block, meta.input_files)
else:
ref, meta_ref = worker.map_block_nosplit.remote(
block, meta.input_files
)
metadata_mapping[ref] = meta_ref
tasks[ref] = worker
block_indices[ref] = len(blocks_in)
tasks_in_flight[worker] += 1

map_bar.close()
new_blocks, new_metadata = [], []
# Put blocks in input order.
results.sort(key=block_indices.get)
if context.block_splitting_enabled:
for result in ray.get(results):
for block, metadata in result:
new_blocks.append(block)
new_metadata.append(metadata)
else:
ready_workers.add(worker)
map_bar.set_description(
"Map Progress ({} actors {} pending)".format(
len(ready_workers), len(workers) - len(ready_workers)
)
)

# Schedule a new task.
while (
blocks_in
and tasks_in_flight[worker] < self.max_tasks_in_flight_per_actor
):
block, meta = blocks_in.pop()
if context.block_splitting_enabled:
ref = worker.map_block_split.remote(block, meta.input_files)
else:
ref, meta_ref = worker.map_block_nosplit.remote(
block, meta.input_files
)
metadata_mapping[ref] = meta_ref
tasks[ref] = worker
block_indices[ref] = len(blocks_in)
tasks_in_flight[worker] += 1

map_bar.close()
new_blocks, new_metadata = [], []
# Put blocks in input order.
results.sort(key=block_indices.get)
if context.block_splitting_enabled:
for result in ray.get(results):
for block, metadata in result:
for block in results:
new_blocks.append(block)
new_metadata.append(metadata)
else:
for block in results:
new_blocks.append(block)
new_metadata.append(metadata_mapping[block])
new_metadata = ray.get(new_metadata)
return BlockList(new_blocks, new_metadata)
new_metadata.append(metadata_mapping[block])
new_metadata = ray.get(new_metadata)
return BlockList(new_blocks, new_metadata)

except Exception as e:
try:
for worker in workers:
ray.kill(worker)
except Exception as err:
logger.exception(f"Error killing workers: {err}")
finally:
raise e


def cache_wrapper(
Expand Down
27 changes: 27 additions & 0 deletions python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import os
import random
import signal
import time

import numpy as np
Expand All @@ -10,6 +11,7 @@
import pytest

import ray
from ray._private.test_utils import wait_for_condition
from ray.data._internal.arrow_block import ArrowRow
from ray.data._internal.block_builder import BlockBuilder
from ray.data._internal.lazy_block_list import LazyBlockList
Expand Down Expand Up @@ -3975,6 +3977,31 @@ def f(should_import_polars):
ctx.use_polars = original_use_polars


def test_actorpoolstrategy_apply_interrupt():
    """Test that _apply kills the actor pool if an interrupt is raised."""
    ray.init(include_dashboard=False, num_cpus=1)

    # Record the baseline CPU count so we can later confirm the actors'
    # CPU reservations were released.
    starting_cpus = ray.available_resources()["CPU"]
    dataset = ray.data.range(5)
    strategy = ray.data.ActorPoolStrategy(max_size=5)
    input_blocks = dataset._plan.execute()

    # Map function whose first row raises a SIGINT (emulating a
    # KeyboardInterrupt); every other row blocks so the pool stays busy.
    def interrupting_map_fn(block):
        for row_index, _ in enumerate(BlockAccessor.for_block(block).iter_rows()):
            if row_index == 0:
                os.kill(os.getpid(), signal.SIGINT)
            else:
                time.sleep(1000)
        return block

    # The interrupt must surface to the caller as a RayTaskError.
    with pytest.raises(ray.exceptions.RayTaskError):
        strategy._apply(interrupting_map_fn, {}, input_blocks, False)

    # All pool actors should be killed, returning their CPUs to the
    # cluster; poll until the available CPU count matches the baseline.
    wait_for_condition(lambda: (ray.available_resources().get("CPU", 0) == starting_cpus))


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit 815dba5

Please sign in to comment.