Skip to content

Commit

Permalink
[tune] HyperBand Fixes (ray-project#1586)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardliaw committed Feb 25, 2018
1 parent 2026c14 commit 31fefa2
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 17 deletions.
34 changes: 25 additions & 9 deletions python/ray/tune/hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class HyperBandScheduler(FIFOScheduler):
algorithm. It divides trials into brackets of varying sizes, and
periodically early stops low-performing trials within each bracket.
To use this implementation of HyperBand with Ray.tune, all you need
To use this implementation of HyperBand with Ray Tune, all you need
to do is specify the max length of time a trial can run `max_t`, the time
units `time_attr`, and the name of the reported objective value
`reward_attr`. We automatically determine reasonable values for the other
Expand Down Expand Up @@ -164,7 +164,7 @@ def _process_bracket(self, trial_runner, bracket, trial):
if bracket.cur_iter_done():
if bracket.finished():
bracket.cleanup_full(trial_runner)
return TrialScheduler.CONTINUE
return TrialScheduler.STOP

good, bad = bracket.successive_halving(self._reward_attr)
# kill bad trials
Expand Down Expand Up @@ -225,6 +225,22 @@ def choose_trial_to_run(self, trial_runner):
return None

def debug_string(self):
"""This provides a progress notification for the algorithm.
For each bracket, the algorithm will output a string as follows:
Bracket(Max Size (n)=5, Milestone (r)=33, completed=14.6%):
{PENDING: 2, RUNNING: 3, TERMINATED: 2}
"Max Size" indicates the max number of pending/running experiments
set according to the Hyperband algorithm.
"Milestone" indicates the iterations a trial will run for before
the next halving will occur.
"Completed" indicates an approximate progress metric. Some brackets,
like ones that are unfilled, will not reach 100%.
"""
out = "Using HyperBand: "
out += "num_stopped={} total_brackets={}".format(
self._num_stopped, sum(len(band) for band in self._hyperbands))
Expand Down Expand Up @@ -367,11 +383,11 @@ def _calculate_total_work(self, n, r, s):

def __repr__(self):
status = ", ".join([
"n={}".format(self._n),
"r={}".format(self._r),
"completed={}%".format(int(100 * self.completion_percentage()))
"Max Size (n)={}".format(self._n),
"Milestone (r)={}".format(self._r),
"completed={:.1%}".format(self.completion_percentage())
])
counts = collections.Counter()
for t in self._all_trials:
counts[t.status] += 1
return "Bracket({}): {}".format(status, dict(counts))
counts = collections.Counter([t.status for t in self._all_trials])
trial_statuses = ", ".join(sorted(
["{}: {}".format(k, v) for k, v in counts.items()]))
return "Bracket({}): {{{}}} ".format(status, trial_statuses)
11 changes: 3 additions & 8 deletions python/ray/tune/test/trial_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,9 @@ def testHalvingStop(self):

self.assertEqual(action, TrialScheduler.STOP)

def testContinueLastOne(self):
def testStopsLastOne(self):
stats = self.default_statistics()
num_trials = stats[str(0)]["n"]
num_trials = stats[str(0)]["n"] # setup one bracket
sched, mock_runner = self.schedulerSetup(num_trials)
big_bracket = sched._state["bracket"]
for trl in big_bracket.current_trials():
Expand All @@ -342,12 +342,7 @@ def testContinueLastOne(self):
mock_runner, trl, result(cur_units, i))
mock_runner.process_action(trl, action)

self.assertEqual(action, TrialScheduler.CONTINUE)

for x in range(100):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units + x, 10))
self.assertEqual(action, TrialScheduler.CONTINUE)
self.assertEqual(action, TrialScheduler.STOP)

def testTrialErrored(self):
"""If a trial errored, make sure successive halving still happens"""
Expand Down

0 comments on commit 31fefa2

Please sign in to comment.