[core][experimental] Accelerated DAG NCCL-based p2p channels for torch.Tensors (ray-project#45092)

## Why are these changes needed?

This adds an NCCL-based transport option for torch.Tensors. Here is an
example of the API:

```python
    with InputNode() as inp:
        dag = sender.send.bind(inp)
        dag = dag.with_type_hint(TorchTensorType(SHAPE, DTYPE, transport="nccl"))
        dag = receiver.recv.bind(dag)

    compiled_dag = dag.experimental_compile()
```
When `transport="nccl"` is specified, Ray initializes an NCCL group with the
actors involved upon `compile()`. The reading actor(s) then `recv` on the NCCL
communicator instead of reading from the default shared-memory channel.
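
For context, here is a fuller end-to-end sketch of how this can be driven. The `TorchTensorWorker` actor, the `SHAPE`/`DTYPE` values, and the `execute()`/`begin_read()` calls are illustrative assumptions layered on top of the API shown above, not code from this PR:

```python
import torch
import ray
from ray.dag import InputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

# Illustrative values; the real shape/dtype are whatever the user declares.
SHAPE = (10,)
DTYPE = torch.float16


@ray.remote(num_gpus=1)
class TorchTensorWorker:
    def send(self, value: int):
        # Produce a GPU tensor; with the NCCL type hint it is sent over NCCL
        # instead of the default shared-memory channel.
        return torch.ones(SHAPE, dtype=DTYPE, device="cuda") * value

    def recv(self, tensor: torch.Tensor):
        # The tensor arrives on this actor's GPU via an NCCL recv.
        return (tensor[0].item(), tensor.shape, tensor.dtype)


sender = TorchTensorWorker.remote()
receiver = TorchTensorWorker.remote()

with InputNode() as inp:
    dag = sender.send.bind(inp)
    dag = dag.with_type_hint(TorchTensorType(SHAPE, DTYPE, transport="nccl"))
    dag = receiver.recv.bind(dag)

compiled_dag = dag.experimental_compile()

# execute() returns a reader; begin_read()/end_read() fetch the result.
output_channel = compiled_dag.execute(1)
result = output_channel.begin_read()
output_channel.end_read()
```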

This PR also refactors the channel types so that we now create
`ChannelInterface`s based on the type hints that appear in the DAG,
either a `TorchTensorType` or the default `SharedMemoryType`.
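
Roughly, channel allocation now dispatches on the type hint. This is a simplified sketch of the `do_allocate_channel` path in this diff; the `allocate_output_channel` helper and its argument names are illustrative:

```python
from ray.experimental.channel.shared_memory_channel import SharedMemoryType


def allocate_output_channel(type_hint, writer, readers, buffer_size_bytes):
    # No explicit hint on the DAG node: fall back to the default
    # shared-memory channel type.
    if type_hint is None:
        type_hint = SharedMemoryType(buffer_size_bytes)
    # Each ChannelOutputType knows how to build its own ChannelInterface:
    # SharedMemoryType -> shared-memory Channel, TorchTensorType with
    # transport="nccl" -> NCCL-backed channel for torch.Tensors.
    return type_hint.create_channel(writer, readers)
```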

Current limitations:
- p2p only, no collectives
- Synchronizes the CUDA stream after receiving data, because kernels enqueued
after the NCCL op have no guarantee that the op completed, so the received
buffer cannot be safely read until we know that it did (see the sketch after
this list).
- Shape and dtype of the tensor must be declared at compile time.
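
To illustrate the stream-synchronization limitation, the receive path conceptually looks like the sketch below; this is not code from this diff, and `nccl_group.recv` stands in for the actual NCCL call:

```python
import torch


def recv_tensor(nccl_group, shape, dtype, src_rank):
    # The destination buffer uses the shape/dtype declared at compile time.
    buf = torch.empty(shape, dtype=dtype, device="cuda")
    # Enqueue the NCCL recv on the current CUDA stream (asynchronous with
    # respect to the host).
    nccl_group.recv(buf, src_rank)
    # Until we know the recv completed successfully, it is not safe for later
    # kernels or the host to read `buf`, so for now the stream is synchronized
    # before handing the tensor to user code.
    torch.cuda.current_stream().synchronize()
    return buf
```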

---------

Signed-off-by: Stephanie Wang <[email protected]>
Co-authored-by: SangBin Cho <[email protected]>
stephanie-wang and rkooo567 committed May 11, 2024
1 parent 094748e commit 79f3995
Showing 19 changed files with 1,563 additions and 420 deletions.
10 changes: 5 additions & 5 deletions python/ray/_private/ray_experimental_perf.py
@@ -66,14 +66,14 @@ def read(self, chans):
chan.begin_read()
chan.end_read()

chans = [ray_channel.Channel([None], 1000)]
chans = [ray_channel.Channel(None, [None], 1000)]
results += timeit(
"[unstable] local put:local get, single channel calls",
lambda: put_channel_small(chans, do_get=True, do_release=True),
)

reader = ChannelReader.remote()
chans = [ray_channel.Channel([reader], 1000)]
chans = [ray_channel.Channel(None, [reader], 1000)]
ray.get(reader.ready.remote())
reader.read.remote(chans)
results += timeit(
@@ -86,7 +86,7 @@ def read(self, chans):
print(f"Testing multiple readers/channels, n={n_cpu}")

readers = [ChannelReader.remote() for _ in range(n_cpu)]
chans = [ray_channel.Channel(readers, 1000)]
chans = [ray_channel.Channel(None, readers, 1000)]
ray.get([reader.ready.remote() for reader in readers])
for reader in readers:
reader.read.remote(chans)
@@ -98,7 +98,7 @@ def read(self, chans):
ray.kill(reader)

reader = ChannelReader.remote()
chans = [ray_channel.Channel([reader], 1000) for _ in range(n_cpu)]
chans = [ray_channel.Channel(None, [reader], 1000) for _ in range(n_cpu)]
ray.get(reader.ready.remote())
reader.read.remote(chans)
results += timeit(
@@ -108,7 +108,7 @@ def read(self, chans):
ray.kill(reader)

readers = [ChannelReader.remote() for _ in range(n_cpu)]
chans = [ray_channel.Channel([readers[i]], 1000) for i in range(n_cpu)]
chans = [ray_channel.Channel(None, [readers[i]], 1000) for i in range(n_cpu)]
ray.get([reader.ready.remote() for reader in readers])
for chan, reader in zip(chans, readers):
reader.read.remote([chan])
25 changes: 20 additions & 5 deletions python/ray/dag/BUILD
@@ -1,4 +1,5 @@
load("//bazel:python.bzl", "doctest")
load("//bazel:python.bzl", "py_test_module_list")

doctest(
files = glob(["**/*.py"]),
@@ -69,10 +70,24 @@ py_test(
deps = [":dag_lib"],
)

py_test_module_list(
files = [
"tests/experimental/test_accelerated_dag.py",
"tests/experimental/test_torch_tensor_dag.py",
],
size = "medium",
tags = ["exclusive", "accelerated_dag", "no_windows", "team:core"],
deps = ["//:ray_lib"],
)

py_test(
name = "test_accelerated_dag",
size = "medium",
srcs = dag_tests_srcs,
tags = ["exclusive", "team:core", "ray_dag_tests"],
deps = [":dag_lib"],
name = "test_torch_tensor_dag_gpu",
srcs = [
"tests/experimental/test_torch_tensor_dag.py",
],
main = "tests/experimental/test_torch_tensor_dag.py",
size = "medium",
tags = ["exclusive", "accelerated_dag", "no_windows", "team:core", "multi_gpu"],
env = {"RAY_PYTEST_USE_GPU": "1"},
deps = ["//:ray_lib"],
)
125 changes: 84 additions & 41 deletions python/ray/dag/compiled_dag_node.py
@@ -1,23 +1,33 @@
import asyncio
from collections import defaultdict
from typing import Any, Callable, Dict, List, Tuple, Union, Optional
from typing import Any, Callable, Dict, List, Tuple, Union, Optional, Set
import logging
import traceback
import threading

import ray
from ray.exceptions import RayTaskError
from ray.experimental.channel import (
Channel,
_do_register_custom_serializers,
ChannelInterface,
ChannelOutputType,
ReaderInterface,
SynchronousReader,
WriterInterface,
SynchronousWriter,
AwaitableBackgroundReader,
AwaitableBackgroundWriter,
_init_nccl_group,
)
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.dag.experimental.types import _do_register_custom_dag_serializers

from ray.experimental.channel.shared_memory_channel import (
SharedMemoryType,
)
from ray.experimental.channel.torch_tensor_type import (
TorchTensorType,
_TorchTensorWrapper,
)

MAX_BUFFER_SIZE = int(100 * 1e6) # 100MB

@@ -28,8 +38,8 @@
def do_allocate_channel(
self,
readers: List[Optional["ray.actor.ActorHandle"]],
buffer_size_bytes: int,
) -> Channel:
typ: ChannelOutputType,
) -> ChannelInterface:
"""Generic actor method to allocate an output channel.
Args:
@@ -39,15 +49,25 @@ def do_allocate_channel(
Returns:
The allocated channel.
"""
output_channel = Channel(readers, buffer_size_bytes)
self_actor = None
try:
self_actor = ray.get_runtime_context().current_actor
except RuntimeError:
# This is the driver so there is no current actor handle.
pass

output_channel = typ.create_channel(
self_actor,
readers,
)
return output_channel


@DeveloperAPI
def do_exec_tasks(
self,
tasks: List["ExecutableTask"],
has_type_hints: bool,
type_hints: List[type],
) -> None:
"""Generic actor method to begin executing the tasks belonging to an actor.
This runs an infinite loop to run each task in turn (following the order specified
@@ -58,8 +78,7 @@ def do_exec_tasks(
tasks: the executable tasks corresponding to the actor methods.
"""
try:
if has_type_hints:
_do_register_custom_dag_serializers(self)
_do_register_custom_serializers(self, type_hints)

self._input_readers = []
self._output_writers = []
@@ -93,7 +112,7 @@ def _prep_task(self, task: "ExecutableTask") -> None:
"""
# Add placeholders for input channels.
for idx, inp in enumerate(task.resolved_args):
if isinstance(inp, Channel):
if isinstance(inp, ChannelInterface):
task.input_channels.append(inp)
task.input_channel_idxs.append(idx)
task.resolved_inputs.append(None)
@@ -201,33 +220,23 @@ def __init__(self, idx: int, dag_node: "ray.dag.DAGNode"):
idx: A unique index into the original DAG.
dag_node: The original DAG node created by the user.
"""
from ray.dag.experimental.types import (
TorchTensorType,
_TorchTensorWrapper,
)

self.idx = idx
self.dag_node = dag_node
self.arg_idx_to_tensor_meta: Dict[int, Dict[str, Any]] = {}

self.downstream_node_idxs = set()
self.downstream_node_idxs: Dict[int, "ray.actor.ActorHandle"] = {}
self.output_channel = None

# If set, a lambda to apply to the task output. This can be used to
# check type hints, if any.
self.output_wrapper_fn = None
if self.dag_node.type_hint is not None:
print(self.dag_node.type_hint)
if isinstance(self.dag_node.type_hint, TorchTensorType):
# Wrap outputs produced by this task to indicate that it
# should be specially serialized.
self.output_wrapper_fn = lambda t: _TorchTensorWrapper(
t, self.dag_node.type_hint
)
else:
raise ValueError(
"DAGNode.with_type_hint may only be called on " "TorchTensorType"
)

@property
def args(self) -> Tuple[Any]:
@@ -270,9 +279,9 @@ def __init__(
self.output_wrapper_fn = task.output_wrapper_fn
self.resolved_args = resolved_args

self.resolved_inputs = []
self.input_channels = []
self.input_channel_idxs = []
self.resolved_inputs: List[Union[Any, ChannelInterface]] = []
self.input_channels: List[ChannelInterface] = []
self.input_channel_idxs: List[int] = []


@DeveloperAPI
@@ -314,6 +323,9 @@ def __init__(
self._buffer_size_bytes: Optional[int] = buffer_size_bytes
if self._buffer_size_bytes is None:
self._buffer_size_bytes = MAX_BUFFER_SIZE
self._default_type_hint: ChannelOutputType = SharedMemoryType(
self._buffer_size_bytes
)
if not isinstance(self._buffer_size_bytes, int) or self._buffer_size_bytes <= 0:
raise ValueError(
"`buffer_size_bytes` must be a positive integer, found "
@@ -345,8 +357,8 @@ def __init__(
self.actor_task_count: Dict["ray._raylet.ActorID", int] = defaultdict(int)

# Cached attributes that are set during compilation.
self.dag_input_channel: Optional[Channel] = None
self.dag_output_channels: Optional[List[Channel]] = None
self.dag_input_channel: Optional[ChannelInterface] = None
self.dag_output_channels: Optional[List[ChannelInterface]] = None
self._dag_submitter: Optional[WriterInterface] = None
self._dag_output_fetcher: Optional[ReaderInterface] = None

@@ -391,6 +403,8 @@ def _preprocess(self) -> None:
self.actor_task_count.clear()
self._type_hints.clear()

nccl_actors: Set["ray.actor.ActorHandle"] = set()

# For each task node, set its upstream and downstream task nodes.
# Also collect the set of tasks that produce torch.tensors.
for node_idx, task in self.idx_to_task.items():
@@ -425,10 +439,30 @@ def _preprocess(self) -> None:
)
self.actor_task_count[actor_handle._actor_id] += 1

for arg in task.args:
if isinstance(arg, DAGNode):
arg_node_idx = self.dag_node_to_idx[arg]
self.idx_to_task[arg_node_idx].downstream_node_idxs.add(node_idx)
if (
isinstance(dag_node.type_hint, TorchTensorType)
and dag_node.type_hint.transport == "nccl"
):
# Add all writers to the NCCL group.
nccl_actors.add(actor_handle)

for arg_idx, arg in enumerate(task.args):
if not isinstance(arg, DAGNode):
continue

upstream_node_idx = self.dag_node_to_idx[arg]
upstream_node = self.idx_to_task[upstream_node_idx]
downstream_actor_handle = None
if isinstance(task.dag_node, ClassMethodNode):
downstream_actor_handle = task.dag_node._get_actor_handle()
upstream_node.downstream_node_idxs[node_idx] = downstream_actor_handle

if (
isinstance(upstream_node.dag_node.type_hint, TorchTensorType)
and upstream_node.dag_node.type_hint.transport == "nccl"
):
# Add all readers to the NCCL group.
nccl_actors.add(downstream_actor_handle)

if dag_node.type_hint is not None:
self._type_hints.append(dag_node.type_hint)
@@ -462,6 +496,14 @@ def _preprocess(self) -> None:
# now.
self._preprocess()

# If there were type hints indicating transport via NCCL, initialize
# the NCCL group on the participating actors.
nccl_actors = list(nccl_actors)
if None in nccl_actors:
raise ValueError("Driver cannot participate in the NCCL group.")
if nccl_actors:
_init_nccl_group(nccl_actors)

def _get_or_compile(
self,
) -> None:
@@ -493,13 +535,14 @@ def _get_or_compile(
continue
visited.add(cur_idx)

# TODO: Check for GPU arguments. Find the actor upstream to that
# GPU argument. If both writer and reader actors are on GPUs, then
# add them.

task = self.idx_to_task[cur_idx]
# Create an output buffer for the actor method.
assert task.output_channel is None

type_hint = task.dag_node.type_hint
if type_hint is None:
type_hint = self._default_type_hint

if isinstance(task.dag_node, ClassMethodNode):
readers = [self.idx_to_task[idx] for idx in task.downstream_node_idxs]
assert len(readers) == 1
@@ -533,7 +576,7 @@ def _get_node_id(self):
fn.remote(
do_allocate_channel,
reader_handles,
buffer_size_bytes=self._buffer_size_bytes,
typ=type_hint,
)
)
actor_handle = task.dag_node._get_actor_handle()
@@ -553,9 +596,10 @@ def _get_node_id(self):
)
reader_handles_set.add(reader_handle)
reader_handles.append(reader_handle)
task.output_channel = Channel(
task.output_channel = do_allocate_channel(
self,
reader_handles,
buffer_size_bytes=self._buffer_size_bytes,
typ=type_hint,
)
else:
assert isinstance(task.dag_node, MultiOutputNode)
@@ -621,15 +665,14 @@ def _get_node_id(self):
] = worker_fn.options(concurrency_group="_ray_system").remote(
do_exec_tasks,
executable_tasks,
has_type_hints=bool(self._type_hints),
type_hints=list(set(self._type_hints)),
)

# Wrapper function for inputs provided to dag.execute().
input_task = self.idx_to_task[self.input_task_idx]
self.input_wrapper_fn = input_task.output_wrapper_fn
self.dag_input_channel = input_task.output_channel
if self._type_hints:
_do_register_custom_dag_serializers(self)
_do_register_custom_serializers(self, list(set(self._type_hints)))

self.dag_output_channels = []
for output in self.idx_to_task[self.output_task_idx].args:
@@ -747,7 +790,7 @@ def execute(
self,
*args,
**kwargs,
) -> Union[Channel, List[Channel]]:
) -> ReaderInterface:
"""Execute this DAG using the compiled execution path.
Args:
9 changes: 4 additions & 5 deletions python/ray/dag/dag_node.py
@@ -17,8 +17,7 @@
import asyncio

from ray.dag.compiled_dag_node import build_compiled_dag_from_ray_dag

from ray.dag.experimental.types import DAGNodeOutputType
from ray.experimental.channel import ChannelOutputType

T = TypeVar("T")

@@ -63,14 +62,14 @@ def __init__(
# Cached values from last call to execute()
self.cache_from_last_execute = {}

self._type_hint: Optional[DAGNodeOutputType] = None
self._type_hint: Optional[ChannelOutputType] = None

def with_type_hint(self, typ: DAGNodeOutputType):
def with_type_hint(self, typ: ChannelOutputType):
self._type_hint = typ
return self

@property
def type_hint(self) -> Optional[DAGNodeOutputType]:
def type_hint(self) -> Optional[ChannelOutputType]:
return self._type_hint

def get_args(self) -> Tuple[Any]:
