[workflow] Enable workflow storage test with cluster (ray-project#24401)
* update
suquark committed May 2, 2022
1 parent 3c9e704 commit 1282ae1
Showing 2 changed files with 41 additions and 15 deletions.
30 changes: 30 additions & 0 deletions python/ray/workflow/tests/conftest.py
@@ -1,4 +1,5 @@
 from contextlib import contextmanager
+import subprocess
 import pytest
 import ray

@@ -47,6 +48,35 @@ def workflow_start_regular_shared(storage_type, request):
         yield res
 
 
+@pytest.fixture(scope="function")
+def workflow_start_cluster(storage_type, request):
+    # This code follows the design of "call_ray_start" fixture.
+    with simulate_storage(storage_type) as storage_url:
+        utils.clear_marks()
+        parameter = getattr(
+            request,
+            "param",
+            "ray start --head --num-cpus=1 --min-worker-port=0 "
+            "--max-worker-port=0 --port 0 --storage=" + storage_url,
+        )
+        command_args = parameter.split(" ")
+        out = ray._private.utils.decode(
+            subprocess.check_output(command_args, stderr=subprocess.STDOUT)
+        )
+        # Get the redis address from the output.
+        address_prefix = "--address='"
+        address_location = out.find(address_prefix) + len(address_prefix)
+        address = out[address_location:]
+        address = address.split("'")[0]
+
+        yield address, storage_url
+
+        # Disconnect from the Ray cluster.
+        ray.shutdown()
+        # Kill the Ray cluster.
+        subprocess.check_call(["ray", "stop"])
+
+
 def pytest_generate_tests(metafunc):
     if "storage_type" in metafunc.fixturenames:
         metafunc.parametrize("storage_type", ["s3", "fs"], scope="session")
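
Note (illustration only, not part of this commit): a test consuming the new workflow_start_cluster fixture could follow the same pattern the updated test_lifetime.py uses: set RAY_ADDRESS to the yielded address, then call ray.init with the yielded storage URL. The test name and the hello task below are hypothetical.

    import os
    from unittest.mock import patch

    import ray
    from ray import workflow


    @ray.remote
    def hello():
        return "hello"


    def test_cluster_smoke(workflow_start_cluster):
        # The fixture yields the address parsed from the `ray start` output
        # and the URL of the simulated storage backend it started with.
        address, storage_url = workflow_start_cluster
        with patch.dict(os.environ, {"RAY_ADDRESS": address}):
            ray.init(storage=storage_url)
            # Hypothetical smoke check using the workflow API seen in this commit.
            assert workflow.create(hello.bind()).run(workflow_id="smoke") == "hello"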
26 changes: 11 additions & 15 deletions python/ray/workflow/tests/test_lifetime.py
@@ -26,35 +26,31 @@ def foo(x):
 if __name__ == "__main__":
-    workflow.init()
+    ray.init(storage="{}")
     output = workflow.create(foo.bind(0)).run_async(workflow_id="driver_terminated")
     time.sleep({})
 """
 
 
-@pytest.mark.skip(
-    reason="TODO (suquark): Figure out how to config a storage using 'ray start'."
-)
-def test_workflow_lifetime_1(call_ray_start, reset_workflow):
+def test_workflow_lifetime_1(workflow_start_cluster):
     # Case 1: driver exits normally
-    with patch.dict(os.environ, {"RAY_ADDRESS": call_ray_start}):
-        run_string_as_driver(driver_script.format(5))
-        workflow.init()
+    address, storage_uri = workflow_start_cluster
+    with patch.dict(os.environ, {"RAY_ADDRESS": address}):
+        ray.init(storage=storage_uri)
+        run_string_as_driver(driver_script.format(storage_uri, 5))
         output = workflow.get_output("driver_terminated")
         assert ray.get(output) == 20
 
 
-@pytest.mark.skip(
-    reason="TODO (suquark): Figure out how to config a storage using 'ray start'."
-)
-def test_workflow_lifetime_2(call_ray_start, reset_workflow):
+def test_workflow_lifetime_2(workflow_start_cluster):
     # Case 2: driver terminated
-    with patch.dict(os.environ, {"RAY_ADDRESS": call_ray_start}):
-        proc = run_string_as_driver_nonblocking(driver_script.format(100))
+    address, storage_uri = workflow_start_cluster
+    with patch.dict(os.environ, {"RAY_ADDRESS": address}):
+        ray.init(storage=storage_uri)
+        proc = run_string_as_driver_nonblocking(driver_script.format(storage_uri, 100))
         time.sleep(10)
         proc.kill()
         time.sleep(1)
-        workflow.init()
         output = workflow.get_output("driver_terminated")
         assert ray.get(output) == 20

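Note (illustration only): driver_script now carries two placeholders, the storage URI and the sleep duration, which is why the tests call driver_script.format(storage_uri, ...). With a hypothetical filesystem URL, driver_script.format("file:///tmp/workflow_data", 5) renders the tail of the driver as:

    if __name__ == "__main__":
        ray.init(storage="file:///tmp/workflow_data")
        output = workflow.create(foo.bind(0)).run_async(workflow_id="driver_terminated")
        time.sleep(5)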
