From 9fb9743ce853e427e0ea077c4d280dcba10d7a46 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Tue, 2 Jan 2024 13:18:55 -0600 Subject: [PATCH] [serve] Unify the `handle_request` unary path for http and grpc in `UserCallableWrapper` (#42141) We had two separate methods, refactored to share one following the same style as the HTTP path. --------- Signed-off-by: Edward Oakes --- python/ray/serve/_private/replica.py | 67 ++++++++-------------------- 1 file changed, 19 insertions(+), 48 deletions(-) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 8dfa0fda1fb56..1ec77d8e43b09 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -299,18 +299,9 @@ async def handle_request( **request_kwargs, ) -> Tuple[bytes, Any]: request_metadata = pickle.loads(pickled_request_metadata) - if request_metadata.is_grpc_request: - # Ensure the request args are a single gRPCRequest object. - assert len(request_args) == 1 and isinstance(request_args[0], gRPCRequest) - result = await self._user_callable_wrapper.call_user_method_grpc_unary( - request_metadata=request_metadata, request=request_args[0] - ) - else: - result = await self._user_callable_wrapper.call_user_method( - request_metadata, request_args, request_kwargs - ) - - return result + return await self._user_callable_wrapper.call_user_method( + request_metadata, request_args, request_kwargs + ) async def _handle_http_request_generator( self, @@ -855,38 +846,6 @@ async def call_user_method_with_grpc_unary_stream( f"function, but '{user_method.__name__}' is not." ) - async def call_user_method_grpc_unary( - self, request_metadata: RequestMetadata, request: gRPCRequest - ) -> Tuple[RayServegRPCContext, bytes]: - """Call a user method that is *not* expected to be a generator. - - Deserializes gRPC request into protobuf object and pass into replica's runner - method. Returns a serialized protobuf bytes from the replica. - """ - async with self.wrap_user_method_call(request_metadata): - user_request = pickle.loads(request.grpc_user_request) - - runner_method = self.get_runner_method(request_metadata) - if inspect.isgeneratorfunction(runner_method) or inspect.isasyncgenfunction( - runner_method - ): - raise TypeError( - f"Method '{runner_method.__name__}' is a generator function. " - "You must use `handle.options(stream=True)` to call " - "generators on a deployment." - ) - - method_to_call = sync_to_async(runner_method) - - if GRPC_CONTEXT_ARG_NAME in inspect.signature(runner_method).parameters: - result = await method_to_call( - user_request, - grpc_context=request_metadata.grpc_context, - ) - else: - result = await method_to_call(user_request) - return request_metadata.grpc_context, result.SerializeToString() - async def call_user_method( self, request_metadata: RequestMetadata, @@ -908,6 +867,12 @@ async def call_user_method( request_args = (scope, receive, send) else: request_args = (Request(scope, receive, send),) + elif request_metadata.is_grpc_request: + # Ensure the request args are a single gRPCRequest object. + assert len(request_args) == 1 and isinstance( + request_args[0], gRPCRequest + ) + request_args = (pickle.loads(request_args[0].grpc_user_request),) runner_method = None try: @@ -925,11 +890,15 @@ async def call_user_method( # Edge case to support empty HTTP handlers: don't pass the Request # argument if the callable has no parameters. - if ( - request_metadata.is_http_request - and len(inspect.signature(runner_method).parameters) == 0 - ): + params = inspect.signature(runner_method).parameters + if request_metadata.is_http_request and len(params) == 0: request_args, request_kwargs = tuple(), {} + elif ( + request_metadata.is_grpc_request and GRPC_CONTEXT_ARG_NAME in params + ): + request_kwargs = { + GRPC_CONTEXT_ARG_NAME: request_metadata.grpc_context + } result = await method_to_call(*request_args, **request_kwargs) if inspect.isgenerator(result) or inspect.isasyncgen(result): @@ -959,6 +928,8 @@ async def call_user_method( # ASGI interface, but for the vanilla deployment codepath we need to # send it. await self.send_user_result_over_asgi(result, scope, receive, send) + elif request_metadata.is_grpc_request: + result = (request_metadata.grpc_context, result.SerializeToString()) return result