[Workflow] Simplify recovery algorithm (ray-project#24594)
* simplify recovery algorithm
suquark committed May 10, 2022
1 parent 4a99977 commit bf6b7f4
Showing 1 changed file with 102 additions and 129 deletions.
231 changes: 102 additions & 129 deletions python/ray/workflow/recovery.py
@@ -1,4 +1,5 @@
from typing import List, Any, Union, Dict, Callable, Tuple, Optional
from collections import deque, defaultdict

import ray
from ray.workflow import workflow_context
@@ -12,7 +13,6 @@
StepType,
)
from ray.workflow import workflow_storage
from ray.workflow.step_function import WorkflowStepFunction


class WorkflowStepNotRecoverableError(Exception):
@@ -32,136 +32,111 @@ def __init__(self, workflow_id: str):
super().__init__(self.message)


@WorkflowStepFunction
def _recover_workflow_step(
input_workflows: List[Any],
input_workflow_refs: List[WorkflowRef],
*args,
**kwargs,
):
"""A workflow step that recovers the output of an unfinished step.
Args:
args: The positional arguments for the step function.
kwargs: The keyword args for the step function.
input_workflows: The workflows in the argument of the (original) step.
They are resolved into physical objects (i.e. the output of the
workflows) here. They come from other recovery workflows we
construct recursively.
Returns:
The output of the recovered step.
"""
reader = workflow_storage.get_workflow_storage()
step_id = workflow_context.get_current_step_id()
func: Callable = reader.load_step_func_body(step_id)
return func(*args, **kwargs)


def _reconstruct_wait_step(
reader: workflow_storage.WorkflowStorage,
step_id: StepID,
result: workflow_storage.StepInspectResult,
input_map: Dict[StepID, Any],
):
input_workflows = []
step_options = result.step_options
wait_options = step_options.ray_options.get("wait_options", {})
for i, _step_id in enumerate(result.workflows):
# Check whether the step has been loaded or not to avoid
# duplication
if _step_id in input_map:
r = input_map[_step_id]
else:
r = _construct_resume_workflow_from_step(reader, _step_id, input_map)
input_map[_step_id] = r
if isinstance(r, Workflow):
input_workflows.append(r)
else:
assert isinstance(r, StepID)
# TODO (Alex): We should consider caching these outputs too.
output = reader.load_step_output(r)
# Simulate a workflow with a workflow reference so it could be
# used directly by 'workflow.wait'.
static_ref = WorkflowStaticRef(step_id=r, ref=ray.put(output))
wf = Workflow.from_ref(static_ref)
input_workflows.append(wf)

from ray import workflow

wait_step = workflow.wait(input_workflows, **wait_options)
# override step id
wait_step._step_id = step_id
return wait_step


def _construct_resume_workflow_from_step(
reader: workflow_storage.WorkflowStorage,
step_id: StepID,
input_map: Dict[StepID, Any],
) -> Union[Workflow, StepID]:
workflow_id: str, step_id: StepID
) -> Union[Workflow, Any]:
"""Try to construct a workflow (step) that recovers the workflow step.
If the workflow step already has an output checkpointing file, we return
the workflow step id instead.
Args:
reader: The storage reader for inspecting the step.
workflow_id: The ID of the workflow.
step_id: The ID of the step we want to recover.
input_map: A context storing inputs that have already been loaded.
This context is important for deduplication.
Returns:
A workflow that recovers the step, or an ID of a step
that contains the output checkpoint file.
A workflow that recovers the step, or the output of the step
if it has been checkpointed.
"""
result: workflow_storage.StepInspectResult = reader.inspect_step(step_id)
if result.output_object_valid:
# we already have the output
return step_id
if isinstance(result.output_step_id, str):
return _construct_resume_workflow_from_step(
reader, result.output_step_id, input_map
)
# The output does not exist or is not valid; try to reconstruct it.
if not result.is_recoverable():
raise WorkflowStepNotRecoverableError(step_id)

step_options = result.step_options
# Process the wait step as a special case.
if step_options.step_type == StepType.WAIT:
return _reconstruct_wait_step(reader, step_id, result, input_map)
reader = workflow_storage.WorkflowStorage(workflow_id)

# Step 1: construct the dependency DAG (BFS)
inspect_results = {}
dependency_map = defaultdict(list)
num_in_edges = {}

dag_visit_queue = deque([step_id])
while dag_visit_queue:
s: StepID = dag_visit_queue.popleft()
if s in inspect_results:
continue
r = reader.inspect_step(s)
inspect_results[s] = r
if not r.is_recoverable():
raise WorkflowStepNotRecoverableError(s)
if r.output_object_valid:
deps = []
elif isinstance(r.output_step_id, str):
deps = [r.output_step_id]
else:
deps = r.workflows
for w in deps:
dependency_map[w].append(s)
num_in_edges[s] = len(deps)
dag_visit_queue.extend(deps)

# Step 2: topological sort to determine the execution order (Kahn's algorithm)
execution_queue: List[StepID] = []

start_nodes = deque(k for k, v in num_in_edges.items() if v == 0)
while start_nodes:
n = start_nodes.popleft()
execution_queue.append(n)
for m in dependency_map[n]:
num_in_edges[m] -= 1
assert num_in_edges[m] >= 0, (m, n)
if num_in_edges[m] == 0:
start_nodes.append(m)

# Step 3: recover the workflow in the order of the execution queue
with serialization.objectref_cache():
input_workflows = []
for i, _step_id in enumerate(result.workflows):
# Check whether the step has been loaded or not to avoid
# duplication
if _step_id in input_map:
r = input_map[_step_id]
else:
r = _construct_resume_workflow_from_step(reader, _step_id, input_map)
input_map[_step_id] = r
if isinstance(r, Workflow):
input_workflows.append(r)
# "input_map" is a context storing the input which has been loaded.
# This context is important for deduplicating step inputs.
input_map: Dict[StepID, Any] = {}

for _step_id in execution_queue:
result = inspect_results[_step_id]
if result.output_object_valid:
input_map[_step_id] = reader.load_step_output(_step_id)
continue
if isinstance(result.output_step_id, str):
input_map[_step_id] = input_map[result.output_step_id]
continue

# Process the wait step as a special case.
if result.step_options.step_type == StepType.WAIT:
wait_input_workflows = []
for w in result.workflows:
output = input_map[w]
if isinstance(output, Workflow):
wait_input_workflows.append(output)
else:
# Simulate a workflow with a workflow reference so it could be
# used directly by 'workflow.wait'.
static_ref = WorkflowStaticRef(step_id=w, ref=ray.put(output))
wait_input_workflows.append(Workflow.from_ref(static_ref))
recovery_workflow = ray.workflow.wait(
wait_input_workflows,
**result.step_options.ray_options.get("wait_options", {}),
)
else:
assert isinstance(r, StepID)
# TODO (Alex): We should consider caching these outputs too.
input_workflows.append(reader.load_step_output(r))
workflow_refs = list(map(WorkflowRef, result.workflow_refs))

args, kwargs = reader.load_step_args(step_id, input_workflows, workflow_refs)
# Note: we must unpack args and kwargs so that the refs in the args/kwargs get
# resolved consistently, as in Ray.
recovery_workflow: Workflow = _recover_workflow_step.step(
input_workflows,
workflow_refs,
*args,
**kwargs,
)
recovery_workflow._step_id = step_id
# override step_options
recovery_workflow.data.step_options = step_options
return recovery_workflow
args, kwargs = reader.load_step_args(
_step_id,
workflows=[input_map[w] for w in result.workflows],
workflow_refs=list(map(WorkflowRef, result.workflow_refs)),
)
func: Callable = reader.load_step_func_body(_step_id)
# TODO(suquark): Use an alternative function when "workflow.step"
# is fully deprecated.
recovery_workflow = ray.workflow.step(func).step(*args, **kwargs)

# override the step id and step options
recovery_workflow._step_id = _step_id
recovery_workflow.data.step_options = result.step_options

input_map[_step_id] = recovery_workflow

# Step 4: return the output of the requested step
return input_map[step_id]
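The loop above replaces the old recursive reconstruction with plain graph plumbing: Step 1 runs a BFS over inspect_step results to build a reverse-dependency map and in-degree counts, Step 2 linearizes it with Kahn's algorithm, and Step 3 replays the steps in that order, so every input a step needs is already in input_map when it is looked up. A minimal standalone sketch of Steps 1-2 (not part of the commit; deps_of is a toy stand-in for the storage reader, and all names here are hypothetical):

from collections import defaultdict, deque
from typing import Dict, List

def execution_order(target: str, deps_of: Dict[str, List[str]]) -> List[str]:
    """Return step IDs ordered so that every dependency precedes its dependents.

    deps_of is a toy stand-in for reader.inspect_step: it maps a step ID to
    the IDs of the steps whose outputs it consumes.
    """
    # Step 1: BFS from the target step to collect reachable steps and edges.
    dependents = defaultdict(list)  # dependency -> steps that consume it
    num_in_edges = {}               # step -> number of unresolved dependencies
    queue = deque([target])
    while queue:
        s = queue.popleft()
        if s in num_in_edges:
            continue  # already visited
        deps = deps_of.get(s, [])
        num_in_edges[s] = len(deps)
        for d in deps:
            dependents[d].append(s)
        queue.extend(deps)

    # Step 2: Kahn's algorithm -- repeatedly emit steps with no unresolved deps.
    order: List[str] = []
    ready = deque(s for s, n in num_in_edges.items() if n == 0)
    while ready:
        s = ready.popleft()
        order.append(s)
        for t in dependents[s]:
            num_in_edges[t] -= 1
            if num_in_edges[t] == 0:
                ready.append(t)
    return order

# Example: execution_order("c", {"c": ["a", "b"], "b": ["a"]}) == ["a", "b", "c"]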


@ray.remote(num_returns=2)
@@ -183,21 +158,19 @@ def _resume_workflow_step_executor(
except Exception:
pass
try:
wf_store = workflow_storage.WorkflowStorage(workflow_id)
r = _construct_resume_workflow_from_step(wf_store, step_id, {})
r = _construct_resume_workflow_from_step(workflow_id, step_id)
except Exception as e:
raise WorkflowNotResumableError(workflow_id) from e

if isinstance(r, Workflow):
with workflow_context.workflow_step_context(
workflow_id, last_step_of_workflow=True
):
from ray.workflow.step_executor import execute_workflow
if not isinstance(r, Workflow):
return r, None
with workflow_context.workflow_step_context(
workflow_id, last_step_of_workflow=True
):
from ray.workflow.step_executor import execute_workflow

result = execute_workflow(job_id, r)
return result.persisted_output, result.volatile_output
assert isinstance(r, StepID)
return wf_store.load_step_output(r), None
result = execute_workflow(job_id, r)
return result.persisted_output, result.volatile_output
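
In the executor hunk above, _resume_workflow_step_executor now passes only workflow_id and step_id to _construct_resume_workflow_from_step and simply branches on whether it got back a Workflow to run or an already-checkpointed output. A rough standalone sketch of that branch (not part of the commit; construct, execute, and workflow_cls are injected stand-ins for the real helpers, and the real execute_workflow returns an object with persisted_output/volatile_output rather than a plain tuple):

from typing import Any, Callable, Tuple

def resume_step(
    workflow_id: str,
    step_id: str,
    construct: Callable[[str, str], Any],       # stand-in for _construct_resume_workflow_from_step
    execute: Callable[[Any], Tuple[Any, Any]],  # stand-in for execute_workflow
    workflow_cls: type,                         # stand-in for the Workflow class
) -> Tuple[Any, Any]:
    """Return (persisted_output, volatile_output) for a resumed step."""
    try:
        result = construct(workflow_id, step_id)
    except Exception as e:
        # The commit raises WorkflowNotResumableError here; this sketch uses a generic error.
        raise RuntimeError(f"workflow {workflow_id!r} is not resumable") from e
    if not isinstance(result, workflow_cls):
        # The step's output was already checkpointed; return it without re-execution.
        return result, None
    # Otherwise re-execute the reconstructed recovery workflow.
    return execute(result)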


def resume_workflow_step(
