[Data] Add partition_filter parameter to read_parquet (ray-project#38479)

If the dataset directory contains any non-parquet files, read_parquet raises an error. To work around this, this PR adds a partition_filter parameter so users can ignore non-parquet files.

Signed-off-by: Balaji Veeramani <[email protected]>
bveeramani committed Aug 22, 2023
1 parent b60a484 commit 4cf94ba
Showing 3 changed files with 43 additions and 4 deletions.
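For context, here is a minimal usage sketch of the new parameter (hypothetical local paths; FileExtensionFilter is the same filter exercised in the test below):

import os

import pyarrow as pa
import pyarrow.parquet as pq

import ray
from ray.data.datasource import FileExtensionFilter

# Hypothetical directory containing a mix of parquet and non-parquet files.
os.makedirs("/tmp/mixed_dir", exist_ok=True)
pq.write_table(pa.table({"x": [1, 2, 3]}), "/tmp/mixed_dir/data.parquet")
open("/tmp/mixed_dir/_SUCCESS", "w").close()  # marker file without a .parquet extension

# Without a filter, the non-parquet file makes read_parquet error;
# partition_filter lets us skip it.
ds = ray.data.read_parquet(
    "/tmp/mixed_dir",
    partition_filter=FileExtensionFilter("parquet"),
)
assert ds.count() == 3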
27 changes: 23 additions & 4 deletions python/ray/data/datasource/parquet_datasource.py
@@ -10,6 +10,9 @@
from ray.data._internal.util import _check_pyarrow_version
from ray.data.block import Block
from ray.data.context import DataContext
from ray.data.datasource._default_metadata_providers import (
get_generic_metadata_provider,
)
from ray.data.datasource.datasource import Reader, ReadTask
from ray.data.datasource.file_based_datasource import _resolve_paths_and_filesystem
from ray.data.datasource.file_meta_provider import (
@@ -186,10 +189,6 @@ def __init__(
import pyarrow as pa
import pyarrow.parquet as pq

paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem)
if len(paths) == 1:
paths = paths[0]

self._local_scheduling = None
if local_uri:
import ray
@@ -199,6 +198,26 @@
ray.get_runtime_context().get_node_id(), soft=False
)

paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem)

# HACK: PyArrow's `ParquetDataset` errors if input paths contain non-parquet
# files. To avoid this, we expand the input paths with the default metadata
# provider and then apply the partition filter.
partition_filter = reader_args.pop("partition_filter", None)
if partition_filter is not None:
default_meta_provider = get_generic_metadata_provider(file_extensions=None)
expanded_paths, _ = map(
list, zip(*default_meta_provider.expand_paths(paths, filesystem))
)
paths = partition_filter(expanded_paths)

filtered_paths = set(expanded_paths) - set(paths)
if filtered_paths:
logger.info(f"Filtered out the following paths: {filtered_paths}")

if len(paths) == 1:
paths = paths[0]

dataset_kwargs = reader_args.pop("dataset_kwargs", {})
try:
pq_ds = pq.ParquetDataset(
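To illustrate the expand-then-filter step in the datasource change above, a rough standalone sketch (made-up paths; assumes FileExtensionFilter's callable interface matches how the datasource invokes partition_filter):

from ray.data.datasource import FileExtensionFilter

# Paths as they might look after the default metadata provider expands an input directory.
expanded_paths = [
    "/data/table/part-0.parquet",
    "/data/table/part-1.parquet",
    "/data/table/_SUCCESS",  # non-parquet marker file
]

partition_filter = FileExtensionFilter("parquet")
paths = partition_filter(expanded_paths)  # keeps only the .parquet paths

filtered_paths = set(expanded_paths) - set(paths)
print(filtered_paths)  # {'/data/table/_SUCCESS'}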
5 changes: 5 additions & 0 deletions python/ray/data/read_api.py
@@ -556,6 +556,7 @@ def read_parquet(
ray_remote_args: Dict[str, Any] = None,
tensor_column_schema: Optional[Dict[str, Tuple[np.dtype, Tuple[int, ...]]]] = None,
meta_provider: Optional[ParquetMetadataProvider] = None,
partition_filter: Optional[PathPartitionFilter] = None,
**arrow_parquet_args,
) -> Dataset:
"""Creates a :class:`~ray.data.Dataset` from parquet files.
@@ -658,6 +659,9 @@
meta_provider: A :ref:`file metadata provider <metadata_provider>`. Custom
metadata providers may be able to resolve file metadata more quickly and/or
accurately. In most cases you do not need to set this parameter.
partition_filter: A
:class:`~ray.data.datasource.partitioning.PathPartitionFilter`. Use
with a custom callback to read only selected partitions of a dataset.
arrow_parquet_args: Other parquet read options to pass to PyArrow. For the full
set of arguments, see the `PyArrow API <https://arrow.apache.org/docs/\
python/generated/pyarrow.dataset.Scanner.html\
Expand All @@ -681,6 +685,7 @@ def read_parquet(
columns=columns,
ray_remote_args=ray_remote_args,
meta_provider=meta_provider,
partition_filter=partition_filter,
**arrow_parquet_args,
)

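The docstring above also mentions using a custom callback; a rough sketch of what that might look like, assuming a hive-partitioned layout and PathPartitionFilter.of (the bucket path is made up):

import ray
from ray.data.datasource.partitioning import PathPartitionFilter

# Hypothetical layout: s3://bucket/table/year=2023/month=08/part-0.parquet
# The callback receives a dict of partition field names to values for each file path.
partition_filter = PathPartitionFilter.of(
    filter_fn=lambda partitions: partitions.get("year") == "2023"
)

ds = ray.data.read_parquet("s3://bucket/table", partition_filter=partition_filter)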
15 changes: 15 additions & 0 deletions python/ray/data/tests/test_parquet.py
@@ -14,6 +14,7 @@
from ray.data.datasource import (
DefaultFileMetadataProvider,
DefaultParquetMetadataProvider,
FileExtensionFilter,
)
from ray.data.datasource.file_based_datasource import _unwrap_protocol
from ray.data.datasource.parquet_base_datasource import ParquetBaseDatasource
@@ -756,6 +757,20 @@ def test_parquet_write(ray_start_regular_shared, fs, data_path, endpoint_url):
fs.delete_dir(_unwrap_protocol(path))


def test_parquet_partition_filter(ray_start_regular_shared, tmp_path):
table = pa.table({"food": ["spam", "ham", "eggs"]})
pq.write_table(table, tmp_path / "table.parquet")
# `spam` should be filtered out.
with open(tmp_path / "spam", "w"):
pass

ds = ray.data.read_parquet(
tmp_path, partition_filter=FileExtensionFilter("parquet")
)

assert ds.count() == 3


@pytest.mark.parametrize(
"fs,data_path,endpoint_url",
[
