Skip to content

Commit

Permalink
Revert "[serve] Add test for Gradio queues (which use websockets) (ra…
Browse files Browse the repository at this point in the history
  • Loading branch information
edoakes authored and harborn committed Aug 17, 2023
1 parent 608a466 commit e6ca90d
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 69 deletions.
8 changes: 4 additions & 4 deletions doc/requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ autodoc_pydantic==1.6.1
sphinx_design==0.4.1

# MyST
myst-parser==0.18.1
myst-nb==0.17.2
myst-parser==0.15.2
myst-nb==0.13.1

# Jupyter conversion
jupytext==1.14.6
jupytext==1.13.6

# Pin urllib to avoid downstream ssl incompatibility issues
urllib3 < 1.27
urllib3 < 1.27
48 changes: 37 additions & 11 deletions python/ray/serve/tests/test_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,27 @@
import sys

import gradio as gr
from gradio_client import Client
import pytest
import requests

import ray
from ray._private.test_utils import wait_for_condition

from ray import serve
from ray.serve.gradio_integrations import GradioIngress, GradioServer


@pytest.fixture
def serve_start_shutdown():
ray.init()
serve.start()
yield
serve.shutdown()
ray.shutdown()


@pytest.mark.parametrize("use_user_defined_class", [False, True])
def test_gradio_ingress_correctness(serve_instance, use_user_defined_class: bool):
def test_gradio_ingress_correctness(serve_start_shutdown, use_user_defined_class: bool):
"""
Ensure a Gradio app deployed to a cluster through GradioIngress still
produces the correct output.
Expand All @@ -38,11 +48,16 @@ def __init__(self):

serve.run(app)

client = Client("http:https://localhost:8000")
assert client.predict("Alice") == "Good morning Alice!"
test_input = "Alice"
response = requests.post(
"http:https://127.0.0.1:8000/api/predict/", json={"data": [test_input]}
)
assert response.status_code == 200 and response.json()["data"][0] == greet(
test_input
)


def test_gradio_ingress_scaling(serve_instance):
def test_gradio_ingress_scaling(serve_start_shutdown):
"""
Check that a Gradio app that has been deployed to a cluster through
GradioIngress scales as needed, i.e. separate client requests are served by
Expand All @@ -52,15 +67,26 @@ def test_gradio_ingress_scaling(serve_instance):
def f(*args):
return os.getpid()

serve.run(
GradioServer.options(num_replicas=2).bind(
lambda: gr.Interface(fn=f, inputs="text", outputs="text")
)
app = GradioServer.options(num_replicas=2).bind(
lambda: gr.Interface(fn=f, inputs="text", outputs="text")
)
serve.run(app)

def two_pids_returned():
@ray.remote
def get_pid_from_request():
r = requests.post(
"http:https://127.0.0.1:8000/api/predict/", json={"data": ["input"]}
)
r.raise_for_status()
return r.json()["data"][0]

return (
len(set(ray.get([get_pid_from_request.remote() for _ in range(10)]))) == 2
)

client = Client("http:https://localhost:8000")
# Verify that the requests are handled by two separate replicas.
wait_for_condition(lambda: len({client.predict("input") for _ in range(3)}) == 2)
wait_for_condition(two_pids_returned)


if __name__ == "__main__":
Expand Down
28 changes: 23 additions & 5 deletions python/ray/serve/tests/test_gradio_visualization.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import pytest
from collections import defaultdict
import asyncio
import aiohttp
import random

from gradio_client import Client

from ray.serve.experimental.gradio_visualize_graph import GraphVisualizer
from ray.dag.utils import _DAGNodeNameGenerator
from ray import serve
Expand Down Expand Up @@ -235,9 +234,28 @@ async def test_gradio_visualization_e2e(graph1):
visualizer = GraphVisualizer()
(_, url, _) = visualizer.visualize_with_gradio(handle, _launch=True, _block=False)

client = Client(url)
client.predict(random.randint(0, 100), 1, 2, fn_index=0)
assert {client.predict(fn_index=1), client.predict(fn_index=2)} == {1, 2}
async with aiohttp.ClientSession() as session:

async def fetch(data, fn_index):
async with session.post(
f"{url.strip('/')}/api/predict/",
json={
"session_hash": "random_hash",
"data": data,
"fn_index": fn_index,
},
) as resp:
return (await resp.json())["data"]

await fetch(
[random.randint(0, 100), 1, 2], 0
) # sends request to dag with input (1,2)
values = await asyncio.gather(
fetch([], 1), # fetches return value for one of the nodes
fetch([], 2), # fetches return value for the other node
)

assert [1] in values and [2] in values


@pytest.mark.asyncio
Expand Down
43 changes: 0 additions & 43 deletions python/ray/serve/tests/test_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from websockets.sync.client import connect

import ray
from ray._private.test_utils import wait_for_condition

from ray import serve
from ray.serve._private.constants import RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING
Expand Down Expand Up @@ -153,47 +152,5 @@ async def ws_hi(self, ws: WebSocket):
assert ws.recv() == "hi"


@pytest.mark.skipif(
not RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING,
reason="Streaming feature flag is disabled.",
)
@pytest.mark.skipif(sys.platform == "win32", reason="Gradio doesn't work on Windows.")
@pytest.mark.skipif(sys.version_info.minor < 8, reason="Times out on Python 3.7.")
def test_gradio_queue(serve_instance):
"""Test the Gradio integration with a Gradio Queue.
Gradio Queues use websockets under the hood.
"""

# Delayed imports because these aren't installed on Windows.
import gradio as gr
from gradio_client import Client
from ray.serve.gradio_integrations import GradioIngress

def counter(num_steps: int = 3):
for i in range(num_steps):
yield str(i)

@serve.deployment
class GradioGenerator(GradioIngress):
def __init__(self):
g = gr.Interface(counter, inputs=gr.Slider(1, 10, 3), outputs="text")
super().__init__(lambda: g.queue())

serve.run(GradioGenerator.bind())

client = Client("http:https://localhost:8000")
job1 = client.submit(3, api_name="/predict")
job2 = client.submit(5, api_name="/predict")

wait_for_condition(
lambda: (
(job1.done() and job2.done())
and job1.outputs() == [str(i) for i in range(3)]
and job2.outputs() == [str(i) for i in range(5)]
)
)


if __name__ == "__main__":
sys.exit(pytest.main(["-v", "-s", __file__]))
11 changes: 5 additions & 6 deletions python/requirements/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ flask==2.1.3
freezegun==1.1.0
google-api-python-client==2.65.0
google-cloud-storage==2.5.0
gradio==3.34; platform_system != "Windows"
gradio-client==0.2.6; platform_system != "Windows"
gradio==3.11; platform_system != "Windows"
websockets==11.0.3
joblib==1.2.0
jsonpatch==1.32
Expand Down Expand Up @@ -79,10 +78,10 @@ xlrd==2.0.1
memray; platform_system != "Windows" and sys_platform != "darwin"

# For doc tests
myst-parser==0.18.1
myst-nb==0.17.2
myst-parser==0.15.2
myst-nb==0.13.1
sphinx==4.3.2
jupytext==1.14.6
jupytext==1.13.6
jinja2==3.0.3
pytest-docker-tools==3.1.3
pytest-forked==1.4.0
Expand All @@ -100,7 +99,7 @@ importlib-metadata==4.10.0
tensorboardX==2.6.0
starlette==0.27.0
h11==0.12.0
markdown-it-py==2.2.0
markdown-it-py==1.1.0
attrs==21.4.0
pytz==2022.7.1
# Compatibility with spacy 3.5 (model en_core_web_sm)
Expand Down

0 comments on commit e6ca90d

Please sign in to comment.