Skip to content

Commit

Permalink
[AIR] Introduce better scoring API for BatchPredictor (ray-project#…
Browse files Browse the repository at this point in the history
…26451)

Signed-off-by: Amog Kamsetty <[email protected]>

As discussed offline, allow configurability for feature columns and keep columns in BatchPredictor for better scoring UX on test datasets.
  • Loading branch information
amogkam committed Jul 14, 2022
1 parent a0ce3c1 commit 6595bd6
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 20 deletions.
6 changes: 3 additions & 3 deletions doc/source/ray-air/examples/torch_image_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@
")\n",
"\n",
"outputs: ray.data.Dataset = batch_predictor.predict(\n",
" data=predict_dataset, dtype=torch.float\n",
" data=test_dataset, dtype=torch.float, feature_columns=[\"image\"], keep_columns=[\"label\"]\n",
")"
]
},
Expand Down Expand Up @@ -482,7 +482,7 @@
"def convert_logits_to_classes(df):\n",
" best_class = df[\"predictions\"].map(lambda x: x.argmax())\n",
" df[\"prediction\"] = best_class\n",
" return df[[\"prediction\"]]\n",
" return df\n",
"\n",
"predictions = outputs.map_batches(\n",
" convert_logits_to_classes, batch_format=\"pandas\"\n",
Expand Down Expand Up @@ -536,7 +536,7 @@
" df[\"correct\"] = df[\"prediction\"] == df[\"label\"]\n",
" return df[[\"prediction\", \"label\", \"correct\"]]\n",
"\n",
"scores = test_dataset.zip(predictions).map_batches(calculate_prediction_scores)\n",
"scores = predictions.map_batches(calculate_prediction_scores)\n",
"\n",
"scores.show(1)"
]
Expand Down
6 changes: 2 additions & 4 deletions doc/source/ray-air/examples/torch_incremental_learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@
"\n",
" batch_predictor = BatchPredictor.from_checkpoint(checkpoint, predictor_cls=TorchPredictor, model=SimpleMLP(num_classes=10))\n",
" model_output = batch_predictor.predict(\n",
" data=test_dataset, feature_columns=[\"image\"], unsqueeze=False\n",
" data=test_dataset, feature_columns=[\"image\"], keep_columns=[\"label\"]\n",
" )\n",
" \n",
" # Postprocess model outputs.\n",
Expand All @@ -654,12 +654,10 @@
" \n",
" # Then, for each prediction output, see if it matches with the ground truth\n",
" # label.\n",
" zipped_dataset = test_dataset.zip(prediction_results)\n",
"\n",
" def calculate_prediction_scores(df):\n",
" return pd.DataFrame({\"correct\": df[\"predictions\"] == df[\"label\"]})\n",
"\n",
" correct_dataset = zipped_dataset.map_batches(\n",
" correct_dataset = prediction_results.map_batches(\n",
" calculate_prediction_scores, batch_format=\"pandas\"\n",
" )\n",
"\n",
Expand Down
8 changes: 4 additions & 4 deletions python/ray/experimental/state/api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
import threading
import urllib
import warnings
import threading
import logging
from dataclasses import fields
from typing import Dict, Generator, List, Optional, Tuple, Union, Any
from contextlib import contextmanager
from dataclasses import fields
from typing import Any, Dict, Generator, List, Optional, Tuple, Union

import requests

Expand Down
58 changes: 49 additions & 9 deletions python/ray/train/batch_predictor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from typing import Any, Dict, Optional, Type, Union
from typing import Any, Dict, Optional, List, Type, Union

import ray
from ray.air import Checkpoint
Expand Down Expand Up @@ -45,6 +45,8 @@ def predict(
self,
data: Union[ray.data.Dataset, ray.data.DatasetPipeline],
*,
feature_columns: Optional[List[str]] = None,
keep_columns: Optional[List[str]] = None,
batch_size: int = 4096,
min_scoring_workers: int = 1,
max_scoring_workers: Optional[int] = None,
Expand All @@ -61,24 +63,41 @@ def predict(
>>> from ray.air import Checkpoint
>>> from ray.train.predictor import Predictor
>>> from ray.train.batch_predictor import BatchPredictor
>>> # Create a dummy predictor that always returns `42` for each input.
>>> # Create a dummy predictor that returns identity as the predictions.
>>> class DummyPredictor(Predictor):
... @classmethod
... def from_checkpoint(cls, checkpoint, **kwargs):
... return cls()
... def predict(self, data, **kwargs):
... return pd.DataFrame({"a": [42] * len(data)})
... def _predict_pandas(self, data_df, **kwargs):
... return data_df
>>> # Create a batch predictor for this dummy predictor.
>>> batch_pred = BatchPredictor( # doctest: +SKIP
... Checkpoint.from_dict({"x": 0}), DummyPredictor)
>>> # Create a dummy dataset.
>>> ds = ray.data.range_tensor(1000, parallelism=4) # doctest: +SKIP
>>> ds = ray.data.from_pandas(pd.DataFrame({ # doctest: +SKIP
... "feature_1": [1, 2, 3], "label": [1, 2, 3]}))
>>> # Execute batch prediction using this predictor.
>>> print(batch_pred.predict(ds)) # doctest: +SKIP
Dataset(num_blocks=4, num_rows=1000, schema={a: int64})
>>> predictions = batch_pred.predict(ds, # doctest: +SKIP
... feature_columns=["feature_1"], keep_columns=["label"])
>>> print(predictions) # doctest: +SKIP
Dataset(num_blocks=1, num_rows=3, schema={a: int64, label: int64})
>>> # Calculate final accuracy.
>>> def calculate_accuracy(df):
... return pd.DataFrame({"correct": df["predictions"] == df["label"]})
>>> correct = predictions.map_batches(calculate_accuracy) # doctest: +SKIP
>>> print("Final accuracy: ", # doctest: +SKIP
... correct.sum(on="correct") / correct.count())
Final accuracy: 1.0000
Args:
data: Ray dataset or pipeline to run batch prediction on.
feature_columns: List of columns in data to use for prediction. Columns not
specified will be dropped from `data` before being passed to the
predictor. If None, use all columns.
keep_columns: List of columns in `data` to include in the prediction result.
This is useful for calculating final accuracies/metrics on the result
dataset. If None, the columns in the output dataset will contain just
the prediction results.
batch_size: Split dataset into batches of this size for prediction.
min_scoring_workers: Minimum number of scoring actors.
max_scoring_workers: If set, specify the maximum number of scoring actors.
Expand Down Expand Up @@ -112,7 +131,15 @@ def __init__(self):
)

def __call__(self, batch):
prediction_output = self.predictor.predict(batch, **predict_kwargs)
if feature_columns:
prediction_batch = batch[feature_columns]
else:
prediction_batch = batch
prediction_output = self.predictor.predict(
prediction_batch, **predict_kwargs
)
if keep_columns:
prediction_output[keep_columns] = batch[keep_columns]
return convert_batch_type_to_pandas(prediction_output)

compute = ray.data.ActorPoolStrategy(
Expand All @@ -123,21 +150,25 @@ def __call__(self, batch):
ray_remote_args["num_cpus"] = num_cpus_per_worker
ray_remote_args["num_gpus"] = num_gpus_per_worker

return data.map_batches(
prediction_results = data.map_batches(
ScoringWrapper,
compute=compute,
batch_format="pandas",
batch_size=batch_size,
**ray_remote_args,
)

return prediction_results

def predict_pipelined(
self,
data: ray.data.Dataset,
*,
blocks_per_window: Optional[int] = None,
bytes_per_window: Optional[int] = None,
# The remaining args are from predict().
feature_columns: Optional[List[str]] = None,
keep_columns: Optional[List[str]] = None,
batch_size: int = 4096,
min_scoring_workers: int = 1,
max_scoring_workers: Optional[int] = None,
Expand Down Expand Up @@ -188,6 +219,13 @@ def predict_pipelined(
This will be treated as an upper bound for the window size, but each
window will still include at least one block. This is mutually
exclusive with ``blocks_per_window``.
feature_columns: List of columns in data to use for prediction. Columns not
specified will be dropped from `data` before being passed to the
predictor. If None, use all columns.
keep_columns: List of columns in `data` to include in the prediction result.
This is useful for calculating final accuracies/metrics on the result
dataset. If None, the columns in the output dataset will contain just
the prediction results.
batch_size: Split dataset into batches of this size for prediction.
min_scoring_workers: Minimum number of scoring actors.
max_scoring_workers: If set, specify the maximum number of scoring actors.
Expand Down Expand Up @@ -215,6 +253,8 @@ def predict_pipelined(
return self.predict(
pipe,
batch_size=batch_size,
feature_columns=feature_columns,
keep_columns=keep_columns,
min_scoring_workers=min_scoring_workers,
max_scoring_workers=max_scoring_workers,
num_cpus_per_worker=num_cpus_per_worker,
Expand Down
31 changes: 31 additions & 0 deletions python/ray/train/tests/test_batch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,37 @@ def test_batch_prediction_fs():
)


def test_batch_prediction_feature_cols():
batch_predictor = BatchPredictor.from_checkpoint(
Checkpoint.from_dict({"factor": 2.0}), DummyPredictor
)

test_dataset = ray.data.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))

assert batch_predictor.predict(
test_dataset, feature_columns=["a"]
).to_pandas().to_numpy().squeeze().tolist() == [4.0, 8.0, 12.0]


def test_batch_prediction_keep_cols():
batch_predictor = BatchPredictor.from_checkpoint(
Checkpoint.from_dict({"factor": 2.0}), DummyPredictor
)

test_dataset = ray.data.from_pandas(
pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
)

output_df = batch_predictor.predict(
test_dataset, feature_columns=["a"], keep_columns=["b"]
).to_pandas()

assert set(output_df.columns) == {"a", "b"}

assert output_df["a"].tolist() == [4.0, 8.0, 12.0]
assert output_df["b"].tolist() == [4, 5, 6]


def test_automatic_enable_gpu_from_num_gpus_per_worker():
"""
Test we automatically set underlying Predictor creation use_gpu to True if
Expand Down

0 comments on commit 6595bd6

Please sign in to comment.