Skip to content

Commit

Permalink
[release] update torch_tune_serve_test to use anyscale connect (#16754)
Browse files Browse the repository at this point in the history
* [release] update torch_tune_serve_test to use anyscale connect

* use download_results to download model checkpoint

* clean up code to support both OSS and Anyscale
  • Loading branch information
matthewdeng committed Jul 7, 2021
1 parent 7318a21 commit 23088bd
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
1 change: 1 addition & 0 deletions release/golden_notebook_tests/golden_notebook_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
compute_template: gpu_tpl.yaml

run:
use_connect: True
timeout: 1800
script: python workloads/torch_tune_serve_test.py

45 changes: 41 additions & 4 deletions release/golden_notebook_tests/workloads/torch_tune_serve_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@
from torchvision.datasets import MNIST


def _is_anyscale_connect():
    """Report whether ``RAY_ADDRESS`` carries the Anyscale Connect prefix.

    Returns:
        ``True`` when the ``RAY_ADDRESS`` environment variable is set and
        starts with ``"anyscale:https://"``; ``False`` when it is unset or
        starts with anything else.
    """
    ray_address = os.environ.get("RAY_ADDRESS")
    if ray_address is None:
        # No address configured at all, so this cannot be Anyscale Connect.
        return False
    return ray_address.startswith("anyscale:https://")


def load_mnist_data(train: bool, download: bool):
transform = transforms.Compose(
[transforms.ToTensor(),
Expand Down Expand Up @@ -85,8 +92,33 @@ def train_mnist(test_mode=False, num_workers=1, use_gpu=False):
checkpoint_at_end=True)


def get_best_model(best_model_checkpoint_path):
model_state = torch.load(best_model_checkpoint_path)
def get_remote_model(remote_model_checkpoint_path):
    """Load the model saved at a checkpoint path that lives on the cluster.

    When ``_is_anyscale_connect()`` is false, ``get_model`` is wrapped with
    ``ray.remote``, invoked with the checkpoint path, and its result is
    fetched with ``ray.get``. When it is true, the remote results directory
    is first downloaded to the local client and ``get_model`` is called
    directly on the corresponding local path.

    Args:
        remote_model_checkpoint_path: Checkpoint path as it exists on the
            cluster. In the Anyscale Connect branch it is resolved relative
            to ``/home/ray/ray_results/``.

    Returns:
        The value produced by ``get_model`` for the checkpoint.
    """
    if not _is_anyscale_connect():
        # Run the loader through Ray and fetch its result.
        remote_loader = ray.remote(get_model)
        return ray.get(remote_loader.remote(remote_model_checkpoint_path))

    # TODO(matt): drop the explicit expansion when Anyscale Connect
    # supports tilde expansion.
    local_results_dir = os.path.expanduser("~/ray_results")
    remote_results_dir = "/home/ray/ray_results/"

    # Pull the training results down to the local client.
    ray.client().download_results(
        local_dir=local_results_dir, remote_dir=remote_results_dir)

    # Re-root the checkpoint path from the remote results directory onto the
    # local copy.
    checkpoint_rel_path = os.path.relpath(remote_model_checkpoint_path,
                                          remote_results_dir)
    return get_model(os.path.join(local_results_dir, checkpoint_rel_path))


def get_model(model_checkpoint_path):
model_state = torch.load(model_checkpoint_path)

model = ResNet18(None)
model.conv1 = nn.Conv2d(
Expand Down Expand Up @@ -184,7 +216,12 @@ def test_predictions(test_mode=False):

start = time.time()

ray.client("anyscale:https://").connect()
client_builder = ray.client()
if (_is_anyscale_connect()):
job_name = os.environ.get("RAY_JOB_NAME", "torch_tune_serve_test")
client_builder.job_name(job_name)
client_builder.connect()

num_workers = 2
use_gpu = True

Expand All @@ -193,7 +230,7 @@ def test_predictions(test_mode=False):

print("Retrieving best model.")
best_checkpoint = analysis.best_checkpoint
model_id = get_best_model(best_checkpoint)
model_id = get_remote_model(best_checkpoint)

print("Setting up Serve.")
setup_serve(model_id)
Expand Down

0 comments on commit 23088bd

Please sign in to comment.