Skip to content

Commit

Permalink
[AIR - Datasets] Encode number of dimensions in variable-shaped tenso…
Browse files Browse the repository at this point in the history
…r extension type. (ray-project#29281)

Knowing the number of dimensions in a variable-shaped tensor column is useful for e.g. inferring a ragged tensor spec when constructing a tf.data Dataset; by encoding this ndim data in the extension type, we can do this type inference base on Dataset metadata, which is required.

Note that this change will disallow variable-shaped tensor columns containing tensor elements that have a variable number of dimensions. This isn't supported by TensorFlow and Torch ragged tensors, so sacrificing this feature seems tenable.
  • Loading branch information
clarkzinzow committed Oct 24, 2022
1 parent 3562cb4 commit 9a3c0fb
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 19 deletions.
6 changes: 5 additions & 1 deletion python/ray/air/tests/test_tensor_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def test_scalar_tensor_array_roundtrip():


def test_arrow_variable_shaped_tensor_array_validation():
# Test tensor elements with differing dimensions raises ValueError.
with pytest.raises(ValueError):
ArrowVariableShapedTensorArray.from_numpy([np.ones((2, 2)), np.ones((3, 3, 3))])

# Test arbitrary object raises ValueError.
with pytest.raises(ValueError):
ArrowVariableShapedTensorArray.from_numpy(object())
Expand Down Expand Up @@ -405,7 +409,7 @@ def test_tensor_array_concat(a1, a2):
assert ta.dtype.element_shape == a1.shape[1:]
np.testing.assert_array_equal(ta.to_numpy(), np.concatenate([a1, a2]))
else:
assert ta.dtype.element_shape is None
assert ta.dtype.element_shape == (None,) * (len(a1.shape) - 1)
for arr, expected in zip(
ta.to_numpy(), np.array([e for a in [a1, a2] for e in a], dtype=object)
):
Expand Down
28 changes: 22 additions & 6 deletions python/ray/air/util/tensor_extensions/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,13 +383,15 @@ class ArrowVariableShapedTensorType(pa.PyExtensionType):
https://arrow.apache.org/docs/python/extending_types.html#defining-extension-types-user-defined-types
"""

def __init__(self, dtype: pa.DataType):
def __init__(self, dtype: pa.DataType, ndim: int):
"""
Construct the Arrow extension type for array of heterogeneous-shaped tensors.
Args:
dtype: pyarrow dtype of tensor elements.
ndim: The number of dimensions in the tensor elements.
"""
self._ndim = ndim
super().__init__(
pa.struct([("data", pa.list_(dtype)), ("shape", pa.list_(pa.int64()))])
)
Expand All @@ -404,14 +406,19 @@ def to_pandas_dtype(self):
from ray.air.util.tensor_extensions.pandas import TensorDtype

return TensorDtype(
None,
(None,) * self.ndim,
self.storage_type["data"].type.value_type.to_pandas_dtype(),
)

@property
def ndim(self) -> int:
"""Return the number of dimensions in the tensor elements."""
return self._ndim

def __reduce__(self):
return (
ArrowVariableShapedTensorType,
(self.storage_type["data"].type.value_type,),
(self.storage_type["data"].type.value_type, self._ndim),
)

def __arrow_ext_class__(self):
Expand All @@ -426,7 +433,7 @@ def __arrow_ext_class__(self):

def __str__(self) -> str:
dtype = self.storage_type["data"].type.value_type
return f"ArrowVariableShapedTensorType(dtype={dtype})"
return f"ArrowVariableShapedTensorType(dtype={dtype}, ndim={self.ndim})"

def __repr__(self) -> str:
return str(self)
Expand All @@ -440,7 +447,8 @@ class ArrowVariableShapedTensorArray(pa.ExtensionArray):
This is the Arrow side of TensorArray for tensor elements that have differing
shapes. Note that this extension only supports non-ragged tensor elements; i.e.,
when considering each tensor element in isolation, they must have a well-defined
shape.
shape. This extension also only supports tensor elements that all have the same
number of dimensions.
See Arrow docs for customizing extension arrays:
https://arrow.apache.org/docs/python/extending_types.html#custom-extension-array-class
Expand Down Expand Up @@ -520,8 +528,16 @@ def from_numpy(

# Whether all subndarrays are contiguous views of the same ndarray.
shapes, sizes, raveled = [], [], []
ndim = None
for a in arr:
a = np.asarray(a)
if ndim is not None and a.ndim != ndim:
raise ValueError(
"ArrowVariableShapedTensorArray only supports tensor elements that "
"all have the same number of dimensions, but got tensor elements "
f"with dimensions: {ndim}, {a.ndim}"
)
ndim = a.ndim
shapes.append(a.shape)
sizes.append(a.size)
# Convert to 1D array view; this should be zero-copy in the common case.
Expand Down Expand Up @@ -571,7 +587,7 @@ def from_numpy(
[data_array, shape_array],
["data", "shape"],
)
type_ = ArrowVariableShapedTensorType(pa_dtype)
type_ = ArrowVariableShapedTensorType(pa_dtype, ndim)
return pa.ExtensionArray.from_storage(type_, storage)

def _to_numpy(self, index: Optional[int] = None, zero_copy_only: bool = False):
Expand Down
12 changes: 6 additions & 6 deletions python/ray/air/util/tensor_extensions/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ class TensorDtype(pd.api.extensions.ExtensionDtype):
# https://github.com/CODAIT/text-extensions-for-pandas/issues/166
base = None

def __init__(self, shape: Optional[Tuple[int, ...]], dtype: np.dtype):
def __init__(self, shape: Tuple[Optional[int], ...], dtype: np.dtype):
self._shape = shape
self._dtype = dtype

Expand All @@ -308,8 +308,8 @@ def element_dtype(self):
@property
def element_shape(self):
"""
The shape of the underlying tensor elements. This will be None if the
corresponding TensorArray for this TensorDtype holds variable-shaped tensor
The shape of the underlying tensor elements. This will be a tuple of Nones if
the corresponding TensorArray for this TensorDtype holds variable-shaped tensor
elements.
"""
return self._shape
Expand All @@ -320,7 +320,7 @@ def is_variable_shaped(self):
Whether the corresponding TensorArray for this TensorDtype holds variable-shaped
tensor elements.
"""
return self.shape is None
return all(dim_size is None for dim_size in self.shape)

@property
def name(self) -> str:
Expand Down Expand Up @@ -384,7 +384,7 @@ def construct_from_string(cls, string: str):
)
# Upstream code uses exceptions as part of its normal control flow and
# will pass this method bogus class names.
regex = r"^TensorDtype\(shape=((?:\((?:\d+,?\s?)*\))|(?:None)), dtype=(\w+)\)$"
regex = r"^TensorDtype\(shape=(\((?:(?:\d+|None),?\s?)*\)), dtype=(\w+)\)$"
m = re.search(regex, string)
err_msg = (
f"Cannot construct a '{cls.__name__}' from '{string}'; expected a string "
Expand Down Expand Up @@ -890,7 +890,7 @@ def dtype(self) -> pd.api.extensions.ExtensionDtype:
# A tensor is only considered variable-shaped if it's non-empty, so no
# non-empty check is needed here.
dtype = self._tensor[0].dtype
shape = None
shape = (None,) * self._tensor[0].ndim
else:
dtype = self.numpy_dtype
shape = self.numpy_shape[1:]
Expand Down
1 change: 0 additions & 1 deletion python/ray/data/_internal/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ def numpy_to_block(

@staticmethod
def _build_tensor_row(row: ArrowRow) -> np.ndarray:
# Getting an item in a tensor column automatically does a NumPy conversion.
return row[VALUE_COL_NAME][0]

def slice(self, start: int, end: int, copy: bool) -> "pyarrow.Table":
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/_internal/arrow_ops/transform_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def sort(table: "pyarrow.Table", key: "SortKeyT", descending: bool) -> "pyarrow.
import pyarrow.compute as pac

indices = pac.sort_indices(table, sort_keys=key)
return table.take(indices)
return take_table(table, indices)


def take_table(
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ def test_tensors_inferred_from_map(ray_start_regular_shared):
)
assert str(ds) == (
"Dataset(num_blocks=4, num_rows=16, "
"schema={a: TensorDtype(shape=None, dtype=float64)})"
"schema={a: TensorDtype(shape=(None, None), dtype=float64)})"
)


Expand Down
13 changes: 10 additions & 3 deletions python/ray/data/tests/test_transform_pyarrow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pyarrow as pa
import pytest

from ray.data.extensions import (
ArrowTensorArray,
Expand Down Expand Up @@ -108,7 +109,7 @@ def test_arrow_concat_tensor_extension_variable_shaped():
assert len(out) == 4
# Check schema.
assert out.column_names == ["a"]
assert out.schema.types == [ArrowVariableShapedTensorType(pa.int64())]
assert out.schema.types == [ArrowVariableShapedTensorType(pa.int64(), 2)]
# Confirm that concatenation is zero-copy (i.e. it didn't trigger chunk
# consolidation).
assert out["a"].num_chunks == 2
Expand Down Expand Up @@ -136,7 +137,7 @@ def test_arrow_concat_tensor_extension_uniform_and_variable_shaped():
assert len(out) == 5
# Check schema.
assert out.column_names == ["a"]
assert out.schema.types == [ArrowVariableShapedTensorType(pa.int64())]
assert out.schema.types == [ArrowVariableShapedTensorType(pa.int64(), 2)]
# Confirm that concatenation is zero-copy (i.e. it didn't trigger chunk
# consolidation).
assert out["a"].num_chunks == 2
Expand All @@ -161,7 +162,7 @@ def test_arrow_concat_tensor_extension_uniform_but_different():
assert len(out) == 6
# Check schema.
assert out.column_names == ["a"]
assert out.schema.types == [ArrowVariableShapedTensorType(pa.int64())]
assert out.schema.types == [ArrowVariableShapedTensorType(pa.int64(), 2)]
# Confirm that concatenation is zero-copy (i.e. it didn't trigger chunk
# consolidation).
assert out["a"].num_chunks == 2
Expand All @@ -170,3 +171,9 @@ def test_arrow_concat_tensor_extension_uniform_but_different():
np.testing.assert_array_equal(out["a"].chunk(1).to_numpy(), a2)
# NOTE: We don't check equivalence with pyarrow.concat_tables since it currently
# fails for this case.


if __name__ == "__main__":
import sys

sys.exit(pytest.main(["-v", __file__]))

0 comments on commit 9a3c0fb

Please sign in to comment.