Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve][AIR] Support mixed input and output type, with batching #25688

Merged
merged 7 commits into from
Jun 16, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 93 additions & 85 deletions python/ray/serve/model_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Type, Union
import numpy as np

from ray._private.utils import import_attr
Expand Down Expand Up @@ -37,62 +37,84 @@ def _load_predictor_cls(
return predictor_cls


def collate_array(
input_list: List[np.ndarray],
) -> Tuple[np.ndarray, Callable[[np.ndarray], List[np.ndarray]]]:
batch_size = len(input_list)
batched = np.stack(input_list)
class BatchingManager:
shrekris-anyscale marked this conversation as resolved.
Show resolved Hide resolved
@staticmethod
def batch_array(input_list: List[np.ndarray]) -> np.ndarray:
batched = np.stack(input_list)
return batched

def unpack(output_arr):
if isinstance(output_arr, list):
return output_arr
if not isinstance(output_arr, np.ndarray):
@staticmethod
def split_array(output_array: np.ndarray, batch_size: int) -> List[np.ndarray]:
if not isinstance(output_array, np.ndarray):
raise TypeError(
f"The output should be np.ndarray but Serve got {type(output_arr)}."
f"The output should be np.ndarray but Serve got {type(output_array)}."
)
if len(output_arr) != batch_size:
if len(output_array) != batch_size:
raise ValueError(
f"The output array should have shape of ({batch_size}, ...) "
f"because the input has {batch_size} entries "
f"but Serve got {output_arr.shape}"
f"but Serve got {output_array.shape}"
)
return [arr.squeeze(axis=0) for arr in np.split(output_arr, batch_size, axis=0)]
return [
arr.squeeze(axis=0) for arr in np.split(output_array, batch_size, axis=0)
]

@staticmethod
@require_packages(["pandas"])
def batch_dataframe(input_list: List["pd.DataFrame"]) -> "pd.DataFrame":
    """Concatenate per-request DataFrames row-wise into one batched DataFrame.

    The per-request indices are discarded (``ignore_index=True``) so the
    batched frame gets a fresh 0..n-1 index.
    """
    import pandas as pd

    return pd.concat(input_list, axis="index", ignore_index=True, copy=False)

@staticmethod
@require_packages(["pandas"])
def split_dataframe(
    output_df: "pd.DataFrame", batch_size: int
) -> List["pd.DataFrame"]:
    """Split a batched output DataFrame back into ``batch_size`` frames.

    Raises:
        TypeError: if ``output_df`` is not a pandas DataFrame.
        ValueError: if the row count is not divisible by ``batch_size``.
    """
    if not isinstance(output_df, pd.DataFrame):
        raise TypeError(
            "The output should be a Pandas DataFrame but Serve got "
            f"{type(output_df)}"
        )
    if len(output_df) % batch_size != 0:
        raise ValueError(
            f"The output dataframe should have length divisible by {batch_size}, "
            f"but Serve got length {len(output_df)}."
        )
    # Re-index each chunk so every per-request frame starts at row 0.
    chunks = np.split(output_df, batch_size)
    return [frame.reset_index(drop=True) for frame in chunks]

return batched, unpack
@staticmethod
def batch_dict_array(
input_list: List[Dict[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
batch_size = len(input_list)

# Check all input has the same dict keys.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Check all input has the same dict keys.
# Check that all inputs have the same dict keys.

input_keys = [set(item.keys()) for item in input_list]
batch_has_same_keys = input_keys.count(input_keys[0]) == batch_size
if not batch_has_same_keys:
raise ValueError(
f"The input batch contains dictionary of different keys: {input_keys}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f"The input batch contains dictionary of different keys: {input_keys}"
f"The input batch's dictionaries must each contain the same keys. Got different keys in some dictionaries: {input_keys}"

)

def collate_dict_array(
input_list: List[Dict[str, np.ndarray]]
) -> Tuple[
Dict[str, np.ndarray],
Callable[[Dict[str, np.ndarray]], List[Dict[str, np.ndarray]]],
]:
batch_size = len(input_list)
# Turn list[dict[str, array]] to dict[str, List[array]]
key_to_list = defaultdict(list)
for single_dict in input_list:
for key, arr in single_dict.items():
key_to_list[key].append(arr)

# Check all input has the same dict keys.
input_keys = [set(item.keys()) for item in input_list]
batch_has_same_keys = input_keys.count(input_keys[0]) == batch_size
if not batch_has_same_keys:
raise ValueError(
f"The input batch contains dictionary of different keys: {input_keys}"
)
# Turn dict[str, List[array]] to dict[str, array]
batched_dict = {}
for key, list_of_arr in key_to_list.items():
arr = BatchingManager.batch_array(list_of_arr)
batched_dict[key] = arr

# Turn list[dict[str, array]] to dict[str, List[array]]
key_to_list = defaultdict(list)
for single_dict in input_list:
for key, arr in single_dict.items():
key_to_list[key].append(arr)

# Turn dict[str, List[array]] to dict[str, array]
batched_dict = {}
unpack_dict = {}
for key, list_of_arr in key_to_list.items():
arr, unpack_func = collate_array(list_of_arr)
batched_dict[key] = arr
unpack_dict[key] = unpack_func

def unpack(output_dict: Dict[str, np.ndarray]):
# short circuit behavior, assume users already unpacked the output for us.
return batched_dict

@staticmethod
def split_dict_array(
output_dict: Dict[str, np.ndarray], batch_size: int
) -> List[Dict[str, np.ndarray]]:
if isinstance(output_dict, list):
return output_dict

Expand All @@ -102,44 +124,14 @@ def unpack(output_dict: Dict[str, np.ndarray]):
)

split_list_of_dict = [{} for _ in range(batch_size)]
for key, arr_unpack_func in unpack_dict.items():
arr_list = arr_unpack_func(output_dict[key])
for key, result_arr in output_dict.items():
split_arrays = BatchingManager.split_array(result_arr, batch_size)
# in place update each dictionary with the split array chunk.
for item, arr in zip(split_list_of_dict, arr_list):
for item, arr in zip(split_list_of_dict, split_arrays):
item[key] = arr

return split_list_of_dict

return batched_dict, unpack


@require_packages(["pandas"])
def collate_dataframe(
input_list: List["pd.DataFrame"],
) -> Tuple["pd.DataFrame", Callable[["pd.DataFrame"], List["pd.DataFrame"]]]:
import pandas as pd

batch_size = len(input_list)
batched = pd.concat(input_list, axis="index", ignore_index=True, copy=False)

def unpack(output_df):
if isinstance(output_df, list):
return output_df
if not isinstance(output_df, pd.DataFrame):
raise TypeError(
"The output should be a Pandas DataFrame but Serve got "
f"{type(output_df)}"
)
if len(output_df) % batch_size != 0:
raise ValueError(
f"The output dataframe should have length divisible by {batch_size}, "
f"because the input from {batch_size} different requests "
f"but Serve got length {len(output_df)}."
)
return [df.reset_index(drop=True) for df in np.split(output_df, batch_size)]

return batched, unpack


class ModelWrapper(SimpleSchemaIngress):
"""Serve any Ray AIR predictor from an AIR checkpoint.
Expand Down Expand Up @@ -204,22 +196,38 @@ async def predict_impl(inp: Union[np.ndarray, "pd.DataFrame"]):

@serve.batch(**batching_params)
async def predict_impl(inp: Union[List[np.ndarray], List["pd.DataFrame"]]):

batch_size = len(inp)
if isinstance(inp[0], np.ndarray):
collate_func = collate_array
batched = BatchingManager.batch_array(inp)
elif pd is not None and isinstance(inp[0], pd.DataFrame):
collate_func = collate_dataframe
batched = BatchingManager.batch_dataframe(inp)
elif isinstance(inp[0], dict):
batched = BatchingManager.batch_dict_array(inp)
else:
raise ValueError(
"ModelWrapper only accepts numpy array or dataframe as input "
"ModelWrapper only accepts numpy array, dataframe, or dict of "
"array as input "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"array as input "
"arrays as input "

f"but got types {[type(i) for i in inp]}"
)

batched, unpack = collate_func(inp)
out = self.model.predict(batched, **predict_kwargs)
if isinstance(out, ray.ObjectRef):
out = await out
return unpack(out)

if isinstance(out, np.ndarray):
return BatchingManager.split_array(out, batch_size)
elif pd is not None and isinstance(out, pd.DataFrame):
return BatchingManager.split_dataframe(out, batch_size)
elif isinstance(out, dict):
return BatchingManager.split_dict_array(out, batch_size)
elif isinstance(out, list) and len(out) == batch_size:
return out
else:
raise ValueError(
f"ModelWrapper only accepts list of length {batch_size}, numpy "
shrekris-anyscale marked this conversation as resolved.
Show resolved Hide resolved
"array, dataframe, or dict of array as output "
f"but got types {[type(i) for i in inp]}"
)

self.predict_impl = predict_impl

Expand Down
60 changes: 45 additions & 15 deletions python/ray/serve/tests/test_model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
import pytest

from ray.serve.model_wrappers import (
BatchingManager,
ModelWrapperDeployment,
collate_array,
collate_dataframe,
collate_dict_array,
)
from ray.air.checkpoint import Checkpoint
from ray.air.predictor import DataBatchType, Predictor
Expand All @@ -23,37 +21,39 @@
from ray import serve


class TestCollationFunctions:
class TestBatchingFunctionFunctions:
def test_array(self):
    """batch_array stacks a list of arrays; split_array reverses it."""
    list_of_arr = [np.array([i]) for i in range(4)]
    batched_arr = np.array([[i] for i in range(4)])
    batch_size = 4

    batched = BatchingManager.batch_array(list_of_arr)
    assert np.array_equal(batched, batched_arr)

    # Round trip: splitting the batch must reproduce the original arrays.
    for i, j in zip(BatchingManager.split_array(batched, batch_size), list_of_arr):
        assert np.array_equal(i, j)

def test_array_error(self):
    """split_array rejects mismatched batch sizes and non-array outputs."""
    # Wrong leading dimension (2 rows for a claimed batch of 10).
    with pytest.raises(ValueError, match="output array should have shape of"):
        BatchingManager.split_array(np.arange(2), 10)
    # Wrong type entirely.
    with pytest.raises(TypeError, match="output should be np.ndarray but"):
        BatchingManager.split_array("string", 6)

def test_dict_array(self):
    """batch_dict_array stacks per-key arrays; split_dict_array reverses it."""
    list_of_dicts = [
        {"a": np.array([1, 2]), "b": np.array(3)},
        {"a": np.array([3, 4]), "b": np.array(4)},
    ]
    batched_dict = {"a": np.array([[1, 2], [3, 4]]), "b": np.array([3, 4])}
    batch_size = 2

    batched = BatchingManager.batch_dict_array(list_of_dicts)
    assert batched.keys() == batched_dict.keys()
    for key in batched.keys():
        assert np.array_equal(batched[key], batched_dict[key])

    # Round trip: splitting must reproduce each original dict of arrays.
    unpacked_list = BatchingManager.split_dict_array(batched, batch_size)
    for original, unpacked in zip(list_of_dicts, unpacked_list):
        assert original.keys() == unpacked.keys()
        for key in original.keys():
            assert np.array_equal(original[key], unpacked[key])
Expand All @@ -66,10 +66,14 @@ def test_dataframe(self):
"b": sum(([i, i] for i in range(4)), []),
}
)
batched, unpack = collate_dataframe(list_of_dfs)
batch_size = 4

batched = BatchingManager.batch_dataframe(list_of_dfs)
assert batched.equals(batched_df)
assert len(unpack(batched)) == len(list_of_dfs)
for i, j in zip(unpack(batched), list_of_dfs):

unpacked_list = BatchingManager.split_dataframe(batched, batch_size)
assert len(unpacked_list) == len(list_of_dfs)
for i, j in zip(unpacked_list, list_of_dfs):
assert i.equals(j)


Expand Down Expand Up @@ -145,6 +149,32 @@ def test_batching(serve_instance):
assert resp == {"value": [42], "batch_size": 2}


class TakeArrayReturnDataFramePredictor(Predictor):
    """Test predictor that consumes an ndarray batch and emits a DataFrame."""

    def __init__(self, increment: int) -> None:
        # Constant added element-wise to every incoming array.
        self.increment = increment

    @classmethod
    def from_checkpoint(
        cls, checkpoint: Checkpoint
    ) -> "TakeArrayReturnDataFramePredictor":
        # Restore the increment that was stored in the checkpoint dict.
        config = checkpoint.to_dict()
        return cls(config["increment"])

    def predict(self, data: np.ndarray) -> DataBatchType:
        # Shift the input, then label the two columns of the result.
        shifted = data + self.increment
        return pd.DataFrame(shifted, columns=["col_a", "col_b"])


def test_mixed_input_output_type_with_batching(serve_instance):
    """End-to-end: array input, DataFrame output, with server-side batching.

    The long batch_wait_timeout_s forces both concurrent requests into one
    batch of size 2.
    """
    ModelWrapperDeployment.options(name="Adder").deploy(
        predictor_cls=TakeArrayReturnDataFramePredictor,
        checkpoint=Checkpoint.from_dict({"increment": 2}),
        batching_params=dict(max_batch_size=2, batch_wait_timeout_s=1000),
    )

    response_refs = [send_request.remote(json={"array": [40, 45]}) for _ in range(2)]
    for response in ray.get(response_refs):
        assert response == [{"col_a": 42.0, "col_b": 47.0}]


app = FastAPI()


Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ async def wrapped(*args, **kwargs):
check_import_once()
return await func(*args, **kwargs)

elif inspect.isfunction(func):
elif inspect.isroutine(func):

@wraps(func)
def wrapped(*args, **kwargs):
Expand Down