forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[core][experimental] Pass torch.Tensors through accelerated DAGs (ray…
…-project#44825) This PR adds support for passing torch.Tensors to local actors in an accelerated DAG, via Ray's shared memory store. It supports the following transfer cases, as long as the sending and receiving actors are on the same node: CPU-CPU, CPU-GPU, GPU-CPU, GPU-GPU (via CPU). This iteration requires the user to explicitly declare which DAG nodes contain torch.Tensors and the tensors' shape and dtype, with a new `with_type_hint` decorator. For example: ```python with InputNode() as inp: dag = sender.send.bind(inp) dag = dag.with_type_hint(TorchTensorType(SHAPE, DTYPE)) dag = receiver.recv.bind(dag) compiled_dag = dag.experimental_compile() ``` This declaration isn't necessarily useful for this PR, but it is included now because it makes it much simpler to efficiently support other cases in the future, such as p2p GPU-GPU transfers. When a TorchTensor node is declared, the serialization of the underlying torch.Tensor is performed differently from vanilla Ray. In particular, we store the numpy view of the data. On the receiving actor, we deserialize to a torch.Tensor and move it to the device assigned to the actor, if any. Microbenchmarking shows that this is 4x faster than normal pickling and unpickling of a torch.Tensor, likely due to Ray's serialization support for numpy. Also, when moving the torch.Tensor to a GPU on the receiving side, we can avoid one extra data copy by copying directly from Ray's shared memory buffer to GPU memory. Limitations: - Only supports tasks that directly return a torch.Tensor, i.e. the torch.Tensor cannot be nested in other data. - The task must declare the shape and dtype of its torch.Tensor at DAG compile time. - Does not support local p2p GPU-GPU transfer, either using `cudaMemCpy` or NCCL. Microbenchmark shows this can be >10x faster than transfer via CPU. - Does not support multinode GPU-GPU transfer, e.g., via RPC between hosts or NCCL. --------- Signed-off-by: Stephanie Wang <[email protected]>
- Loading branch information
1 parent
ab8e10a
commit bbcdc49
Showing
5 changed files
with
354 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from typing import ( | ||
TYPE_CHECKING, | ||
Tuple, | ||
Any, | ||
) | ||
|
||
import ray.util.serialization | ||
from ray.util.annotations import DeveloperAPI, PublicAPI | ||
|
||
if TYPE_CHECKING: | ||
import torch | ||
import numpy as np | ||
|
||
|
||
class DAGNodeOutputType:
    """Marker base class for declared DAG node output types.

    Subclasses (e.g., TorchTensorType) are attached to DAG nodes via
    `with_type_hint` to tell the compiled DAG how to serialize the node's
    output.
    """
|
||
|
||
def _do_register_custom_dag_serializers(self: Any) -> None:
    """Register custom serializers for DAG-specific wrapper types.

    Helper run on both the DAG driver and the DAG actors so that
    _TorchTensorWrapper values can round-trip through Ray's object store.
    Stores the serializer on `self` as `_torch_tensor_serializer`.
    """
    from ray.air._internal import torch_utils

    # The first device assigned to this worker (its GPU, if it has one)
    # determines where deserialized tensors are placed.
    default_device = torch_utils.get_devices()[0]
    serializer = _TorchTensorSerializer(default_device)

    ray.util.serialization.register_serializer(
        _TorchTensorWrapper,
        serializer=serializer.serialize_to_numpy,
        deserializer=serializer.deserialize_from_numpy,
    )

    self._torch_tensor_serializer = serializer
|
||
|
||
@PublicAPI(stability="alpha")
class TorchTensorType(DAGNodeOutputType):
    """Declares that a DAG node returns a torch.Tensor of fixed shape/dtype.

    Attaching this type hint to a DAG node lets the compiled DAG serialize
    the tensor through Ray's shared memory store instead of default pickling.

    Args:
        shape: Expected shape of the returned tensor. Fixed to
            ``Tuple[int, ...]`` — the original annotation ``Tuple[int]``
            incorrectly meant a 1-element tuple, but shapes have arbitrary
            rank (validated against ``tensor.shape`` in _TorchTensorWrapper).
        dtype: Expected torch dtype of the returned tensor.
    """

    def __init__(self, shape: Tuple[int, ...], dtype: "torch.dtype"):
        self.shape = shape
        self.dtype = dtype
|
||
|
||
@DeveloperAPI
class _TorchTensorWrapper:
    """Wraps a DAG node's torch.Tensor output, validating it against the
    declared TorchTensorType before serialization.

    Raises:
        ValueError: If the value is not a torch.Tensor, or its shape or
            dtype does not match the declared type hint.
    """

    def __init__(
        self,
        tensor: "torch.Tensor",
        typ: TorchTensorType,
    ):
        import torch

        # Validate type first, then shape, then dtype, so the most
        # fundamental mismatch is reported.
        if not isinstance(tensor, torch.Tensor):
            raise ValueError(
                "DAG nodes wrapped with ray.experimental.TorchTensor must return a "
                "torch.Tensor."
            )
        if tensor.shape != typ.shape:
            raise ValueError(
                f"DAG node wrapped with ray.experimental.TorchTensor(shape="
                f"{typ.shape}) returned "
                f"a torch.Tensor of the shape {tensor.shape}"
            )
        if tensor.dtype != typ.dtype:
            raise ValueError(
                f"DAG node wrapped with ray.experimental.TorchTensor(dtype="
                f"{typ.dtype}) returned "
                f"a torch.Tensor of the dtype {tensor.dtype}"
            )

        # Validated payload; the custom serializer reads this attribute.
        self.tensor = tensor
|
||
|
||
class _TorchTensorSerializer: | ||
def __init__(self, device: "torch.device"): | ||
self.device = device | ||
|
||
@staticmethod | ||
def serialize_to_numpy(instance: "_TorchTensorWrapper") -> "np.ndarray": | ||
tensor = instance.tensor | ||
# Transfer through Ray's shared memory store for now. | ||
# TODO(swang): This requires two copies, one to transfer from GPU to | ||
# CPU and another from CPU to shared memory. Ideally we should elide | ||
# the first copy and memcpy directly from GPU to the shared memory | ||
# buffer. | ||
if tensor.device.type == "cuda": | ||
tensor = tensor.to("cpu") | ||
|
||
return tensor.numpy() | ||
|
||
def deserialize_from_numpy(self, np_array: "np.ndarray"): | ||
import torch | ||
|
||
# TODO(swang): Support local P2P transfers if available. | ||
# TODO(swang): Support multinode transfers with NCCL. | ||
|
||
# If there is a GPU assigned to this worker, move it there. | ||
if self.device.type == "cuda": | ||
# Use zero-copy from_numpy() because we are going to copy to GPU | ||
# anyway. | ||
# TODO: Pin the np_array memory to reduce data movement time. | ||
# TODO: Set np_array.flags.writeable=True to avoid the PyTorch | ||
# warning about not owning the underlying memory. This is safe to | ||
# do as long as all other readers are also copying the data to a | ||
# GPU. | ||
cpu_tensor = torch.from_numpy(np_array) | ||
return cpu_tensor.to(device=self.device) | ||
|
||
# TODO(swang): Use zero-copy from_numpy() if np_array.flags.writeable | ||
# is True. This is safe to set when deserializing np_array if the | ||
# upstream task has num_readers=1. | ||
return torch.tensor(np_array, device=self.device) |
Oops, something went wrong.