Skip to content

Commit

Permalink
[Data] Make iter_batches an Iterable (ray-project#37881)
Browse files Browse the repository at this point in the history
Make iter_batches, iter_torch_batches, iter_rows, and iter_tf_batches return an Iterable instead of a single-use Iterator.

---------

Signed-off-by: amogkam <[email protected]>
  • Loading branch information
amogkam committed Aug 1, 2023
1 parent 844c2cc commit 0bded3b
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 435 deletions.
343 changes: 0 additions & 343 deletions doc/source/data/doc_code/tensor.py

This file was deleted.

2 changes: 1 addition & 1 deletion doc/source/train/doc_code/key_concepts.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def train_fn(config):
# Local worker rank on the current machine
"local_rank": session.get_local_rank(),
# Data
"data_shard": next(dataset_shard.iter_batches(batch_format="pandas")),
"data_shard": next(iter(dataset_shard.iter_batches(batch_format="pandas"))),
}
)

Expand Down
13 changes: 11 additions & 2 deletions python/ray/data/_internal/iterator/pipelined_iterator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
Optional,
Tuple,
Union,
)

from ray.data._internal.stats import DatasetStats
from ray.data.block import Block, BlockMetadata, DataBatch
Expand Down Expand Up @@ -70,7 +79,7 @@ def iter_batches(
_finalize_fn: Optional[Callable[[Any], Any]] = None,
# Deprecated.
prefetch_blocks: int = 0,
) -> Iterator[DataBatch]:
) -> Iterable[DataBatch]:
# Set prefetch_batches to default of 0 for DatasetPipeline.
return super().iter_batches(
prefetch_batches=prefetch_batches,
Expand Down
8 changes: 6 additions & 2 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2296,8 +2296,12 @@ def take_batch(
batch_format = _apply_strict_mode_batch_format(batch_format)
try:
res = next(
self.iter_batches(
batch_size=batch_size, prefetch_batches=0, batch_format=batch_format
iter(
self.iter_batches(
batch_size=batch_size,
prefetch_batches=0,
batch_format=batch_format,
)
)
)
except StopIteration:
Expand Down
Loading

0 comments on commit 0bded3b

Please sign in to comment.