[Serve][AIR] Support mixed input and output type, with batching #25688

Merged
7 commits merged on Jun 16, 2022
2 changes: 1 addition & 1 deletion python/ray/serve/BUILD
@@ -333,7 +333,7 @@ py_test(

py_test(
name = "test_model_wrappers",
size = "small",
size = "medium",
srcs = serve_tests_srcs,
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
193 changes: 97 additions & 96 deletions python/ray/serve/model_wrappers.py
@@ -1,18 +1,7 @@
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

import numpy as np

import ray
from ray import serve
from ray._private.utils import import_attr
@@ -55,62 +44,87 @@ def _load_predictor_cls(
return predictor_cls


def collate_array(
input_list: List[np.ndarray],
) -> Tuple[np.ndarray, Callable[[np.ndarray], List[np.ndarray]]]:
batch_size = len(input_list)
batched = np.stack(input_list)
class BatchingManager:
"""A collection of utilities for batching and splitting data."""

@staticmethod
def batch_array(input_list: List[np.ndarray]) -> np.ndarray:
batched = np.stack(input_list)
return batched

def unpack(output_arr):
if isinstance(output_arr, list):
return output_arr
if not isinstance(output_arr, np.ndarray):
@staticmethod
def split_array(output_array: np.ndarray, batch_size: int) -> List[np.ndarray]:
if not isinstance(output_array, np.ndarray):
raise TypeError(
f"The output should be np.ndarray but Serve got {type(output_arr)}."
f"The output should be np.ndarray but Serve got {type(output_array)}."
)
if len(output_arr) != batch_size:
if len(output_array) != batch_size:
raise ValueError(
f"The output array should have shape of ({batch_size}, ...) "
f"because the input has {batch_size} entries "
f"but Serve got {output_arr.shape}"
f"but Serve got {output_array.shape}"
)
return [arr.squeeze(axis=0) for arr in np.split(output_arr, batch_size, axis=0)]
return [
arr.squeeze(axis=0) for arr in np.split(output_array, batch_size, axis=0)
]

@staticmethod
@require_packages(["pandas"])
def batch_dataframe(input_list: List["pd.DataFrame"]) -> "pd.DataFrame":
import pandas as pd

batched = pd.concat(input_list, axis="index", ignore_index=True, copy=False)
return batched

@staticmethod
@require_packages(["pandas"])
def split_dataframe(
output_df: "pd.DataFrame", batch_size: int
) -> List["pd.DataFrame"]:
if not isinstance(output_df, pd.DataFrame):
raise TypeError(
"The output should be a Pandas DataFrame but Serve got "
f"{type(output_df)}"
)
if len(output_df) % batch_size != 0:
raise ValueError(
f"The output dataframe should have length divisible by {batch_size}, "
f"but Serve got length {len(output_df)}."
)
return [df.reset_index(drop=True) for df in np.split(output_df, batch_size)]

return batched, unpack
@staticmethod
def batch_dict_array(
input_list: List[Dict[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
batch_size = len(input_list)

# Check that all inputs have the same dict keys.
input_keys = [set(item.keys()) for item in input_list]
batch_has_same_keys = input_keys.count(input_keys[0]) == batch_size
if not batch_has_same_keys:
raise ValueError(
"The input batch's dictoinary must contain the same keys. "
f"Got different keys in some dictionaries: {input_keys}."
)

def collate_dict_array(
input_list: List[Dict[str, np.ndarray]]
) -> Tuple[
Dict[str, np.ndarray],
Callable[[Dict[str, np.ndarray]], List[Dict[str, np.ndarray]]],
]:
batch_size = len(input_list)
# Turn list[dict[str, array]] to dict[str, List[array]]
key_to_list = defaultdict(list)
for single_dict in input_list:
for key, arr in single_dict.items():
key_to_list[key].append(arr)

# Check all input has the same dict keys.
input_keys = [set(item.keys()) for item in input_list]
batch_has_same_keys = input_keys.count(input_keys[0]) == batch_size
if not batch_has_same_keys:
raise ValueError(
f"The input batch contains dictionary of different keys: {input_keys}"
)
# Turn dict[str, List[array]] to dict[str, array]
batched_dict = {}
for key, list_of_arr in key_to_list.items():
arr = BatchingManager.batch_array(list_of_arr)
batched_dict[key] = arr

return batched_dict

# Turn list[dict[str, array]] to dict[str, List[array]]
key_to_list = defaultdict(list)
for single_dict in input_list:
for key, arr in single_dict.items():
key_to_list[key].append(arr)

# Turn dict[str, List[array]] to dict[str, array]
batched_dict = {}
unpack_dict = {}
for key, list_of_arr in key_to_list.items():
arr, unpack_func = collate_array(list_of_arr)
batched_dict[key] = arr
unpack_dict[key] = unpack_func

def unpack(output_dict: Dict[str, np.ndarray]):
# short circuit behavior, assume users already unpacked the output for us.
@staticmethod
def split_dict_array(
output_dict: Dict[str, np.ndarray], batch_size: int
) -> List[Dict[str, np.ndarray]]:
if isinstance(output_dict, list):
return output_dict

@@ -120,44 +134,14 @@ def unpack(output_dict: Dict[str, np.ndarray]):
)

split_list_of_dict = [{} for _ in range(batch_size)]
for key, arr_unpack_func in unpack_dict.items():
arr_list = arr_unpack_func(output_dict[key])
for key, result_arr in output_dict.items():
split_arrays = BatchingManager.split_array(result_arr, batch_size)
# in place update each dictionary with the split array chunk.
for item, arr in zip(split_list_of_dict, arr_list):
for item, arr in zip(split_list_of_dict, split_arrays):
item[key] = arr

return split_list_of_dict

return batched_dict, unpack


@require_packages(["pandas"])
def collate_dataframe(
input_list: List["pd.DataFrame"],
) -> Tuple["pd.DataFrame", Callable[["pd.DataFrame"], List["pd.DataFrame"]]]:
import pandas as pd

batch_size = len(input_list)
batched = pd.concat(input_list, axis="index", ignore_index=True, copy=False)

def unpack(output_df):
if isinstance(output_df, list):
return output_df
if not isinstance(output_df, pd.DataFrame):
raise TypeError(
"The output should be a Pandas DataFrame but Serve got "
f"{type(output_df)}"
)
if len(output_df) % batch_size != 0:
raise ValueError(
f"The output dataframe should have length divisible by {batch_size}, "
f"because the input from {batch_size} different requests "
f"but Serve got length {len(output_df)}."
)
return [df.reset_index(drop=True) for df in np.split(output_df, batch_size)]

return batched, unpack


class ModelWrapper(SimpleSchemaIngress):
"""Serve any Ray AIR predictor from an AIR checkpoint.
@@ -222,22 +206,39 @@ async def predict_impl(inp: Union[np.ndarray, "pd.DataFrame"]):

@serve.batch(**batching_params)
async def predict_impl(inp: Union[List[np.ndarray], List["pd.DataFrame"]]):

batch_size = len(inp)
if isinstance(inp[0], np.ndarray):
collate_func = collate_array
batched = BatchingManager.batch_array(inp)
elif pd is not None and isinstance(inp[0], pd.DataFrame):
collate_func = collate_dataframe
batched = BatchingManager.batch_dataframe(inp)
elif isinstance(inp[0], dict):
batched = BatchingManager.batch_dict_array(inp)
else:
raise ValueError(
"ModelWrapper only accepts numpy array or dataframe as input "
"ModelWrapper only accepts numpy array, dataframe, or dict of "
"arrays as input "
f"but got types {[type(i) for i in inp]}"
)

batched, unpack = collate_func(inp)
out = self.model.predict(batched, **predict_kwargs)
if isinstance(out, ray.ObjectRef):
out = await out
return unpack(out)

if isinstance(out, np.ndarray):
return BatchingManager.split_array(out, batch_size)
elif pd is not None and isinstance(out, pd.DataFrame):
return BatchingManager.split_dataframe(out, batch_size)
elif isinstance(out, dict):
return BatchingManager.split_dict_array(out, batch_size)
elif isinstance(out, list) and len(out) == batch_size:
return out
else:
raise ValueError(
f"ModelWrapper only accepts list of length {batch_size}, numpy "
"array, dataframe, or dict of arrays as output "
f"but got type {type(out)} with length "
f"{len(out) if hasattr(out, '__len__') else 'unknown'}."
)

self.predict_impl = predict_impl

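For orientation, a minimal sketch of the round trip the new `BatchingManager` implements, assuming a Ray build that includes this change; the increment-by-2 "model" is illustrative only, not part of the diff:

```python
import numpy as np
import pandas as pd

from ray.serve.model_wrappers import BatchingManager

# Two single-request inputs, as serve.batch would collect them.
singles = [np.array([40, 45]), np.array([10, 15])]

batched = BatchingManager.batch_array(singles)  # np.stack -> shape (2, 2)
assert batched.shape == (2, 2)

# Splitting undoes the stacking: one (2,)-shaped array per request.
restored = BatchingManager.split_array(batched, batch_size=2)
assert all(np.array_equal(a, b) for a, b in zip(restored, singles))

# The mixed case this PR enables: the model takes the batched array but
# answers with a DataFrame, which is split row-wise per request.
out = pd.DataFrame(batched + 2, columns=["col_a", "col_b"])  # toy "model"
per_request = BatchingManager.split_dataframe(out, batch_size=2)
assert per_request[0].equals(pd.DataFrame({"col_a": [42], "col_b": [47]}))

# Dict-of-array payloads batch and split key-wise the same way.
dict_singles = [{"a": np.array([1, 2])}, {"a": np.array([3, 4])}]
dict_batched = BatchingManager.batch_dict_array(dict_singles)
dict_restored = BatchingManager.split_dict_array(dict_batched, batch_size=2)
assert np.array_equal(dict_restored[0]["a"], np.array([1, 2]))
```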
66 changes: 46 additions & 20 deletions python/ray/serve/tests/test_model_wrappers.py
@@ -4,56 +4,52 @@
import numpy as np
import pandas as pd
import pytest
import ray
import requests
from fastapi import Depends, FastAPI

import ray
from ray import serve
from ray.air.checkpoint import Checkpoint
from ray.serve.dag import InputNode
from ray.serve.deployment_graph import RayServeDAGHandle
from ray.serve.deployment_graph_build import build
from ray.serve.http_adapters import json_to_ndarray
from ray.serve.model_wrappers import (
ModelWrapperDeployment,
collate_array,
collate_dataframe,
collate_dict_array,
)
from ray.serve.model_wrappers import BatchingManager, ModelWrapperDeployment
from ray.train.predictor import DataBatchType, Predictor


class TestCollationFunctions:
class TestBatchingFunctions:
def test_array(self):
list_of_arr = [np.array([i]) for i in range(4)]
batched_arr = np.array([[i] for i in range(4)])
batch_size = 4

batched, unpack = collate_array(list_of_arr)
batched = BatchingManager.batch_array(list_of_arr)
assert np.array_equal(batched, batched_arr)
for i, j in zip(unpack(batched), list_of_arr):

for i, j in zip(BatchingManager.split_array(batched, batch_size), list_of_arr):
assert np.array_equal(i, j)

def test_array_error(self):
list_of_arr = [np.array([i]) for i in range(4)]
_, unpack = collate_array(list_of_arr)
with pytest.raises(ValueError, match="output array should have shape of"):
unpack(np.arange(2))
BatchingManager.split_array(np.arange(2), 10)
with pytest.raises(TypeError, match="output should be np.ndarray but"):
unpack("string")
BatchingManager.split_array("string", 6)

def test_dict_array(self):
list_of_dicts = [
{"a": np.array([1, 2]), "b": np.array(3)},
{"a": np.array([3, 4]), "b": np.array(4)},
]
batched_dict = {"a": np.array([[1, 2], [3, 4]]), "b": np.array([3, 4])}
batch_size = 2

batched, unpack = collate_dict_array(list_of_dicts)
batched = BatchingManager.batch_dict_array(list_of_dicts)
assert batched.keys() == batched_dict.keys()
for key in batched.keys():
assert np.array_equal(batched[key], batched_dict[key])

for original, unpacked in zip(list_of_dicts, unpack(batched)):
unpacked_list = BatchingManager.split_dict_array(batched, batch_size)
for original, unpacked in zip(list_of_dicts, unpacked_list):
assert original.keys() == unpacked.keys()
for key in original.keys():
assert np.array_equal(original[key], unpacked[key])
@@ -66,10 +62,14 @@ def test_dataframe(self):
"b": sum(([i, i] for i in range(4)), []),
}
)
batched, unpack = collate_dataframe(list_of_dfs)
batch_size = 4

batched = BatchingManager.batch_dataframe(list_of_dfs)
assert batched.equals(batched_df)
assert len(unpack(batched)) == len(list_of_dfs)
for i, j in zip(unpack(batched), list_of_dfs):

unpacked_list = BatchingManager.split_dataframe(batched, batch_size)
assert len(unpacked_list) == len(list_of_dfs)
for i, j in zip(unpacked_list, list_of_dfs):
assert i.equals(j)


@@ -145,6 +145,32 @@ def test_batching(serve_instance):
assert resp == {"value": [42], "batch_size": 2}


class TakeArrayReturnDataFramePredictor(Predictor):
def __init__(self, increment: int) -> None:
self.increment = increment

@classmethod
def from_checkpoint(
cls, checkpoint: Checkpoint
) -> "TakeArrayReturnDataFramePredictor":
return cls(checkpoint.to_dict()["increment"])

def predict(self, data: np.ndarray) -> DataBatchType:
return pd.DataFrame(data + self.increment, columns=["col_a", "col_b"])


def test_mixed_input_output_type_with_batching(serve_instance):
ModelWrapperDeployment.options(name="Adder").deploy(
predictor_cls=TakeArrayReturnDataFramePredictor,
checkpoint=Checkpoint.from_dict({"increment": 2}),
batching_params=dict(max_batch_size=2, batch_wait_timeout_s=1000),
)

refs = [send_request.remote(json={"array": [40, 45]}) for _ in range(2)]
for resp in ray.get(refs):
assert resp == [{"col_a": 42.0, "col_b": 47.0}]


app = FastAPI()


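A hypothetical client-side view of the end-to-end test above; the route name and default HTTP address are assumptions based on Serve's conventions, not pinned down by this diff:

```python
import requests

# Assumes the "Adder" ModelWrapperDeployment from the test above is live
# on a local Serve instance at the default HTTP address.
resp = requests.post("http://localhost:8000/Adder/", json={"array": [40, 45]})

# [40, 45] + increment 2 -> one DataFrame row, rendered as JSON records.
assert resp.json() == [{"col_a": 42.0, "col_b": 47.0}]
```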