Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[serve] Enable passing starlette requests w/ a warning #37046

Merged
merged 2 commits into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 49 additions & 8 deletions python/ray/serve/_private/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
Tuple,
Union,
)
import warnings

from starlette.requests import Request

import ray
from ray.actor import ActorHandle
Expand All @@ -32,6 +35,7 @@
SERVE_LOGGER_NAME,
HANDLE_METRIC_PUSH_INTERVAL_S,
)
from ray.serve._private.http_util import make_buffered_asgi_receive
from ray.serve._private.long_poll import LongPollClient, LongPollNamespace
from ray.serve._private.utils import (
compute_iterable_delta,
Expand All @@ -45,6 +49,9 @@

logger = logging.getLogger(SERVE_LOGGER_NAME)

# Used to print only a single warning when users pass starlette requests via handle.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Used to only print a single warning when users pass starlette requests via handle.
# Used only to print a single warning when users pass starlette requests via handle.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The phrasing is as-is to specify that it's "used to only print a single" warning, versus printing multiple.

Not that this variable is "only used to print" a warning.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh got it, could we phrase it like this:

Suggested change
# Used to only print a single warning when users pass starlette requests via handle.
# Used to print only a single warning when users pass starlette requests via handle.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah not going to wait for another build for it though, if it fails i'll update it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds good.

WARNED_ABOUT_STARLETTE_REQUESTS_ONCE = False


@dataclass
class RequestMetadata:
Expand Down Expand Up @@ -77,15 +84,49 @@ class Query:
async def resolve_async_tasks(self):
    """Find all unresolved asyncio.Task and gather them all at once.

    Scans ``self.args`` and ``self.kwargs`` for ``asyncio.Task`` objects,
    awaits all of them together with ``asyncio.gather``, and rebinds
    ``self.args`` / ``self.kwargs`` to copies in which each task has been
    replaced by its result.

    The scan/replace is done exactly once, and the scanner is always cleared
    afterwards (even if gathering raises) so that it can be garbage collected.
    """
    scanner = _PyObjScanner(source_type=asyncio.Task)

    try:
        tasks = scanner.find_nodes((self.args, self.kwargs))

        if len(tasks) > 0:
            resolved = await asyncio.gather(*tasks)
            # Map each original task to its result for in-place substitution.
            replacement_table = dict(zip(tasks, resolved))
            self.args, self.kwargs = scanner.replace_nodes(replacement_table)
    finally:
        # Make the scanner GC-able to avoid memory leaks.
        scanner.clear()

async def buffer_starlette_requests_and_warn(self):
    """Buffer any `starlette.requests.Request` objects to make them serializable.

    Every `Request` found in ``self.args`` / ``self.kwargs`` is mutated in
    place: its ``_send`` is replaced with a no-op coroutine function and its
    ``_receive`` is replaced with a buffered receive built from the
    already-read request body.

    This is an anti-pattern because the requests will not be fully functional, so
    warn the user (at most once per process, tracked by the module-level
    ``WARNED_ABOUT_STARLETTE_REQUESTS_ONCE`` flag). We may fully disallow it in
    the future.
    """
    global WARNED_ABOUT_STARLETTE_REQUESTS_ONCE
    scanner = _PyObjScanner(source_type=Request)

    async def empty_send():
        # Stand-in for the ASGI send callable; intentionally does nothing.
        pass

    try:
        requests = scanner.find_nodes((self.args, self.kwargs))
        if len(requests) > 0 and not WARNED_ABOUT_STARLETTE_REQUESTS_ONCE:
            # Flip the flag first so the warning is emitted only once.
            WARNED_ABOUT_STARLETTE_REQUESTS_ONCE = True
            # TODO(edoakes): fully disallow this in the future.
            warnings.warn(
                "`starlette.Request` objects should not be directly passed via "
                "`ServeHandle` calls. Not all functionality is guaranteed to work "
                "(e.g., detecting disconnects) and this may be disallowed in a "
                "future release."
            )

        for request in requests:
            request._send = empty_send
            # Read the full body now and replay it from memory later.
            request._receive = make_buffered_asgi_receive(await request.body())
    finally:
        # Make the scanner GC-able to avoid memory leaks. This is the only
        # place the scanner is cleared, so it stays valid for the whole loop.
        scanner.clear()


class ReplicaWrapper(ABC):
Expand Down Expand Up @@ -806,7 +847,6 @@ async def assign_replica(self, query: Query) -> ray.ObjectRef:
if query.metadata.is_streaming:
raise NotImplementedError("Streaming requires new routing to be enabled.")

await query.resolve_async_tasks()
assigned_ref = self._try_assign_replica(query)
while assigned_ref is None: # Can't assign a replica right now.
logger.debug(
Expand Down Expand Up @@ -935,6 +975,7 @@ async def assign_request(
metadata=request_meta,
)
await query.resolve_async_tasks()
await query.buffer_starlette_requests_and_warn()
result = await self._replica_scheduler.assign_replica(query)

self.num_queued_queries -= 1
Expand Down
10 changes: 2 additions & 8 deletions python/ray/serve/http_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from ray.util.annotations import PublicAPI
from ray.serve._private.utils import require_packages
from ray.serve._private.http_util import make_buffered_asgi_receive


_1DArray = List[float]
Expand Down Expand Up @@ -63,16 +62,11 @@ def json_to_multi_ndarray(payload: Dict[str, NdArray]) -> Dict[str, np.ndarray]:


@PublicAPI(stability="beta")
def starlette_request(
    request: starlette.requests.Request,
) -> starlette.requests.Request:
    """Returns the raw request object.

    The request is passed through unchanged: no copy is made and the body is
    not read or buffered here.

    Args:
        request: The incoming Starlette request.

    Returns:
        The very same ``request`` object.
    """
    return request


@PublicAPI(stability="beta")
Expand Down
41 changes: 39 additions & 2 deletions python/ray/serve/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
import asyncio
import os
from typing import Optional
from typing import Dict, Optional

from fastapi import FastAPI
import requests
from pydantic import BaseModel, ValidationError
import pytest
import starlette.responses
from starlette.requests import Request

import ray
from ray import serve
from ray._private.test_utils import SignalActor, wait_for_condition

from ray import serve
from ray.serve.built_application import BuiltApplication
from ray.serve.deployment import Application
from ray.serve.deployment_graph import RayServeDAGHandle
from ray.serve.drivers import DAGDriver
from ray.serve.exceptions import RayServeException
from ray.serve.handle import RayServeHandle
from ray.serve._private.api import call_app_builder_with_args_if_necessary
from ray.serve._private.constants import (
SERVE_DEFAULT_APP_NAME,
Expand Down Expand Up @@ -847,6 +850,40 @@ def f():
serve.run(f.bind(), route_prefix="no_slash")


def test_pass_starlette_request_over_handle(serve_instance):
    """A Starlette request forwarded through a handle keeps body, headers, params.

    ``Upstream`` receives the HTTP request and passes the `Request` object
    straight to ``Downstream`` via its handle. ``Downstream`` then reads the
    JSON body, a header, and the query parameters from it and returns them
    merged into one dict, which is checked against the values that were sent.
    """

    @serve.deployment
    class Downstream:
        async def __call__(self, request: Request) -> Dict[str, str]:
            r = await request.json()
            r["foo"] = request.headers["foo"]
            r.update(request.query_params)
            return r

    @serve.deployment
    class Upstream:
        def __init__(self, downstream: RayServeHandle):
            self._downstream = downstream

        async def __call__(self, request: Request) -> Dict[str, str]:
            # The handle call yields a reference that must itself be awaited.
            ref = await self._downstream.remote(request)
            return await ref

    serve.run(Upstream.bind(Downstream.bind()))

    # Use plain HTTP: no TLS is configured for the Serve endpoint in this test.
    r = requests.get(
        "http://127.0.0.1:8000/",
        json={"hello": "world"},
        headers={"foo": "bar"},
        params={"baz": "quux"},
    )
    r.raise_for_status()
    assert r.json() == {
        "hello": "world",
        "foo": "bar",
        "baz": "quux",
    }


if __name__ == "__main__":
import sys

Expand Down