Revert "Revert "[Serve] ServeHandle detects ActorError and drop repl…
Browse files Browse the repository at this point in the history
…icas from target group (ray-project#26685)" (ray-project#27283)" (ray-project#27348)

Signed-off-by: Stefan van der Kleij <[email protected]>
simon-mo authored and Stefan van der Kleij committed Aug 18, 2022
1 parent 7d5c7dd commit 352b9e0
Showing 3 changed files with 88 additions and 10 deletions.
7 changes: 4 additions & 3 deletions python/ray/serve/_private/long_poll.py
@@ -148,13 +148,14 @@ def _process_update(self, updates: Dict[str, UpdatedObject]):
         if isinstance(updates, (ray.exceptions.RayTaskError)):
             if isinstance(updates.as_instanceof_cause(), (asyncio.TimeoutError)):
                 logger.debug("LongPollClient polling timed out. Retrying.")
+                self._schedule_to_event_loop(self._reset)
             else:
                 # Some error happened in the controller. It could be a bug or
                 # some undesired state.
                 logger.error("LongPollHost errored\n" + updates.traceback_str)
-            # We must call this in event loop so it works in Ray Client.
-            # See https://github.com/ray-project/ray/issues/20971
-            self._schedule_to_event_loop(self._poll_next)
+                # We must call this in event loop so it works in Ray Client.
+                # See https://github.com/ray-project/ray/issues/20971
+                self._schedule_to_event_loop(self._poll_next)
             return

         logger.debug(
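For readers skimming the long_poll.py hunk: a minimal standalone sketch of the control flow it establishes. The callback parameters (`reset`, `poll_next`, `schedule`) are hypothetical stand-ins for the private `LongPollClient` methods; only `RayTaskError.as_instanceof_cause()` is actual Ray API.

import asyncio

from ray.exceptions import RayTaskError


def process_poll_result(updates, reset, poll_next, schedule) -> None:
    # `reset`, `poll_next`, and `schedule` stand in for LongPollClient._reset,
    # ._poll_next, and ._schedule_to_event_loop (hypothetical names here).
    if isinstance(updates, RayTaskError):
        # as_instanceof_cause() unwraps the exception originally raised inside
        # the remote task, so a controller-side asyncio.TimeoutError can be
        # distinguished from an unexpected failure.
        if isinstance(updates.as_instanceof_cause(), asyncio.TimeoutError):
            schedule(reset)  # benign timeout: reset polling state and retry
        else:
            schedule(poll_next)  # unexpected error: log it, but keep polling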
54 changes: 47 additions & 7 deletions python/ray/serve/_private/router.py
@@ -9,6 +9,7 @@

 import ray
 from ray.actor import ActorHandle
+from ray.exceptions import RayActorError, RayTaskError
 from ray.util import metrics

 from ray.serve._private.common import RunningReplicaInfo
@@ -87,6 +88,17 @@ def __init__(
             {"deployment": self.deployment_name}
         )

+    def _reset_replica_iterator(self):
+        """Reset the iterator used to load balance replicas.
+
+        This is expected to be called after the replica membership has been
+        updated. It shuffles the replicas randomly to avoid multiple handles
+        sending requests in the same order.
+        """
+        replicas = list(self.in_flight_queries.keys())
+        random.shuffle(replicas)
+        self.replica_iterator = itertools.cycle(replicas)
+
     def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]):
         added, removed, _ = compute_iterable_delta(
             self.in_flight_queries.keys(), running_replicas
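The new `_reset_replica_iterator` combines `random.shuffle` with `itertools.cycle`; a self-contained sketch of that pattern (replica names invented for illustration):

import itertools
import random

replicas = ["replica-a", "replica-b", "replica-c"]

# Each handle shuffles independently, so handles created at the same time are
# unlikely to walk the replicas in the same order and dog-pile one replica.
random.shuffle(replicas)
replica_iterator = itertools.cycle(replicas)

for _ in range(5):
    print(next(replica_iterator))  # round-robins over the shuffled order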
@@ -97,14 +109,13 @@ def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]):

         for removed_replica in removed:
             # Delete it directly because shutdown is processed by controller.
-            del self.in_flight_queries[removed_replica]
+            # Replicas might already have been deleted due to early detection
+            # of an actor error.
+            self.in_flight_queries.pop(removed_replica, None)

         if len(added) > 0 or len(removed) > 0:
-            # Shuffle the keys to avoid synchronization across clients.
-            replicas = list(self.in_flight_queries.keys())
-            random.shuffle(replicas)
-            self.replica_iterator = itertools.cycle(replicas)
             logger.debug(f"ReplicaSet: +{len(added)}, -{len(removed)} replicas.")
+            self._reset_replica_iterator()
             self.config_updated_event.set()

     def _try_assign_replica(self, query: Query) -> Optional[ray.ObjectRef]:
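The switch from `del` to `dict.pop(key, None)` makes removal idempotent, which matters because the drain path below may have already dropped a crashed replica; a toy illustration:

in_flight = {"replica-a": set()}

# `del in_flight["replica-b"]` would raise KeyError if the early-detection
# path already removed the entry; pop() with a default is a harmless no-op.
in_flight.pop("replica-b", None)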
@@ -160,9 +171,38 @@ def _all_query_refs(self):

     def _drain_completed_object_refs(self) -> int:
         refs = self._all_query_refs
+        # NOTE(simon): even though the timeout is 0, a large number of refs can still
+        # cause some blocking delay in the event loop. Consider moving this to async?
         done, _ = ray.wait(refs, num_returns=len(refs), timeout=0)
-        for replica_in_flight_queries in self.in_flight_queries.values():
-            replica_in_flight_queries.difference_update(done)
+        replicas_to_remove = []
+        for replica_info, replica_in_flight_queries in self.in_flight_queries.items():
+            completed_queries = replica_in_flight_queries.intersection(done)
+            if len(completed_queries):
+                try:
+                    # NOTE(simon): this ray.get call should be cheap because all these
+                    # refs are ready, as indicated by the previous `ray.wait` call.
+                    ray.get(list(completed_queries))
+                except RayActorError:
+                    logger.debug(
+                        f"Removing {replica_info.replica_tag} from replica set "
+                        "because the actor exited."
+                    )
+                    replicas_to_remove.append(replica_info)
+                except RayTaskError:
+                    # Ignore application errors.
+                    pass
+                except Exception:
+                    logger.exception(
+                        "Handle received unexpected error when processing request."
+                    )
+
+                replica_in_flight_queries.difference_update(completed_queries)
+
+        if len(replicas_to_remove) > 0:
+            for replica_info in replicas_to_remove:
+                self.in_flight_queries.pop(replica_info, None)
+            self._reset_replica_iterator()
+
         return len(done)

     async def assign_replica(self, query: Query) -> ray.ObjectRef:
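A self-contained sketch of the detection mechanism `_drain_completed_object_refs` now uses: calling `ray.get` on already-finished refs surfaces `RayActorError` when the actor process died, versus `RayTaskError` for an application-level exception. The `Worker` actor below is invented for illustration; in the real code the poll is non-blocking (`timeout=0`).

import os

import ray
from ray.exceptions import RayActorError, RayTaskError


@ray.remote
class Worker:  # hypothetical stand-in for a Serve replica actor
    def work(self, crash: bool = False):
        if crash:
            os._exit(1)  # hard-exit the actor process, like a crashed replica
        return "ok"


ray.init()
worker = Worker.remote()
refs = [worker.work.remote(), worker.work.remote(crash=True)]

# The router polls with timeout=0 (non-blocking); here we wait a bit so both
# refs have resolved by the time we inspect them.
done, pending = ray.wait(refs, num_returns=len(refs), timeout=5)

for ref in done:
    try:
        ray.get(ref)  # cheap: the ref is already resolved
    except RayActorError:
        print("actor exited -> drop the replica from the set")
    except RayTaskError:
        pass  # application error: the replica itself is still healthy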
37 changes: 37 additions & 0 deletions python/ray/serve/tests/test_standalone2.py
@@ -10,6 +10,7 @@
 import requests

 import ray
+import ray.actor
 import ray._private.state
 from ray import serve
 from ray._private.test_utils import wait_for_condition
@@ -650,6 +651,42 @@ def test_shutdown_remote(start_and_shutdown_ray_cli_function):
         os.unlink(shutdown_file.name)


+def test_handle_early_detect_failure(shutdown_ray):
+    """Check that the handle can be notified about replica failures.
+
+    It should detect when a replica raises an ActorError and take it out of the replica set.
+    """
+    ray.init()
+    serve.start(detached=True)
+
+    @serve.deployment(num_replicas=2, max_concurrent_queries=1)
+    def f(do_crash: bool = False):
+        if do_crash:
+            os._exit(1)
+        return os.getpid()
+
+    handle = serve.run(f.bind())
+    pids = ray.get([handle.remote() for _ in range(2)])
+    assert len(set(pids)) == 2
+    assert len(handle.router._replica_set.in_flight_queries.keys()) == 2
+
+    client = get_global_client()
+    # Kill the controller so that replica membership won't be updated through
+    # the controller health check + long polling.
+    ray.kill(client._controller, no_restart=True)
+
+    with pytest.raises(RayActorError):
+        ray.get(handle.remote(do_crash=True))
+
+    pids = ray.get([handle.remote() for _ in range(10)])
+    assert len(set(pids)) == 1
+    assert len(handle.router._replica_set.in_flight_queries.keys()) == 1
+
+    # Restart the controller, and then clean up all the replicas.
+    serve.start(detached=True)
+    serve.shutdown()
+
+
 def test_autoscaler_shutdown_node_http_everynode(
     shutdown_ray, call_ray_stop_only  # noqa: F811
 ):
