[Data] iter_torch_batches updates (ray-project#37625)
Adds more documentation to the iter_torch_batches docstring.
Changes the default value of the device parameter to "auto" to make the automatic device transfer behavior more explicit.

---------

Signed-off-by: amogkam <[email protected]>
Signed-off-by: NripeshN <[email protected]>
amogkam authored and NripeshN committed Aug 15, 2023
1 parent 298fef5 commit 9a24b5e
Showing 3 changed files with 116 additions and 72 deletions.
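
For quick orientation, a minimal sketch of the behavior this commit documents (a local Ray Data and PyTorch install is assumed; per the docstrings below, the "auto" default resolves to CPU outside of Ray Train):

import ray

ds = ray.data.range(8)

# device now defaults to "auto": tensors stay on CPU here, but move to the
# Ray Train worker's device automatically when used inside Ray Train.
for batch in ds.iter_torch_batches(batch_size=4):
    print(batch)
# {'id': tensor([0, 1, 2, 3])}
# {'id': tensor([4, 5, 6, 7])}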
8 changes: 0 additions & 8 deletions python/ray/data/_internal/torch_iterable_dataset.py
@@ -1,13 +1,5 @@
-from typing import TYPE_CHECKING, Dict, Union
-
 from torch.utils.data import IterableDataset
 
-if TYPE_CHECKING:
-    import torch
-
-
-TorchTensorBatchType = Union["torch.Tensor", Dict[str, "torch.Tensor"]]
-
 
 class TorchIterableDataset(IterableDataset):
     def __init__(self, generator_func):
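The alias removed above is superseded by the typing added in dataset.py below; a hedged sketch of how the new aliases compose (the alias definitions are taken from this commit, the surrounding demo is illustrative):

from typing import Dict, TypeVar, Union

import numpy as np
import torch

# As introduced in python/ray/data/dataset.py by this commit:
CollatedData = TypeVar("CollatedData")
TorchBatchType = Union[Dict[str, "torch.Tensor"], CollatedData]

# Without a collate_fn, iter_torch_batches yields Dict[str, torch.Tensor];
# with one, the yield type follows the collate_fn's return type.
def collate_fn(batch: Dict[str, np.ndarray]) -> torch.Tensor:
    return torch.stack([torch.as_tensor(a) for a in batch.values()], axis=1)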
83 changes: 53 additions & 30 deletions python/ray/data/dataset.py
@@ -18,6 +18,7 @@
     Optional,
     Tuple,
     Type,
+    TypeVar,
     Union,
 )
 from uuid import uuid4
@@ -152,7 +153,6 @@
     from tensorflow_metadata.proto.v0 import schema_pb2
 
     from ray.data._internal.execution.interfaces import Executor, NodeIdStr
-    from ray.data._internal.torch_iterable_dataset import TorchTensorBatchType
     from ray.data.dataset_pipeline import DatasetPipeline
     from ray.data.grouped_data import GroupedData
@@ -165,6 +165,9 @@
 
 TensorFlowTensorBatchType = Union["tf.Tensor", Dict[str, "tf.Tensor"]]
 
+CollatedData = TypeVar("CollatedData")
+TorchBatchType = Union[Dict[str, "torch.Tensor"], CollatedData]
+
 
 @PublicAPI
 class Dataset:
@@ -3336,7 +3339,7 @@ def iter_batches(
         drop_last: bool = False,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
-        _collate_fn: Optional[Callable[[DataBatch], Any]] = None,
+        _collate_fn: Optional[Callable[[DataBatch], CollatedData]] = None,
         # Deprecated.
         prefetch_blocks: int = 0,
     ) -> Iterator[DataBatch]:
@@ -3408,37 +3411,49 @@ def iter_torch_batches(
         prefetch_batches: int = 1,
         batch_size: Optional[int] = 256,
         dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
-        device: Optional[str] = None,
-        collate_fn: Optional[Callable[[Dict[str, np.ndarray]], Any]] = None,
+        device: str = "auto",
+        collate_fn: Optional[Callable[[Dict[str, np.ndarray]], CollatedData]] = None,
         drop_last: bool = False,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
         # Deprecated
         prefetch_blocks: int = 0,
-    ) -> Iterator["TorchTensorBatchType"]:
+    ) -> Iterator[TorchBatchType]:
         """Return an iterator over batches of data represented as Torch tensors.
 
+        This iterator yields batches of type ``Dict[str, torch.Tensor]``.
+        For more flexibility, call :meth:`~Dataset.iter_batches` and manually convert
+        your data to Torch tensors.
+
         Examples:
+            >>> import ray
+            >>> for batch in ray.data.range(
+            ...     12,
+            ... ).iter_torch_batches(batch_size=4):
+            ...     print(batch)
+            {'id': tensor([0, 1, 2, 3])}
+            {'id': tensor([4, 5, 6, 7])}
+            {'id': tensor([ 8,  9, 10, 11])}
+
+            Use the ``collate_fn`` to customize how the tensor batch is created.
+
+            >>> from typing import Any, Dict
+            >>> import torch
+            >>> import numpy as np
+            >>> import ray
+            >>> def collate_fn(batch: Dict[str, np.ndarray]) -> Any:
+            ...     return torch.stack(
+            ...         [torch.as_tensor(array) for array in batch.values()],
+            ...         axis=1
+            ...     )
+            >>> dataset = ray.data.from_items([
+            ...     {"col_1": 1, "col_2": 2},
+            ...     {"col_1": 3, "col_2": 4}])
+            >>> for batch in dataset.iter_torch_batches(collate_fn=collate_fn):
+            ...     print(batch)
+            tensor([[1, 2],
+                    [3, 4]])
-
-            .. testcode::
-
-                import ray
-
-                # This dataset contains three images.
-                ds = ray.data.read_images("example://image-datasets/simple")
-
-                for batch in ds.iter_torch_batches(batch_size=2):
-                    print(batch)
-
-            .. testoutput::
-                :options: +MOCK
-
-                {'image': <tf.Tensor: shape=(2, 32, 32, 3), dtype=uint8, numpy=array([[[[...]]]], dtype=uint8)>}
-                {'image': <tf.Tensor: shape=(1, 32, 32, 3), dtype=uint8, numpy=array([[[[...]]]], dtype=uint8)>}
 
         Time complexity: O(1)
@@ -3451,17 +3466,25 @@ def iter_torch_batches(
                 blocks as batches (blocks may contain different number of rows).
                 The final batch may include fewer than ``batch_size`` rows if
                 ``drop_last`` is ``False``. Defaults to 256.
-            dtypes: The Torch dtype(s) for the created tensor(s); if ``None``, the
-                dtype is inferred from the tensor data.
-            device: The device on which the tensor should be placed; if ``None``, the
-                Torch tensor is constructed on CPU.
+            dtypes: The Torch dtype(s) for the created tensor(s); if ``None``, the dtype
+                is inferred from the tensor data. You can't use this parameter with
+                ``collate_fn``.
+            device: The device on which the tensor should be placed. Defaults to
+                "auto" which moves the tensors to the appropriate device when the
+                Dataset is passed to Ray Train and ``collate_fn`` is not provided.
+                Otherwise, defaults to CPU. You can't use this parameter with
+                ``collate_fn``.
             collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch.
-                Potential use cases include collating along a dimension other than the
-                first, padding sequences of various lengths, or generally handling
-                batches of different length tensors. If not provided, the default
-                collate function is used which simply converts the batch of numpy
-                arrays to a batch of PyTorch tensors. This API is still experimental
-                and is subject to change.
+                When this parameter is specified, the user should manually handle the
+                host to device data transfer outside of collate_fn.
+                This is useful for further processing the data after it has been
+                batched. Potential use cases include collating along a dimension other
+                than the first, padding sequences of various lengths, or generally
+                handling batches of different length tensors. If not provided, the
+                default collate function is used which simply converts the batch of
+                numpy arrays to a batch of PyTorch tensors. This API is still
+                experimental and is subject to change. You can't use this parameter in
+                conjunction with ``dtypes`` or ``device``.
             drop_last: Whether to drop the last batch if it's incomplete.
             local_shuffle_buffer_size: If not ``None``, the data is randomly shuffled
                 using a local in-memory shuffle buffer, and this value serves as the
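To illustrate the mutual exclusivity documented above, a small hedged example (the ValueError comes from the validation in the iterator.py hunk further down; a local Ray install is assumed):

import numpy as np
import torch
import ray

def collate_fn(batch):
    # Device transfer is the caller's responsibility when using collate_fn.
    return torch.stack([torch.as_tensor(a) for a in batch.values()], axis=1)

ds = ray.data.from_items([{"col_1": 1, "col_2": 2}, {"col_1": 3, "col_2": 4}])

# Allowed: collate_fn on its own.
for batch in ds.iter_torch_batches(collate_fn=collate_fn):
    print(batch)  # tensor([[1, 2], [3, 4]])

# Not allowed: collate_fn together with device (or dtypes) raises ValueError,
# because device="cpu" differs from the "auto" sentinel.
try:
    next(iter(ds.iter_torch_batches(collate_fn=collate_fn, device="cpu")))
except ValueError as err:
    print(err)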
97 changes: 63 additions & 34 deletions python/ray/data/iterator.py
@@ -32,8 +32,12 @@
     import tensorflow as tf
     import torch
 
-    from ray.data._internal.torch_iterable_dataset import TorchTensorBatchType
-    from ray.data.dataset import Schema, TensorFlowTensorBatchType
+    from ray.data.dataset import (
+        CollatedData,
+        Schema,
+        TensorFlowTensorBatchType,
+        TorchBatchType,
+    )
 
 
 @PublicAPI(stability="beta")
@@ -93,7 +97,7 @@ def iter_batches(
         drop_last: bool = False,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
-        _collate_fn: Optional[Callable[[DataBatch], Any]] = None,
+        _collate_fn: Optional[Callable[[DataBatch], "CollatedData"]] = None,
         _finalize_fn: Optional[Callable[[Any], Any]] = None,
         # Deprecated.
         prefetch_blocks: int = 0,
@@ -255,29 +259,48 @@ def iter_torch_batches(
         prefetch_batches: int = 1,
         batch_size: Optional[int] = 256,
         dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
-        device: Optional[str] = None,
-        collate_fn: Optional[
-            Callable[[Union[np.ndarray, Dict[str, np.ndarray]]], Any]
-        ] = None,
+        device: str = "auto",
+        collate_fn: Optional[Callable[[Dict[str, np.ndarray]], "CollatedData"]] = None,
         drop_last: bool = False,
         local_shuffle_buffer_size: Optional[int] = None,
         local_shuffle_seed: Optional[int] = None,
         # Deprecated.
         prefetch_blocks: int = 0,
-    ) -> Iterator["TorchTensorBatchType"]:
+    ) -> Iterator["TorchBatchType"]:
         """Return a batched iterator of Torch Tensors over the dataset.
 
-        This iterator will yield single-tensor batches if the underlying dataset
-        consists of a single column; otherwise, it will yield a dictionary of
-        column-tensors. If looking for more flexibility in the tensor conversion (e.g.
-        casting dtypes) or the batch format, try using `.iter_batches` directly.
+        This iterator yields a dictionary of column-tensors. If you are looking for
+        more flexibility in the tensor conversion (e.g. casting dtypes) or the batch
+        format, try using :meth:`~ray.data.iterator.DataIterator.iter_batches` directly.
 
         Examples:
-            >>> import ray
-            >>> for row in ray.data.range(
-            ...     1000000
-            ... ).iterator().iter_rows(): # doctest: +SKIP
-            ...     print(row) # doctest: +SKIP
+            >>> import ray
+            >>> for batch in ray.data.range(
+            ...     12,
+            ... ).iterator().iter_torch_batches(batch_size=4):
+            ...     print(batch)
+            {'id': tensor([0, 1, 2, 3])}
+            {'id': tensor([4, 5, 6, 7])}
+            {'id': tensor([ 8,  9, 10, 11])}
+
+            Use the ``collate_fn`` to customize how the tensor batch is created.
+
+            >>> from typing import Any, Dict
+            >>> import torch
+            >>> import numpy as np
+            >>> import ray
+            >>> def collate_fn(batch: Dict[str, np.ndarray]) -> Any:
+            ...     return torch.stack(
+            ...         [torch.as_tensor(array) for array in batch.values()],
+            ...         axis=1
+            ...     )
+            >>> iterator = ray.data.from_items([
+            ...     {"col_1": 1, "col_2": 2},
+            ...     {"col_1": 3, "col_2": 4}]).iterator()
+            >>> for batch in iterator.iter_torch_batches(collate_fn=collate_fn):
+            ...     print(batch)
+            tensor([[1, 2],
+                    [3, 4]])
 
         Time complexity: O(1)
@@ -293,17 +316,24 @@ def iter_torch_batches(
                 The final batch may include fewer than ``batch_size`` rows if
                 ``drop_last`` is ``False``. Defaults to 256.
             dtypes: The Torch dtype(s) for the created tensor(s); if None, the dtype
-                will be inferred from the tensor data.
-            device: The device on which the tensor should be placed; if None, the Torch
-                tensor will be constructed on the CPU.
-            collate_fn: A function to apply to each data batch before returning it. When
-                this parameter is specified, the user should manually handle the host
-                to device data transfer outside of collate_fn. Potential use cases
-                include collating along a dimension other than the first, padding
-                sequences of various lengths, or generally handling batches of different
-                length tensors. This API is still experimental and is subject to change.
-                This parameter cannot be used in conjunction with ``dtypes`` or
-                ``device``.
+                will be inferred from the tensor data. You can't use this parameter
+                with ``collate_fn``.
+            device: The device on which the tensor should be placed. Defaults to
+                "auto" which moves the tensors to the appropriate device when the
+                Dataset is passed to Ray Train and ``collate_fn`` is not provided.
+                Otherwise, defaults to CPU. You can't use this parameter with
+                ``collate_fn``.
+            collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch.
+                When this parameter is specified, the user should manually handle the
+                host to device data transfer outside of ``collate_fn``.
+                This is useful for further processing the data after it has been
+                batched. Potential use cases include collating along a dimension other
+                than the first, padding sequences of various lengths, or generally
+                handling batches of different length tensors. If not provided, the
+                default collate function is used which simply converts the batch of
+                numpy arrays to a batch of PyTorch tensors. This API is still
+                experimental and is subject to change. You can't use this parameter in
+                conjunction with ``dtypes`` or ``device``.
             drop_last: Whether to drop the last batch if it's incomplete.
             local_shuffle_buffer_size: If non-None, the data will be randomly shuffled
                 using a local in-memory shuffle buffer, and this value will serve as the
@@ -324,20 +354,19 @@
             get_device,
         )
 
-        if collate_fn is not None and (dtypes is not None or device is not None):
+        if collate_fn is not None and (dtypes is not None or device != "auto"):
             raise ValueError(
                 "collate_fn cannot be used with dtypes and device. "
                 "You should manually move the output Torch tensors to the "
                 "desired dtype and device outside of collate_fn."
             )
 
-        if collate_fn is None:
-            # Automatically move torch tensors to the appropriate device.
-            if device is None:
-                default_device = get_device()
-                if default_device.type != "cpu":
-                    device = default_device
+        if device == "auto":
+            # Use the appropriate device for Ray Train, or fall back to CPU if
+            # Ray Train is not being used.
+            device = get_device()
 
+        if collate_fn is None:
             # The default collate_fn handles formatting and Tensor creation.
             # Here, we set device=None to defer host to device data transfer
             # to the subsequent finalize_fn.
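A minimal standalone sketch of the "auto" resolution flow implemented above (get_device is stubbed out here; in Ray it is an internal ray.air helper that returns the Ray Train worker's device, falling back to CPU):

import torch

def get_device() -> torch.device:
    # Stub for Ray's helper: inside a Ray Train worker this would return the
    # worker's assigned device; outside of Ray Train it returns CPU.
    return torch.device("cpu")

def resolve_device(device: str, collate_fn=None, dtypes=None) -> torch.device:
    # Mirrors the validation and "auto" handling in the hunk above.
    if collate_fn is not None and (dtypes is not None or device != "auto"):
        raise ValueError(
            "collate_fn cannot be used with dtypes and device. "
            "You should manually move the output Torch tensors to the "
            "desired dtype and device outside of collate_fn."
        )
    if device == "auto":
        return get_device()
    return torch.device(device)

print(resolve_device("auto"))    # cpu (outside Ray Train)
print(resolve_device("cuda:0"))  # cuda:0, taken at face value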
