Skip to content

Commit

Permalink
[core][experimental] Avoid serialization for data passed between two …
Browse files Browse the repository at this point in the history
…tasks on the same actor (#45591)

* `LocalChannel` is a channel for communication between two tasks in the
same worker process. It writes data directly to the worker's
serialization context and reads data from the serialization context to
avoid the serialization overhead and the need for reading/writing from
shared memory.

* `MultiChannel` can be used to send data to different readers via
different channels. For example, if the reader is in the same worker
process as the writer, the data can be sent via LocalChannel. If the
reader is in a different worker process, the data can be sent via shared
memory channel.

## Simple benchmark

* I used [this Python
script](https://gist.github.com/kevin85421/0ebffad403d158ab140c4d4dc879e214)
to conduct a simple benchmark with the commit
97292b1
for an early signal.
* Experiment results
  * Case 1: with this PR => `Execution time: 1.246751070022583 seconds`
* Case 2: without this PR => `Execution time: 2.6125080585479736
seconds`

Closes #45230

---------

Signed-off-by: kaihsun <[email protected]>
  • Loading branch information
kevin85421 committed Jun 13, 2024
1 parent d59d1ef commit d577652
Show file tree
Hide file tree
Showing 8 changed files with 523 additions and 14 deletions.
6 changes: 3 additions & 3 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def do_allocate_channel(
Args:
readers: The actor handles of the readers.
buffer_size_bytes: The maximum size of messages in the channel.
typ: The output type hint for the channel.
Returns:
The allocated channel.
Expand Down Expand Up @@ -131,7 +131,7 @@ def _exec_task(self, task: "ExecutableTask", idx: int) -> bool:
True if we are done executing all tasks of this actor, False otherwise.
"""
# TODO: for cases where output is passed as input to a task on
# the same actor, introduce a "LocalChannel" to avoid the overhead
# the same actor, introduce a "IntraProcessChannel" to avoid the overhead
# of serialization/deserialization and synchronization.
method = getattr(self, task.method_name)
input_reader = self._input_readers[idx]
Expand Down Expand Up @@ -683,12 +683,12 @@ def _get_or_compile(
# `readers` is the nodes that are ordered after the current one (`task`)
# in the DAG.
readers = [self.idx_to_task[idx] for idx in task.downstream_node_idxs]
assert len(readers) == 1

def _get_node_id(self):
return ray.get_runtime_context().get_node_id()

if isinstance(readers[0].dag_node, MultiOutputNode):
assert len(readers) == 1
# This node is a multi-output node, which means that it will only be
# read by the driver, not an actor. Thus, we handle this case by
# setting `reader_handles` to `[self._driver_actor]`.
Expand Down
113 changes: 113 additions & 0 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,119 @@ async def main():
compiled_dag.teardown()


class TestCompositeChannel:
def test_composite_channel_one_actor(self, ray_start_regular_shared):
"""
In this test, there are three 'inc' tasks on the same Ray actor, chained
together. Therefore, the DAG will look like this:
Driver -> a.inc -> a.inc -> a.inc -> Driver
All communication between the driver and the actor will be done through remote
channels, i.e., shared memory channels. All communication between the actor
tasks will be conducted through local channels, i.e., IntraProcessChannel in
this case.
To elaborate, all output channels of the actor DAG nodes will be
CompositeChannel, and the first two will have a local channel, while the last
one will have a remote channel.
"""
a = Actor.remote(0)
with InputNode() as inp:
dag = a.inc.bind(inp)
dag = a.inc.bind(dag)
dag = a.inc.bind(dag)

compiled_dag = dag.experimental_compile()
output_channel = compiled_dag.execute(1)
result = output_channel.begin_read()
assert result == 4
output_channel.end_read()

output_channel = compiled_dag.execute(2)
result = output_channel.begin_read()
assert result == 24
output_channel.end_read()

output_channel = compiled_dag.execute(3)
result = output_channel.begin_read()
assert result == 108
output_channel.end_read()

compiled_dag.teardown()

def test_composite_channel_two_actors(self, ray_start_regular_shared):
"""
In this test, there are three 'inc' tasks on the two Ray actors, chained
together. Therefore, the DAG will look like this:
Driver -> a.inc -> b.inc -> a.inc -> Driver
All communication between the driver and actors will be done through remote
channels. Also, all communication between the actor tasks will be conducted
through remote channels, i.e., shared memory channel in this case because no
consecutive tasks are on the same actor.
"""
a = Actor.remote(0)
b = Actor.remote(100)
with InputNode() as inp:
dag = a.inc.bind(inp)
dag = b.inc.bind(dag)
dag = a.inc.bind(dag)

# a: 0+1 -> b: 100+1 -> a: 1+101
compiled_dag = dag.experimental_compile()
output_channel = compiled_dag.execute(1)
result = output_channel.begin_read()
assert result == 102
output_channel.end_read()

# a: 102+2 -> b: 101+104 -> a: 104+205
output_channel = compiled_dag.execute(2)
result = output_channel.begin_read()
assert result == 309
output_channel.end_read()

# a: 309+3 -> b: 205+312 -> a: 312+517
output_channel = compiled_dag.execute(3)
result = output_channel.begin_read()
assert result == 829
output_channel.end_read()

compiled_dag.teardown()

def test_composite_channel_multi_output(self, ray_start_regular_shared):
"""
Driver -> a.inc -> a.inc ---> Driver
| |
-> b.inc -
All communication in this DAG will be done through CompositeChannel.
Under the hood, the communication between two `a.inc` tasks will
be done through a local channel, i.e., IntraProcessChannel in this
case, while the communication between `a.inc` and `b.inc` will be
done through a shared memory channel.
"""
a = Actor.remote(0)
b = Actor.remote(100)
with InputNode() as inp:
dag = a.inc.bind(inp)
dag = MultiOutputNode([a.inc.bind(dag), b.inc.bind(dag)])

compiled_dag = dag.experimental_compile()
output_channel = compiled_dag.execute(1)
result = output_channel.begin_read()
assert result == [2, 101]
output_channel.end_read()

output_channel = compiled_dag.execute(3)
result = output_channel.begin_read()
assert result == [10, 106]
output_channel.end_read()

compiled_dag.teardown()


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
Expand Down
5 changes: 4 additions & 1 deletion python/ray/experimental/channel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
SynchronousWriter,
WriterInterface,
)
from ray.experimental.channel.shared_memory_channel import Channel
from ray.experimental.channel.intra_process_channel import IntraProcessChannel
from ray.experimental.channel.shared_memory_channel import Channel, CompositeChannel
from ray.experimental.channel.torch_tensor_nccl_channel import TorchTensorNcclChannel

__all__ = [
Expand All @@ -22,4 +23,6 @@
"WriterInterface",
"ChannelContext",
"TorchTensorNcclChannel",
"IntraProcessChannel",
"CompositeChannel",
]
6 changes: 6 additions & 0 deletions python/ray/experimental/channel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,15 @@ def __init__(
pass

def ensure_registered_as_writer(self):
"""
Check whether the process is a valid writer. This method must be idempotent.
"""
raise NotImplementedError

def ensure_registered_as_reader(self):
"""
Check whether the process is a valid reader. This method must be idempotent.
"""
raise NotImplementedError

def write(self, value: Any) -> None:
Expand Down
65 changes: 65 additions & 0 deletions python/ray/experimental/channel/intra_process_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import uuid
from typing import Any, Optional

import ray
from ray.experimental.channel import ChannelContext
from ray.experimental.channel.common import ChannelInterface
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class IntraProcessChannel(ChannelInterface):
"""
IntraProcessChannel is a channel for communication between two tasks in the same
worker process. It writes data directly to the worker's _SerializationContext
and reads data from the _SerializationContext to avoid the serialization
overhead and the need for reading/writing from shared memory.
Args:
actor_handle: The actor handle of the worker process.
"""

def __init__(
self,
actor_handle: ray.actor.ActorHandle,
_channel_id: Optional[str] = None,
):
# TODO (kevin85421): Currently, if we don't pass `actor_handle` to
# `IntraProcessChannel`, the actor will die due to the reference count of
# `actor_handle` is 0. We should fix this issue in the future.
self._actor_handle = actor_handle
# Generate a unique ID for the channel. The writer and reader will use
# this ID to store and retrieve data from the _SerializationContext.
self._channel_id = _channel_id
if self._channel_id is None:
self._channel_id = str(uuid.uuid4())

def ensure_registered_as_writer(self) -> None:
pass

def ensure_registered_as_reader(self) -> None:
pass

def __reduce__(self):
return IntraProcessChannel, (
self._actor_handle,
self._channel_id,
)

def write(self, value: Any):
# Because both the reader and writer are in the same worker process,
# we can directly store the data in the context instead of storing
# it in the channel object. This removes the serialization overhead of `value`.
ctx = ChannelContext.get_current().serialization_context
ctx.set_data(self._channel_id, value)

def begin_read(self) -> Any:
ctx = ChannelContext.get_current().serialization_context
return ctx.get_data(self._channel_id)

def end_read(self):
pass

def close(self) -> None:
ctx = ChannelContext.get_current().serialization_context
ctx.reset_data(self._channel_id)
23 changes: 22 additions & 1 deletion python/ray/experimental/channel/serialization_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, List, Union
from typing import TYPE_CHECKING, Any, Dict, List, Union

if TYPE_CHECKING:
import numpy as np
Expand All @@ -9,10 +9,31 @@ class _SerializationContext:
def __init__(self):
self.use_external_transport: bool = False
self.tensors: List["torch.Tensor"] = []
# Buffer for transferring data between tasks in the same worker process.
# The key is the channel ID, and the value is the data. We don't use a
# lock when reading/writing the buffer because a DAG node actor will only
# execute one task at a time in `do_exec_tasks`. It will not execute multiple
# Ray tasks on a single actor simultaneously.
self.intra_process_channel_buffers: Dict[str, Any] = {}

def set_use_external_transport(self, use_external_transport: bool) -> None:
self.use_external_transport = use_external_transport

def set_data(self, channel_id: str, value: Any) -> None:
assert (
channel_id not in self.intra_process_channel_buffers
), f"Channel {channel_id} already exists in the buffer."
self.intra_process_channel_buffers[channel_id] = value

def get_data(self, channel_id: str) -> Any:
assert (
channel_id in self.intra_process_channel_buffers
), f"Channel {channel_id} does not exist in the buffer."
return self.intra_process_channel_buffers.pop(channel_id)

def reset_data(self, channel_id: str) -> None:
self.intra_process_channel_buffers.pop(channel_id, None)

def reset_tensors(self, tensors: List["torch.Tensor"]) -> List["torch.Tensor"]:
prev_tensors = self.tensors
self.tensors = tensors
Expand Down
Loading

0 comments on commit d577652

Please sign in to comment.