[Data] Add support for shuffling input files (#40154)
This PR adds support for shuffling the ordering of input files for all file-based data sources. The behavior is controlled through the `shuffle` argument, exposed in all read APIs for file-based data sources:

```py
# Enable input file shuffling with the default seed
ds = ray.data.read_parquet(..., shuffle="files")
ds = ray.data.read_images(..., shuffle="files")
```
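
As a rough illustration (the bucket path below is hypothetical; only the `shuffle="files"` argument itself comes from this change), file-level shuffling randomizes the order in which input files are handed to read tasks, while rows within each file keep their original order:

```py
import ray

# Hypothetical input directory; any file-based read API accepts the argument.
path = "s3://my-bucket/training-data/"

# Default: input files are read in a deterministic order.
ds_ordered = ray.data.read_parquet(path)

# With shuffle="files", the file ordering is randomized before read tasks
# are created; rows within each individual file keep their original order.
ds_shuffled = ray.data.read_parquet(path, shuffle="files")
```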

Signed-off-by: Cheng Su <[email protected]>
c21 committed Oct 13, 2023
1 parent 9eb4416 commit ba6ae3e
Showing 7 changed files with 208 additions and 38 deletions.
27 changes: 18 additions & 9 deletions python/ray/data/datasource/file_based_datasource.py
@@ -29,7 +29,6 @@
 from ray.data._internal.util import (
     _check_pyarrow_version,
     _resolve_custom_scheme,
-    get_attribute_from_class_name,
     make_async_gen,
 )
 from ray.data.block import Block, BlockAccessor
@@ -46,6 +45,11 @@
 )
 from ray.util.annotations import DeveloperAPI, PublicAPI

+if sys.version_info >= (3, 8):
+    from typing import Literal
+else:
+    from typing_extensions import Literal
+
 if TYPE_CHECKING:
     import pandas as pd
     import pyarrow
@@ -471,6 +475,7 @@ def __init__(
         partition_filter: PathPartitionFilter = None,
         partitioning: Partitioning = None,
         ignore_missing_paths: bool = False,
+        shuffle: Union[Literal["files"], None] = None,
         **reader_args,
     ):
         _check_pyarrow_version()
@@ -511,10 +516,9 @@
                 "No input files found to read. Please double check that "
                 "'partition_filter' field is set properly."
             )
-
-        ctx = DataContext.get_current()
-        shuffler_class = get_attribute_from_class_name(ctx.file_metadata_shuffler)
-        self._file_metadata_shuffler = shuffler_class(self._reader_args)
+        self._file_metadata_shuffler = None
+        if shuffle == "files":
+            self._file_metadata_shuffler = np.random.default_rng()

     def estimate_inmemory_data_size(self) -> Optional[int]:
         total_size = 0
@@ -531,10 +535,15 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
         reader_args = self._reader_args
         partitioning = self._partitioning

-        paths_and_sizes = self._file_metadata_shuffler.shuffle_files(
-            list(zip(self._paths, self._file_sizes))
-        )
-        paths, file_sizes = list(map(list, zip(*paths_and_sizes)))
+        if self._file_metadata_shuffler is not None:
+            files_metadata = list(zip(self._paths, self._file_sizes))
+            shuffled_files_metadata = [
+                files_metadata[i]
+                for i in self._file_metadata_shuffler.permutation(len(files_metadata))
+            ]
+            paths, file_sizes = list(map(list, zip(*shuffled_files_metadata)))
+        else:
+            paths, file_sizes = self._paths, self._file_sizes

         read_stream = self._delegate._read_stream
         filesystem = _wrap_s3_serialization_workaround(self._filesystem)
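
The new `get_read_tasks` logic above amounts to permuting the zipped `(path, file_size)` pairs with a NumPy random generator. A minimal standalone sketch of the same idea, using made-up paths and sizes:

```py
import numpy as np

paths = ["a.parquet", "b.parquet", "c.parquet", "d.parquet"]
file_sizes = [100, 200, 300, 400]

rng = np.random.default_rng()

# Zip paths with their sizes so both lists are reordered together, then apply
# a single random permutation of the indices, as get_read_tasks() does above.
files_metadata = list(zip(paths, file_sizes))
shuffled = [files_metadata[i] for i in rng.permutation(len(files_metadata))]
paths, file_sizes = map(list, zip(*shuffled))
print(paths, file_sizes)
```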
1 change: 1 addition & 0 deletions python/ray/data/datasource/file_meta_provider.py
@@ -284,6 +284,7 @@ def _get_block_metadata(
         if (
             prefetched_metadata is not None
             and len(prefetched_metadata) == num_fragments
+            and all(m is not None for m in prefetched_metadata)
         ):
             # Fragment metadata was available, construct a normal
             # BlockMetadata.
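
This added guard matters because prefetched metadata can now be padded with `None` placeholders (see the Parquet datasource change below), so it should only be used when every fragment's entry is actually present. A small self-contained check with fabricated values:

```py
def metadata_is_usable(prefetched_metadata, num_fragments):
    # Mirrors the condition above: metadata must exist, cover every fragment,
    # and contain no None placeholders left over from padding.
    return (
        prefetched_metadata is not None
        and len(prefetched_metadata) == num_fragments
        and all(m is not None for m in prefetched_metadata)
    )

print(metadata_is_usable(["meta_a", "meta_b"], 2))  # True
print(metadata_is_usable(["meta_a", None], 2))      # False: padded entry
print(metadata_is_usable(None, 2))                  # False: nothing prefetched
```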
51 changes: 25 additions & 26 deletions python/ray/data/datasource/file_metadata_shuffler.py
@@ -1,34 +1,33 @@
-from typing import Any, Dict, List, Tuple
+import sys
+from typing import Any, List, Union

+import numpy as np

-class FileMetadataShuffler:
-    """Abstract class for file metadata shuffler.
-    Shufflers live on the driver side of the Dataset only.
-    """
+if sys.version_info >= (3, 8):
+    from typing import Literal
+else:
+    from typing_extensions import Literal

-    def __init__(self, reader_args: Dict[str, Any]):
-        self._reader_args = reader_args

-    def shuffle_files(
-        self,
-        paths_and_sizes: List[Tuple[str, int]],
-    ) -> List[Tuple[str, int]]:
-        """Shuffle files in the given paths and sizes.
-        Args:
-            paths_and_sizes: The file paths and file sizes to shuffle.
-        Returns:
-            The file paths and their sizes after shuffling.
-        """
-        raise NotImplementedError
+class FileMetadataShuffler:
+    """Random shuffle file metadata when the `shuffle` parameter enables it.
+    Otherwise returns file metadata in its original order.
+    """

+    def __init__(self, shuffle: Union[Literal["files"], None]):
+        self._is_shuffle_enabled = False
+        if shuffle == "files":
+            self._is_shuffle_enabled = True
+            self._generator = np.random.default_rng()

-class SequentialFileMetadataShuffler(FileMetadataShuffler):
     def shuffle_files(
         self,
-        paths_and_sizes: List[Tuple[str, int]],
-    ) -> List[Tuple[str, int]]:
-        """Return files in the given paths and sizes sequentially."""
-        return paths_and_sizes
+        files_metadata: List[Any],
+    ) -> List[Any]:
+        if self._is_shuffle_enabled:
+            return [
+                files_metadata[i]
+                for i in self._generator.permutation(len(files_metadata))
+            ]
+        else:
+            return files_metadata
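
A quick usage sketch of the rewritten shuffler (the import path follows the file location shown above; the file list is fabricated):

```py
from ray.data.datasource.file_metadata_shuffler import FileMetadataShuffler

files = [("part-0.csv", 10), ("part-1.csv", 20), ("part-2.csv", 30)]

# shuffle=None keeps the original order; shuffle="files" returns a permutation.
assert FileMetadataShuffler(None).shuffle_files(files) == files
print(FileMetadataShuffler("files").shuffle_files(files))
```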
38 changes: 35 additions & 3 deletions python/ray/data/datasource/parquet_datasource.py
@@ -1,4 +1,5 @@
 import logging
+import sys
 from typing import TYPE_CHECKING, Callable, Iterator, List, Optional, Union

 import numpy as np
@@ -22,6 +23,11 @@
 from ray.data.datasource.parquet_base_datasource import ParquetBaseDatasource
 from ray.util.annotations import PublicAPI

+if sys.version_info >= (3, 8):
+    from typing import Literal
+else:
+    from typing_extensions import Literal
+
 if TYPE_CHECKING:
     import pyarrow
     from pyarrow.dataset import ParquetFileFragment
@@ -182,6 +188,7 @@ def __init__(
         schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
         meta_provider: ParquetMetadataProvider = DefaultParquetMetadataProvider(),
         _block_udf: Optional[Callable[[Block], Block]] = None,
+        shuffle: Union[Literal["files"], None] = None,
         **reader_args,
     ):
         _check_pyarrow_version()
@@ -275,6 +282,9 @@ def __init__(
         self._columns = columns
         self._schema = schema
         self._encoding_ratio = self._estimate_files_encoding_ratio()
+        self._file_metadata_shuffler = None
+        if shuffle == "files":
+            self._file_metadata_shuffler = np.random.default_rng()

     def estimate_inmemory_data_size(self) -> Optional[int]:
         total_size = 0
@@ -289,11 +299,33 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
         # method in order to leverage pyarrow's ParquetDataset abstraction,
         # which simplifies partitioning logic. We still use
         # FileBasedDatasource's write side (do_write), however.
+        pq_metadata = self._metadata
+        if len(pq_metadata) < len(self._pq_fragments):
+            # Pad `pq_metadata` to be same length of `self._pq_fragments`.
+            # This can happen when no file metadata being prefetched.
+            pq_metadata += [None] * (len(self._pq_fragments) - len(pq_metadata))
+
+        if self._file_metadata_shuffler is not None:
+            files_metadata = list(zip(self._pq_fragments, self._pq_paths, pq_metadata))
+            shuffled_files_metadata = [
+                files_metadata[i]
+                for i in self._file_metadata_shuffler.permutation(len(files_metadata))
+            ]
+            pq_fragments, pq_paths, pq_metadata = list(
+                map(list, zip(*shuffled_files_metadata))
+            )
+        else:
+            pq_fragments, pq_paths, pq_metadata = (
+                self._pq_fragments,
+                self._pq_paths,
+                pq_metadata,
+            )
+
         read_tasks = []
         for fragments, paths, metadata in zip(
-            np.array_split(self._pq_fragments, parallelism),
-            np.array_split(self._pq_paths, parallelism),
-            np.array_split(self._metadata, parallelism),
+            np.array_split(pq_fragments, parallelism),
+            np.array_split(pq_paths, parallelism),
+            np.array_split(pq_metadata, parallelism),
         ):
             if len(fragments) <= 0:
                 continue
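
The tail of this hunk relies on `np.array_split` to bucket the (possibly shuffled) fragments, paths, and metadata into `parallelism` groups, skipping empty buckets. A self-contained sketch with fabricated paths, showing both the `None` padding and the splitting:

```py
import numpy as np

pq_paths = [f"part-{i}.parquet" for i in range(5)]
pq_metadata = ["meta-0", "meta-1"]  # pretend prefetching stopped early

# Pad the metadata so it lines up with the paths, as in the hunk above.
pq_metadata += [None] * (len(pq_paths) - len(pq_metadata))

parallelism = 3
for paths, metadata in zip(
    np.array_split(pq_paths, parallelism),
    np.array_split(pq_metadata, parallelism),
):
    if len(paths) <= 0:
        continue
    print(list(paths), list(metadata))
```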