[CI] Format Python code with Black #21975

Merged: 12 commits, Jan 30, 2022
Format parquet_datasource.py
bveeramani committed Jan 29, 2022
commit ec91cdfa5c139fbb19f21a8fdfe2c476085977d6
120 changes: 65 additions & 55 deletions in python/ray/data/datasource/parquet_datasource.py
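
Since this commit is a pure formatting pass, the diff below changes only layout: long signatures and calls are exploded onto one argument per line with trailing commas, and hanging indents and backslash continuations are replaced by parenthesized blocks. The snippet that follows is a rough sketch, not part of this PR, showing how the same rewrite can be reproduced with Black's Python API (black.format_str and black.Mode, as exposed by Black releases of this period); the example signature is illustrative rather than copied from the file.

import black

# Illustrative old-style signature (hanging indent, closing parenthesis
# glued to the last argument), similar to the code this commit reformats.
OLD_STYLE = '''\
def prepare_read(
        self, parallelism: int, paths: Union[str, List[str]],
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        columns: Optional[List[str]] = None,
        **reader_args) -> List[ReadTask]:
    return []
'''

# With Black's default 88-character limit this comes back with one
# parameter per line, a trailing comma, and a dedented closing
# parenthesis, the shape seen in the reformatted prepare_read below.
# The CI job runs the `black` command line; this sketch uses the
# programmatic API so the example is self-contained.
print(black.format_str(OLD_STYLE, mode=black.Mode()))
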
@@ -1,7 +1,6 @@
import logging
import itertools
from typing import Any, Callable, Dict, Optional, List, Union, \
Iterator, TYPE_CHECKING
from typing import Any, Callable, Dict, Optional, List, Union, Iterator, TYPE_CHECKING

import numpy as np

@@ -13,7 +12,10 @@
from ray.data.context import DatasetContext
from ray.data.datasource.datasource import ReadTask
from ray.data.datasource.file_based_datasource import (
FileBasedDatasource, _resolve_paths_and_filesystem, _resolve_kwargs)
FileBasedDatasource,
_resolve_paths_and_filesystem,
_resolve_kwargs,
)
from ray.data.impl.block_list import BlockMetadata
from ray.data.impl.output_buffer import BlockOutputBuffer
from ray.data.impl.progress_bar import ProgressBar
@@ -40,16 +42,16 @@ class ParquetDatasource(FileBasedDatasource):
"""

def prepare_read(
self,
parallelism: int,
paths: Union[str, List[str]],
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
columns: Optional[List[str]] = None,
schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
_block_udf: Optional[Callable[[Block], Block]] = None,
**reader_args) -> List[ReadTask]:
"""Creates and returns read tasks for a Parquet file-based datasource.
"""
self,
parallelism: int,
paths: Union[str, List[str]],
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
columns: Optional[List[str]] = None,
schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
_block_udf: Optional[Callable[[Block], Block]] = None,
**reader_args,
) -> List[ReadTask]:
"""Creates and returns read tasks for a Parquet file-based datasource."""
# NOTE: We override the base class FileBasedDatasource.prepare_read
# method in order to leverage pyarrow's ParquetDataset abstraction,
# which simplifies partitioning logic. We still use
@@ -66,24 +68,24 @@ def prepare_read(

dataset_kwargs = reader_args.pop("dataset_kwargs", {})
pq_ds = pq.ParquetDataset(
paths,
**dataset_kwargs,
filesystem=filesystem,
use_legacy_dataset=False)
paths, **dataset_kwargs, filesystem=filesystem, use_legacy_dataset=False
)
if schema is None:
schema = pq_ds.schema
if columns:
schema = pa.schema([schema.field(column) for column in columns],
schema.metadata)
schema = pa.schema(
[schema.field(column) for column in columns], schema.metadata
)

def read_pieces(serialized_pieces: str) -> Iterator[pa.Table]:
# Implicitly trigger S3 subsystem initialization by importing
# pyarrow.fs.
import pyarrow.fs # noqa: F401

# Deserialize after loading the filesystem class.
pieces: List["pyarrow._dataset.ParquetFileFragment"] = \
cloudpickle.loads(serialized_pieces)
pieces: List["pyarrow._dataset.ParquetFileFragment"] = cloudpickle.loads(
serialized_pieces
)

# Ensure that we're reading at least one dataset fragment.
assert len(pieces) > 0
@@ -92,8 +94,8 @@ def read_pieces(serialized_pieces: str) -> Iterator[pa.Table]:

ctx = DatasetContext.get_current()
output_buffer = BlockOutputBuffer(
block_udf=_block_udf,
target_max_block_size=ctx.target_max_block_size)
block_udf=_block_udf, target_max_block_size=ctx.target_max_block_size
)

logger.debug(f"Reading {len(pieces)} parquet pieces")
use_threads = reader_args.pop("use_threads", False)
@@ -104,14 +106,17 @@ def read_pieces(serialized_pieces: str) -> Iterator[pa.Table]:
columns=columns,
schema=schema,
batch_size=PARQUET_READER_ROW_BATCH_SIZE,
**reader_args)
**reader_args,
)
for batch in batches:
table = pyarrow.Table.from_batches([batch], schema=schema)
if part:
for col, value in part.items():
table = table.set_column(
table.schema.get_field_index(col), col,
pa.array([value] * len(table)))
table.schema.get_field_index(col),
col,
pa.array([value] * len(table)),
)
# If the table is empty, drop it.
if table.num_rows > 0:
output_buffer.add_block(table)
@@ -126,13 +131,13 @@ def read_pieces(serialized_pieces: str) -> Iterator[pa.Table]:
dummy_table = schema.empty_table()
try:
inferred_schema = _block_udf(dummy_table).schema
inferred_schema = inferred_schema.with_metadata(
schema.metadata)
inferred_schema = inferred_schema.with_metadata(schema.metadata)
except Exception:
logger.debug(
"Failed to infer schema of dataset by passing dummy table "
"through UDF due to the following exception:",
exc_info=True)
exc_info=True,
)
inferred_schema = schema
else:
inferred_schema = schema
@@ -142,22 +147,26 @@ def read_pieces(serialized_pieces: str) -> Iterator[pa.Table]:
else:
metadata = _fetch_metadata(pq_ds.pieces)
for piece_data in np.array_split(
list(zip(pq_ds.pieces, metadata)), parallelism):
list(zip(pq_ds.pieces, metadata)), parallelism
):
if len(piece_data) == 0:
continue
pieces, metadata = zip(*piece_data)
serialized_pieces = cloudpickle.dumps(pieces)
meta = _build_block_metadata(pieces, metadata, inferred_schema)
read_tasks.append(
ReadTask(lambda p=serialized_pieces: read_pieces(p), meta))
ReadTask(lambda p=serialized_pieces: read_pieces(p), meta)
)

return read_tasks

def _write_block(self,
f: "pyarrow.NativeFile",
block: BlockAccessor,
writer_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
**writer_args):
def _write_block(
self,
f: "pyarrow.NativeFile",
block: BlockAccessor,
writer_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
**writer_args,
):
import pyarrow.parquet as pq

writer_args = _resolve_kwargs(writer_args_fn, **writer_args)
@@ -168,11 +177,11 @@ def _file_format(self) -> str:


def _fetch_metadata_remotely(
pieces: List["pyarrow._dataset.ParquetFileFragment"]
pieces: List["pyarrow._dataset.ParquetFileFragment"],
) -> List[ObjectRef["pyarrow.parquet.FileMetaData"]]:
from ray import cloudpickle
remote_fetch_metadata = cached_remote_fn(
_fetch_metadata_serialization_wrapper)

remote_fetch_metadata = cached_remote_fn(_fetch_metadata_serialization_wrapper)
metas = []
parallelism = min(len(pieces) // PIECES_PER_META_FETCH, 100)
meta_fetch_bar = ProgressBar("Metadata Fetch Progress", total=parallelism)
@@ -185,21 +194,22 @@ def _fetch_metadata_remotely(


def _fetch_metadata_serialization_wrapper(
pieces: str) -> List["pyarrow.parquet.FileMetaData"]:
pieces: str,
) -> List["pyarrow.parquet.FileMetaData"]:
# Implicitly trigger S3 subsystem initialization by importing
# pyarrow.fs.
import pyarrow.fs # noqa: F401
from ray import cloudpickle

# Deserialize after loading the filesystem class.
pieces: List["pyarrow._dataset.ParquetFileFragment"] = \
cloudpickle.loads(pieces)
pieces: List["pyarrow._dataset.ParquetFileFragment"] = cloudpickle.loads(pieces)

return _fetch_metadata(pieces)


def _fetch_metadata(pieces: List["pyarrow.dataset.ParquetFileFragment"]
) -> List["pyarrow.parquet.FileMetaData"]:
def _fetch_metadata(
pieces: List["pyarrow.dataset.ParquetFileFragment"],
) -> List["pyarrow.parquet.FileMetaData"]:
piece_metadata = []
for p in pieces:
try:
@@ -210,28 +220,28 @@ def _fetch_metadata(pieces: List["pyarrow.dataset.ParquetFileFragment"]


def _build_block_metadata(
pieces: List["pyarrow.dataset.ParquetFileFragment"],
metadata: List["pyarrow.parquet.FileMetaData"],
schema: Optional[Union[type, "pyarrow.lib.Schema"]]) -> BlockMetadata:
pieces: List["pyarrow.dataset.ParquetFileFragment"],
metadata: List["pyarrow.parquet.FileMetaData"],
schema: Optional[Union[type, "pyarrow.lib.Schema"]],
) -> BlockMetadata:
input_files = [p.path for p in pieces]
if len(metadata) == len(pieces):
# Piece metadata was available, construct a normal
# BlockMetadata.
block_metadata = BlockMetadata(
num_rows=sum(m.num_rows for m in metadata),
size_bytes=sum(
sum(
m.row_group(i).total_byte_size
for i in range(m.num_row_groups)) for m in metadata),
sum(m.row_group(i).total_byte_size for i in range(m.num_row_groups))
for m in metadata
),
schema=schema,
input_files=input_files,
exec_stats=None) # Exec stats filled in later.
exec_stats=None,
) # Exec stats filled in later.
else:
# Piece metadata was not available, construct an empty
# BlockMetadata.
block_metadata = BlockMetadata(
num_rows=None,
size_bytes=None,
schema=schema,
input_files=input_files)
num_rows=None, size_bytes=None, schema=schema, input_files=input_files
)
return block_metadata
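
For context on how this file gets exercised: ParquetDatasource is the datasource behind ray.data.read_parquet, which builds its read tasks through the prepare_read method reformatted above. The snippet below is a minimal usage sketch under that assumption, not code from this PR; the temporary file path is illustrative.

import pyarrow as pa
import pyarrow.parquet as pq
import ray

ray.init()

# Write a tiny Parquet file so the read below has something to consume.
table = pa.table({"x": [1, 2, 3], "y": ["a", "b", "c"]})
pq.write_table(table, "/tmp/example.parquet")

# read_parquet builds ReadTasks via ParquetDatasource.prepare_read and
# runs them as Ray tasks; the columns argument is forwarded to the
# datasource, matching the prepare_read signature above.
ds = ray.data.read_parquet("/tmp/example.parquet", columns=["x"])
print(ds.take(3))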