From 4cf94ba1d4652e7bf8b18006f69cce404a1de2cc Mon Sep 17 00:00:00 2001
From: Balaji Veeramani
Date: Tue, 22 Aug 2023 15:18:15 -0500
Subject: [PATCH] [Data] Add `partition_filter` parameter to `read_parquet` (#38479)

If your dataset contains any non-Parquet files, `read_parquet` errors. To work
around this, this PR adds a `partition_filter` parameter so users can ignore
non-Parquet files.

Signed-off-by: Balaji Veeramani
---
 .../ray/data/datasource/parquet_datasource.py | 27 ++++++++++++++++---
 python/ray/data/read_api.py                   |  5 ++++
 python/ray/data/tests/test_parquet.py         | 15 +++++++++++
 3 files changed, 43 insertions(+), 4 deletions(-)

diff --git a/python/ray/data/datasource/parquet_datasource.py b/python/ray/data/datasource/parquet_datasource.py
index 8b950a9f4bfce..1162c35da5877 100644
--- a/python/ray/data/datasource/parquet_datasource.py
+++ b/python/ray/data/datasource/parquet_datasource.py
@@ -10,6 +10,9 @@
 from ray.data._internal.util import _check_pyarrow_version
 from ray.data.block import Block
 from ray.data.context import DataContext
+from ray.data.datasource._default_metadata_providers import (
+    get_generic_metadata_provider,
+)
 from ray.data.datasource.datasource import Reader, ReadTask
 from ray.data.datasource.file_based_datasource import _resolve_paths_and_filesystem
 from ray.data.datasource.file_meta_provider import (
@@ -186,10 +189,6 @@ def __init__(
         import pyarrow as pa
         import pyarrow.parquet as pq
 
-        paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem)
-        if len(paths) == 1:
-            paths = paths[0]
-
         self._local_scheduling = None
         if local_uri:
             import ray
@@ -199,6 +198,26 @@ def __init__(
                 ray.get_runtime_context().get_node_id(), soft=False
             )
 
+        paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem)
+
+        # HACK: PyArrow's `ParquetDataset` errors if input paths contain non-parquet
+        # files. To avoid this, we expand the input paths with the default metadata
+        # provider and then apply the partition filter.
+        partition_filter = reader_args.pop("partition_filter", None)
+        if partition_filter is not None:
+            default_meta_provider = get_generic_metadata_provider(file_extensions=None)
+            expanded_paths, _ = map(
+                list, zip(*default_meta_provider.expand_paths(paths, filesystem))
+            )
+            paths = partition_filter(expanded_paths)
+
+            filtered_paths = set(expanded_paths) - set(paths)
+            if filtered_paths:
+                logger.info(f"Filtered out the following paths: {filtered_paths}")
+
+        if len(paths) == 1:
+            paths = paths[0]
+
         dataset_kwargs = reader_args.pop("dataset_kwargs", {})
         try:
             pq_ds = pq.ParquetDataset(
diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py
index ef519829e1139..daa04b61de529 100644
--- a/python/ray/data/read_api.py
+++ b/python/ray/data/read_api.py
@@ -556,6 +556,7 @@ def read_parquet(
     ray_remote_args: Dict[str, Any] = None,
     tensor_column_schema: Optional[Dict[str, Tuple[np.dtype, Tuple[int, ...]]]] = None,
     meta_provider: Optional[ParquetMetadataProvider] = None,
+    partition_filter: Optional[PathPartitionFilter] = None,
     **arrow_parquet_args,
 ) -> Dataset:
     """Creates a :class:`~ray.data.Dataset` from parquet files.
@@ -658,6 +659,9 @@ def read_parquet(
         meta_provider: A :ref:`file metadata provider `. Custom metadata providers
             may be able to resolve file metadata more quickly and/or accurately. In
             most cases you do not need to set this parameter.
+        partition_filter: A
+            :class:`~ray.data.datasource.partitioning.PathPartitionFilter`. Use
+            with a custom callback to read only selected partitions of a dataset.
         arrow_parquet_args: Other parquet read options to pass to PyArrow. For the
             full set of arguments, see the `PyArrow API
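
For context, a minimal usage sketch of the new argument (not part of the patch): it assumes `FileExtensionFilter`, a `PathPartitionFilter` subclass importable from `ray.data.datasource` in Ray releases around this change, and a hypothetical S3 path that mixes Parquet and non-Parquet files.

    # Sketch only: read a directory that mixes Parquet data with stray files
    # (e.g. _SUCCESS markers), keeping just the .parquet files.
    import ray
    from ray.data.datasource import FileExtensionFilter  # assumed available

    ds = ray.data.read_parquet(
        "s3://example-bucket/logs/",  # hypothetical mixed-content directory
        partition_filter=FileExtensionFilter("parquet"),
    )
    print(ds.schema())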