Skip to content

Commit

Permalink
[Serve] User custom class name for replica class (ray-project#26574)
Browse files Browse the repository at this point in the history
Signed-off-by: Stefan van der Kleij <[email protected]>
  • Loading branch information
simon-mo authored and Stefan van der Kleij committed Aug 18, 2022
1 parent 58cd156 commit b95d611
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 4 deletions.
17 changes: 14 additions & 3 deletions python/ray/serve/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
logger = logging.getLogger(SERVE_LOGGER_NAME)


def _format_replica_actor_name(deployment_name: str):
return f"ServeReplica:{deployment_name}"


def create_replica_wrapper(name: str):
"""Creates a replica class wrapping the provided function or class.
Expand Down Expand Up @@ -211,8 +215,13 @@ async def prepare_for_shutdown(self):
async def check_health(self):
await self.replica.check_health()

RayServeWrappedReplica.__name__ = name
return RayServeWrappedReplica
# Dynamically create a new class with custom name here so Ray picks it up
# correctly in actor metadata table and observability stack.
return type(
_format_replica_actor_name(name),
(RayServeWrappedReplica,),
dict(RayServeWrappedReplica.__dict__),
)


class RayServeReplica:
Expand Down Expand Up @@ -333,7 +342,9 @@ async def check_health(self):

def _get_handle_request_stats(self) -> Optional[Dict[str, int]]:
actor_stats = ray.runtime_context.get_runtime_context()._get_actor_call_stats()
method_stat = actor_stats.get("RayServeWrappedReplica.handle_request")
method_stat = actor_stats.get(
f"{_format_replica_actor_name(self.deployment_name)}.handle_request"
)
return method_stat

def _collect_autoscaling_metrics(self):
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def get_actor_distributions():
actors = ray._private.state.actors()
node_to_actors = defaultdict(list)
for actor in actors.values():
if "RayServeWrappedReplica" not in actor["ActorClassName"]:
if "ServeReplica" not in actor["ActorClassName"]:
continue
if actor["State"] != "ALIVE":
continue
Expand Down
14 changes: 14 additions & 0 deletions python/ray/serve/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ray import serve
from ray._private.test_utils import wait_for_condition
from ray.serve.utils import block_until_http_ready
import ray.experimental.state.api as state_api


def test_serve_metrics_for_successful_connection(serve_instance):
Expand Down Expand Up @@ -142,6 +143,19 @@ def verify_error_count(do_assert=False):
verify_error_count(do_assert=True)


def test_actor_summary(serve_instance):
@serve.deployment
def f():
pass

serve.run(f.bind())
actors = state_api.list_actors()
class_names = {actor["class_name"] for actor in actors}
assert class_names.issuperset(
{"ServeController", "HTTPProxyActor", "ServeReplica:f"}
)


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit b95d611

Please sign in to comment.