Skip to content

Commit

Permalink
[tune] Fix reuse_actors error on actor cleanup for function trainab…
Browse files Browse the repository at this point in the history
…les (#42951)

This PR fixes a bug caused by reuse actors stopping `FunctionTrainable` actors before their training thread started. Additionally, this PR disables `reuse_actors` by default for all trainable types, due to the nondeterministic and unstable behavior of the number of total actors spawned throughout training which has been reported by many users.

---------

Signed-off-by: Justin Yu <[email protected]>
  • Loading branch information
justinvyu committed Feb 5, 2024
1 parent 248521a commit e3ce49a
Show file tree
Hide file tree
Showing 14 changed files with 45 additions and 90 deletions.
4 changes: 3 additions & 1 deletion python/ray/air/execution/_internal/actor_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def _try_start_actors(self, max_actors: Optional[int] = None) -> int:

# Iterate through all resource requests
for resource_request in self._resource_request_to_pending_actors:
if max_actors and started_actors >= max_actors:
if max_actors is not None and started_actors >= max_actors:
break

# While we have resources ready and there are actors left to schedule
Expand Down Expand Up @@ -401,6 +401,8 @@ def on_error(exception: Exception):

self._enqueue_cached_actor_tasks(tracked_actor=tracked_actor)

started_actors += 1

return started_actors

def _enqueue_cached_actor_tasks(self, tracked_actor: TrackedActor):
Expand Down
17 changes: 9 additions & 8 deletions python/ray/train/_internal/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,10 @@ def pause_reporting(self):
"""Ignore all future ``session.report()`` calls."""
self.ignore_report = True

def finish(self, timeout: Optional[float] = None):
def finish(self, timeout: Optional[float] = None) -> Optional[Any]:
"""Finishes the training thread.
Either returns the output from training or raises any Exception from
training.
Raises any Exception from training.
"""
# Set the stop event for the training thread to gracefully exit.
self.stop_event.set()
Expand All @@ -244,11 +243,13 @@ def finish(self, timeout: Optional[float] = None):
self.storage.persist_artifacts(force=True)

# Wait for training to finish.
# This will raise any errors that occur during training, including
# SystemError
func_output = self.training_thread.join(timeout=timeout)
# If training finished successfully, then return results.
return func_output
# This will raise any errors that occur during training, including SystemError
# This returns the result of the training function.
output = None
if self.training_started:
output = self.training_thread.join(timeout=timeout)

return output

def get_next(self) -> Optional[_TrainingResult]:
"""Gets the next ``_TrainingResult`` from the result queue.
Expand Down
35 changes: 19 additions & 16 deletions python/ray/train/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,40 +217,43 @@ def test_train_failure(ray_start_2_cpus):
assert e.finish_training() == [1, 1]


def test_train_single_worker_failure(ray_start_2_cpus):
"""Tests if training fails immediately if only one worker raises an Exception."""
def test_single_worker_user_failure(ray_start_2_cpus):
"""Tests if training fails immediately if one worker raises an Exception
while executing the user training code."""
config = TestConfig()
e = BackendExecutor(config, num_workers=2)
e.start()

def single_worker_fail():
def single_worker_user_failure():
if train.get_context().get_world_rank() == 0:
raise ValueError
raise RuntimeError
else:
time.sleep(1000000)

_start_training(e, single_worker_fail)
_start_training(e, single_worker_user_failure)

with pytest.raises(StartTraceback) as exc:
e.get_next_results()
assert isinstance(exc.value.__cause__, ValueError)
assert isinstance(exc.value.__cause__, RuntimeError)


# TODO(@justinvyu: fix test and/or deprecate relevant code path)
@pytest.mark.skip("Mocked execute_async doesn't work as intended")
def test_worker_failure(ray_start_2_cpus):
def test_single_worker_actor_failure(ray_start_2_cpus):
"""Tests is training fails immediately if one worker actor dies."""
config = TestConfig()
e = BackendExecutor(config, num_workers=2)
e.start()

def train_fail():
ray.actor.exit_actor()
def single_worker_actor_failure():
if train.get_context().get_world_rank() == 0:
# Simulate actor failure
os._exit(1)
else:
time.sleep(1000)

new_execute_func = gen_execute_special(train_fail)
with patch.object(WorkerGroup, "execute_async", new_execute_func):
with pytest.raises(TrainingWorkerError):
_start_training(e, lambda: 1)
e.finish_training()
_start_training(e, single_worker_actor_failure)

with pytest.raises(TrainingWorkerError):
e.get_next_results()


def test_tensorflow_start(ray_start_2_cpus):
Expand Down
3 changes: 1 addition & 2 deletions python/ray/train/tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ def test_world_size(session):

def test_train(session):
session.start()
output = session.finish()
assert output == 1
session.finish()


def test_get_dataset_shard(shutdown_only):
Expand Down
1 change: 1 addition & 0 deletions python/ray/tune/examples/pb2_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
metric="mean_accuracy",
mode="max",
num_samples=8,
reuse_actors=True,
),
param_space={
"lr": 0.0001,
Expand Down
2 changes: 1 addition & 1 deletion python/ray/tune/examples/pb2_ppo_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def explore(config):


if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("--max", type=int, default=1000000)
parser.add_argument("--algo", type=str, default="PPO")
Expand Down Expand Up @@ -108,6 +107,7 @@ def explore(config):
scheduler=methods[args.method],
verbose=1,
num_samples=args.num_samples,
reuse_actors=True,
stop={args.criteria: args.max},
config={
"env": args.env_name,
Expand Down
1 change: 1 addition & 0 deletions python/ray/tune/examples/pbt_convnet_function_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def stop_all(self):
metric="mean_accuracy",
mode="max",
num_samples=4,
reuse_actors=True,
),
param_space={
"lr": tune.uniform(0.001, 1),
Expand Down
1 change: 1 addition & 0 deletions python/ray/tune/examples/pbt_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def run_tune_pbt(smoke_test=False):
metric="mean_accuracy",
mode="max",
num_samples=8,
reuse_actors=True,
),
param_space={
"lr": 0.0001,
Expand Down
1 change: 1 addition & 0 deletions python/ray/tune/examples/pbt_memnn_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def load_checkpoint(self, checkpoint_dir):
metric="mean_accuracy",
mode="max",
num_samples=2,
reuse_actors=True,
),
param_space={
"finish_fast": args.smoke_test,
Expand Down
2 changes: 1 addition & 1 deletion python/ray/tune/examples/pbt_ppo_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@


if __name__ == "__main__":

# Postprocess the perturbed config to ensure it's still valid
def explore(config):
# ensure we collect enough timesteps to do sgd
Expand Down Expand Up @@ -54,6 +53,7 @@ def explore(config):
num_samples=8,
metric="episode_reward_mean",
mode="max",
reuse_actors=True,
),
param_space={
"env": "Humanoid-v1",
Expand Down
1 change: 1 addition & 0 deletions python/ray/tune/examples/pbt_tune_cifar10_with_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def cleanup(self):
num_samples=4,
metric="mean_accuracy",
mode="max",
reuse_actors=True,
),
param_space=space,
)
Expand Down
18 changes: 0 additions & 18 deletions python/ray/tune/tests/test_actor_reuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,24 +190,6 @@ def test_trial_reuse_disabled(trainable, ray_start_1_cpu):
assert [t.last_result["num_resets"] for t in trials] == [0, 0, 0, 0]


def test_trial_reuse_disabled_per_default(trainable, ray_start_1_cpu):
"""Test that reuse=None disables actor re-use for class trainables.
Setup: Pass `reuse_actors=None` to tune.run()
We assert the `num_resets` of each trainable class to be 0 (no reuse).
"""
analysis = _run_trials_with_frequent_pauses(trainable, reuse=None)
trials = analysis.trials
assert [t.last_result["id"] for t in trials] == [0, 1, 2, 3]
assert [t.last_result["iter"] for t in trials] == [2, 2, 2, 2]
if inspect.isclass(trainable):
assert [t.last_result["num_resets"] for t in trials] == [0, 0, 0, 0]
else:
# reuse=None defaults to True for fn trainables
assert [t.last_result["num_resets"] for t in trials] == [4, 5, 6, 7]


def test_trial_reuse_enabled(trainable, ray_start_1_cpu):
"""Test that reuse=True enables actor re-use.
Expand Down
43 changes: 4 additions & 39 deletions python/ray/tune/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,13 @@
_prepare_progress_reporter_for_ray_client,
_stream_client_output,
)
from ray.tune.registry import get_trainable_cls, is_function_trainable
from ray.tune.registry import get_trainable_cls

# Must come last to avoid circular imports
from ray.tune.schedulers import (
FIFOScheduler,
PopulationBasedTraining,
PopulationBasedTrainingReplay,
ResourceChangingScheduler,
TrialScheduler,
)
from ray.tune.schedulers.util import (
Expand Down Expand Up @@ -260,7 +259,7 @@ def run(
fail_fast: bool = False,
restore: Optional[str] = None,
resume: Union[bool, str] = False,
reuse_actors: Optional[bool] = None,
reuse_actors: bool = False,
raise_on_failed_trial: bool = True,
callbacks: Optional[Sequence[Callback]] = None,
max_concurrent_trials: Optional[int] = None,
Expand Down Expand Up @@ -436,8 +435,7 @@ def run(
when possible. This can drastically speed up experiments that start
and stop actors often (e.g., PBT in time-multiplexing mode). This
requires trials to have the same resource requirements.
Defaults to ``True`` for function trainables and ``False`` for
class and registered trainables.
Defaults to ``False``.
raise_on_failed_trial: Raise TuneError if there exists failed
trial (of ERROR state) when the experiments complete.
callbacks: List of callbacks that will be called at different
Expand Down Expand Up @@ -696,39 +694,6 @@ class and registered trainables.
)
os.environ["TUNE_RESULT_BUFFER_LENGTH"] = "1"

# If reuse_actors is unset, default to False for string and class trainables,
# and default to True for everything else (i.e. function trainables)
if reuse_actors is None:
trainable = (
run_or_experiment.run_identifier
if isinstance(run_or_experiment, Experiment)
else run_or_experiment
)
reuse_actors = (
# Only default to True for function trainables that meet certain conditions
is_function_trainable(trainable)
and not (
# Changing resources requires restarting actors
scheduler
and isinstance(scheduler, ResourceChangingScheduler)
)
and not (
# If GPUs are requested we could run into problems with device memory
_check_gpus_in_resources(resources_per_trial)
)
and not (
# If the resource request is overridden, we don't know if GPUs
# will be requested, yet, so default to False
_check_default_resources_override(trainable)
)
and not (
# Mixins do not work with reuse_actors as the mixin setup will only
# be invoked once
_check_mixin(trainable)
)
)
logger.debug(f"Auto-detected `reuse_actors={reuse_actors}`")

if (
isinstance(scheduler, (PopulationBasedTraining, PopulationBasedTrainingReplay))
and not reuse_actors
Expand Down Expand Up @@ -1067,7 +1032,7 @@ def run_experiments(
verbose: Optional[Union[int, AirVerbosity, Verbosity]] = None,
progress_reporter: Optional[ProgressReporter] = None,
resume: Union[bool, str] = False,
reuse_actors: Optional[bool] = None,
reuse_actors: bool = False,
raise_on_failed_trial: bool = True,
concurrent: bool = True,
callbacks: Optional[Sequence[Callback]] = None,
Expand Down
6 changes: 2 additions & 4 deletions python/ray/tune/tune_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ class TuneConfig:
when possible. This can drastically speed up experiments that start
and stop actors often (e.g., PBT in time-multiplexing mode). This
requires trials to have the same resource requirements.
Defaults to ``True`` for function trainables (including most
Ray Train Trainers) and ``False`` for class and registered trainables
(e.g. RLlib).
Defaults to ``False``.
trial_name_creator: Optional function that takes in a Trial and returns
its name (i.e. its string representation). Be sure to include some unique
identifier (such as `Trial.trial_id`) in each trial's name.
Expand All @@ -71,7 +69,7 @@ class TuneConfig:
num_samples: int = 1
max_concurrent_trials: Optional[int] = None
time_budget_s: Optional[Union[int, float, datetime.timedelta]] = None
reuse_actors: Optional[bool] = None
reuse_actors: bool = False
trial_name_creator: Optional[Callable[[Trial], str]] = None
trial_dirname_creator: Optional[Callable[[Trial], str]] = None
chdir_to_trial_dir: bool = _DEPRECATED_VALUE

0 comments on commit e3ce49a

Please sign in to comment.