Skip to content

Commit

Permalink
[data] fix nested ragged ndarray (ray-project#44236)
Browse files Browse the repository at this point in the history
---------

Signed-off-by: Hao Chen <[email protected]>
  • Loading branch information
raulchen authored Mar 26, 2024
1 parent 99cb040 commit 459edae
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
23 changes: 11 additions & 12 deletions python/ray/data/_internal/numpy_support.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, List

import numpy as np

Expand All @@ -24,12 +24,11 @@ def is_valid_udf_return(udf_return_col: Any) -> bool:
return isinstance(udf_return_col, list) or is_array_like(udf_return_col)


def is_scalar_list(udf_return_col: Any) -> bool:
"""Check whether a UDF column is is a scalar list."""

return isinstance(udf_return_col, list) and (
not udf_return_col or np.isscalar(udf_return_col[0])
)
def is_nested_list(udf_return_col: List[Any]) -> bool:
for e in udf_return_col:
if isinstance(e, list):
return True
return False


def convert_udf_returns_to_numpy(udf_return_col: Any) -> Any:
Expand Down Expand Up @@ -62,11 +61,11 @@ def convert_udf_returns_to_numpy(udf_return_col: Any) -> Any:
# `str` are also Iterable.
try:
# Try to cast the inner scalars to numpy as well, to avoid unnecessarily
# creating an inefficient array of array of object dtype. Don't convert
# scalar lists though, since those can be represented as pyarrow list type
# without needing to go through our tensor extension.
if all(
is_valid_udf_return(e) and not is_scalar_list(e) for e in udf_return_col
# creating an inefficient array of array of object dtype.
# But don't convert if the list is nested. Because if sub-lists have
# heterogeneous shapes, we need to create a ragged ndarray.
if not is_nested_list(udf_return_col) and all(
is_valid_udf_return(e) for e in udf_return_col
):
# Use np.asarray() instead of np.array() to avoid copying if possible.
udf_return_col = [np.asarray(e) for e in udf_return_col]
Expand Down
15 changes: 14 additions & 1 deletion python/ray/data/tests/test_numpy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_ragged_array_like(ray_start_regular_shared):
def test_scalar_nested_arrays(ray_start_regular_shared):
data = [[[1]], [[2]]]
output = do_map_batches(data)
assert_structure_equals(output, np.array([[[1]], [[2]]]))
assert_structure_equals(output, create_ragged_ndarray(data))


def test_scalar_lists_not_converted(ray_start_regular_shared):
Expand Down Expand Up @@ -155,6 +155,19 @@ def test_scalar_ragged_array_like(ray_start_regular_shared):
)


def test_nested_ragged_arrays(ray_start_regular_shared):
data = [
{"a": [[1], [2, 3]]},
{"a": [[4, 5], [6]]},
]

def f(row):
return data[row["id"]]

output = ray.data.range(2).map(f).take_all()
assert output == data


# https://github.com/ray-project/ray/issues/35340
def test_complex_ragged_arrays(ray_start_regular_shared):
data = [[{"a": 1}, {"a": 2}, {"a": 3}], [{"b": 1}]]
Expand Down

0 comments on commit 459edae

Please sign in to comment.