[Datasets] [Arrow 7+ Support] [2/N] Add support for Arrow 7 by fixing Arrow serialization bug. (#29993)

This PR adds support for Arrow 7 in Ray. It is the second in a set of stacked PRs making up the Arrow 7+ support mono-PR (#29161), and is stacked on top of a PR that fixes task cancellation in Ray Core (#29984).

This PR:
- fixes an Arrow serialization bug by registering a custom serializer for Arrow data (#29814: "[Datasets] Arrow data buffers aren't truncated when pickling zero-copy slice views, leading to huge serialization bloat"); a sketch of the issue is shown after this list
- removes the defensive copying of Arrow data that previously worked around the aforementioned Arrow serialization bug
- adds a CI job for Arrow 7
- bumps the pyarrow upper bound to 8.0.0
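
For context, a minimal sketch of the underlying Arrow issue (an illustration, not Ray's actual serializer): pickling a zero-copy slice of an Arrow table serializes the parent buffers rather than just the sliced range, so a tiny slice can pickle to roughly the size of the full table.

import pickle

import pyarrow as pa

# A large table and a tiny zero-copy slice view over it.
table = pa.table({"a": list(range(1_000_000))})
tiny_slice = table.slice(0, 10)

full_size = len(pickle.dumps(table))
slice_size = len(pickle.dumps(tiny_slice))
# On affected pyarrow versions, slice_size is close to full_size because the
# underlying buffers are not truncated; this is the bloat that the custom
# serializer fixes and that the removed defensive copies used to work around.
print(full_size, slice_size)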
clarkzinzow authored Nov 8, 2022
1 parent 60cde11 commit 75b206e
Showing 33 changed files with 606 additions and 158 deletions.
15 changes: 13 additions & 2 deletions .buildkite/pipeline.ml.yml
@@ -269,14 +269,25 @@
# Dask tests and examples.
- bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-client python/ray/util/dask/...

- label: "Dataset tests"
- label: "Dataset tests (Arrow 7)"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_PYTHON_AFFECTED", "RAY_CI_DATA_AFFECTED"]
instance_size: medium
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- DATA_PROCESSING_TESTING=1 ./ci/env/install-dependencies.sh
- DATA_PROCESSING_TESTING=1 ARROW_VERSION=7.* ./ci/env/install-dependencies.sh
- ./ci/env/env_info.sh
- bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only python/ray/data/...
- bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=ray_data python/ray/air/...

- label: "Dataset tests (Arrow 6)"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_PYTHON_AFFECTED", "RAY_CI_DATA_AFFECTED"]
instance_size: medium
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- DATA_PROCESSING_TESTING=1 ARROW_VERSION=6.* ./ci/env/install-dependencies.sh
- ./ci/env/env_info.sh
- bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only python/ray/data/...
- bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=ray_data python/ray/air/...

- label: "Workflow tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_PYTHON_AFFECTED", "RAY_CI_WORKFLOW_AFFECTED"]
7 changes: 7 additions & 0 deletions ci/env/install-dependencies.sh
@@ -399,6 +399,13 @@ install_pip_packages() {
  fi
  if [ "${DATA_PROCESSING_TESTING-}" = 1 ]; then
    pip install -U -c "${WORKSPACE_DIR}"/python/requirements.txt -r "${WORKSPACE_DIR}"/python/requirements/data_processing/requirements_dataset.txt
    if [ -n "${ARROW_VERSION-}" ]; then
      if [ "${ARROW_VERSION-}" = nightly ]; then
        pip install --extra-index-url https://pypi.fury.io/arrow-nightlies/ --prefer-binary --pre pyarrow
      else
        pip install -U pyarrow=="${ARROW_VERSION}"
      fi
    fi
  fi

# Remove this entire section once Serve dependencies are fixed.
18 changes: 18 additions & 0 deletions python/ray/_private/utils.py
@@ -61,6 +61,7 @@


ENV_DISABLE_DOCKER_CPU_WARNING = "RAY_DISABLE_DOCKER_CPU_WARNING" in os.environ
_PYARROW_VERSION = None


def get_user_temp_dir():
@@ -1596,3 +1597,20 @@ def split_address(address: str) -> Tuple[str, str]:

    module_string, inner_address = address.split("://", maxsplit=1)
    return (module_string, inner_address)


def _get_pyarrow_version() -> Optional[str]:
    """Get the version string of the installed pyarrow package.

    Returns None if the package is not found.
    """
    global _PYARROW_VERSION
    if _PYARROW_VERSION is None:
        try:
            import pyarrow
        except ModuleNotFoundError:
            # pyarrow not installed, short-circuit.
            pass
        else:
            if hasattr(pyarrow, "__version__"):
                _PYARROW_VERSION = pyarrow.__version__
    return _PYARROW_VERSION
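
A hypothetical usage sketch (not part of this diff) of how a caller could gate version-dependent behavior on the returned version string; the packaging dependency and the 7.0.0 cutoff here are illustrative assumptions:

from packaging.version import parse as parse_version

# Hypothetical: choose a code path based on the installed pyarrow version.
version = _get_pyarrow_version()
if version is not None and parse_version(version) < parse_version("7.0.0"):
    # Fall back to behavior that is known to be safe on older Arrow releases.
    ...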
4 changes: 2 additions & 2 deletions python/ray/air/BUILD
@@ -62,7 +62,7 @@ py_test(
name = "test_data_batch_conversion",
size = "small",
srcs = ["tests/test_data_batch_conversion.py"],
tags = ["team:ml", "exclusive"],
tags = ["team:ml", "exclusive", "ray_data"],
deps = [":ml_lib"]
)

@@ -110,7 +110,7 @@ py_test(
name = "test_tensor_extension",
size = "small",
srcs = ["tests/test_tensor_extension.py"],
tags = ["team:ml", "exclusive"],
tags = ["team:ml", "exclusive", "ray_data"],
deps = [":ml_lib"]
)

63 changes: 61 additions & 2 deletions python/ray/air/tests/test_tensor_extension.py
@@ -331,14 +331,17 @@ def test_tensor_array_reductions():
np.testing.assert_equal(df["two"].agg(name), reducer(arr, axis=0, **np_kwargs))


def test_arrow_tensor_array_getitem():
@pytest.mark.parametrize("chunked", [False, True])
def test_arrow_tensor_array_getitem(chunked):
outer_dim = 3
inner_shape = (2, 2, 2)
shape = (outer_dim,) + inner_shape
num_items = np.prod(np.array(shape))
arr = np.arange(num_items).reshape(shape)

t_arr = ArrowTensorArray.from_numpy(arr)
if chunked:
t_arr = pa.chunked_array(t_arr)

for idx in range(outer_dim):
np.testing.assert_array_equal(t_arr[idx], arr[idx])
@@ -352,8 +355,64 @@ def test_arrow_tensor_array_getitem():

# Test slicing and indexing.
t_arr2 = t_arr[1:]
if chunked:
# For extension arrays, ChunkedArray.to_numpy() concatenates chunk storage
# arrays and calls to_numpy() on the resulting array, which returns the wrong
# ndarray.
# TODO(Clark): Fix this in Arrow by (1) providing an ExtensionArray hook for
# concatenation, and (2) using that + a to_numpy() call on the resulting
# ExtensionArray.
t_arr2_npy = t_arr2.chunk(0).to_numpy()
else:
t_arr2_npy = t_arr2.to_numpy()

np.testing.assert_array_equal(t_arr2_npy, arr[1:])

for idx in range(1, outer_dim):
np.testing.assert_array_equal(t_arr2[idx - 1], arr[idx])


@pytest.mark.parametrize("chunked", [False, True])
def test_arrow_variable_shaped_tensor_array_getitem(chunked):
shapes = [(2, 2), (3, 3), (4, 4)]
outer_dim = len(shapes)
cumsum_sizes = np.cumsum([0] + [np.prod(shape) for shape in shapes[:-1]])
arrs = [
np.arange(offset, offset + np.prod(shape)).reshape(shape)
for offset, shape in zip(cumsum_sizes, shapes)
]
arr = np.array(arrs, dtype=object)
t_arr = ArrowVariableShapedTensorArray.from_numpy(arr)

if chunked:
t_arr = pa.chunked_array(t_arr)

for idx in range(outer_dim):
np.testing.assert_array_equal(t_arr[idx], arr[idx])

# Test __iter__.
for t_subarr, subarr in zip(t_arr, arr):
np.testing.assert_array_equal(t_subarr, subarr)

# Test to_pylist.
for t_subarr, subarr in zip(t_arr.to_pylist(), list(arr)):
np.testing.assert_array_equal(t_subarr, subarr)

# Test slicing and indexing.
t_arr2 = t_arr[1:]
if chunked:
# For extension arrays, ChunkedArray.to_numpy() concatenates chunk storage
# arrays and calls to_numpy() on the resulting array, which returns the wrong
# ndarray.
# TODO(Clark): Fix this in Arrow by (1) providing an ExtensionArray hook for
# concatenation, and (2) using that + a to_numpy() call on the resulting
# ExtensionArray.
t_arr2_npy = t_arr2.chunk(0).to_numpy()
else:
t_arr2_npy = t_arr2.to_numpy()

for t_subarr, subarr in zip(t_arr2_npy, arr[1:]):
np.testing.assert_array_equal(t_subarr, subarr)

for idx in range(1, outer_dim):
np.testing.assert_array_equal(t_arr2[idx - 1], arr[idx])
7 changes: 1 addition & 6 deletions python/ray/air/util/tensor_extensions/arrow.py
@@ -470,12 +470,7 @@ def __getitem__(self, key):
# unfortunately overriding Cython cdef methods with normal Python methods isn't
# allowed.
if isinstance(key, slice):
sliced = super().__getitem__(key).to_numpy()
if sliced.dtype.type is not np.object_:
# Force ths slice to match NumPy semantics for unit (single-element)
# slices.
sliced = sliced[0:1]
return sliced
return super().__getitem__(key)
return self._to_numpy(key)

def __iter__(self):
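Based on the tests above, a short sketch of the resulting __getitem__ semantics for the tensor extension array (the import path is taken from the modules touched in this diff): integer indexing yields a single element as an ndarray, while slicing yields another extension array view that can still be converted with to_numpy().

import numpy as np

from ray.air.util.tensor_extensions.arrow import ArrowTensorArray

arr = np.arange(24).reshape(3, 2, 2, 2)
t_arr = ArrowTensorArray.from_numpy(arr)

elem = t_arr[0]   # A single element, comparable to arr[0].
view = t_arr[1:]  # A zero-copy slice that is still an extension array.
np.testing.assert_array_equal(view.to_numpy(), arr[1:])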
14 changes: 0 additions & 14 deletions python/ray/data/__init__.py
@@ -1,8 +1,3 @@
import ray
from ray.data._internal.arrow_serialization import (
_register_arrow_json_parseoptions_serializer,
_register_arrow_json_readoptions_serializer,
)
from ray.data._internal.compute import ActorPoolStrategy
from ray.data._internal.progress_bar import set_progress_bars
from ray.data.dataset import Dataset
@@ -39,15 +34,6 @@
read_tfrecords,
)

# Register custom Arrow JSON ReadOptions and ParseOptions serializer after worker has
# initialized.
if ray.is_initialized():
_register_arrow_json_readoptions_serializer()
_register_arrow_json_parseoptions_serializer()
else:
pass
# ray._internal.worker._post_init_hooks.append(_register_arrow_json_readoptions_serializer)

__all__ = [
"ActorPoolStrategy",
"Dataset",
39 changes: 27 additions & 12 deletions python/ray/data/_internal/arrow_block.py
@@ -46,6 +46,7 @@

from ray.data._internal.sort import SortKeyT


T = TypeVar("T")


@@ -70,6 +71,19 @@ class ArrowRow(TableRow):
"""

def __getitem__(self, key: str) -> Any:
from ray.data.extensions.tensor_extension import (
ArrowTensorType,
ArrowVariableShapedTensorType,
)

schema = self._row.schema
if isinstance(
schema.field(key).type,
(ArrowTensorType, ArrowVariableShapedTensorType),
):
# Build a tensor row.
return ArrowBlockAccessor._build_tensor_row(self._row, col_name=key)

col = self._row[key]
if len(col) == 0:
return None
@@ -79,7 +93,7 @@ def __getitem__(self, key: str) -> Any:
return item.as_py()
except AttributeError:
# Assume that this row is an element of an extension array, and
# that it is bypassing pyarrow's scalar model.
# that it is bypassing pyarrow's scalar model for Arrow < 8.0.0.
return item

def __iter__(self) -> Iterator:
@@ -166,10 +180,13 @@ def numpy_to_block(
return pa.Table.from_pydict(new_batch)

@staticmethod
def _build_tensor_row(row: ArrowRow) -> np.ndarray:
return row[VALUE_COL_NAME][0]
def _build_tensor_row(row: ArrowRow, col_name: str = VALUE_COL_NAME) -> np.ndarray:
element = row[col_name][0]
# For Arrow < 8.0.0, accessing an element in a chunked tensor array produces an
# ndarray, which we return directly.
return element

def slice(self, start: int, end: int, copy: bool) -> "pyarrow.Table":
def slice(self, start: int, end: int, copy: bool = False) -> "pyarrow.Table":
view = self._table.slice(start, end - start)
if copy:
view = _copy_table(view)
@@ -212,10 +229,10 @@ def to_numpy(
arrays = []
for column in columns:
array = self._table[column]
if array.num_chunks == 0:
array = pyarrow.array([], type=array.type)
elif _is_column_extension_type(array):
if _is_column_extension_type(array):
array = _concatenate_extension_column(array)
elif array.num_chunks == 0:
array = pyarrow.array([], type=array.type)
else:
array = array.combine_chunks()
arrays.append(array.to_numpy(zero_copy_only=False))
@@ -399,11 +416,9 @@ def sort_and_partition(
bounds = np.searchsorted(table[col], boundaries)
last_idx = 0
for idx in bounds:
# Slices need to be copied to avoid including the base table
# during serialization.
partitions.append(_copy_table(table.slice(last_idx, idx - last_idx)))
partitions.append(table.slice(last_idx, idx - last_idx))
last_idx = idx
partitions.append(_copy_table(table.slice(last_idx)))
partitions.append(table.slice(last_idx))
return partitions

def combine(self, key: KeyFn, aggs: Tuple[AggregateFn]) -> Block[ArrowRow]:
@@ -449,7 +464,7 @@ def iter_groups() -> Iterator[Tuple[KeyType, Block]]:
except StopIteration:
next_row = None
break
yield next_key, self.slice(start, end, copy=False)
yield next_key, self.slice(start, end)
start = end
except StopIteration:
break
2 changes: 1 addition & 1 deletion python/ray/data/_internal/arrow_ops/transform_polars.py
@@ -7,7 +7,7 @@


if TYPE_CHECKING:
from ray.data.impl.sort import SortKeyT
from ray.data._internal.sort import SortKeyT

pl = None

2 changes: 1 addition & 1 deletion python/ray/data/_internal/arrow_ops/transform_pyarrow.py
@@ -6,7 +6,7 @@
pyarrow = None

if TYPE_CHECKING:
from ray.data.impl.sort import SortKeyT
from ray.data._internal.sort import SortKeyT


def sort(table: "pyarrow.Table", key: "SortKeyT", descending: bool) -> "pyarrow.Table":