diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py
index c43420fe056be..fb5837eb83552 100644
--- a/rllib/policy/policy.py
+++ b/rllib/policy/policy.py
@@ -340,8 +340,15 @@ def from_checkpoint(
 
         # Policy checkpoint: Return a single Policy instance.
         else:
+            msgpack = None
+            if checkpoint_info.get("format") == "msgpack":
+                msgpack = try_import_msgpack(error=True)
+
             with open(checkpoint_info["state_file"], "rb") as f:
-                state = pickle.load(f)
+                if msgpack is not None:
+                    state = msgpack.load(f)
+                else:
+                    state = pickle.load(f)
             return Policy.from_state(state)
 
     @staticmethod
@@ -1843,7 +1850,6 @@ def get_gym_space_from_struct_of_tensors(
         value: Union[Mapping, Tuple, List, TensorType],
         batched_input=True,
     ) -> gym.Space:
-
         start_idx = 1 if batched_input else 0
         struct = tree.map_structure(
             lambda x: gym.spaces.Box(
diff --git a/rllib/utils/checkpoints.py b/rllib/utils/checkpoints.py
index 428ef3d64e8d9..1795f75113860 100644
--- a/rllib/utils/checkpoints.py
+++ b/rllib/utils/checkpoints.py
@@ -275,6 +275,42 @@ def convert_to_msgpack_checkpoint(
     return msgpack_checkpoint_dir
 
 
+@PublicAPI(stability="beta")
+def convert_to_msgpack_policy_checkpoint(
+    policy_checkpoint: Union[str, Checkpoint, NewCheckpoint],
+    msgpack_checkpoint_dir: str,
+) -> str:
+    """Converts a Policy checkpoint (pickle based) to a msgpack based one.
+
+    Msgpack has the advantage of being python version independent.
+
+    Args:
+        policy_checkpoint: The directory, in which to find the Policy checkpoint (pickle
+            based).
+        msgpack_checkpoint_dir: The directory, in which to create the new msgpack
+            based checkpoint.
+
+    Returns:
+        The directory in which the msgpack checkpoint has been created. Note that
+        this is the same as `msgpack_checkpoint_dir`.
+    """
+    from ray.rllib.policy.policy import Policy
+
+    policy = Policy.from_checkpoint(policy_checkpoint)
+
+    os.makedirs(msgpack_checkpoint_dir, exist_ok=True)
+    policy.export_checkpoint(
+        msgpack_checkpoint_dir,
+        policy_state=policy.get_state(),
+        checkpoint_format="msgpack",
+    )
+
+    # Release all resources used by the Policy.
+    del policy
+
+    return msgpack_checkpoint_dir
+
+
 @PublicAPI
 def try_import_msgpack(error: bool = False):
     """Tries importing msgpack and msgpack_numpy and returns the patched msgpack module.
diff --git a/rllib/utils/tests/test_checkpoint_utils.py b/rllib/utils/tests/test_checkpoint_utils.py
index ff74a69a8f500..7fae5348f7f8a 100644
--- a/rllib/utils/tests/test_checkpoint_utils.py
+++ b/rllib/utils/tests/test_checkpoint_utils.py
@@ -9,9 +9,11 @@
 from ray.rllib.algorithms.dqn import DQNConfig
 from ray.rllib.algorithms.simple_q import SimpleQConfig
 from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
+from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.checkpoints import (
     get_checkpoint_info,
     convert_to_msgpack_checkpoint,
+    convert_to_msgpack_policy_checkpoint,
 )
 from ray.rllib.utils.test_utils import check
 from ray import tune
@@ -86,7 +88,7 @@ def test_get_policy_checkpoint_info_v1_1(self):
     def test_msgpack_checkpoint_translation(self):
         """Tests, whether a checkpoint can be translated into a msgpack-checkpoint ...
 
-        ... and recovered back into and Algorithm, which is identical to a
+        ... and recovered back into an Algorithm, which is identical to a
         pickle-checkpoint-recovered Algorithm (given same initial config).
         """
         # Base config used for both pickle-based checkpoint and msgpack-based one.
@@ -141,9 +143,10 @@ def test_msgpack_checkpoint_translation(self):
     def test_msgpack_checkpoint_translation_multi_agent(self):
         """Tests, whether a checkpoint can be translated into a msgpack-checkpoint ...
 
-        ... and recovered back into and Algorithm, which is identical to a
+        ... and recovered back into an Algorithm, which is identical to a
         pickle-checkpoint-recovered Algorithm (given same initial config).
         """
+
         # Base config used for both pickle-based checkpoint and msgpack-based one.
         def mapping_fn(aid, episode, worker, **kwargs):
             return "pol" + str(aid)
@@ -223,6 +226,47 @@ def mapping_fn(aid, episode, worker, **kwargs):
         algo1.stop()
         algo2.stop()
 
+    def test_msgpack_policy_checkpoint_translation(self):
+        """Tests, whether a Policy checkpoint can be translated into msgpack ...
+
+        ... and recovered back into a Policy, which is identical to a
+        pickle-checkpoint-recovered Policy (given same initial config).
+        """
+        # Base config used for both pickle-based checkpoint and msgpack-based one.
+        config = SimpleQConfig().environment("CartPole-v1")
+        # Build algorithm/policy objects.
+        algo1 = config.build()
+        pol1 = algo1.get_policy()
+        # Get its state.
+        pickle_state = pol1.get_state()
+
+        # Create standard (pickle-based) checkpoint.
+        with tempfile.TemporaryDirectory() as pickle_cp_dir:
+            pol1.export_checkpoint(pickle_cp_dir)
+            # Now convert pickle checkpoint to msgpack using the provided
+            # utility function.
+            with tempfile.TemporaryDirectory() as msgpack_cp_dir:
+                convert_to_msgpack_policy_checkpoint(pickle_cp_dir, msgpack_cp_dir)
+                msgpack_cp_info = get_checkpoint_info(msgpack_cp_dir)
+                self.assertEqual(msgpack_cp_info["type"], "Policy")
+                self.assertEqual(msgpack_cp_info["format"], "msgpack")
+                self.assertIsNone(msgpack_cp_info["policy_ids"])
+                # Try recreating a new policy object from the msgpack checkpoint.
+                pol2 = Policy.from_checkpoint(msgpack_cp_dir)
+                # Get the state of the policy recovered from msgpack.
+                msgpack_state = pol2.get_state()
+
+                # Make sure the states match 100%. Our `check` utility
+                # cannot handle comparing types/classes, so we'll have to serialize the
+                # pickle'd config (which contains types, rather than class strings).
+                pickle_state["policy_spec"]["config"] = AlgorithmConfig._serialize_dict(
+                    pickle_state["policy_spec"]["config"]
+                )
+                check(pickle_state, msgpack_state)
+
+        # Shut down the Algorithm built for this test.
+        algo1.stop()
+
 
 if __name__ == "__main__":
     import pytest