Skip to content

Commit

Permalink
[workflow] Defining and updating workflow options (ray-project#24498)
Browse files Browse the repository at this point in the history
* implement "options" for workflow

* update tests
  • Loading branch information
suquark committed May 6, 2022
1 parent 189f7a4 commit 84ccab2
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 110 deletions.
74 changes: 46 additions & 28 deletions python/ray/workflow/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import types
from typing import Dict, Set, List, Tuple, Union, Optional, Any, Callable, TYPE_CHECKING
from typing import Dict, Set, List, Tuple, Union, Optional, Any, TYPE_CHECKING
import time

import ray
Expand Down Expand Up @@ -624,37 +624,55 @@ def continuation(dag_node: "DAGNode") -> Union[Workflow, ray.ObjectRef]:


@PublicAPI(stability="beta")
def options(
**workflow_options: Dict[str, Any]
) -> Callable[[RemoteFunction], RemoteFunction]:
# TODO(suquark): More rigid arguments check like @ray.remote arguments. This is
# fairly complex, but we should enable it later.
valid_options = {
"name",
"metadata",
"catch_exceptions",
"max_retries",
"allow_inplace",
"checkpoint",
}
invalid_keywords = set(workflow_options.keys()) - valid_options
if invalid_keywords:
raise ValueError(
f"Invalid option keywords {invalid_keywords} for workflow steps. "
f"Valid ones are {valid_options}."
)
class options:
"""This class serves both as a decorator and options for workflow.
def _apply_workflow_options(f: RemoteFunction):
if not isinstance(f, RemoteFunction):
raise ValueError("Only apply 'workflow.options' to Ray remote functions.")
Examples:
>>> import ray
>>> from ray import workflow
>>>
>>> # specify workflow options with a decorator
>>> @workflow.options(catch_exceptions=True)
>>> @ray.remote
>>> def foo():
>>> return 1
>>>
>>> # specify workflow options in ".options"
>>> foo_new = foo.options(**workflow.options(catch_exceptions=False))
"""

def __init__(self, **workflow_options: Dict[str, Any]):
# TODO(suquark): More rigid arguments check like @ray.remote arguments. This is
# fairly complex, but we should enable it later.
valid_options = {
"name",
"metadata",
"catch_exceptions",
"max_retries",
"allow_inplace",
"checkpoint",
}
invalid_keywords = set(workflow_options.keys()) - valid_options
if invalid_keywords:
raise ValueError(
f"Invalid option keywords {invalid_keywords} for workflow steps. "
f"Valid ones are {valid_options}."
)
from ray.workflow.common import WORKFLOW_OPTIONS

if "_metadata" not in f._default_options:
f._default_options["_metadata"] = {}
f._default_options["_metadata"][WORKFLOW_OPTIONS] = workflow_options
return f
self.options = {"_metadata": {WORKFLOW_OPTIONS: workflow_options}}

def keys(self):
return ("_metadata",)

def __getitem__(self, key):
return self.options[key]

return _apply_workflow_options
def __call__(self, f: RemoteFunction) -> RemoteFunction:
if not isinstance(f, RemoteFunction):
raise ValueError("Only apply 'workflow.options' to Ray remote functions.")
f._default_options.update(self.options)
return f


__all__ = (
Expand Down
34 changes: 18 additions & 16 deletions python/ray/workflow/tests/test_basic_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import ray
from ray import workflow
from ray.workflow import workflow_access
from ray.workflow.tests.utils import update_workflow_options


def test_basic_workflows(workflow_start_regular_shared):
Expand Down Expand Up @@ -208,29 +207,31 @@ def unstable_step():
return v

with pytest.raises(Exception):
workflow.create(update_workflow_options(unstable_step, max_retries=-2).bind())
# A negative max_retries must be rejected. Bind the remote function returned by
# ".options(...)", not the workflow.options object (which has no ".bind()").
workflow.create(
    unstable_step.options(**workflow.options(max_retries=-2)).bind()
)

with pytest.raises(Exception):
workflow.create(
update_workflow_options(unstable_step, max_retries=2).bind()
unstable_step.options(**workflow.options(max_retries=2)).bind()
).run()
assert (
10
== workflow.create(
update_workflow_options(unstable_step, max_retries=7).bind()
unstable_step.options(**workflow.options(max_retries=7)).bind()
).run()
)
(tmp_path / "test").write_text("0")
(ret, err) = workflow.create(
update_workflow_options(
unstable_step, max_retries=2, catch_exceptions=True
unstable_step.options(
**workflow.options(max_retries=2, catch_exceptions=True)
).bind()
).run()
assert ret is None
assert isinstance(err, ValueError)
(ret, err) = workflow.create(
update_workflow_options(
unstable_step, max_retries=7, catch_exceptions=True
unstable_step.options(
**workflow.options(max_retries=7, catch_exceptions=True)
).bind()
).run()
assert ret == 10
Expand Down Expand Up @@ -293,7 +294,7 @@ def f1():
return workflow.continuation(f2.bind())

assert (10, None) == workflow.create(
update_workflow_options(f1, catch_exceptions=True).bind()
f1.options(**workflow.options(catch_exceptions=True)).bind()
).run()


Expand All @@ -306,7 +307,7 @@ def f1(n):
return workflow.continuation(f1.bind(n - 1))

ret, err = workflow.create(
update_workflow_options(f1, catch_exceptions=True).bind(5)
f1.options(**workflow.options(catch_exceptions=True)).bind(5)
).run()
assert ret is None
assert isinstance(err, ValueError)
Expand All @@ -319,7 +320,7 @@ def exponential_fail(k, n):
if n < 3:
raise Exception("Failed intentionally")
return workflow.continuation(
update_workflow_options(exponential_fail, name=f"step_{n}").bind(
exponential_fail.options(**workflow.options(name=f"step_{n}")).bind(
k * 2, n - 1
)
)
Expand All @@ -329,7 +330,7 @@ def exponential_fail(k, n):
# latest successful step.
try:
workflow.create(
update_workflow_options(exponential_fail, name="step_0").bind(3, 10)
exponential_fail.options(**workflow.options(name="step_0")).bind(3, 10)
).run(workflow_id="dynamic_output")
except Exception:
pass
Expand Down Expand Up @@ -357,7 +358,7 @@ def test_workflow_error_message():
assert str(e.value) == expected_error_msg


def test_options_update(workflow_start_regular_shared):
def test_options_update():
from ray.workflow.common import WORKFLOW_OPTIONS

# Options are given in decorator first, then in the first .options()
Expand All @@ -371,18 +372,19 @@ def f():
# .options(), then preserved in the second options.
# metadata and ray_options are "updated"
# max_retries only defined in the decorator and it got preserved all the way
new_f = update_workflow_options(
f, name="new_name", num_returns=2, metadata={"extra_k2": "extra_v2"}
new_f = f.options(
num_returns=2,
**workflow.options(name="new_name", metadata={"extra_k2": "extra_v2"}),
)
options = new_f.bind().get_options()
assert options == {
"num_cpus": 2,
"num_returns": 2,
"_metadata": {
WORKFLOW_OPTIONS: {
"name": "new_name",
"metadata": {"extra_k2": "extra_v2"},
"max_retries": 1,
"num_returns": 2,
}
},
}
Expand Down
18 changes: 10 additions & 8 deletions python/ray/workflow/tests/test_basic_workflows_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from filelock import FileLock
from ray._private.test_utils import SignalActor
from ray import workflow
from ray.workflow.tests.utils import update_workflow_options
from ray.tests.conftest import * # noqa


Expand Down Expand Up @@ -84,7 +83,9 @@ def incr():
return 10

with pytest.raises(ray.exceptions.RaySystemError):
workflow.create(update_workflow_options(incr, max_retries=0).bind()).run("incr")
workflow.create(incr.options(**workflow.options(max_retries=0)).bind()).run(
"incr"
)

assert cnt_file.read_text() == "1"

Expand All @@ -105,8 +106,8 @@ def double(v):

# Get the result from named step after workflow finished
assert 4 == workflow.create(
update_workflow_options(double, name="outer").bind(
update_workflow_options(double, name="inner").bind(1)
double.options(**workflow.options(name="outer")).bind(
double.options(**workflow.options(name="inner")).bind(1)
)
).run("double")
assert ray.get(workflow.get_output("double", name="inner")) == 2
Expand All @@ -127,8 +128,9 @@ def double(v, lock=None):
lock = FileLock(lock_path)
lock.acquire()
output = workflow.create(
update_workflow_options(double, name="outer").bind(
update_workflow_options(double, name="inner").bind(1, lock_path), lock_path
double.options(**workflow.options(name="outer")).bind(
double.options(**workflow.options(name="inner")).bind(1, lock_path),
lock_path,
)
).run_async("double-2")

Expand Down Expand Up @@ -176,8 +178,8 @@ def double(v, error):
# Force it to fail for the outer step
with pytest.raises(Exception):
workflow.create(
update_workflow_options(double, name="outer").bind(
update_workflow_options(double, name="inner").bind(1, False), True
double.options(**workflow.options(name="outer")).bind(
double.options(**workflow.options(name="inner")).bind(1, False), True
)
).run("double")

Expand Down
38 changes: 16 additions & 22 deletions python/ray/workflow/tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import numpy as np
from ray import workflow
from ray.workflow.tests import utils
from ray.workflow import workflow_storage


Expand All @@ -22,14 +21,14 @@ def identity(x):
def average(x):
return np.mean(x)

x = utils.update_workflow_options(
large_input, name="large_input", checkpoint=checkpoint
x = large_input.options(
**workflow.options(name="large_input", checkpoint=checkpoint)
).bind()
y = utils.update_workflow_options(
identity, name="identity", checkpoint=checkpoint
y = identity.options(
**workflow.options(name="identity", checkpoint=checkpoint)
).bind(x)
return workflow.continuation(
utils.update_workflow_options(average, name="average").bind(y)
average.options(**workflow.options(name="average")).bind(y)
)


Expand All @@ -54,13 +53,11 @@ def _assert_step_checkpoints(wf_storage, step_id, mode):


def test_checkpoint_dag_skip_all(workflow_start_regular_shared):
outputs = utils.run_workflow_dag_with_options(
checkpoint_dag,
(False,),
workflow_id="checkpoint_skip",
name="checkpoint_dag",
checkpoint=False,
)
outputs = workflow.create(
checkpoint_dag.options(
**workflow.options(name="checkpoint_dag", checkpoint=False)
).bind(False)
).run(workflow_id="checkpoint_skip")
assert np.isclose(outputs, 8388607.5)
recovered = ray.get(workflow.resume("checkpoint_skip"))
assert np.isclose(recovered, 8388607.5)
Expand All @@ -73,12 +70,9 @@ def test_checkpoint_dag_skip_all(workflow_start_regular_shared):


def test_checkpoint_dag_skip_partial(workflow_start_regular_shared):
outputs = utils.run_workflow_dag_with_options(
checkpoint_dag,
(False,),
workflow_id="checkpoint_partial",
name="checkpoint_dag",
)
outputs = workflow.create(
checkpoint_dag.options(**workflow.options(name="checkpoint_dag")).bind(False)
).run(workflow_id="checkpoint_partial")
assert np.isclose(outputs, 8388607.5)
recovered = ray.get(workflow.resume("checkpoint_partial"))
assert np.isclose(recovered, 8388607.5)
Expand All @@ -91,9 +85,9 @@ def test_checkpoint_dag_skip_partial(workflow_start_regular_shared):


def test_checkpoint_dag_full(workflow_start_regular_shared):
outputs = utils.run_workflow_dag_with_options(
checkpoint_dag, (True,), workflow_id="checkpoint_whole", name="checkpoint_dag"
)
outputs = workflow.create(
checkpoint_dag.options(**workflow.options(name="checkpoint_dag")).bind(True)
).run(workflow_id="checkpoint_whole")
assert np.isclose(outputs, 8388607.5)
recovered = ray.get(workflow.resume("checkpoint_whole"))
assert np.isclose(recovered, 8388607.5)
Expand Down
13 changes: 5 additions & 8 deletions python/ray/workflow/tests/test_checkpoint_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def identity(x):
def average(x):
return np.mean(x)

x = utils.update_workflow_options(large_input, checkpoint=checkpoint).bind()
y = utils.update_workflow_options(identity, checkpoint=checkpoint).bind(x)
x = large_input.options(**workflow.options(checkpoint=checkpoint)).bind()
y = identity.options(**workflow.options(checkpoint=checkpoint)).bind(x)
return workflow.continuation(average.bind(y))


Expand All @@ -40,12 +40,9 @@ def test_checkpoint_dag_recovery_skip(workflow_start_regular_shared):

start = time.time()
with pytest.raises(RaySystemError):
utils.run_workflow_dag_with_options(
checkpoint_dag,
(False,),
workflow_id="checkpoint_skip_recovery",
checkpoint=False,
)
workflow.create(
checkpoint_dag.options(**workflow.options(checkpoint=False)).bind(False)
).run(workflow_id="checkpoint_skip_recovery")
run_duration_skipped = time.time() - start

utils.set_global_mark()
Expand Down
Loading

0 comments on commit 84ccab2

Please sign in to comment.