Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Learner group checkpointing #34379

Merged
merged 13 commits into from
Apr 18, 2023

Conversation

avnishn
Copy link
Member

@avnishn avnishn commented Apr 13, 2023

Signed-off-by: Avnish [email protected]

Implement multinode learner group checkpointing and tests.

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

- stop creating multiple distributed tf strategies
- add multinode release test for checkpointing

Signed-off-by: avnishn <[email protected]>
Signed-off-by: avnishn <[email protected]>
@@ -609,3 +609,7 @@ def as_multi_agent(self) -> "MultiAgentRLModule":
marl_module = MultiAgentRLModule()
marl_module.add_module(DEFAULT_POLICY_ID, self)
return marl_module

def unwrapped(self) -> "RLModule":
"""Returns the underlying module if this module is a wrapper."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you specify, what wrapper here means?
Like what are examples for RLModule wrappers?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch rl modules get wrapped with the torch ddp rl module wrapper

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

from ray.rllib.env.multi_agent_env import make_multi_agent
from ray.rllib.utils.test_utils import check


DEFAULT_POLICY_ID = "default_policy"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! :)

@@ -26,6 +25,9 @@
Optimizer = Union["tf.keras.optimizers.Optimizer", "torch.optim.Optimizer"]


DEFAULT_POLICY_ID = "default_policy"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we import this from policy.py here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to avoid mixing policy code in the new stack.

# the default strategy is a no-op that can be used in the local mode
# cpu only case, build will override this if needed.
self._strategy = tf.distribute.get_strategy()
self._strategy = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you leave the comment on what self._strategy is (or should be when not None)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strategy is a tf distributed strategy object that is used for the ddp logic.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a param notation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -349,6 +347,25 @@ def remove_module(self, module_id: ModuleID) -> None:
if self._enable_tf_function:
self._update_fn = tf.function(self._do_update_fn, reduce_retracing=True)

def _make_distributed_strategy(self):
"""Create a distributed strategy for the learner."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, can you add a little more explanation here on what a "strategy" is and which types exist (an example?)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strategy is a tf distributed strategy object.

The different types of strategies are contained within the function.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -483,7 +483,7 @@ def from_module(self, module: MultiAgentRLModule) -> "MultiAgentRLModuleSpec":
The MultiAgentRLModuleSpec.
"""
module_specs = {
module_id: SingleAgentRLModuleSpec.from_module(rl_module)
module_id: SingleAgentRLModuleSpec.from_module(rl_module.unwrapped())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Explain why we need to unwrap here. rl_module could be a framework-specific DDP wrapper?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

@@ -19,7 +25,7 @@

REMOTE_SCALING_CONFIGS = {
"remote-cpu": LearnerGroupScalingConfig(num_workers=1),
"remote-gpu": LearnerGroupScalingConfig(num_workers=1, num_gpus_per_worker=0.5),
"remote-gpu": LearnerGroupScalingConfig(num_workers=1, num_gpus_per_worker=1),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did we change this? Would it break if we used fractional GPUs here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this learner/group actually won't even take fractional gpus. so it is pointless. I changed it while I was doing some debugging

learner_group.load_state(initial_learner_checkpoint_dir)
check(learner_group.get_weights(), initial_learner_group_weights)
learner_group.update(batch.as_multi_agent(), reduce_fn=None)
results_without_break = learner_group.update(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we check here again to see whether the weights after one update (based off the initial state) are the same as the weights of the original learner (after one update)?

Copy link
Contributor

@sven1977 sven1977 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome PR @avnish! Thanks for covering this important feature in our release tests from here on.
Just a few nits, questions, and suggestions for better comments.

Copy link
Contributor

@sven1977 sven1977 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM now. Thanks!

@amogkam amogkam merged commit 4995e14 into ray-project:master Apr 18, 2023
Copy link
Member

@gjoliver gjoliver left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me talk to you offline about how you intend to use this?

def remove_dir(w):
import shutil

shutil.rmtree(worker_temp_dir)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you make this a member function on Worker as well?
so you can do lambda w: w.remove_worker_temp_dir() below.

import socket
import tempfile

hostname = socket.gethostname()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ray.util.get_node_ip

elliottower pushed a commit to elliottower/ray that referenced this pull request Apr 22, 2023
Implement multinode learner group checkpointing and tests.

---------

Signed-off-by: Avnish <[email protected]>
Signed-off-by: avnishn <[email protected]>
Signed-off-by: elliottower <[email protected]>
ProjectsByJackHe pushed a commit to ProjectsByJackHe/ray that referenced this pull request May 4, 2023
Implement multinode learner group checkpointing and tests.

---------

Signed-off-by: Avnish <[email protected]>
Signed-off-by: avnishn <[email protected]>
Signed-off-by: Jack He <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants