[core][experimental] Pass torch.Tensors through accelerated DAGs (ray-project#44825)

This PR adds support for passing torch.Tensors to local actors in an
accelerated DAG, via Ray's shared memory store. It supports the
following transfer cases, as long as the sending and receiving actors
are on the same node: CPU-CPU, CPU-GPU, GPU-CPU, GPU-GPU (via CPU).

This iteration requires the user to explicitly declare which DAG nodes
contain torch.Tensors, along with the tensors' shape and dtype, using a
new `with_type_hint` method. For example:

```python
    with InputNode() as inp:
        dag = sender.send.bind(inp)
        dag = dag.with_type_hint(TorchTensorType(SHAPE, DTYPE))
        dag = receiver.recv.bind(dag)

    compiled_dag = dag.experimental_compile()
```
This declaration isn't strictly necessary for this PR, but it is
included now because it makes it much simpler to efficiently support
other cases in the future, such as p2p GPU-GPU transfers.
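
For context, `sender` and `receiver` in the snippet above are ordinary Ray
actor handles, and `SHAPE`/`DTYPE` describe the tensor being passed. A
minimal sketch of what those definitions could look like (the actor class,
its methods, and the concrete shape/dtype below are illustrative and not
part of this PR):

```python
import torch
import ray

# In this PR, TorchTensorType is defined in
# python/ray/dag/experimental/types.py; the import path shown here is an
# assumption and may change.
from ray.dag.experimental.types import TorchTensorType

SHAPE = (10,)            # illustrative shape
DTYPE = torch.float16    # illustrative dtype


@ray.remote
class MyActor:
    def send(self, value: int) -> torch.Tensor:
        # Produce a tensor matching the declared shape and dtype.
        return torch.ones(SHAPE, dtype=DTYPE) * value

    def recv(self, tensor: torch.Tensor) -> int:
        # The tensor arrives already moved to this actor's device, if any.
        return int(tensor[0].item())


sender = MyActor.remote()
receiver = MyActor.remote()
```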

When a TorchTensor node is declared, the underlying torch.Tensor is
serialized differently from vanilla Ray. In particular, we store a
numpy view of the data. On the receiving actor, we deserialize it to a
torch.Tensor and move it to the device assigned to the actor, if any.
Microbenchmarking shows that this is 4x faster than normal pickling and
unpickling of a torch.Tensor, likely due to Ray's serialization support
for numpy. Also, when moving the torch.Tensor to a GPU on the receiving
side, we avoid one extra data copy by copying directly from Ray's
shared memory buffer to GPU memory.
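
Roughly, the custom serialization path boils down to the following (a
simplified sketch of the serializer added in
`python/ray/dag/experimental/types.py` below; the shape/dtype checks and
the wrapper class are omitted):

```python
import numpy as np
import torch


def serialize_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    # Move GPU tensors to host memory first; Ray then stores the numpy
    # array in its shared memory object store via its fast numpy
    # serialization path.
    if tensor.device.type == "cuda":
        tensor = tensor.to("cpu")
    return tensor.numpy()


def deserialize_from_numpy(np_array: np.ndarray,
                           device: torch.device) -> torch.Tensor:
    if device.type == "cuda":
        # Zero-copy view of the shared-memory buffer, then a single copy
        # directly into GPU memory.
        return torch.from_numpy(np_array).to(device=device)
    # On CPU, copy out of the (read-only) shared-memory buffer.
    return torch.tensor(np_array, device=device)
```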

Limitations:
- Only supports tasks that directly return a torch.Tensor, i.e., the
torch.Tensor cannot be nested inside other data.
- The task must declare the shape and dtype of its torch.Tensor at DAG
compile time.
- Does not support local p2p GPU-GPU transfer, using either `cudaMemcpy`
or NCCL. Microbenchmarks show this can be >10x faster than transfer via
CPU.
- Does not support multinode GPU-GPU transfer, e.g., via RPC between
hosts or NCCL.

---------

Signed-off-by: Stephanie Wang <[email protected]>
stephanie-wang committed May 3, 2024
1 parent ab8e10a commit bbcdc49
Showing 5 changed files with 354 additions and 22 deletions.
109 changes: 87 additions & 22 deletions python/ray/dag/compiled_dag_node.py
@@ -1,6 +1,6 @@
import asyncio
from collections import defaultdict
from typing import Any, Dict, List, Tuple, Union, Optional
from typing import Any, Callable, Dict, List, Tuple, Union, Optional
import logging
import traceback
import threading
@@ -17,7 +17,7 @@
AwaitableBackgroundWriter,
)
from ray.util.annotations import DeveloperAPI, PublicAPI

from ray.dag.experimental.types import _do_register_custom_dag_serializers

MAX_BUFFER_SIZE = int(100 * 1e6) # 100MB

@@ -44,11 +44,26 @@ def do_allocate_channel(
return self._output_channel


def _wrap_exception(exc):
backtrace = ray._private.utils.format_error_message(
"".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
task_exception=True,
)
wrapped = RayTaskError(
function_name="do_exec_compiled_task",
traceback_str=backtrace,
cause=exc,
)
return wrapped


@DeveloperAPI
def do_exec_compiled_task(
self,
inputs: List[Union[Any, Channel]],
actor_method_name: str,
output_wrapper_fn: Optional[Callable[[Any], Any]],
has_type_hints: bool,
) -> None:
"""Generic actor method to begin executing a compiled DAG. This runs an
infinite loop to repeatedly read input channel(s), execute the given
@@ -64,6 +79,9 @@ def do_exec_compiled_task(
the loop.
"""
try:
if has_type_hints:
_do_register_custom_dag_serializers(self)

method = getattr(self, actor_method_name)

resolved_inputs = []
@@ -84,9 +102,14 @@ def do_exec_compiled_task(
self._output_writer.start()

while True:
res = None
try:
res = self._input_reader.begin_read()
except ValueError as exc:
# ValueError is raised if a type hint was set and the returned
# type did not match the hint.
self._output_writer.write(exc)
self._input_reader.end_read()
continue
except IOError:
break

@@ -95,19 +118,10 @@ def do_exec_compiled_task(

try:
output_val = method(*resolved_inputs)
if output_wrapper_fn is not None:
output_val = output_wrapper_fn(output_val)
except Exception as exc:
backtrace = ray._private.utils.format_error_message(
"".join(
traceback.format_exception(type(exc), exc, exc.__traceback__)
),
task_exception=True,
)
wrapped = RayTaskError(
function_name="do_exec_compiled_task",
traceback_str=backtrace,
cause=exc,
)
self._output_writer.write(wrapped)
self._output_writer.write(_wrap_exception(exc))
else:
self._output_writer.write(output_val)

@@ -160,12 +174,34 @@ def __init__(self, idx: int, dag_node: "ray.dag.DAGNode"):
idx: A unique index into the original DAG.
dag_node: The original DAG node created by the user.
"""
from ray.dag.experimental.types import (
TorchTensorType,
_TorchTensorWrapper,
)

self.idx = idx
self.dag_node = dag_node
self.arg_idx_to_tensor_meta: Dict[int, Dict[str, Any]] = {}

self.downstream_node_idxs = set()
self.output_channel = None

# If set, a lambda to apply to the task output. This can be used to
# check type hints, if any.
self.output_wrapper_fn = None
if self.dag_node.type_hint is not None:
print(self.dag_node.type_hint)
if isinstance(self.dag_node.type_hint, TorchTensorType):
# Wrap outputs produced by this task to indicate that it
# should be specially serialized.
self.output_wrapper_fn = lambda t: _TorchTensorWrapper(
t, self.dag_node.type_hint
)
else:
raise ValueError(
"DAGNode.with_type_hint may only be called on " "TorchTensorType"
)

@property
def args(self) -> Tuple[Any]:
return self.dag_node.get_args()
@@ -246,6 +282,7 @@ def __init__(
# Attributes that are set during preprocessing.
# Preprocessing identifies the input node and output node.
self.input_task_idx: Optional[int] = None
self.input_wrapper_fn: Optional[Callable[[Any], Any]] = None
self.output_task_idx: Optional[int] = None
self.has_single_output: bool = False
self.actor_task_count: Dict["ray._raylet.ActorID", int] = defaultdict(int)
@@ -262,6 +299,9 @@ def __init__(
# Set of actors present in the DAG.
self.actor_refs = set()

# Type hints specified by the user for DAG (intermediate) outputs.
self._type_hints = []

def _add_node(self, node: "ray.dag.DAGNode") -> None:
idx = self.counter
self.idx_to_task[idx] = CompiledTask(idx, node)
@@ -286,9 +326,11 @@ def _preprocess(self) -> None:

self.input_task_idx, self.output_task_idx = None, None
self.actor_task_count.clear()
self._type_hints.clear()

# For each task node, set its upstream and downstream task nodes.
for idx, task in self.idx_to_task.items():
# Also collect the set of tasks that produce torch.tensors.
for node_idx, task in self.idx_to_task.items():
dag_node = task.dag_node
if not (
isinstance(dag_node, InputNode)
@@ -320,10 +362,13 @@ def _preprocess(self) -> None:
)
self.actor_task_count[actor_handle._actor_id] += 1

for arg in task.args:
for arg_idx, arg in enumerate(task.args):
if isinstance(arg, DAGNode):
arg_idx = self.dag_node_to_idx[arg]
self.idx_to_task[arg_idx].downstream_node_idxs.add(idx)
arg_node_idx = self.dag_node_to_idx[arg]
self.idx_to_task[arg_node_idx].downstream_node_idxs.add(node_idx)

if dag_node.type_hint is not None:
self._type_hints.append(dag_node.type_hint)

for actor_id, task_count in self.actor_task_count.items():
if task_count > 1:
@@ -392,6 +437,10 @@ def _get_or_compile(
continue
visited.add(cur_idx)

# TODO: Check for GPU arguments. Find the actor upstream to that
# GPU argument. If both writer and reader actors are on GPUs, then
# add them.

task = self.idx_to_task[cur_idx]
# Create an output buffer on the actor.
assert task.output_channel is None
@@ -481,10 +530,17 @@ def _get_node_id(self):
do_exec_compiled_task,
resolved_args,
task.dag_node.get_method_name(),
output_wrapper_fn=task.output_wrapper_fn,
has_type_hints=bool(self._type_hints),
)
)

self.dag_input_channel = self.idx_to_task[self.input_task_idx].output_channel
# Wrapper function for inputs provided to dag.execute().
input_task = self.idx_to_task[self.input_task_idx]
self.input_wrapper_fn = input_task.output_wrapper_fn
self.dag_input_channel = input_task.output_channel
if self._type_hints:
_do_register_custom_dag_serializers(self)

self.dag_output_channels = []
for output in self.idx_to_task[self.output_task_idx].args:
@@ -596,7 +652,12 @@ def execute(
raise ValueError("Use execute_async if enable_asyncio=True")

self._get_or_compile()
self._dag_submitter.write(args[0])

inp = args[0]
if self.input_wrapper_fn is not None:
inp = self.input_wrapper_fn(inp)

self._dag_submitter.write(inp)

return self._dag_output_fetcher

@@ -628,7 +689,11 @@ async def execute_async(

self._get_or_compile()
async with self._dag_submission_lock:
await self._dag_submitter.write(args[0])
inp = args[0]
if self.input_wrapper_fn is not None:
inp = self.input_wrapper_fn(inp)

await self._dag_submitter.write(inp)
# Allocate a future that the caller can use to get the result.
fut = asyncio.Future()
await self._fut_queue.put(fut)
13 changes: 13 additions & 0 deletions python/ray/dag/dag_node.py
@@ -18,6 +18,8 @@

from ray.dag.compiled_dag_node import build_compiled_dag_from_ray_dag

from ray.dag.experimental.types import DAGNodeOutputType

T = TypeVar("T")


@@ -61,6 +63,16 @@ def __init__(
# Cached values from last call to execute()
self.cache_from_last_execute = {}

self._type_hint: Optional[DAGNodeOutputType] = None

def with_type_hint(self, typ: DAGNodeOutputType):
self._type_hint = typ
return self

@property
def type_hint(self) -> Optional[DAGNodeOutputType]:
return self._type_hint

def get_args(self) -> Tuple[Any]:
"""Return the tuple of arguments for this node."""

@@ -352,6 +364,7 @@ def _copy(
new_args, new_kwargs, new_options, new_other_args_to_resolve
)
instance._stable_uuid = self._stable_uuid
instance = instance.with_type_hint(self.type_hint)
return instance

def __getstate__(self):
Empty file.
118 changes: 118 additions & 0 deletions python/ray/dag/experimental/types.py
@@ -0,0 +1,118 @@
from typing import (
TYPE_CHECKING,
Tuple,
Any,
)

import ray.util.serialization
from ray.util.annotations import DeveloperAPI, PublicAPI

if TYPE_CHECKING:
import torch
import numpy as np


class DAGNodeOutputType:
pass


def _do_register_custom_dag_serializers(self: Any) -> None:
# Helper method to run on the DAG driver and actors to register custom
# serializers.
from ray.air._internal import torch_utils

default_device = torch_utils.get_devices()[0]
torch_tensor_serializer = _TorchTensorSerializer(default_device)

CUSTOM_SERIALIZERS = (
(
_TorchTensorWrapper,
torch_tensor_serializer.serialize_to_numpy,
torch_tensor_serializer.deserialize_from_numpy,
),
)

for cls, serializer, deserializer in CUSTOM_SERIALIZERS:
ray.util.serialization.register_serializer(
cls, serializer=serializer, deserializer=deserializer
)

self._torch_tensor_serializer = torch_tensor_serializer


@PublicAPI(stability="alpha")
class TorchTensorType(DAGNodeOutputType):
def __init__(self, shape: Tuple[int], dtype: "torch.dtype"):
self.shape = shape
self.dtype = dtype


@DeveloperAPI
class _TorchTensorWrapper:
def __init__(
self,
tensor: "torch.Tensor",
typ: TorchTensorType,
):
import torch

if not isinstance(tensor, torch.Tensor):
raise ValueError(
"DAG nodes wrapped with ray.experimental.TorchTensor must return a "
"torch.Tensor."
)
if tensor.shape != typ.shape:
raise ValueError(
"DAG node wrapped with ray.experimental.TorchTensor(shape="
f"{typ.shape}) returned "
f"a torch.Tensor of the shape {tensor.shape}"
)
if tensor.dtype != typ.dtype:
raise ValueError(
"DAG node wrapped with ray.experimental.TorchTensor(dtype="
f"{typ.dtype}) returned "
f"a torch.Tensor of the dtype {tensor.dtype}"
)

self.tensor = tensor


class _TorchTensorSerializer:
def __init__(self, device: "torch.device"):
self.device = device

@staticmethod
def serialize_to_numpy(instance: "_TorchTensorWrapper") -> "np.ndarray":
tensor = instance.tensor
# Transfer through Ray's shared memory store for now.
# TODO(swang): This requires two copies, one to transfer from GPU to
# CPU and another from CPU to shared memory. Ideally we should elide
# the first copy and memcpy directly from GPU to the shared memory
# buffer.
if tensor.device.type == "cuda":
tensor = tensor.to("cpu")

return tensor.numpy()

def deserialize_from_numpy(self, np_array: "np.ndarray"):
import torch

# TODO(swang): Support local P2P transfers if available.
# TODO(swang): Support multinode transfers with NCCL.

# If there is a GPU assigned to this worker, move it there.
if self.device.type == "cuda":
# Use zero-copy from_numpy() because we are going to copy to GPU
# anyway.
# TODO: Pin the np_array memory to reduce data movement time.
# TODO: Set np_array.flags.writeable=True to avoid the PyTorch
# warning about not owning the underlying memory. This is safe to
# do as long as all other readers are also copying the data to a
# GPU.
cpu_tensor = torch.from_numpy(np_array)
return cpu_tensor.to(device=self.device)

# TODO(swang): Use zero-copy from_numpy() if np_array.flags.writeable
# is True. This is safe to set when deserializing np_array if the
# upstream task has num_readers=1.
return torch.tensor(np_array, device=self.device)