[RLlib] Issue 40347: Fix saving and loading algos by loading stored learner state. (ray-project#42090)
simonsays1980 authored Jan 2, 2024
1 parent 7c818f2 commit 61b531c
Showing 5 changed files with 49 additions and 7 deletions.
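As a usage-level illustration of what this commit fixes, here is a minimal sketch (not part of the diff; the PPO setup, environment, and paths are assumptions for illustration): save an algorithm running on the new API stack, then restore it — after this change, restoring also reloads the stored learner state.

import os

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig

# Build a small algorithm on the new API stack (illustrative config).
config = (
    PPOConfig()
    .experimental(_enable_new_api_stack=True)
    .environment("CartPole-v1")
)
algo = config.build()
algo.train()

# Save a checkpoint; with the new API stack this includes a "learner" subdir.
ckpt_dir = "/tmp/ppo_checkpoint"
os.makedirs(ckpt_dir, exist_ok=True)
algo.save_checkpoint(ckpt_dir)

# Before this fix, restoring skipped the stored learner state; now
# `from_checkpoint()` also restores the learner group from that subdir.
restored = Algorithm.from_checkpoint(ckpt_dir)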
21 changes: 17 additions & 4 deletions rllib/algorithms/algorithm.py
@@ -372,6 +372,7 @@ def from_state(state: Dict) -> "Algorithm":
new_algo = algorithm_class(config=config)
# Set the new algo's state.
new_algo.__setstate__(state)

# Return the new algo.
return new_algo

@@ -551,7 +552,6 @@ def _remote_worker_ids_for_metrics(self) -> List[int]:
@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(Trainable)
def setup(self, config: AlgorithmConfig) -> None:

# Setup our config: Merge the user-supplied config dict (which could
# be a partial config dict) with the class' default.
if not isinstance(config, AlgorithmConfig):
@@ -2357,7 +2357,6 @@ def cleanup(self) -> None:
def default_resource_request(
cls, config: Union[AlgorithmConfig, PartialAlgorithmConfigDict]
) -> Union[Resources, PlacementGroupFactory]:

# Default logic for RLlib Algorithms:
# Create one bundle per individual worker (local or remote).
# Use `num_cpus_for_local_worker` and `num_gpus` for the local worker and
@@ -2773,7 +2772,7 @@ def __setstate__(self, state) -> None:
# Also, what should the behavior be if e.g. some training parameter
# (e.g. lr) changed?

if hasattr(self, "workers") and "worker" in state:
if hasattr(self, "workers") and "worker" in state and state["worker"]:
self.workers.local_worker().set_state(state["worker"])
remote_state = ray.put(state["worker"])
self.workers.foreach_worker(
@@ -2815,6 +2814,15 @@ def _setup_eval_worker(w):
"data found in state!"
)

if self.config._enable_new_api_stack:
if "learner_state_dir" in state:
self.learner_group.load_state(state["learner_state_dir"])
else:
logger.warning(
"You configured `_enable_new_api_stack=True`, but no "
"`learner_state_dir` key could be found in the state dict!"
)

if "counters" in state:
self._counters = state["counters"]

@@ -2878,6 +2886,7 @@ def _checkpoint_info_to_algorithm_state(
if (
checkpoint_info["checkpoint_version"] > version.Version("0.1")
and state.get("worker") is not None
and state.get("worker")
):
worker_state = state["worker"]

@@ -2966,6 +2975,11 @@
):
worker_state["is_policy_to_train"] = policies_to_train

if state["config"]._enable_new_api_stack:
state["learner_state_dir"] = os.path.join(
checkpoint_info["checkpoint_dir"], "learner"
)

return state

@DeveloperAPI
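Reviewer-style note on how the two algorithm.py changes above fit together, as a simplified sketch (the helper names below are illustrative, not RLlib source): the checkpoint-to-state translation now records where the learner checkpoint lives, and `__setstate__()` consumes that entry to restore the learner group.

import os


def build_state(checkpoint_dir: str, enable_new_api_stack: bool) -> dict:
    # Simplified stand-in for _checkpoint_info_to_algorithm_state().
    state = {"worker": {}, "counters": {}}  # Real state has more keys.
    if enable_new_api_stack:
        state["learner_state_dir"] = os.path.join(checkpoint_dir, "learner")
    return state


def restore_learner(algo, state: dict) -> None:
    # Simplified stand-in for the new branch in Algorithm.__setstate__().
    if "learner_state_dir" in state:
        algo.learner_group.load_state(state["learner_state_dir"])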
@@ -3331,7 +3345,6 @@ def get_time_taken_sec(self) -> float:
return self.time_stop - self.time_start

def should_stop(self, results):

# Before first call to `step()`.
if results is None:
# Fail after n retries.
27 changes: 26 additions & 1 deletion rllib/algorithms/dreamerv3/dreamerv3.py
@@ -29,7 +29,7 @@
from ray.rllib.models.catalog import MODEL_DEFAULTS
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils import deep_update
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.numpy import one_hot
from ray.rllib.utils.metrics import (
@@ -698,3 +698,28 @@ def training_ratio(self) -> float:
@staticmethod
def _reduce_results(results: List[Dict[str, Any]]):
return tree.map_structure(lambda *s: np.mean(s, axis=0), *results)

# TODO (sven): Remove this once DreamerV3 is on the new SingleAgentEnvRunner.
@PublicAPI
def __setstate__(self, state) -> None:
"""Sts the algorithm to the provided state
Args:
state: The state dictionary to restore this `DreamerV3` instance to.
`state` may have been returned by a call to an `Algorithm`'s
`__getstate__()` method.
"""
# Call the `Algorithm`'s `__setstate__()` method.
super().__setstate__(state=state)

# Assign the module to the local `EnvRunner` if sharing is enabled.
# Note, in `Learner.load_state()` the module is first deleted and then
# a new one is built - therefore the worker no longer holds a copy of
# the learner's module.
if self.config.share_module_between_env_runner_and_learner:
assert id(self.workers.local_worker().module) != id(
self.learner_group._learner.module[DEFAULT_POLICY_ID]
)
self.workers.local_worker().module = self.learner_group._learner.module[
DEFAULT_POLICY_ID
]
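To see why the override above re-assigns the module reference, here is a minimal, self-contained sketch of the staleness problem (plain Python, not RLlib source; the class and attribute names are made up):

class TinyLearner:
    """Stand-in for a Learner whose load_state() rebuilds its module."""

    def __init__(self):
        self.module = {"default_policy": object()}

    def load_state(self):
        # Loading a checkpoint deletes the old module and builds a new one,
        # so the object identity changes.
        self.module = {"default_policy": object()}


learner = TinyLearner()
shared_ref = learner.module["default_policy"]  # EnvRunner's shared reference.
learner.load_state()
# The previously shared reference is now stale ...
assert shared_ref is not learner.module["default_policy"]
# ... which is why DreamerV3.__setstate__() re-points the local EnvRunner's
# module at the freshly loaded learner module.
shared_ref = learner.module["default_policy"]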
2 changes: 1 addition & 1 deletion rllib/core/learner/learner.py
@@ -1396,7 +1396,7 @@ def load_state(
self._check_is_built()
path = pathlib.Path(path)
del self._module
# TODO(avnishn) from checkpoint doesn't currently support modules_to_load,
# TODO (avnishn): from checkpoint doesn't currently support modules_to_load,
# but it should, so we will add it later.
self._module_obj = MultiAgentRLModule.from_checkpoint(path / "module_state")
self._reset()
4 changes: 4 additions & 0 deletions rllib/env/env_runner.py
@@ -79,6 +79,10 @@ def get_state(self) -> Dict[str, Any]:
Returns:
The current state of this EnvRunner.
"""
# TODO (sven, simon): With this default, `Algorithm.save_checkpoint()`
# stores an empty worker state, and `Algorithm.from_checkpoint()` must
# separately ensure that the state is empty (and not `None`). Should we
# instead return `None` here as the default?
return {}

def set_state(self, state: Dict[str, Any]) -> None:
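For context on the `get_state()` default discussed in the TODO above, a hypothetical `EnvRunner` subclass could return a real (non-empty) state dict so that checkpoints carry env-runner state. A minimal sketch, assuming the class name and fields below (they are made up) and stubbing out unrelated abstract methods:

from typing import Any, Dict

from ray.rllib.env.env_runner import EnvRunner


class CountingEnvRunner(EnvRunner):
    """Illustrative runner that tracks how many sample() calls it has served."""

    def __init__(self, *, config, **kwargs):
        super().__init__(config=config, **kwargs)
        self.num_sample_calls = 0

    def sample(self, **kwargs):
        # Stub only; a real implementation would collect and return episodes.
        self.num_sample_calls += 1
        return []

    def assert_healthy(self):
        pass

    def stop(self):
        pass

    def get_state(self) -> Dict[str, Any]:
        # Return a real state so checkpoints can restore this runner.
        return {"num_sample_calls": self.num_sample_calls}

    def set_state(self, state: Dict[str, Any]) -> None:
        self.num_sample_calls = state.get("num_sample_calls", 0)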
2 changes: 1 addition & 1 deletion rllib/policy/torch_policy_v2.py
@@ -1085,7 +1085,7 @@ def set_state(self, state: PolicyState) -> None:
if hasattr(self, "exploration") and "_exploration_state" in state:
self.exploration.set_state(state=state["_exploration_state"])

# Restore glbal timestep.
# Restore global timestep.
self.global_timestep = state["global_timestep"]

# Then the Policy's (NN) weights and connectors.
