Skip to content

Commit

Permalink
[serve] Resolve FunctionNode with no arguments to RayServeHandle,…
Browse files Browse the repository at this point in the history
… not `RayServeDAGHandle` (ray-project#37261)

Currently function deployments are resolved to `RayServeDAGHandle`s even if they are not used as part of the deployment graph API (with an `InputNode`).

This leads to some unexpected behaviors such as `.options()` not working properly and input arguments not being propagated.

This PR fixes the issue by replacing the function node with a regular `RayServeHandle` when there are zero bound arguments (which is always the case when not using the deployment graph API).

I've also added a test to prevent us regressing on the handle typing.
  • Loading branch information
edoakes committed Jul 11, 2023
1 parent 8ffb7aa commit f872d4a
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 9 deletions.
13 changes: 6 additions & 7 deletions python/ray/serve/_private/deployment_graph_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,12 @@ def generate_executor_dag_driver_deployment(
def replace_with_handle(node):
    """Return the handle that should stand in for an executor node.

    Args:
        node: A node from the executor DAG.

    Returns:
        * For a ``DeploymentExecutorNode``: the node's ``_deployment_handle``.
        * For a ``DeploymentFunctionExecutorNode`` with no bound args and no
          bound kwargs: the node's ``_deployment_function_handle``.
        * For a ``DeploymentFunctionExecutorNode`` with bound arguments, or a
          ``DeploymentMethodExecutorNode``: a ``RayServeDAGHandle`` wrapping
          the cloudpickled node.
        * For any other node type: ``None``.
    """
    if isinstance(node, DeploymentExecutorNode):
        return node._deployment_handle
    elif isinstance(node, DeploymentFunctionExecutorNode):
        # A function node with zero bound arguments is not being used through
        # the deployment graph API, so hand back its regular handle. Wrapping
        # it in a DAG handle breaks `.options()` and drops call-time
        # arguments.
        if len(node.get_args()) == 0 and len(node.get_kwargs()) == 0:
            return node._deployment_function_handle
        else:
            return RayServeDAGHandle(cloudpickle.dumps(node))
    elif isinstance(node, DeploymentMethodExecutorNode):
        return RayServeDAGHandle(cloudpickle.dumps(node))
    # Other node types have no handle replacement.
    return None

(
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/tests/test_config_files/arithmetic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ deployments:
ray_actor_options:
runtime_env:
py_modules:
- "https://github.com/ray-project/test_module/archive/aa6f366f7daa78c98408c27d917a983caa9f888b.zip"
- "https://github.com/ray-project/test_module/archive/aa6f366f7daa78c98408c27d917a983caa9f888b.zip"
41 changes: 40 additions & 1 deletion python/ray/serve/tests/test_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ray import serve
from ray.serve.context import get_global_client
from ray.serve.exceptions import RayServeException
from ray.serve.handle import HandleOptions
from ray.serve.handle import HandleOptions, RayServeHandle, RayServeSyncHandle
from ray.serve._private.constants import (
DEPLOYMENT_NAME_PREFIX_SEPARATOR,
RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING,
Expand Down Expand Up @@ -314,6 +314,45 @@ async def cache_get():
asyncio.run_coroutine_threadsafe(refresh_get(), loop).result()


def test_handle_typing(serve_instance):
    """Check the handle types given to a deployment and returned to the caller.

    Both a class deployment and a function deployment bound as constructor
    arguments must arrive as ``RayServeHandle``; ``serve.run`` itself must
    return a ``RayServeSyncHandle``.
    """

    @serve.deployment
    class DeploymentClass:
        pass

    @serve.deployment
    def deployment_func():
        pass

    @serve.deployment
    class Ingress:
        def __init__(
            self, class_downstream: RayServeHandle, func_downstream: RayServeHandle
        ):
            # A failed assertion here makes serve.run() on this deployment fail.
            for downstream in (class_downstream, func_downstream):
                assert isinstance(downstream, RayServeHandle)

    app = Ingress.bind(DeploymentClass.bind(), deployment_func.bind())
    app_handle = serve.run(app)
    assert isinstance(app_handle, RayServeSyncHandle)


def test_call_function_with_argument(serve_instance):
    """Check that a call-time argument reaches a downstream function deployment."""

    @serve.deployment
    def echo(name: str):
        return f"Hi {name}"

    @serve.deployment
    class Ingress:
        def __init__(self, h: RayServeHandle):
            self._h = h

        async def __call__(self, name: str):
            # The handle call is awaited twice: once for the call itself and
            # once for whatever that call hands back.
            pending = await self._h.remote(name)
            return await pending

    ingress_app = Ingress.bind(echo.bind())
    ingress_handle = serve.run(ingress_app)
    greeting = ray.get(ingress_handle.remote("sned"))
    assert greeting == "Hi sned"


if __name__ == "__main__":
import sys
import pytest
Expand Down
41 changes: 41 additions & 0 deletions python/ray/serve/tests/test_handle_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ def unary(self, n: int) -> int:
return n


@serve.deployment
def sync_gen_function(n: int):
    """Yield the integers 0 through n - 1 from a synchronous generator."""
    yield from range(n)


@serve.deployment
async def async_gen_function(n: int):
    """Yield the integers 0 through n - 1 from an asynchronous generator."""
    for value in range(n):
        yield value


@pytest.mark.skipif(
not RAY_SERVE_ENABLE_NEW_ROUTING, reason="Routing FF must be enabled."
)
Expand Down Expand Up @@ -271,5 +283,34 @@ async def __call__(self):
ray.get(h.remote())


@pytest.mark.skipif(
    not RAY_SERVE_ENABLE_NEW_ROUTING, reason="Routing FF must be enabled."
)
@pytest.mark.parametrize("deployment", [sync_gen_function, async_gen_function])
class TestGeneratorFunctionDeployment:
    """Streaming calls to generator function deployments.

    Every test runs once for the sync generator and once for the async one.
    """

    def test_app_handle(self, deployment: Deployment):
        """Stream results through the handle returned by ``serve.run``."""
        streaming_handle = serve.run(deployment.bind()).options(stream=True)
        streamed_refs = list(streaming_handle.remote(5))
        assert ray.get(streamed_refs) == list(range(5))

    def test_deployment_handle(self, deployment: Deployment):
        """Stream results through a handle held by another deployment."""

        @serve.deployment
        class Delegate:
            def __init__(self, f: RayServeHandle):
                self._f = f.options(stream=True)

            async def __call__(self):
                ref_stream = await self._f.remote(5)
                # Resolve each streamed reference in arrival order.
                collected = [await ref async for ref in ref_stream]
                assert collected == list(range(5))

        delegate_handle = serve.run(Delegate.bind(deployment.bind()))
        ray.get(delegate_handle.remote())


if __name__ == "__main__":
    # Run this file's tests under pytest (verbose, output not captured) and
    # use pytest's result as the process exit status.
    exit_code = pytest.main(["-v", "-s", __file__])
    sys.exit(exit_code)

0 comments on commit f872d4a

Please sign in to comment.