Skip to content

Commit

Permalink
[air/docs] checkpoints (ray-project#25901)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardliaw committed Jul 12, 2022
1 parent 1abe908 commit 92efc85
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/source/_toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ parts:
sections:
- file: ray-air/preprocessors
- file: ray-air/check-ingest
- file: ray-air/checkpoints
- file: ray-air/examples/analyze_tuning_results
- file: ray-air/examples/upload_to_comet_ml
- file: ray-air/examples/upload_to_wandb
Expand Down
84 changes: 84 additions & 0 deletions doc/source/ray-air/checkpoints.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
Checkpoints
===========

Checkpoints are the common format for models that are used across different components of the Ray AI Runtime.

.. image:: images/checkpoints.jpg

What exactly is a checkpoint?
-----------------------------

The Checkpoint object is a serializable reference to a model. The model can represented in one of three ways:

- a directory located on local (on-disk) storage
- a directory located on external storage (e.g. cloud storage)
- an in-memory dictionary

The flexibility provided in the Checkpoint model representation is useful in distributed environments,
where you may want to recreate the same model on multiple nodes in your Ray cluster for inference
or across different Ray clusters.


Creating a checkpoint
---------------------

There are two ways of generating a checkpoint.

The first way is to generate it from a pretrained model. Each framework that AIR supports has a ``to_air_checkpoint`` method that can be used to generate an AIR checkpoint:

.. literalinclude:: doc_code/checkpoint_usage.py
:language: python
:start-after: __checkpoint_quick_start__
:end-before: __checkpoint_quick_end__


Another way is to retrieve it from the results of a Trainer or a Tuner.

.. literalinclude:: doc_code/checkpoint_usage.py
:language: python
:start-after: __use_trainer_checkpoint_start__
:end-before: __use_trainer_checkpoint_end__

What can I do with a checkpoint?
--------------------------------

Checkpoints can be used to instantiate a :class:`Predictor`, :class:`BatchPredictor`, or :class:`PredictorDeployment`.
Upon usage, the model held by the Checkpoint will be instantiated in memory and used for inference.

Below is an example using a checkpoint in the :class:`BatchPredictor` for scalable batch inference:

.. literalinclude:: doc_code/checkpoint_usage.py
:language: python
:start-after: __batch_pred_start__
:end-before: __batch_pred_end__

Below is an example using a checkpoint in a service for online inference via :class:`PredictorDeployment`:

.. literalinclude:: doc_code/checkpoint_usage.py
:language: python
:start-after: __online_inference_start__
:end-before: __online_inference_end__

The Checkpoint object has methods to translate between different checkpoint storage locations.
With this flexibility, Checkpoint objects can be serialized and used in different contexts
(e.g., on a different process or a different machine):

.. literalinclude:: doc_code/checkpoint_usage.py
:language: python
:start-after: __basic_checkpoint_start__
:end-before: __basic_checkpoint_end__


Example: Using Checkpoints with MLflow
--------------------------------------

MLflow has its own `checkpoint format <https://www.mlflow.org/docs/latest/models.html>`__ called the "MLflow Model". It is a standard format for packaging machine learning models that can be used in a variety of downstream tools.

Below is an example of using MLflow models as a Ray AIR Checkpoint.

.. literalinclude:: doc_code/checkpoint_mlflow.py
:language: python
:start-after: __mlflow_checkpoint_start__
:end-before: __mlflow_checkpoint_end__


28 changes: 28 additions & 0 deletions doc/source/ray-air/doc_code/checkpoint_mlflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# flake8: noqa
# isort: skip_file

# __mlflow_checkpoint_start__
from ray.air.checkpoint import Checkpoint
from sklearn.ensemble import RandomForestClassifier
import mlflow.sklearn

# Create an sklearn classifier
clf = RandomForestClassifier(max_depth=7, random_state=0)
# ... e.g. train model with clf.fit()
# Save model using MLflow
mlflow.sklearn.save_model(clf, "model_directory")

# Create checkpoint object from path
checkpoint = Checkpoint.from_directory("model_directory")

# Write it to some other directory
checkpoint.to_directory("other_directory")
# You can also use `checkpoint.to_uri/from_uri` to
# read from/write to cloud storage

# We can now use MLflow to re-load the model
clf = mlflow.sklearn.load_model("other_directory")

# It is guaranteed that the original data was recovered
assert isinstance(clf, RandomForestClassifier)
# __mlflow_checkpoint_end__
130 changes: 130 additions & 0 deletions doc/source/ray-air/doc_code/checkpoint_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# flake8: noqa
# isort: skip_file

# __checkpoint_quick_start__
from ray.train.tensorflow import to_air_checkpoint
import tensorflow as tf

# This can be a trained model.
def build_model() -> tf.keras.Model:
model = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(input_shape=(1,)),
tf.keras.layers.Dense(1),
]
)
return model


model = build_model()

checkpoint = to_air_checkpoint(model)
# __checkpoint_quick_end__


# __use_trainer_checkpoint_start__
import pandas as pd
import ray
from ray.air import train_test_split
from ray.train.xgboost import XGBoostTrainer


bc_df = pd.read_csv(
"https://air-example-data.s3.us-east-2.amazonaws.com/breast_cancer.csv"
)
dataset = ray.data.from_pandas(bc_df)
# Optionally, read directly from s3
# dataset = ray.data.read_csv("s3:https://air-example-data/breast_cancer.csv")

# Split data into train and validation.
train_dataset, valid_dataset = train_test_split(dataset, test_size=0.3)

trainer = XGBoostTrainer(
scaling_config={"num_workers": 2},
label_column="target",
params={
"objective": "binary:logistic",
"eval_metric": ["logloss", "error"],
},
datasets={"train": train_dataset},
num_boost_round=5,
)

result = trainer.fit()
checkpoint = result.checkpoint
# __use_trainer_checkpoint_end__

# __batch_pred_start__
from ray.train.batch_predictor import BatchPredictor
from ray.train.xgboost import XGBoostPredictor

# Create a test dataset by dropping the target column.
test_dataset = valid_dataset.map_batches(
lambda df: df.drop("target", axis=1), batch_format="pandas"
)

batch_predictor = BatchPredictor.from_checkpoint(checkpoint, XGBoostPredictor)

# Bulk batch prediction.
batch_predictor.predict(test_dataset)
# __batch_pred_end__


# __online_inference_start__
import requests
from fastapi import Request
import pandas as pd

from ray import serve
from ray.serve import PredictorDeployment
from ray.serve.http_adapters import json_request


async def adapter(request: Request):
content = await request.json()
print(content)
return pd.DataFrame.from_dict(content)


serve.start(detached=True)
deployment = PredictorDeployment.options(name="XGBoostService")

deployment.deploy(
XGBoostPredictor, checkpoint, batching_params=False, http_adapter=adapter
)

print(deployment.url)

sample_input = test_dataset.take(1)
sample_input = dict(sample_input[0])

output = requests.post(deployment.url, json=[sample_input]).json()
print(output)
# __online_inference_end__

# __basic_checkpoint_start__
from ray.air.checkpoint import Checkpoint

# Create checkpoint data dict
checkpoint_data = {"data": 123}

# Create checkpoint object from data
checkpoint = Checkpoint.from_dict(checkpoint_data)

# Save checkpoint to a directory on the file system.
path = checkpoint.to_directory()

# This path can then be passed around,
# # e.g. to a different function or a different script.
# You can also use `checkpoint.to_uri/from_uri` to
# read from/write to cloud storage

# In another function or script, recover Checkpoint object from path
checkpoint = Checkpoint.from_directory(path)

# Convert into dictionary again
recovered_data = checkpoint.to_dict()

# It is guaranteed that the original data has been recovered
assert recovered_data == checkpoint_data
# __basic_checkpoint_end__
Binary file added doc/source/ray-air/images/checkpoints.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 9 additions & 0 deletions doc/source/ray-air/user-guides.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ AIR Feature Guides
:type: ref
:text: How to use AIR Preprocessors?
:classes: btn-link btn-block stretched-link
---
:img-top: /ray-overview/images/ray_svg_logo.svg

+++
.. link-button:: /ray-air/checkpoints
:type: ref
:text: What are AIR Checkpoints?
:classes: btn-link btn-block stretched-link


---
:img-top: /ray-overview/images/ray_svg_logo.svg
Expand Down
2 changes: 1 addition & 1 deletion python/ray/air/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

@PublicAPI(stability="alpha")
class Checkpoint:
"""Ray ML Checkpoint.
"""Ray AIR Checkpoint.
This implementation provides methods to translate between
different checkpoint storage locations: Local storage, external storage
Expand Down

0 comments on commit 92efc85

Please sign in to comment.