[Tune] Allow overwriting a trainable on Tuner restore (#30351)
Introduces a way to re-specify a trainable when resuming a Tune experiment with Tuner.restore. This allows restoring a Tune experiment whose trainable is not serializable (e.g., a function trainable wrapped with tune.with_parameters, an AIR trainer with a non-serializable dataset, or a trainable with object references attached).
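
For illustration, a minimal sketch of the new flow, based on the usage shown in the error message and tests below; the training function, attached dataset, and experiment path are hypothetical:

```python
from ray import tune
from ray.air import session
from ray.tune import Tuner

# Hypothetical function trainable with a large object attached via
# `tune.with_parameters`. The attached object lives in the Ray object
# store, so the pickled Tuner state alone cannot restore it.
def train_fn(config, data=None):
    session.report({"score": len(data) * config.get("scale", 1)})

data = list(range(10_000))
tuner = Tuner(tune.with_parameters(train_fn, data=data))
tuner.fit()

# On restore, re-specify the trainable with the same parameters attached.
tuner = Tuner.restore(
    "~/ray_results/<experiment_name>",  # hypothetical experiment path
    overwrite_trainable=tune.with_parameters(train_fn, data=data),
    resume_errored=True,
)
tuner.fit()
```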

Signed-off-by: Justin Yu <[email protected]>
justinvyu authored Nov 29, 2022
1 parent 20e4848 commit c54dcc3
Showing 7 changed files with 377 additions and 52 deletions.
55 changes: 55 additions & 0 deletions python/ray/train/tests/test_tune.py
@@ -1,4 +1,5 @@
import os
import logging

import pytest

@@ -21,6 +22,14 @@
from ray.tune.tuner import Tuner


@pytest.fixture
def propagate_logs():
logger = logging.getLogger("ray")
logger.propagate = True
yield
logger.propagate = False


@pytest.fixture
def ray_start_4_cpus():
address_info = ray.init(num_cpus=4)
@@ -224,6 +233,52 @@ def train_func():
assert len(trial_dfs[0]["training_iteration"]) == 4


def test_restore_with_new_trainer(ray_start_4_cpus, tmpdir, propagate_logs, caplog):
def train_func(config):
raise RuntimeError("failing!")

trainer = DataParallelTrainer(
train_func,
backend_config=TestConfig(),
scaling_config=ScalingConfig(num_workers=1),
run_config=RunConfig(local_dir=str(tmpdir), name="restore_new_trainer"),
datasets={"train": ray.data.from_items([{"a": i} for i in range(10)])},
)
results = Tuner(trainer).fit()
assert results.errors

def train_func(config):
dataset = session.get_dataset_shard("train")
assert session.get_world_size() == 2
assert dataset.count() == 10

trainer = DataParallelTrainer(
# Training function can be modified
train_func,
backend_config=TestConfig(),
# ScalingConfig can be modified
scaling_config=ScalingConfig(num_workers=2),
# New RunConfig will be ignored
run_config=RunConfig(name="ignored"),
# Datasets and preprocessors can be re-specified
datasets={"train": ray.data.from_items([{"a": i} for i in range(20)])},
)
caplog.clear()
with caplog.at_level(logging.WARNING, logger="ray.tune.impl.tuner_internal"):
with pytest.warns() as warn_record:
tuner = Tuner.restore(
str(tmpdir / "restore_new_trainer"),
overwrite_trainable=trainer,
resume_errored=True,
)
# Should warn about the RunConfig being ignored
assert "RunConfig" in str(warn_record[0].message)
assert "The trainable will be overwritten" in caplog.text

results = tuner.fit()
assert not results.errors


if __name__ == "__main__":
import sys

1 change: 1 addition & 0 deletions python/ray/tune/execution/trial_runner.py
@@ -93,6 +93,7 @@ def _load_trial_from_checkpoint(
if new_local_dir:
trial_cp["local_dir"] = new_local_dir
new_trial.__setstate__(trial_cp)
new_trial.refresh_default_resource_request()
return new_trial


15 changes: 13 additions & 2 deletions python/ray/tune/experiment/trial.py
@@ -599,7 +599,9 @@ def init_logdir(self):

self.invalidate_json_state()

-    def update_resources(self, resources: Union[Dict, PlacementGroupFactory]):
+    def update_resources(
+        self, resources: Union[Dict, Resources, PlacementGroupFactory]
+    ):
"""EXPERIMENTAL: Updates the resource requirements.
Should only be called when the trial is not running.
@@ -613,7 +615,7 @@ def update_resources(self, resources: Union[Dict, PlacementGroupFactory]):
placement_group_factory = None
if isinstance(resources, PlacementGroupFactory):
placement_group_factory = resources
-        else:
+        elif isinstance(resources, dict):
resources = Resources(**resources)

self.placement_group_factory = _to_pg_factory(
@@ -622,6 +624,15 @@ def update_resources(self, resources: Union[Dict, PlacementGroupFactory]):

self.invalidate_json_state()

def refresh_default_resource_request(self):
"""Update trial resources according to the trainable's default resource
request, if it is provided."""
trainable_cls = self.get_trainable_cls()
if trainable_cls:
default_resources = trainable_cls.default_resource_request(self.config)
if default_resources:
self.update_resources(default_resources)

def set_runner(self, runner):
self.runner = runner
if runner:
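For context on refresh_default_resource_request above: default_resource_request is the existing Trainable hook that lets a trainable derive its resource bundle from the trial config. A minimal sketch of a trainable providing one (the class name and bundle contents are illustrative):

```python
from ray import tune
from ray.tune.execution.placement_groups import PlacementGroupFactory

class MyTrainable(tune.Trainable):
    @classmethod
    def default_resource_request(cls, config):
        # Tune consults this hook when building the trial's placement group.
        # A restored trial must re-query it after the trainable is
        # overwritten, or it would keep the old trainable's resources.
        return PlacementGroupFactory([{"CPU": config.get("num_cpus", 1)}])

    def step(self):
        return {"score": 1.0}
```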
122 changes: 105 additions & 17 deletions python/ray/tune/impl/tuner_internal.py
@@ -1,6 +1,7 @@
import copy
import os
import math
import logging
import warnings
import shutil
import tempfile
@@ -31,6 +32,11 @@
_PARAM_SPACE_KEY = "_param_space"
_EXPERIMENT_ANALYSIS_KEY = "_experiment_analysis"

logger = logging.getLogger(__name__)

TrainableType = Union[str, Callable, Type[Trainable]]
TrainableTypeOrTrainer = Union[TrainableType, "BaseTrainer"]


class TunerInternal:
"""The real implementation behind external facing ``Tuner``.
@@ -65,14 +71,7 @@ def __init__(
self,
restore_path: str = None,
resume_config: Optional[_ResumeConfig] = None,
-        trainable: Optional[
-            Union[
-                str,
-                Callable,
-                Type[Trainable],
-                "BaseTrainer",
-            ]
-        ] = None,
+        trainable: Optional[TrainableTypeOrTrainer] = None,
param_space: Optional[Dict[str, Any]] = None,
tune_config: Optional[TuneConfig] = None,
run_config: Optional[RunConfig] = None,
@@ -88,10 +87,14 @@
self._tune_config = tune_config or TuneConfig()
self._run_config = run_config or RunConfig()

self._missing_params_error_message = None

# Restore from Tuner checkpoint.
if restore_path:
            self._restore_from_path_or_uri(
-                path_or_uri=restore_path, resume_config=resume_config
+                path_or_uri=restore_path,
+                resume_config=resume_config,
+                overwrite_trainable=trainable,
            )
return

@@ -196,8 +199,84 @@ def _maybe_warn_resource_contention(self):
stacklevel=4,
)

def _validate_overwrite_trainable(
self,
original_trainable: TrainableTypeOrTrainer,
overwrite_trainable: Optional[TrainableTypeOrTrainer],
):
"""Determines whether the new `overwrite_trainable` is compatible
with the restored experiment with some basic sanity checks
(ensuring same type and name as the original trainable).
"""

        # Check if the trainable was wrapped with `tune.with_parameters`,
        # and set the Tuner to fail on fit if the trainable is not re-specified.
trainable_wrapped_params = getattr(
original_trainable, "_attached_param_names", None
)
if trainable_wrapped_params and not overwrite_trainable:
self._missing_params_error_message = (
"The original trainable cannot be used to resume training, since "
"`tune.with_parameters` attached references to objects "
"in the Ray object store that may not exist anymore. "
"You must re-supply the trainable with the same parameters "
f"{trainable_wrapped_params} attached:\n\n"
"from ray import tune\n\n"
"# Reconstruct the trainable with the same parameters\n"
"trainable_with_params = tune.with_parameters(trainable, ...)\n"
"tuner = tune.Tuner.restore(\n"
" ..., overwrite_trainable=trainable_with_params\n"
")\n\nSee https://docs.ray.io/en/master/tune/api_docs/trainable.html"
"#tune-with-parameters for more details."
)
if not overwrite_trainable:
return

error_message = (
"Usage of `overwrite_trainable` is limited to re-specifying the "
"same trainable that was passed to `Tuner`, in the case "
"that the trainable is not serializable (e.g. it holds object references)."
)

if type(original_trainable) != type(overwrite_trainable):
raise ValueError(
f"{error_message}\n"
f"Got new trainable of type {type(overwrite_trainable)} "
f"but expected {type(original_trainable)}."
)

from ray.train.trainer import BaseTrainer

if isinstance(overwrite_trainable, BaseTrainer):
if overwrite_trainable.run_config != original_trainable.run_config:
warnings.warn(
"Overwriting the AIR Trainer with a new `RunConfig` is not "
"supported - the restored experiment will continue with the old "
"config. To avoid this warning, revert changes made to `RunConfig`."
)
overwrite_trainable.run_config = original_trainable.run_config
else:
original_name = Experiment.get_trainable_name(original_trainable)
overwrite_name = Experiment.get_trainable_name(overwrite_trainable)
if original_name != overwrite_name:
raise ValueError(
f"{error_message}\nGot new trainable with identifier "
f"{overwrite_name} but expected {original_name}."
)

logger.warning(
"The trainable will be overwritten - this should be done with caution: "
"it's possible to supply an incompatible trainable, and there are "
"no guarantees that the resumed experiment will continue successfully. "
"If you encounter errors during training, ensure that you are passing "
"in the same trainable that was passed into the initial `Tuner` object."
)

    def _restore_from_path_or_uri(
-        self, path_or_uri: str, resume_config: Optional[_ResumeConfig]
+        self,
+        path_or_uri: str,
+        resume_config: Optional[_ResumeConfig],
+        overwrite_trainable: Optional[TrainableTypeOrTrainer],
    ):
# Sync down from cloud storage if needed
synced, experiment_checkpoint_dir = self._maybe_sync_down_tuner_state(
@@ -223,6 +302,10 @@ def _restore_from_path_or_uri(
tuner = pickle.load(fp)
self.__dict__.update(tuner.__dict__)

self._validate_overwrite_trainable(trainable, overwrite_trainable)
if overwrite_trainable:
trainable = overwrite_trainable

self._is_restored = True
self.trainable = trainable
self._resume_config = resume_config
@@ -306,19 +389,19 @@ def get_experiment_checkpoint_dir(self) -> str:
return self._experiment_checkpoint_dir

@property
-    def trainable(self):
+    def trainable(self) -> TrainableTypeOrTrainer:
return self._trainable

@property
-    def converted_trainable(self):
+    def converted_trainable(self) -> TrainableType:
return self._converted_trainable

@trainable.setter
-    def trainable(self, trainable):
+    def trainable(self, trainable: TrainableTypeOrTrainer):
self._trainable = trainable
self._converted_trainable = self._convert_trainable(trainable)

-    def _convert_trainable(self, trainable) -> Union[str, Callable, Type[Trainable]]:
+    def _convert_trainable(self, trainable: TrainableTypeOrTrainer) -> TrainableType:
"""Converts an AIR Trainer to a Tune trainable and saves the converted
trainable. If not using an AIR Trainer, this leaves the trainable as is."""
from ray.train.trainer import BaseTrainer
@@ -350,7 +433,7 @@ def get_results(self) -> ResultGrid:
)
return ResultGrid(self._experiment_analysis)

-    def _get_tune_run_arguments(self, trainable) -> Dict[str, Any]:
+    def _get_tune_run_arguments(self, trainable: TrainableType) -> Dict[str, Any]:
"""Get tune.run arguments common for both new and resumed runs."""
checkpoint_freq = self._run_config.checkpoint_config.checkpoint_frequency
checkpoint_at_end = self._run_config.checkpoint_config.checkpoint_at_end
@@ -429,7 +512,9 @@ def _get_tune_run_arguments(self, trainable) -> Dict[str, Any]:
chdir_to_trial_dir=self._tune_config.chdir_to_trial_dir,
)

-    def _fit_internal(self, trainable, param_space) -> ExperimentAnalysis:
+    def _fit_internal(
+        self, trainable: TrainableType, param_space
+    ) -> ExperimentAnalysis:
"""Fitting for a fresh Tuner."""
args = {
**self._get_tune_run_arguments(trainable),
Expand All @@ -450,8 +535,11 @@ def _fit_internal(self, trainable, param_space) -> ExperimentAnalysis:
self.clear_remote_string_queue()
return analysis

-    def _fit_resume(self, trainable) -> ExperimentAnalysis:
+    def _fit_resume(self, trainable: TrainableType) -> ExperimentAnalysis:
"""Fitting for a restored Tuner."""
if self._missing_params_error_message:
raise ValueError(self._missing_params_error_message)

resume = "AUTO"

if self._resume_config:
(3 more changed files not shown.)
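
Rounding out the behavior above: if a tune.with_parameters trainable is restored without being re-specified, _validate_overwrite_trainable records the error message and _fit_resume raises it when fit() is called. A sketch of that failure path (the path is hypothetical, and the exception surfaced by Tuner.fit may be wrapped):

```python
from ray.tune import Tuner

# Restore WITHOUT `overwrite_trainable`: the object references attached by
# `tune.with_parameters` are stale, so fitting fails fast with the stored
# error message rather than deep inside the object store.
tuner = Tuner.restore("~/ray_results/<experiment_name>")
try:
    tuner.fit()
except Exception as e:
    print(e)  # "The original trainable cannot be used to resume training, ..."
```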
