
[Data] Replace deprecated .pieces with updated .fragments (ray-project#39523)

ParquetDataset.pieces has been deprecated since PyArrow 5. This PR replaces it with the updated .fragments attribute.

---------

Signed-off-by: Balaji Veeramani <[email protected]>
bveeramani committed Sep 12, 2023
1 parent 67d7943 commit 4f39207
Showing 4 changed files with 90 additions and 85 deletions.
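For context, a minimal sketch of the API change this commit tracks, assuming a local Parquet directory named example_data/ (the path and variable names are illustrative, not taken from the diff):

import pyarrow.parquet as pq

# Hypothetical local dataset path; any Parquet directory works.
dataset = pq.ParquetDataset("example_data/")

# Deprecated since PyArrow 5 (legacy dataset API):
# file_paths = [piece.path for piece in dataset.pieces]

# The attribute this PR switches to (available on the newer, non-legacy
# ParquetDataset implementation; older PyArrow may need use_legacy_dataset=False):
file_paths = [fragment.path for fragment in dataset.fragments]
print(file_paths)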
39 changes: 21 additions & 18 deletions python/ray/data/datasource/file_meta_provider.py
@@ -220,7 +220,7 @@ def _get_block_metadata(
paths: List[str],
schema: Optional[Union[type, "pyarrow.lib.Schema"]],
*,
-        num_pieces: int,
+        num_fragments: int,
prefetched_metadata: Optional[List[Any]],
) -> BlockMetadata:
"""Resolves and returns block metadata for files of a single dataset block.
@@ -229,11 +229,11 @@ def _get_block_metadata(
paths: The file paths for a single dataset block.
schema: The user-provided or inferred schema for the given file
paths, if any.
-            num_pieces: The number of Parquet file fragments derived from the input
+            num_fragments: The number of Parquet file fragments derived from the input
file paths.
prefetched_metadata: Metadata previously returned from
`prefetch_file_metadata()` for each file fragment, where
-                `prefetched_metadata[i]` contains the metadata for `pieces[i]`.
+                `prefetched_metadata[i]` contains the metadata for `fragments[i]`.
Returns:
BlockMetadata aggregated across the given file paths.
@@ -242,7 +242,7 @@ def _get_block_metadata(

def prefetch_file_metadata(
self,
-        pieces: List["pyarrow.dataset.ParquetFileFragment"],
+        fragments: List["pyarrow.dataset.ParquetFileFragment"],
**ray_remote_args,
) -> Optional[List[Any]]:
"""Pre-fetches file metadata for all Parquet file fragments in a single batch.
@@ -255,12 +255,12 @@ def prefetch_file_metadata(
override this method.
Args:
-            pieces: The Parquet file fragments to fetch metadata for.
+            fragments: The Parquet file fragments to fetch metadata for.
Returns:
Metadata resolved for each input file fragment, or `None`. Metadata
must be returned in the same order as all input file fragments, such
-            that `metadata[i]` always contains the metadata for `pieces[i]`.
+            that `metadata[i]` always contains the metadata for `fragments[i]`.
"""
return None

@@ -278,11 +278,14 @@ def _get_block_metadata(
paths: List[str],
schema: Optional[Union[type, "pyarrow.lib.Schema"]],
*,
-        num_pieces: int,
+        num_fragments: int,
prefetched_metadata: Optional[List["pyarrow.parquet.FileMetaData"]],
) -> BlockMetadata:
-        if prefetched_metadata is not None and len(prefetched_metadata) == num_pieces:
-            # Piece metadata was available, construct a normal
+        if (
+            prefetched_metadata is not None
+            and len(prefetched_metadata) == num_fragments
+        ):
+            # Fragment metadata was available, construct a normal
# BlockMetadata.
block_metadata = BlockMetadata(
num_rows=sum(m.num_rows for m in prefetched_metadata),
@@ -295,7 +298,7 @@ def _get_block_metadata(
exec_stats=None,
) # Exec stats filled in later.
else:
-            # Piece metadata was not available, construct an empty
+            # Fragment metadata was not available, construct an empty
# BlockMetadata.
block_metadata = BlockMetadata(
num_rows=None,
@@ -308,32 +311,32 @@

def prefetch_file_metadata(
self,
-        pieces: List["pyarrow.dataset.ParquetFileFragment"],
+        fragments: List["pyarrow.dataset.ParquetFileFragment"],
**ray_remote_args,
) -> Optional[List["pyarrow.parquet.FileMetaData"]]:
from ray.data.datasource.file_based_datasource import _fetch_metadata_parallel
from ray.data.datasource.parquet_datasource import (
+            FRAGMENTS_PER_META_FETCH,
PARALLELIZE_META_FETCH_THRESHOLD,
-            PIECES_PER_META_FETCH,
_fetch_metadata,
_fetch_metadata_serialization_wrapper,
-            _SerializedPiece,
+            _SerializedFragment,
)

-        if len(pieces) > PARALLELIZE_META_FETCH_THRESHOLD:
+        if len(fragments) > PARALLELIZE_META_FETCH_THRESHOLD:
# Wrap Parquet fragments in serialization workaround.
-            pieces = [_SerializedPiece(piece) for piece in pieces]
+            fragments = [_SerializedFragment(fragment) for fragment in fragments]
# Fetch Parquet metadata in parallel using Ray tasks.
return list(
_fetch_metadata_parallel(
-                    pieces,
+                    fragments,
_fetch_metadata_serialization_wrapper,
-                    PIECES_PER_META_FETCH,
+                    FRAGMENTS_PER_META_FETCH,
**ray_remote_args,
)
)
else:
-            return _fetch_metadata(pieces)
+            return _fetch_metadata(fragments)


def _handle_read_os_error(error: OSError, paths: Union[str, List[str]]) -> str:
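The prefetch path above only parallelizes once the fragment count crosses PARALLELIZE_META_FETCH_THRESHOLD, and then fetches metadata in groups of FRAGMENTS_PER_META_FETCH. A minimal, self-contained sketch of that batching logic, with the Ray task dispatch replaced by a plain sequential loop (the helper and function names here are illustrative, not the ones in the diff):

from typing import Any, Callable, Iterable, List

FRAGMENTS_PER_META_FETCH = 6           # value from the diff
PARALLELIZE_META_FETCH_THRESHOLD = 24  # value from the diff


def _chunks(items: List[Any], size: int) -> Iterable[List[Any]]:
    # Yield successive fixed-size slices of `items`.
    for i in range(0, len(items), size):
        yield items[i : i + size]


def prefetch_metadata_sketch(
    fragments: List[Any], fetch: Callable[[List[Any]], List[Any]]
) -> List[Any]:
    # Small inputs: fetch everything in one call, as the serial branch does.
    if len(fragments) <= PARALLELIZE_META_FETCH_THRESHOLD:
        return fetch(fragments)
    # Large inputs: fetch metadata in groups of FRAGMENTS_PER_META_FETCH.
    # The real code submits each group as a Ray task; this loop is sequential.
    results: List[Any] = []
    for group in _chunks(fragments, FRAGMENTS_PER_META_FETCH):
        results.extend(fetch(group))
    return results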
100 changes: 50 additions & 50 deletions python/ray/data/datasource/parquet_datasource.py
@@ -29,7 +29,7 @@

logger = logging.getLogger(__name__)

-PIECES_PER_META_FETCH = 6
+FRAGMENTS_PER_META_FETCH = 6
PARALLELIZE_META_FETCH_THRESHOLD = 24

# The number of rows to read per batch. This is sized to generate 10MiB batches
@@ -69,7 +69,7 @@

# TODO(ekl) this is a workaround for a pyarrow serialization bug, where serializing a
# raw pyarrow file fragment causes S3 network calls.
-class _SerializedPiece:
+class _SerializedFragment:
def __init__(self, frag: "ParquetFileFragment"):
self._data = cloudpickle.dumps(
(frag.format, frag.path, frag.filesystem, frag.partition_expression)
@@ -87,10 +87,10 @@ def deserialize(self) -> "ParquetFileFragment":


# Visible for test mocking.
-def _deserialize_pieces(
-    serialized_pieces: List[_SerializedPiece],
+def _deserialize_fragments(
+    serialized_fragments: List[_SerializedFragment],
) -> List["pyarrow._dataset.ParquetFileFragment"]:
-    return [p.deserialize() for p in serialized_pieces]
+    return [p.deserialize() for p in serialized_fragments]


# This retry helps when the upstream datasource is not able to handle
@@ -100,14 +100,14 @@ def _deserialize_pieces(
# simultaneously running many hyper parameter tuning jobs
# with ray.data parallelism setting at high value like the default 200
# Such connection failure can be restored with some waiting and retry.
-def _deserialize_pieces_with_retry(
-    serialized_pieces: List[_SerializedPiece],
+def _deserialize_fragments_with_retry(
+    serialized_fragments: List[_SerializedFragment],
) -> List["pyarrow._dataset.ParquetFileFragment"]:
min_interval = 0
final_exception = None
for i in range(FILE_READING_RETRY):
try:
-            return _deserialize_pieces(serialized_pieces)
+            return _deserialize_fragments(serialized_fragments)
except Exception as e:
import random
import time
@@ -123,7 +123,7 @@ def _deserialize_pieces_with_retry(
else (
f"If earlier read attempt threw certain Exception"
f", it may or may not be an issue depends on these retries "
f"succeed or not. serialized_pieces:{serialized_pieces}"
f"succeed or not. serialized_fragments:{serialized_fragments}"
)
)
logger.exception(
@@ -256,7 +256,7 @@ def __init__(
prefetch_remote_args["scheduling_strategy"] = self._local_scheduling
self._metadata = (
meta_provider.prefetch_file_metadata(
-                    pq_ds.pieces, **prefetch_remote_args
+                    pq_ds.fragments, **prefetch_remote_args
)
or []
)
@@ -265,9 +265,9 @@

# NOTE: Store the custom serialized `ParquetFileFragment` to avoid unexpected
# network calls when `_ParquetDatasourceReader` is serialized. See
-        # `_SerializedPiece()` implementation for more details.
-        self._pq_pieces = [_SerializedPiece(p) for p in pq_ds.pieces]
-        self._pq_paths = [p.path for p in pq_ds.pieces]
+        # `_SerializedFragment()` implementation for more details.
+        self._pq_fragments = [_SerializedFragment(p) for p in pq_ds.fragments]
+        self._pq_paths = [p.path for p in pq_ds.fragments]
self._meta_provider = meta_provider
self._inferred_schema = inferred_schema
self._block_udf = _block_udf
@@ -290,18 +290,18 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
# which simplifies partitioning logic. We still use
# FileBasedDatasource's write side (do_write), however.
read_tasks = []
-        for pieces, paths, metadata in zip(
-            np.array_split(self._pq_pieces, parallelism),
+        for fragments, paths, metadata in zip(
+            np.array_split(self._pq_fragments, parallelism),
np.array_split(self._pq_paths, parallelism),
np.array_split(self._metadata, parallelism),
):
-            if len(pieces) <= 0:
+            if len(fragments) <= 0:
continue

meta = self._meta_provider(
paths,
self._inferred_schema,
-                num_pieces=len(pieces),
+                num_fragments=len(fragments),
prefetched_metadata=metadata,
)
# If there is a filter operation, reset the calculated row count,
@@ -338,13 +338,13 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
)
read_tasks.append(
ReadTask(
-                    lambda p=pieces: _read_pieces(
+                    lambda f=fragments: _read_fragments(
block_udf,
reader_args,
default_read_batch_size,
columns,
schema,
-                        p,
+                        f,
),
meta,
)
@@ -364,7 +364,7 @@ def _estimate_files_encoding_ratio(self) -> float:
# Launch tasks to sample multiple files remotely in parallel.
# Evenly distributed to sample N rows in i-th row group in i-th file.
# TODO(ekl/cheng) take into account column pruning.
-        num_files = len(self._pq_pieces)
+        num_files = len(self._pq_fragments)
num_samples = int(num_files * PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO)
min_num_samples = min(
PARQUET_ENCODING_RATIO_ESTIMATE_MIN_NUM_SAMPLES, num_files
@@ -377,19 +377,19 @@ def _estimate_files_encoding_ratio(self) -> float:
# Evenly distributed to choose which file to sample, to avoid biased prediction
# if data is skewed.
file_samples = [
-            self._pq_pieces[idx]
+            self._pq_fragments[idx]
for idx in np.linspace(0, num_files - 1, num_samples).astype(int).tolist()
]

-        sample_piece = cached_remote_fn(_sample_piece)
+        sample_fragment = cached_remote_fn(_sample_fragment)
futures = []
scheduling = self._local_scheduling or "SPREAD"
for sample in file_samples:
# Sample the first rows batch in i-th file.
# Use SPREAD scheduling strategy to avoid packing many sampling tasks on
# same machine to cause OOM issue, as sampling can be memory-intensive.
futures.append(
-                sample_piece.options(scheduling_strategy=scheduling).remote(
+                sample_fragment.options(scheduling_strategy=scheduling).remote(
self._reader_args,
self._columns,
self._schema,
@@ -404,34 +404,34 @@ def _estimate_files_encoding_ratio(self) -> float:
return max(ratio, PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND)


-def _read_pieces(
+def _read_fragments(
block_udf,
reader_args,
default_read_batch_size,
columns,
schema,
-    serialized_pieces: List[_SerializedPiece],
+    serialized_fragments: List[_SerializedFragment],
) -> Iterator["pyarrow.Table"]:
# This import is necessary to load the tensor extension type.
from ray.data.extensions.tensor_extension import ArrowTensorType # noqa

# Deserialize after loading the filesystem class.
-    pieces: List[
+    fragments: List[
        "pyarrow._dataset.ParquetFileFragment"
-    ] = _deserialize_pieces_with_retry(serialized_pieces)
+    ] = _deserialize_fragments_with_retry(serialized_fragments)

# Ensure that we're reading at least one dataset fragment.
-    assert len(pieces) > 0
+    assert len(fragments) > 0

import pyarrow as pa
from pyarrow.dataset import _get_partition_keys

logger.debug(f"Reading {len(pieces)} parquet pieces")
logger.debug(f"Reading {len(fragments)} parquet fragments")
use_threads = reader_args.pop("use_threads", False)
batch_size = reader_args.pop("batch_size", default_read_batch_size)
-    for piece in pieces:
-        part = _get_partition_keys(piece.partition_expression)
-        batches = piece.to_batches(
+    for fragment in fragments:
+        part = _get_partition_keys(fragment.partition_expression)
+        batches = fragment.to_batches(
use_threads=use_threads,
columns=columns,
schema=schema,
@@ -456,46 +456,46 @@ def _read_pieces(


def _fetch_metadata_serialization_wrapper(
-    pieces: _SerializedPiece,
+    fragments: List[_SerializedFragment],
) -> List["pyarrow.parquet.FileMetaData"]:
-    pieces: List[
+    fragments: List[
        "pyarrow._dataset.ParquetFileFragment"
-    ] = _deserialize_pieces_with_retry(pieces)
+    ] = _deserialize_fragments_with_retry(fragments)

-    return _fetch_metadata(pieces)
+    return _fetch_metadata(fragments)


def _fetch_metadata(
pieces: List["pyarrow.dataset.ParquetFileFragment"],
fragments: List["pyarrow.dataset.ParquetFileFragment"],
) -> List["pyarrow.parquet.FileMetaData"]:
-    piece_metadata = []
-    for p in pieces:
+    fragment_metadata = []
+    for f in fragments:
try:
-            piece_metadata.append(p.metadata)
+            fragment_metadata.append(f.metadata)
except AttributeError:
break
-    return piece_metadata
+    return fragment_metadata


-def _sample_piece(
+def _sample_fragment(
reader_args,
columns,
schema,
-    file_piece: _SerializedPiece,
+    file_fragment: _SerializedFragment,
) -> float:
-    # Sample the first rows batch from file piece `serialized_piece`.
+    # Sample the first rows batch from file fragment `serialized_fragment`.
# Return the encoding ratio calculated from the sampled rows.
-    piece = _deserialize_pieces_with_retry([file_piece])[0]
+    fragment = _deserialize_fragments_with_retry([file_fragment])[0]

# Only sample the first row group.
-    piece = piece.subset(row_group_ids=[0])
+    fragment = fragment.subset(row_group_ids=[0])
batch_size = max(
-        min(piece.metadata.num_rows, PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS), 1
+        min(fragment.metadata.num_rows, PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS), 1
)
# Use the batch_size calculated above, and ignore the one specified by user if set.
# This is to avoid sampling too few or too many rows.
reader_args.pop("batch_size", None)
-    batches = piece.to_batches(
+    batches = fragment.to_batches(
columns=columns,
schema=schema,
batch_size=batch_size,
@@ -509,7 +509,7 @@ def _sample_piece(
else:
if batch.num_rows > 0:
in_memory_size = batch.nbytes / batch.num_rows
-            metadata = piece.metadata
+            metadata = fragment.metadata
total_size = 0
for idx in range(metadata.num_row_groups):
total_size += metadata.row_group(idx).total_byte_size
@@ -518,7 +518,7 @@ def _sample_piece(
else:
ratio = PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND
logger.debug(
f"Estimated Parquet encoding ratio is {ratio} for piece {piece} "
f"Estimated Parquet encoding ratio is {ratio} for fragment {fragment} "
f"with batch size {batch_size}."
)
return ratio
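The renamed _SerializedFragment class exists only to work around the pyarrow serialization issue noted in the diff: it pickles just the pieces needed to rebuild a fragment, so serializing the reader does not trigger S3 calls. The deserialize body is collapsed in this view, so the sketch below is an assumption about one way such a round trip could look; FileFormat.make_fragment is a public pyarrow API, but its use here is illustrative rather than the exact upstream implementation, and the class name is hypothetical.

import pyarrow.dataset as ds
from ray import cloudpickle  # assumption: the diff shows cloudpickle.dumps but not the import


class SerializedFragmentSketch:
    """Illustrative, simplified stand-in for _SerializedFragment."""

    def __init__(self, frag: "ds.ParquetFileFragment"):
        # Pickle only what is needed to rebuild the fragment, not the fragment
        # itself, so no filesystem/S3 calls happen during serialization.
        self._data = cloudpickle.dumps(
            (frag.format, frag.path, frag.filesystem, frag.partition_expression)
        )

    def deserialize(self) -> "ds.ParquetFileFragment":
        fmt, path, filesystem, partition_expression = cloudpickle.loads(self._data)
        # Assumption: rebuild the fragment from its parts via make_fragment().
        return fmt.make_fragment(
            path, filesystem=filesystem, partition_expression=partition_expression
        )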
