[serve] Wait until replicas have finished recovering (with timeout) to broadcast `LongPoll` updates (ray-project#34675)

When the controller recovers, all replicas are put into the `RECOVERING` state. Recovering replicas are not included in the long poll updates for running replicas, so we broadcast an update that effectively clears out all available replicas in every handle.

This PR addresses the problem by holding back long poll broadcasts until all replicas have fully recovered (or a 10s timeout is reached).

We also hold off on running the `http_state` update loop, because a newly started proxy would otherwise be unable to serve any traffic while it has no replicas available.
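
For illustration, here is a minimal sketch of the gating pattern the diff below introduces (the class name, constructor arguments, and loop sleep interval are assumptions for this example, not the real controller code): long poll listeners and the proxy update loop block on an asyncio.Event that is set once no deployment reports recovering replicas, or once the timeout elapses.

import asyncio
import time

# Mirrors the new constant added in python/ray/serve/_private/constants.py.
RECOVERING_LONG_POLL_BROADCAST_TIMEOUT_S = 10.0


class ControllerSketch:
    """Illustrative only; the real logic lives in the controller diff below."""

    def __init__(self, deployment_state_manager, http_state=None):
        self.deployment_state_manager = deployment_state_manager
        self.http_state = http_state
        self.done_recovering_event = asyncio.Event()

    async def listen_for_change(self, keys_to_snapshot_ids):
        # Long poll clients block here until recovery finishes (or times out),
        # so they never receive an update with the replica set emptied out.
        if not self.done_recovering_event.is_set():
            await self.done_recovering_event.wait()
        # ...then delegate to the LongPollHost as before.

    async def run_control_loop(self):
        start_time = time.time()
        while True:
            # Give up waiting after the timeout so updates are never blocked forever.
            if (
                not self.done_recovering_event.is_set()
                and time.time() - start_time > RECOVERING_LONG_POLL_BROADCAST_TIMEOUT_S
            ):
                self.done_recovering_event.set()

            # Don't update proxies until recovery is done; a freshly started proxy
            # with no replica info could not route any traffic.
            if self.http_state and self.done_recovering_event.is_set():
                self.http_state.update()

            # update() now also reports whether any replicas are still RECOVERING.
            any_recovering = self.deployment_state_manager.update()
            if not self.done_recovering_event.is_set() and not any_recovering:
                self.done_recovering_event.set()

            await asyncio.sleep(0.1)  # assumed loop period for this sketch
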
edoakes committed Apr 25, 2023
1 parent 72268e8 commit 38f4e44
Showing 4 changed files with 77 additions and 29 deletions.
4 changes: 4 additions & 0 deletions python/ray/serve/_private/constants.py
@@ -120,6 +120,10 @@
# Env var to control legacy sync deployment handle behavior in DAG.
SYNC_HANDLE_IN_DAG_FEATURE_FLAG_ENV_KEY = "SERVE_DEPLOYMENT_HANDLE_IS_SYNC"

# Maximum duration to wait until broadcasting a long poll update if there are
# still replicas in the RECOVERING state.
RECOVERING_LONG_POLL_BROADCAST_TIMEOUT_S = 10.0


class ServeHandleType(str, Enum):
SYNC = "SYNC"
45 changes: 28 additions & 17 deletions python/ray/serve/_private/deployment_state.py
@@ -1434,7 +1434,7 @@ def _scale_deployment_replicas(self) -> bool:

return replicas_stopped

def _check_curr_status(self) -> bool:
def _check_curr_status(self) -> Tuple[bool, bool]:
"""Check the current deployment status.
Checks the difference between the target vs. running replica count for
@@ -1443,8 +1443,7 @@
This will update the current deployment status depending on the state
of the replicas.
Returns:
was_deleted
Returns (deleted, any_replicas_recovering).
"""
# TODO(edoakes): we could make this more efficient in steady-state by
# having a "healthy" flag that gets flipped if an update or replica
@@ -1453,6 +1452,9 @@
target_version = self._target_state.version
target_replica_count = self._target_state.num_replicas

any_replicas_recovering = (
self._replicas.count(states=[ReplicaState.RECOVERING]) > 0
)
all_running_replica_cnt = self._replicas.count(states=[ReplicaState.RUNNING])
running_at_target_version_replica_cnt = self._replicas.count(
states=[ReplicaState.RUNNING], version=target_version
@@ -1488,7 +1490,7 @@
f"details. Retrying after {self._backoff_time_s} seconds."
),
)
return False
return False, any_replicas_recovering

# If we have pending ops, the current goal is *not* ready.
if (
@@ -1504,16 +1506,16 @@
):
# Check for deleting.
if self._target_state.deleting and all_running_replica_cnt == 0:
return True
return True, any_replicas_recovering

# Check for a non-zero number of deployments.
if target_replica_count == running_at_target_version_replica_cnt:
self._curr_status_info = DeploymentStatusInfo(
self._name, DeploymentStatus.HEALTHY
)
return False
return False, any_replicas_recovering

return False
return False, any_replicas_recovering

def _check_startup_replicas(
self, original_state: ReplicaState, stop_on_slow=False
@@ -1707,16 +1709,17 @@ def _check_and_update_replicas(self) -> bool:

return running_replicas_changed

def update(self) -> bool:
def update(self) -> Tuple[bool, bool]:
"""Attempts to reconcile this deployment to match its goal state.
This is an asynchronous call; it's expected to be called repeatedly.
Also updates the internal DeploymentStatusInfo based on the current
state of the system.
Returns true if this deployment was successfully deleted.
Returns (deleted, any_replicas_recovering).
"""
deleted, any_replicas_recovering = False, False
try:
# Add or remove DeploymentReplica instances in self._replicas.
# This should be the only place we adjust total number of replicas
@@ -1730,16 +1733,15 @@ def update(self) -> bool:
if running_replicas_changed:
self._notify_running_replicas_changed()

deleted = self._check_curr_status()
deleted, any_replicas_recovering = self._check_curr_status()
except Exception:
self._curr_status_info = DeploymentStatusInfo(
name=self._name,
status=DeploymentStatus.UNHEALTHY,
message="Failed to update deployment:" f"\n{traceback.format_exc()}",
)
deleted = False

return deleted
return deleted, any_replicas_recovering

def _stop_one_running_replica_for_testing(self):
running_replicas = self._replicas.pop(states=[ReplicaState.RUNNING])
@@ -1837,7 +1839,8 @@ def _calculate_max_replicas_to_stop(self) -> int:
pending_replicas = nums_nodes - new_running_replicas - old_running_replicas
return max(rollout_size - pending_replicas, 0)

def update(self) -> bool:
def update(self) -> Tuple[bool, bool]:
"""Returns (deleted, any_replicas_recovering)."""
try:
if self._target_state.deleting:
self._stop_all_replicas()
@@ -1864,7 +1867,7 @@ def update(self) -> bool:
status=DeploymentStatus.UNHEALTHY,
message="Failed to update deployment:" f"\n{traceback.format_exc()}",
)
return False
return False, False

def should_autoscale(self) -> bool:
return False
@@ -2194,9 +2197,13 @@ def get_handle_queueing_metrics(
current_handle_queued_queries = 0
return current_handle_queued_queries

def update(self):
"""Updates the state of all deployments to match their goal state."""
def update(self) -> bool:
"""Updates the state of all deployments to match their goal state.
Returns True if any of the deployments have replicas in the RECOVERING state.
"""
deleted_tags = []
any_recovering = False
for deployment_name, deployment_state in self._deployment_states.items():
if deployment_state.should_autoscale():
current_num_ongoing_requests = self.get_replica_ongoing_request_metrics(
@@ -2210,7 +2217,7 @@ def update(self):
deployment_state.autoscale(
current_num_ongoing_requests, current_handle_queued_queries
)
deleted = deployment_state.update()
deleted, recovering = deployment_state.update()
if deleted:
deleted_tags.append(deployment_name)
deployment_info = deployment_state.target_info
@@ -2219,12 +2226,16 @@ def update(self):
self._deleted_deployment_metadata.popitem(last=False)
self._deleted_deployment_metadata[deployment_name] = deployment_info

any_recovering |= recovering

for tag in deleted_tags:
del self._deployment_states[tag]

if len(deleted_tags):
self._record_deployment_usage()

return any_recovering

def _record_deployment_usage(self):
record_extra_usage_tag(
TagKey.SERVE_NUM_DEPLOYMENTS, str(len(self._deployment_states))
31 changes: 29 additions & 2 deletions python/ray/serve/controller.py
@@ -32,6 +32,7 @@
SERVE_ROOT_URL_ENV_KEY,
SERVE_NAMESPACE,
RAY_INTERNAL_SERVE_CONTROLLER_PIN_ON_NODE,
RECOVERING_LONG_POLL_BROADCAST_TIMEOUT_S,
SERVE_DEFAULT_APP_NAME,
DEPLOYMENT_NAME_PREFIX_SEPARATOR,
MULTI_APP_MIGRATION_MESSAGE,
@@ -120,6 +121,7 @@ async def __init__(
self.deployment_stats = defaultdict(lambda: defaultdict(dict))

self.long_poll_host = LongPollHost()
self.done_recovering_event = asyncio.Event()

if _disable_http_proxy:
self.http_state = None
@@ -197,6 +199,9 @@ async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
determine whether or not the host should immediately return the
data or wait for the value to be changed.
"""
if not self.done_recovering_event.is_set():
await self.done_recovering_event.wait()

return await (self.long_poll_host.listen_for_change(keys_to_snapshot_ids))

async def listen_for_change_java(self, keys_to_snapshot_ids_bytes: bytes):
@@ -206,6 +211,9 @@ async def listen_for_change_java(self, keys_to_snapshot_ids_bytes: bytes):
keys_to_snapshot_ids_bytes (Dict[str, int]): the protobuf bytes of
keys_to_snapshot_ids (Dict[str, int]).
"""
if not self.done_recovering_event.is_set():
await self.done_recovering_event.wait()

return await (
self.long_poll_host.listen_for_change_java(keys_to_snapshot_ids_bytes)
)
@@ -250,16 +258,35 @@ async def run_control_loop(self) -> None:
# NOTE(edoakes): we catch all exceptions here and simply log them,
# because an unhandled exception would cause the main control loop to
# halt, which should *never* happen.
recovering_timeout = RECOVERING_LONG_POLL_BROADCAST_TIMEOUT_S
start_time = time.time()
while True:
if self.http_state:
if (
not self.done_recovering_event.is_set()
and time.time() - start_time > recovering_timeout
):
logger.warning(
f"Replicas still recovering after {recovering_timeout}s, "
"setting done recovering event to broadcast long poll updates."
)
self.done_recovering_event.set()

# Don't update http_state until after the done recovering event is set,
# otherwise we may start a new HTTP proxy but not broadcast it any
# info about available deployments & their replicas.
if self.http_state and self.done_recovering_event.is_set():
try:
self.http_state.update()
except Exception:
logger.exception("Exception updating HTTP state.")

try:
self.deployment_state_manager.update()
any_recovering = self.deployment_state_manager.update()
if not self.done_recovering_event.is_set() and not any_recovering:
self.done_recovering_event.set()
except Exception:
logger.exception("Exception updating deployment state.")

try:
self.application_state_manager.update()
except Exception:
26 changes: 16 additions & 10 deletions python/ray/serve/tests/test_deployment_state.py
@@ -508,7 +508,7 @@ def test_create_delete_single_replica(mock_get_all_node_ids, mock_deployment_sta
# Once it's done stopping, replica should be removed.
replica = deployment_state._replicas.get()[0]
replica._actor.set_done_stopping()
deleted = deployment_state.update()
deleted, _ = deployment_state.update()
assert deleted
check_counts(deployment_state, total=0)

Expand Down Expand Up @@ -557,7 +557,7 @@ def test_force_kill(mock_get_all_node_ids, mock_deployment_state):
# Once the replica is done stopping, it should be removed.
replica = deployment_state._replicas.get()[0]
replica._actor.set_done_stopping()
deleted = deployment_state.update()
deleted, _ = deployment_state.update()
assert deleted
check_counts(deployment_state, total=0)

Expand Down Expand Up @@ -689,7 +689,7 @@ def test_redeploy_no_version(mock_get_all_node_ids, mock_deployment_state):
check_counts(deployment_state, total=1, by_state=[(ReplicaState.STARTING, 1)])
assert deployment_state.curr_status_info.status == DeploymentStatus.UPDATING

deleted = deployment_state.update()
deleted, _ = deployment_state.update()
assert not deleted
check_counts(deployment_state, total=1, by_state=[(ReplicaState.RUNNING, 1)])
assert deployment_state.curr_status_info.status == DeploymentStatus.HEALTHY
@@ -793,7 +793,7 @@ def test_redeploy_new_version(mock_get_all_node_ids, mock_deployment_state):
by_state=[(ReplicaState.STARTING, 1)],
)

deleted = deployment_state.update()
deleted, _ = deployment_state.update()
assert not deleted
check_counts(
deployment_state,
@@ -2148,7 +2148,8 @@ def test_resume_deployment_state_from_replica_tags(
deployment_state_manager._deployment_states[tag] = deployment_state

# Single replica should be created.
deployment_state_manager.update()
any_recovering = deployment_state_manager.update()
assert not any_recovering
check_counts(
deployment_state,
total=1,
@@ -2158,7 +2159,8 @@
deployment_state._replicas.get()[0]._actor.set_ready()

# Now the replica should be marked running.
deployment_state_manager.update()
any_recovering = deployment_state_manager.update()
assert not any_recovering
check_counts(
deployment_state,
total=1,
@@ -2170,8 +2172,8 @@

# Step 2: Delete _replicas from deployment_state
deployment_state._replicas = ReplicaStateContainer()
# Step 3: Create new deployment_state by resuming from passed in replicas

# Step 3: Create new deployment_state by resuming from passed in replicas
deployment_state_manager._recover_from_checkpoint(
[ReplicaName.prefix + mocked_replica.replica_tag]
)
@@ -2183,11 +2185,12 @@
check_counts(
deployment_state, total=1, version=None, by_state=[(ReplicaState.RECOVERING, 1)]
)
deployment_state._replicas.get()[0]._actor.set_ready()
deployment_state._replicas.get()[0]._actor.set_starting_version(b_version_1)

# Now the replica should be marked running.
deployment_state_manager.update()
deployment_state._replicas.get()[0]._actor.set_ready()
deployment_state._replicas.get()[0]._actor.set_starting_version(b_version_1)
any_recovering = deployment_state_manager.update()
assert not any_recovering
check_counts(
deployment_state,
total=1,
@@ -2197,6 +2200,9 @@
# Ensure same replica name is used
assert deployment_state._replicas.get()[0].replica_tag == mocked_replica.replica_tag

any_recovering = deployment_state_manager.update()
assert not any_recovering


def test_stopping_replicas_ranking():
@dataclass
