Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Algorithm Level Checkpointing with Learner and RL Modules #34717

Merged

Conversation

avnishn
Copy link
Member

@avnishn avnishn commented Apr 24, 2023

Signed-off-by: Avnish [email protected]

This PR introduces algorithm level checkpointing with the RL modules stack. It also introduces a test for making sure that the checkpointing runs. Checkpointing however isn't seed reproducible. Upon some inspection by me and @kouroshHakha, there is some portion of the sampler that is not seed reproducible.

That being said, if I take an algorithm, checkpoint it, and then multiple times restore it and train it, the restored versions are seed reproducible with respect to each other. I've added a test that reflects this.

The more I think about it the more I realize that the algorithm won't be seed reproducible across interrupts. This is because when loading from checkpoint, we first construct an algorithm instance, then seed it, then load training state in. We aren't restoring the seeded state at the time that the algorithm was checkpointed, therefore this random state won't carry across checkpoints.

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

rllib/evaluation/rollout_worker.py Show resolved Hide resolved
rllib/algorithms/algorithm.py Outdated Show resolved Hide resolved
@@ -2131,6 +2148,17 @@ def load_checkpoint(self, checkpoint: Union[Dict, str]) -> None:
else:
checkpoint_data = checkpoint
self.__setstate__(checkpoint_data)
if isinstance(checkpoint, str) and self.config._enable_learner_api:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the location of this logic ties well to the existing code where checkpoint can take both a dict or str value. You need to map the checkpoint input (str or dict) to a checkpoint data first and then use checkpoint data inside the __setstate__() api to set the state of the learner group.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my question is more like what do you need to do if checkpoint is a dict? when would that happen, and what would that mean for the learner group

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok so this is interesting. upon further inspection, this reason that this is supposed to accept a dict is in the case that trainable.save_checkpoint ever returns a dictionary. However, we don't do this, which means that we don't need to even support this inside of load_checkpoint to begin with. I just ended up removing all the logic related to handling dicts.

rllib/algorithms/ppo/tests/test_ppo_learner.py Outdated Show resolved Hide resolved
rllib/tests/test_algorithm_save_load_checkpoint_learner.py Outdated Show resolved Hide resolved
rllib/tests/test_algorithm_save_load_checkpoint_learner.py Outdated Show resolved Hide resolved
rllib/tests/test_algorithm_save_load_checkpoint_learner.py Outdated Show resolved Hide resolved
if self.config._enable_learner_api:
learner_state_dir = os.path.join(checkpoint_dir, "learner")
self.learner_group.save_state(learner_state_dir)
state["learner_state_dir"] = "learner/"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

state dict has been already dumped into a file when we get to this line. So what's the point of writing new kvs into it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leftover from experimenting, you're right :)

rllib/algorithms/algorithm_config.py Outdated Show resolved Hide resolved
rllib/policy/sample_batch.py Outdated Show resolved Hide resolved
Copy link
Contributor

@kouroshHakha kouroshHakha left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

approved contingent on tests passing. Thanks @avnishn

@gjoliver gjoliver merged commit 6b59692 into ray-project:master Apr 26, 2023
ProjectsByJackHe pushed a commit to ProjectsByJackHe/ray that referenced this pull request May 4, 2023
architkulkarni pushed a commit to architkulkarni/ray that referenced this pull request May 16, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants