From b95d611d78d93b5cfd69b78c515e75434ea5e4ee Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 14 Jul 2022 20:10:56 -0700 Subject: [PATCH] [Serve] User custom class name for replica class (#26574) Signed-off-by: Stefan van der Kleij --- python/ray/serve/replica.py | 17 ++++++++++++++--- python/ray/serve/tests/test_cluster.py | 2 +- python/ray/serve/tests/test_metrics.py | 14 ++++++++++++++ 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/python/ray/serve/replica.py b/python/ray/serve/replica.py index 1c442f30360c1..ae387baff4a54 100644 --- a/python/ray/serve/replica.py +++ b/python/ray/serve/replica.py @@ -38,6 +38,10 @@ logger = logging.getLogger(SERVE_LOGGER_NAME) +def _format_replica_actor_name(deployment_name: str): + return f"ServeReplica:{deployment_name}" + + def create_replica_wrapper(name: str): """Creates a replica class wrapping the provided function or class. @@ -211,8 +215,13 @@ async def prepare_for_shutdown(self): async def check_health(self): await self.replica.check_health() - RayServeWrappedReplica.__name__ = name - return RayServeWrappedReplica + # Dynamically create a new class with custom name here so Ray picks it up + # correctly in actor metadata table and observability stack. + return type( + _format_replica_actor_name(name), + (RayServeWrappedReplica,), + dict(RayServeWrappedReplica.__dict__), + ) class RayServeReplica: @@ -333,7 +342,9 @@ async def check_health(self): def _get_handle_request_stats(self) -> Optional[Dict[str, int]]: actor_stats = ray.runtime_context.get_runtime_context()._get_actor_call_stats() - method_stat = actor_stats.get("RayServeWrappedReplica.handle_request") + method_stat = actor_stats.get( + f"{_format_replica_actor_name(self.deployment_name)}.handle_request" + ) return method_stat def _collect_autoscaling_metrics(self): diff --git a/python/ray/serve/tests/test_cluster.py b/python/ray/serve/tests/test_cluster.py index e5222f6ad814b..31b12567bbf16 100644 --- a/python/ray/serve/tests/test_cluster.py +++ b/python/ray/serve/tests/test_cluster.py @@ -181,7 +181,7 @@ def get_actor_distributions(): actors = ray._private.state.actors() node_to_actors = defaultdict(list) for actor in actors.values(): - if "RayServeWrappedReplica" not in actor["ActorClassName"]: + if "ServeReplica" not in actor["ActorClassName"]: continue if actor["State"] != "ALIVE": continue diff --git a/python/ray/serve/tests/test_metrics.py b/python/ray/serve/tests/test_metrics.py index 5a2f828ae4951..cd26e8cf72e30 100644 --- a/python/ray/serve/tests/test_metrics.py +++ b/python/ray/serve/tests/test_metrics.py @@ -7,6 +7,7 @@ from ray import serve from ray._private.test_utils import wait_for_condition from ray.serve.utils import block_until_http_ready +import ray.experimental.state.api as state_api def test_serve_metrics_for_successful_connection(serve_instance): @@ -142,6 +143,19 @@ def verify_error_count(do_assert=False): verify_error_count(do_assert=True) +def test_actor_summary(serve_instance): + @serve.deployment + def f(): + pass + + serve.run(f.bind()) + actors = state_api.list_actors() + class_names = {actor["class_name"] for actor in actors} + assert class_names.issuperset( + {"ServeController", "HTTPProxyActor", "ServeReplica:f"} + ) + + if __name__ == "__main__": import sys