From 31fefa20b746e2ae8d9e9af3c72e61112bcb9725 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Sun, 25 Feb 2018 13:26:58 -0800 Subject: [PATCH] [tune] HyperBand Fixes (#1586) --- python/ray/tune/hyperband.py | 34 ++++++++++++++------ python/ray/tune/test/trial_scheduler_test.py | 11 ++----- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/python/ray/tune/hyperband.py b/python/ray/tune/hyperband.py index abca5d01cc669..f0efe99cb638d 100644 --- a/python/ray/tune/hyperband.py +++ b/python/ray/tune/hyperband.py @@ -38,7 +38,7 @@ class HyperBandScheduler(FIFOScheduler): algorithm. It divides trials into brackets of varying sizes, and periodically early stops low-performing trials within each bracket. - To use this implementation of HyperBand with Ray.tune, all you need + To use this implementation of HyperBand with Ray Tune, all you need to do is specify the max length of time a trial can run `max_t`, the time units `time_attr`, and the name of the reported objective value `reward_attr`. We automatically determine reasonable values for the other @@ -164,7 +164,7 @@ def _process_bracket(self, trial_runner, bracket, trial): if bracket.cur_iter_done(): if bracket.finished(): bracket.cleanup_full(trial_runner) - return TrialScheduler.CONTINUE + return TrialScheduler.STOP good, bad = bracket.successive_halving(self._reward_attr) # kill bad trials @@ -225,6 +225,22 @@ def choose_trial_to_run(self, trial_runner): return None def debug_string(self): + """This provides a progress notification for the algorithm. + + For each bracket, the algorithm will output a string as follows: + + Bracket(Max Size (n)=5, Milestone (r)=33, completed=14.6%): + {PENDING: 2, RUNNING: 3, TERMINATED: 2} + + "Max Size" indicates the max number of pending/running experiments + set according to the Hyperband algorithm. + + "Milestone" indicates the iterations a trial will run for before + the next halving will occur. + + "Completed" indicates an approximate progress metric. Some brackets, + like ones that are unfilled, will not reach 100%. + """ out = "Using HyperBand: " out += "num_stopped={} total_brackets={}".format( self._num_stopped, sum(len(band) for band in self._hyperbands)) @@ -367,11 +383,11 @@ def _calculate_total_work(self, n, r, s): def __repr__(self): status = ", ".join([ - "n={}".format(self._n), - "r={}".format(self._r), - "completed={}%".format(int(100 * self.completion_percentage())) + "Max Size (n)={}".format(self._n), + "Milestone (r)={}".format(self._r), + "completed={:.1%}".format(self.completion_percentage()) ]) - counts = collections.Counter() - for t in self._all_trials: - counts[t.status] += 1 - return "Bracket({}): {}".format(status, dict(counts)) + counts = collections.Counter([t.status for t in self._all_trials]) + trial_statuses = ", ".join(sorted( + ["{}: {}".format(k, v) for k, v in counts.items()])) + return "Bracket({}): {{{}}} ".format(status, trial_statuses) diff --git a/python/ray/tune/test/trial_scheduler_test.py b/python/ray/tune/test/trial_scheduler_test.py index a73647c4c26b2..9ada2b04d5d4c 100644 --- a/python/ray/tune/test/trial_scheduler_test.py +++ b/python/ray/tune/test/trial_scheduler_test.py @@ -327,9 +327,9 @@ def testHalvingStop(self): self.assertEqual(action, TrialScheduler.STOP) - def testContinueLastOne(self): + def testStopsLastOne(self): stats = self.default_statistics() - num_trials = stats[str(0)]["n"] + num_trials = stats[str(0)]["n"] # setup one bracket sched, mock_runner = self.schedulerSetup(num_trials) big_bracket = sched._state["bracket"] for trl in big_bracket.current_trials(): @@ -342,12 +342,7 @@ def testContinueLastOne(self): mock_runner, trl, result(cur_units, i)) mock_runner.process_action(trl, action) - self.assertEqual(action, TrialScheduler.CONTINUE) - - for x in range(100): - action = sched.on_trial_result( - mock_runner, trl, result(cur_units + x, 10)) - self.assertEqual(action, TrialScheduler.CONTINUE) + self.assertEqual(action, TrialScheduler.STOP) def testTrialErrored(self): """If a trial errored, make sure successive halving still happens"""