Skip to content

Commit

Permalink
[serve] Unify the handle_request unary path for http and grpc in `U…
Browse files Browse the repository at this point in the history
…serCallableWrapper` (ray-project#42141)

We had two separate methods, refactored to share one following the same style as the HTTP path.

---------

Signed-off-by: Edward Oakes <[email protected]>
  • Loading branch information
edoakes committed Jan 2, 2024
1 parent 923ea82 commit 9fb9743
Showing 1 changed file with 19 additions and 48 deletions.
67 changes: 19 additions & 48 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,18 +299,9 @@ async def handle_request(
**request_kwargs,
) -> Tuple[bytes, Any]:
request_metadata = pickle.loads(pickled_request_metadata)
if request_metadata.is_grpc_request:
# Ensure the request args are a single gRPCRequest object.
assert len(request_args) == 1 and isinstance(request_args[0], gRPCRequest)
result = await self._user_callable_wrapper.call_user_method_grpc_unary(
request_metadata=request_metadata, request=request_args[0]
)
else:
result = await self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
)

return result
return await self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
)

async def _handle_http_request_generator(
self,
Expand Down Expand Up @@ -855,38 +846,6 @@ async def call_user_method_with_grpc_unary_stream(
f"function, but '{user_method.__name__}' is not."
)

async def call_user_method_grpc_unary(
self, request_metadata: RequestMetadata, request: gRPCRequest
) -> Tuple[RayServegRPCContext, bytes]:
"""Call a user method that is *not* expected to be a generator.
Deserializes gRPC request into protobuf object and pass into replica's runner
method. Returns a serialized protobuf bytes from the replica.
"""
async with self.wrap_user_method_call(request_metadata):
user_request = pickle.loads(request.grpc_user_request)

runner_method = self.get_runner_method(request_metadata)
if inspect.isgeneratorfunction(runner_method) or inspect.isasyncgenfunction(
runner_method
):
raise TypeError(
f"Method '{runner_method.__name__}' is a generator function. "
"You must use `handle.options(stream=True)` to call "
"generators on a deployment."
)

method_to_call = sync_to_async(runner_method)

if GRPC_CONTEXT_ARG_NAME in inspect.signature(runner_method).parameters:
result = await method_to_call(
user_request,
grpc_context=request_metadata.grpc_context,
)
else:
result = await method_to_call(user_request)
return request_metadata.grpc_context, result.SerializeToString()

async def call_user_method(
self,
request_metadata: RequestMetadata,
Expand All @@ -908,6 +867,12 @@ async def call_user_method(
request_args = (scope, receive, send)
else:
request_args = (Request(scope, receive, send),)
elif request_metadata.is_grpc_request:
# Ensure the request args are a single gRPCRequest object.
assert len(request_args) == 1 and isinstance(
request_args[0], gRPCRequest
)
request_args = (pickle.loads(request_args[0].grpc_user_request),)

runner_method = None
try:
Expand All @@ -925,11 +890,15 @@ async def call_user_method(

# Edge case to support empty HTTP handlers: don't pass the Request
# argument if the callable has no parameters.
if (
request_metadata.is_http_request
and len(inspect.signature(runner_method).parameters) == 0
):
params = inspect.signature(runner_method).parameters
if request_metadata.is_http_request and len(params) == 0:
request_args, request_kwargs = tuple(), {}
elif (
request_metadata.is_grpc_request and GRPC_CONTEXT_ARG_NAME in params
):
request_kwargs = {
GRPC_CONTEXT_ARG_NAME: request_metadata.grpc_context
}

result = await method_to_call(*request_args, **request_kwargs)
if inspect.isgenerator(result) or inspect.isasyncgen(result):
Expand Down Expand Up @@ -959,6 +928,8 @@ async def call_user_method(
# ASGI interface, but for the vanilla deployment codepath we need to
# send it.
await self.send_user_result_over_asgi(result, scope, receive, send)
elif request_metadata.is_grpc_request:
result = (request_metadata.grpc_context, result.SerializeToString())

return result

Expand Down

0 comments on commit 9fb9743

Please sign in to comment.