diff --git a/rllib/utils/checkpoints.py b/rllib/utils/checkpoints.py index bb3ce107db9dc..8b9e2d7775f8b 100644 --- a/rllib/utils/checkpoints.py +++ b/rllib/utils/checkpoints.py @@ -410,8 +410,22 @@ def from_checkpoint( # Get the class constructor to call. with open(path / cls.CLASS_AND_CTOR_ARGS_FILE_NAME, "rb") as f: ctor_info = pickle.load(f) + ctor = ctor_info["class"] + + # Check, whether the constructor actually goes together with `cls`. + if not issubclass(ctor, cls): + raise ValueError( + f"The class ({ctor}) stored in checkpoint ({path}) does not seem to be " + f"a subclass of `cls` ({cls})!" + ) + elif not issubclass(ctor, Checkpointable): + raise ValueError( + f"The class ({ctor}) stored in checkpoint ({path}) does not seem to be " + "an implementer of the `Checkpointable` API!" + ) + # Construct an initial object. - obj = ctor_info["class"]( + obj = ctor( *ctor_info["ctor_args_and_kwargs"][0], **ctor_info["ctor_args_and_kwargs"][1], )