Skip to content

Commit

Permalink
[RLlib] Add error message if Checkpointable.from_checkpoint() is ca…
Browse files Browse the repository at this point in the history
…lled on wrong class. (ray-project#46676)
  • Loading branch information
sven1977 committed Jul 18, 2024
1 parent e70db6e commit 3ed721c
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion rllib/utils/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,22 @@ def from_checkpoint(
# Get the class constructor to call.
with open(path / cls.CLASS_AND_CTOR_ARGS_FILE_NAME, "rb") as f:
ctor_info = pickle.load(f)
ctor = ctor_info["class"]

# Check, whether the constructor actually goes together with `cls`.
if not issubclass(ctor, cls):
raise ValueError(
f"The class ({ctor}) stored in checkpoint ({path}) does not seem to be "
f"a subclass of `cls` ({cls})!"
)
elif not issubclass(ctor, Checkpointable):
raise ValueError(
f"The class ({ctor}) stored in checkpoint ({path}) does not seem to be "
"an implementer of the `Checkpointable` API!"
)

# Construct an initial object.
obj = ctor_info["class"](
obj = ctor(
*ctor_info["ctor_args_and_kwargs"][0],
**ctor_info["ctor_args_and_kwargs"][1],
)
Expand Down

0 comments on commit 3ed721c

Please sign in to comment.