Skip to content

Commit

Permalink
[release] update modin_xgboost_test to use anyscale connect (ray-proj…
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdeng committed Jul 8, 2021
1 parent cc21535 commit 264e2df
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 23 deletions.
3 changes: 2 additions & 1 deletion release/golden_notebook_tests/golden_notebook_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
compute_template: compute_tpl.yaml

run:
timeout: 900
use_connect: True
timeout: 1800
script: python workloads/modin_xgboost_test.py

- name: torch_tune_serve_test
Expand Down
8 changes: 4 additions & 4 deletions release/golden_notebook_tests/modin_xgboost_app_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ debian_packages:
python:
pip_packages:
- pytest
- xgboost_ray
- modin
- s3fs
conda_packages: [ ]

post_build_cmds:
- pip uninstall -y ray || true
- pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }}
- pip uninstall -y modin || true
- pip3 install -U git+https://github.com/modin-project/modin
- pip install -U {{ env["RAY_WHEELS"] | default("ray") }}
- pip install git+https://github.com/ray-project/xgboost_ray.git#xgboost_ray
30 changes: 20 additions & 10 deletions release/golden_notebook_tests/workloads/modin_xgboost_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,21 @@
import ray
from xgboost_ray import RayDMatrix, RayParams, train

FILE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/" \
"00280/HIGGS.csv.gz"
from utils.utils import is_anyscale_connect

parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing.")
args = parser.parse_args()
HIGGS_S3_URI = "s3:https://ray-ci-higgs/HIGGS.csv"
SIMPLE_HIGGS_S3_URI = "s3:https://ray-ci-higgs/simpleHIGGS.csv"


def main():
ray.client("anyscale:https://").connect()

print("Loading HIGGS data.")

colnames = ["label"] + ["feature-%02d" % i for i in range(1, 29)]

if args.smoke_test:
data = pd.read_csv(FILE_URL, names=colnames, nrows=1000)
data = pd.read_csv(SIMPLE_HIGGS_S3_URI, names=colnames)
else:
data = pd.read_csv(FILE_URL, names=colnames)
data = pd.read_csv(HIGGS_S3_URI, names=colnames)

print("Loaded HIGGS data.")

Expand All @@ -52,8 +47,23 @@ def main():


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test",
action="store_true",
help="Finish quickly for testing.")
args = parser.parse_args()

start = time.time()

client_builder = ray.client()
if is_anyscale_connect():
job_name = os.environ.get("RAY_JOB_NAME", "modin_xgboost_test")
client_builder.job_name(job_name)
client_builder.connect()

main()

taken = time.time() - start
result = {
"time_taken": taken,
Expand Down
11 changes: 3 additions & 8 deletions release/golden_notebook_tests/workloads/torch_tune_serve_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST


def _is_anyscale_connect():
address = os.environ.get("RAY_ADDRESS")
is_anyscale_connect = address is not None and address.startswith(
"anyscale:https://")
return is_anyscale_connect
from utils.utils import is_anyscale_connect


def load_mnist_data(train: bool, download: bool):
Expand Down Expand Up @@ -93,7 +88,7 @@ def train_mnist(test_mode=False, num_workers=1, use_gpu=False):


def get_remote_model(remote_model_checkpoint_path):
if _is_anyscale_connect():
if is_anyscale_connect():
# Download training results to local client.
local_dir = "~/ray_results"
# TODO(matt): remove the following line when Anyscale Connect
Expand Down Expand Up @@ -217,7 +212,7 @@ def test_predictions(test_mode=False):
start = time.time()

client_builder = ray.client()
if (_is_anyscale_connect()):
if is_anyscale_connect():
job_name = os.environ.get("RAY_JOB_NAME", "torch_tune_serve_test")
client_builder.job_name(job_name)
client_builder.connect()
Expand Down
9 changes: 9 additions & 0 deletions release/golden_notebook_tests/workloads/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import os


def is_anyscale_connect():
"""Returns whether or not the Ray Address points to an Anyscale cluster."""
address = os.environ.get("RAY_ADDRESS")
is_anyscale_connect = address is not None and address.startswith(
"anyscale:https://")
return is_anyscale_connect

0 comments on commit 264e2df

Please sign in to comment.