Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Provide msgpack checkpoint translation utility for Policy-only cases. #38825

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
wip
Signed-off-by: sven1977 <[email protected]>
  • Loading branch information
sven1977 committed Aug 24, 2023
commit 0f52ac2aa582d46c17968d2e1c5beee73135042c
10 changes: 8 additions & 2 deletions rllib/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,15 @@ def from_checkpoint(

# Policy checkpoint: Return a single Policy instance.
else:
msgpack = None
if checkpoint_info.get("format") == "msgpack":
msgpack = try_import_msgpack(error=True)

with open(checkpoint_info["state_file"], "rb") as f:
state = pickle.load(f)
if msgpack is not None:
state = msgpack.load(f)
else:
state = pickle.load(f)
return Policy.from_state(state)

@staticmethod
Expand Down Expand Up @@ -1843,7 +1850,6 @@ def get_gym_space_from_struct_of_tensors(
value: Union[Mapping, Tuple, List, TensorType],
batched_input=True,
) -> gym.Space:

start_idx = 1 if batched_input else 0
struct = tree.map_structure(
lambda x: gym.spaces.Box(
Expand Down
36 changes: 36 additions & 0 deletions rllib/utils/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,42 @@ def convert_to_msgpack_checkpoint(
return msgpack_checkpoint_dir


@PublicAPI(stability="beta")
def convert_to_msgpack_policy_checkpoint(
    policy_checkpoint: Union[str, Checkpoint, NewCheckpoint],
    msgpack_checkpoint_dir: str,
) -> str:
    """Re-exports a pickle-based Policy checkpoint in msgpack format.

    The Policy is first restored from `policy_checkpoint` and then written out
    again via `Policy.export_checkpoint(..., checkpoint_format="msgpack")`.
    Msgpack has the advantage of being python version independent.

    Args:
        policy_checkpoint: The Policy checkpoint (pickle based) to convert. The
            value is handed unchanged to `Policy.from_checkpoint()`.
        msgpack_checkpoint_dir: The directory, in which to create the new
            msgpack based checkpoint. It is created if it does not exist yet.

    Returns:
        The directory in which the msgpack checkpoint has been created. Note
        that this is the same as `msgpack_checkpoint_dir`.
    """
    # NOTE(review): function-local import, presumably to avoid a circular
    # import between the checkpoint utils and the Policy module - confirm.
    from ray.rllib.policy.policy import Policy

    # Restore the Policy from the existing (pickle based) checkpoint.
    restored_policy = Policy.from_checkpoint(policy_checkpoint)

    # Make sure the target directory exists, then write the Policy's current
    # state back out using the msgpack format.
    os.makedirs(msgpack_checkpoint_dir, exist_ok=True)
    restored_policy.export_checkpoint(
        msgpack_checkpoint_dir,
        policy_state=restored_policy.get_state(),
        checkpoint_format="msgpack",
    )

    # Drop our reference to the Policy so that it can be released.
    del restored_policy

    return msgpack_checkpoint_dir


@PublicAPI
def try_import_msgpack(error: bool = False):
"""Tries importing msgpack and msgpack_numpy and returns the patched msgpack module.
Expand Down
45 changes: 43 additions & 2 deletions rllib/utils/tests/test_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.simple_q import SimpleQConfig
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.checkpoints import (
get_checkpoint_info,
convert_to_msgpack_checkpoint,
convert_to_msgpack_policy_checkpoint,
)
from ray.rllib.utils.test_utils import check
from ray import tune
Expand Down Expand Up @@ -86,7 +88,7 @@ def test_get_policy_checkpoint_info_v1_1(self):
def test_msgpack_checkpoint_translation(self):
"""Tests, whether a checkpoint can be translated into a msgpack-checkpoint ...
... and recovered back into and Algorithm, which is identical to a
... and recovered back into an Algorithm, which is identical to a
pickle-checkpoint-recovered Algorithm (given same initial config).
"""
# Base config used for both pickle-based checkpoint and msgpack-based one.
Expand Down Expand Up @@ -141,9 +143,10 @@ def test_msgpack_checkpoint_translation(self):
def test_msgpack_checkpoint_translation_multi_agent(self):
"""Tests, whether a checkpoint can be translated into a msgpack-checkpoint ...
... and recovered back into and Algorithm, which is identical to a
... and recovered back into an Algorithm, which is identical to a
pickle-checkpoint-recovered Algorithm (given same initial config).
"""

# Base config used for both pickle-based checkpoint and msgpack-based one.
def mapping_fn(aid, episode, worker, **kwargs):
return "pol" + str(aid)
Expand Down Expand Up @@ -223,6 +226,44 @@ def mapping_fn(aid, episode, worker, **kwargs):
algo1.stop()
algo2.stop()

def test_msgpack_policy_checkpoint_translation(self):
    """Tests, whether a Policy checkpoint can be translated into msgpack ...

    ... and recovered back into a Policy, which is identical to a
    pickle-checkpoint-recovered Policy (given same initial config).
    """
    # Base config used for both pickle-based checkpoint and msgpack-based one.
    config = SimpleQConfig().environment("CartPole-v1")
    # Build algorithm/policy objects.
    algo1 = config.build()
    # Always stop the Algorithm again (as the other tests in this file do),
    # even if one of the assertions below fails.
    try:
        pol1 = algo1.get_policy()
        # Get its state.
        pickle_state = pol1.get_state()

        # Create standard (pickle-based) checkpoint.
        with tempfile.TemporaryDirectory() as pickle_cp_dir:
            pol1.export_checkpoint(pickle_cp_dir)
            # Now convert pickle checkpoint to msgpack using the provided
            # utility function.
            with tempfile.TemporaryDirectory() as msgpack_cp_dir:
                convert_to_msgpack_policy_checkpoint(pickle_cp_dir, msgpack_cp_dir)
                msgpack_cp_info = get_checkpoint_info(msgpack_cp_dir)
                self.assertEqual(msgpack_cp_info["type"], "Policy")
                self.assertEqual(msgpack_cp_info["format"], "msgpack")
                self.assertIsNone(msgpack_cp_info["policy_ids"])
                # Try recreating a new policy object from the msgpack checkpoint.
                pol2 = Policy.from_checkpoint(msgpack_cp_dir)
                # Get the state of the policy recovered from msgpack.
                msgpack_state = pol2.get_state()

        # Make sure the states match 100%. Our `check` utility
        # cannot handle comparing types/classes, so we'll have to serialize the
        # pickle'd config (which contains types, rather than class strings).
        pickle_state["policy_spec"]["config"] = AlgorithmConfig._serialize_dict(
            pickle_state["policy_spec"]["config"]
        )
        check(pickle_state, msgpack_state)
    finally:
        algo1.stop()


if __name__ == "__main__":
import pytest
Expand Down
Loading