[Release] Add multi-node, multi-GPU SGD release test (ray-project#16046)
amogkam authored May 31, 2021
1 parent 9fa3b9f commit da6f28d
Showing 5 changed files with 99 additions and 25 deletions.
59 changes: 34 additions & 25 deletions python/ray/util/sgd/torch/examples/cifar_pytorch_example.py
@@ -88,6 +88,36 @@ def setup(self, config):
             train_loader=train_loader, validation_loader=validation_loader)
 
 
+def train_cifar(num_workers, use_gpu, num_epochs, fp16=False, test_mode=False):
+    trainer1 = TorchTrainer(
+        training_operator_cls=CifarTrainingOperator,
+        initialization_hook=initialization_hook,
+        num_workers=num_workers,
+        config={
+            "lr": 0.1,
+            "test_mode": test_mode,  # subset the data
+            # this will be split across workers.
+            BATCH_SIZE: 128 * num_workers
+        },
+        use_gpu=use_gpu,
+        scheduler_step_freq="epoch",
+        use_fp16=fp16,
+        use_tqdm=False)
+    pbar = trange(num_epochs, unit="epoch")
+    for i in pbar:
+        info = {"num_steps": 1} if test_mode else {}
+        info["epoch_idx"] = i
+        info["num_epochs"] = num_epochs
+        # Increase `max_retries` to turn on fault tolerance.
+        trainer1.train(max_retries=1, info=info)
+        val_stats = trainer1.validate()
+        pbar.set_postfix(dict(acc=val_stats["val_accuracy"]))
+
+    print(trainer1.validate())
+    trainer1.shutdown()
+    print("success!")
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -125,30 +155,9 @@ def setup(self, config):
     num_cpus = 4 if args.smoke_test else None
     ray.init(address=args.address, num_cpus=num_cpus, log_to_driver=True)
 
-    trainer1 = TorchTrainer(
-        training_operator_cls=CifarTrainingOperator,
-        initialization_hook=initialization_hook,
+    train_cifar(
         num_workers=args.num_workers,
-        config={
-            "lr": 0.1,
-            "test_mode": args.smoke_test,  # subset the data
-            # this will be split across workers.
-            BATCH_SIZE: 128 * args.num_workers
-        },
         use_gpu=args.use_gpu,
-        scheduler_step_freq="epoch",
-        use_fp16=args.fp16,
-        use_tqdm=False)
-    pbar = trange(args.num_epochs, unit="epoch")
-    for i in pbar:
-        info = {"num_steps": 1} if args.smoke_test else {}
-        info["epoch_idx"] = i
-        info["num_epochs"] = args.num_epochs
-        # Increase `max_retries` to turn on fault tolerance.
-        trainer1.train(max_retries=1, info=info)
-        val_stats = trainer1.validate()
-        pbar.set_postfix(dict(acc=val_stats["val_accuracy"]))
-
-    print(trainer1.validate())
-    trainer1.shutdown()
-    print("success!")
+        fp16=args.fp16,
+        num_epochs=args.num_epochs,
+        test_mode=args.smoke_test)
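
For reference, the refactor above turns the example's training loop into an importable train_cifar() helper, which is what the new release test calls into. A minimal sketch of driving it directly follows; the CPU-only smoke-test settings are illustrative assumptions, not part of this commit:

import ray
from ray.util.sgd.torch.examples.cifar_pytorch_example import train_cifar

# Start a small local Ray instance. test_mode subsets the data and runs a
# single step per epoch, mirroring the example's --smoke-test path.
ray.init(num_cpus=4)
train_cifar(num_workers=2, use_gpu=False, num_epochs=1, test_mode=True)
ray.shutdown()
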
14 changes: 14 additions & 0 deletions release/nightly_gpu_tests/nightly_gpu_tests.yaml
@@ -0,0 +1,14 @@
# Test multi-node, multi-GPU Ray SGD example.
- name: sgd_gpu
  owner:
    mail: "[email protected]"
    slack: "@tune-team"

  cluster:
    app-config: sgd_gpu/sgd_gpu_app_config.yaml
    compute_template: sgd_gpu/sgd_gpu_compute.yaml

  run:
    timeout: 3000
    prepare: python wait_cluster.py 2 600
    script: python sgd_gpu/sgd_gpu_test.py --num-workers=2 --use-gpu --address=auto
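
The prepare step above blocks until both nodes have joined before the test script runs. The sketch below is not the actual wait_cluster.py from the release tooling, only a rough illustration of that idea using Ray's public ray.nodes() API, with the node count and timeout mirroring `python wait_cluster.py 2 600`:

import time
import ray

ray.init(address="auto")

expected_nodes, timeout_s = 2, 600  # mirrors `python wait_cluster.py 2 600`
deadline = time.time() + timeout_s
while time.time() < deadline:
    # ray.nodes() lists known nodes; count only the ones still alive.
    alive = [node for node in ray.nodes() if node["Alive"]]
    if len(alive) >= expected_nodes:
        break
    time.sleep(5)
else:
    raise RuntimeError(
        f"Cluster did not reach {expected_nodes} nodes within {timeout_s}s")
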
12 changes: 12 additions & 0 deletions release/nightly_gpu_tests/sgd_gpu/sgd_gpu_app_config.yaml
@@ -0,0 +1,12 @@
base_image: "anyscale/ray-ml:pinned-nightly"
env_vars: {}
debian_packages: []

python:
  pip_packages: []
  conda_packages: []

post_build_cmds:
  - pip uninstall -y ray
  - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }}
  - pip3 install -U ray[tune]
14 changes: 14 additions & 0 deletions release/nightly_gpu_tests/sgd_gpu/sgd_gpu_compute.yaml
@@ -0,0 +1,14 @@
cloud_id: {{env["ANYSCALE_CLOUD_ID"]}}
region: us-west-2

max_workers: 1

head_node_type:
  name: head_node
  instance_type: g3.4xlarge

worker_node_types:
  - name: worker_node
    instance_type: g3.4xlarge
    min_workers: 1
    max_workers: 1
25 changes: 25 additions & 0 deletions release/nightly_gpu_tests/sgd_gpu/sgd_gpu_test.py
@@ -0,0 +1,25 @@
import json
import os
import time

import ray
from ray.util.sgd.torch.examples.cifar_pytorch_example import train_cifar

if __name__ == "__main__":
    ray.init(address=os.environ.get("RAY_ADDRESS", "auto"))
    start_time = time.time()
    success = True
    try:
        train_cifar(
            num_workers=2,
            use_gpu=True,
            num_epochs=5,
            fp16=True,
            test_mode=False)
    except Exception as e:
        print(f"The test failed with {e}")
        success = False

    delta = time.time() - start_time
    with open(os.environ["TEST_OUTPUT_JSON"], "w") as f:
        f.write(json.dumps({"train_time": delta, "success": success}))
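
The test writes a single JSON record with train_time and success keys to the path given by TEST_OUTPUT_JSON. A small sketch of checking that artifact after a run; the file path here is an assumption, since the real location is whatever the release harness sets:

import json

# Hypothetical output path; in the harness it is the value of TEST_OUTPUT_JSON.
with open("/tmp/sgd_gpu_release_result.json") as f:
    result = json.load(f)

assert result["success"], (
    f"SGD GPU release test failed after {result['train_time']:.0f}s")
print(f"Training finished in {result['train_time']:.0f}s")
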
