[docs][ml][kuberay] Add a --disable-check flag to the XGBoost benchmark. (ray-project#27277)

This PR adds a --disable-check flag to the XGBoost benchmark script. The flag disables the RuntimeError that is raised when training or prediction takes too long. It is meant for non-CI, exploratory use cases.

The motivation:
We will include the XGBoost benchmark as an example workload in the KubeRay documentation.
The workload's actual performance is highly sensitive to the infrastructure environment, so we don't want to raise an alarming RuntimeError if the workload takes too long on the user's infrastructure.
(When I tried the 100 GB benchmark on KubeRay, training ran just a couple of minutes longer than the 1000-second cutoff.)
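
The gating described above can be sketched in isolation. This is a minimal illustration under stated assumptions, not the benchmark script itself: `check_times` is a hypothetical helper, the 1000-second training cutoff is the one mentioned in this message, and the prediction cutoff is a placeholder value.

```python
# Thresholds in seconds. The training cutoff is taken from the commit message;
# the prediction cutoff is a placeholder assumption for illustration only.
TRAINING_TIME_THRESHOLD = 1000
PREDICTION_TIME_THRESHOLD = 450


def check_times(training_time, prediction_time, disable_check=False):
    """Raise RuntimeError if a phase exceeded its threshold, unless disabled."""
    if disable_check:
        # Exploratory (non-CI) runs: report timings elsewhere, never fail here.
        return

    if training_time > TRAINING_TIME_THRESHOLD:
        raise RuntimeError(
            f"Training on XGBoost is taking {training_time} seconds, "
            f"which is longer than expected ({TRAINING_TIME_THRESHOLD} seconds)."
        )

    if prediction_time > PREDICTION_TIME_THRESHOLD:
        raise RuntimeError(
            f"Batch prediction on XGBoost is taking {prediction_time} seconds, "
            f"which is longer than expected ({PREDICTION_TIME_THRESHOLD} seconds)."
        )
```

With `disable_check=True`, a run that trains for 1120 seconds returns normally; with the default, the same call raises RuntimeError.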
DmitriGekhtman committed Jul 29, 2022
1 parent 1a10b53 commit 8bdeb30
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions release/air_tests/air_benchmarks/workloads/xgboost_benchmark.py
@@ -122,23 +122,32 @@ def main(args):
     with open(test_output_json, "wt") as f:
         json.dump(result, f)

-    if training_time > _TRAINING_TIME_THRESHOLD:
-        raise RuntimeError(
-            f"Training on XGBoost is taking {training_time} seconds, "
-            f"which is longer than expected ({_TRAINING_TIME_THRESHOLD} seconds)."
-        )
+    if not args.disable_check:
+        if training_time > _TRAINING_TIME_THRESHOLD:
+            raise RuntimeError(
+                f"Training on XGBoost is taking {training_time} seconds, "
+                f"which is longer than expected ({_TRAINING_TIME_THRESHOLD} seconds)."
+            )

-    if prediction_time > _PREDICTION_TIME_THRESHOLD:
-        raise RuntimeError(
-            f"Batch prediction on XGBoost is taking {prediction_time} seconds, "
-            f"which is longer than expected ({_PREDICTION_TIME_THRESHOLD} seconds)."
-        )
+        if prediction_time > _PREDICTION_TIME_THRESHOLD:
+            raise RuntimeError(
+                f"Batch prediction on XGBoost is taking {prediction_time} seconds, "
+                f"which is longer than expected ({_PREDICTION_TIME_THRESHOLD} seconds)."
+            )


 if __name__ == "__main__":
     import argparse

     parser = argparse.ArgumentParser()
     parser.add_argument("--size", type=str, choices=["10G", "100G"], default="100G")
+    # Add a flag for disabling the timeout error.
+    # Use case: running the benchmark as a documented example, in infra settings
+    # different from the formal benchmark's EC2 setup.
+    parser.add_argument(
+        "--disable-check",
+        action="store_true",
+        help="disable runtime error on benchmark timeout",
+    )
     args = parser.parse_args()
     main(args)
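
For reference, a small standalone sketch of how argparse exposes the new flag. `build_parser` is a hypothetical wrapper introduced only for illustration; the two `add_argument` calls mirror the ones in the diff. argparse maps the dash in `--disable-check` to an underscore in the attribute name, and `store_true` makes the value default to False.

```python
import argparse


def build_parser():
    """Hypothetical wrapper around the argument definitions from the diff."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--size", type=str, choices=["10G", "100G"], default="100G")
    parser.add_argument(
        "--disable-check",
        action="store_true",
        help="disable runtime error on benchmark timeout",
    )
    return parser


# No flag given: the timeout checks stay active.
default_args = build_parser().parse_args([])
print(default_args.size, default_args.disable_check)

# Flag given: args.disable_check is True, so main() would skip the checks.
exploratory_args = build_parser().parse_args(["--size", "10G", "--disable-check"])
print(exploratory_args.size, exploratory_args.disable_check)
```

Because the flag is opt-in, running the script without it keeps the timeout checks active, so existing invocations behave as before.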
