[RLlib] Load state from load_state_path for rlmodule spec #35180
Signed-off-by: Avnish <[email protected]>
…ccept a dir instead of a path Signed-off-by: Avnish <[email protected]>
…state_loading_rl_module_spec
Signed-off-by: avnishn <[email protected]>
add end to end tests for the marl module uncheckpointing with the ppo algorithm. Move the kl checking in ppo tf module because it is causing a tf auto graph error for some reason Signed-off-by: avnishn <[email protected]>
had a couple of questions I want to answer first.
release/rllib_tests/checkpointing_tests/test_rl_module_spec_uncheckpointing.py
Outdated
TODOS:
Nice PR, man. I really enjoyed it. I have a couple of questions, none of them merge blockers. Let me know when I should push the button.
@@ -403,12 +400,24 @@ class MultiAgentRLModuleSpec:
    module_specs: The module specs for each individual module. It can be either a
        SingleAgentRLModuleSpec used for all module_ids or a dictionary mapping
        from module IDs to SingleAgentRLModuleSpecs for each individual module.
    load_state_path: The path to the module state to load from. NOTE: This must be
Nice. Thanks for preemptively answering my questions ;)
dude @kouroshHakha, why is a path part of the spec?
would we save this path with the serialized module?
No, it's for when a user wants to load up a new MARL module (or part of a MARL module) from an old, already checkpointed one by setting this attribute.
I included the path as part of the MARL module spec because I want users to be able to load the module by specifying the path there. From a UX standpoint it made the most sense to me.
The path does not get saved when we serialize the spec via a call to to_dict, and therefore isn't included in the checkpoint later.
That sounds perfect. Thanks for the explanation.
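The behavior described in this thread (a path that lives on the spec but is dropped on serialization) can be sketched as follows. This is only an illustration under assumptions: ToyModuleSpec, its fields other than load_state_path, and from_dict are hypothetical stand-ins, not RLlib's actual spec classes.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional


@dataclass
class ToyModuleSpec:
    """Hypothetical stand-in for a module spec; not RLlib's real class."""

    module_class: str
    model_config: Dict[str, Any] = field(default_factory=dict)
    # Only consulted when the module is built; never serialized.
    load_state_path: Optional[Path] = None

    def to_dict(self) -> Dict[str, Any]:
        # Deliberately omit load_state_path so that a checkpoint written from
        # this spec does not point back at the directory it was restored from.
        return {
            "module_class": self.module_class,
            "model_config": dict(self.model_config),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToyModuleSpec":
        return cls(
            module_class=data["module_class"],
            model_config=dict(data["model_config"]),
        )


spec = ToyModuleSpec(
    "PPOModule", {"hidden": 32}, load_state_path=Path("/tmp/old_ckpt")
)
serialized = spec.to_dict()
restored = ToyModuleSpec.from_dict(serialized)
```

With this shape, a spec rebuilt from a checkpoint carries no load path, so restoring it does not trigger a second load from the old directory.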
Loving the example ...
@@ -363,6 +366,84 @@ def test_save_load_state(self):
            weights_after_1_update_with_break, weights_after_1_update_without_break
        )

    def test_load_module_state(self):
Just want to point out that the resources and time required to run this test may increase after adding this unit test. If that happens, we may have to break the test down into smaller, isolated unit tests.
rllib/core/learner/learner_group.py
Outdated
    agent RLModules take precedence over the module states in the
    MultiAgentRLModule checkpoint.

    NOTE: At least one of multi_agent_module_state or single_agent_module_states
I don't get this NOTE: there is no multi_agent_module_state or single_agent_module_states in the args of this method.
These may be leftovers from your dev history.
fixed.
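For readers following the docstring under review, the precedence rule it states (per-agent RLModule states win over the states in the MultiAgentRLModule checkpoint) can be sketched roughly like this. The function name, argument names, and the use of plain dicts for state are assumptions made for illustration; the PR's actual method signature is not shown in this thread.

```python
from typing import Any, Dict, Optional


def resolve_module_states(
    marl_checkpoint_states: Optional[Dict[str, Any]] = None,
    single_agent_states: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Merge per-module states; single-agent entries win on conflicts.

    Hypothetical helper, not part of RLlib.
    """
    # Start from whatever the multi-agent checkpoint provides ...
    resolved = dict(marl_checkpoint_states or {})
    # ... then let individually specified module states override it.
    resolved.update(single_agent_states or {})
    return resolved


resolved = resolve_module_states(
    marl_checkpoint_states={"policy_a": "marl_a", "policy_b": "marl_b"},
    single_agent_states={"policy_b": "single_b"},
)
```

Here policy_a keeps the state from the multi-agent checkpoint, while policy_b takes the individually supplied state.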
some nits.
rllib/core/learner/learner_group.py
Outdated
# also in the RLModule checkpoints.
if modules_to_load:
    for module_id in rl_module_ckpt_dirs.keys():
        if module_id in modules_to_load:
if any(module_id in modules_to_load for module_id in rl_module_ckpt_dirs.keys())
fixed
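As a small illustration of the nit above, here is how an overlap check between the requested module IDs and the per-module checkpoint dirs looks with any() and with a key-set intersection. The variable names come from the diff; the values are made up, and what the PR does once an overlap is found is not shown in this thread.

```python
# Made-up example values; only the variable names appear in the diff.
modules_to_load = {"policy_a", "policy_c"}
rl_module_ckpt_dirs = {"policy_a": "/ckpt/a", "policy_b": "/ckpt/b"}

# Generator form of the suggested check: stops at the first match and
# avoids building an intermediate list.
has_overlap = any(
    module_id in modules_to_load for module_id in rl_module_ckpt_dirs
)

# If the overlapping IDs themselves are needed, dict key views support
# set operations directly.
overlap = sorted(rl_module_ckpt_dirs.keys() & modules_to_load)
```

The generator form also avoids naming the loop variable dir, which would shadow the Python builtin.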
rllib/core/learner/learner_group.py
Outdated
        path / RLMODULE_STATE_DIR_NAME
    )
else:
    assert len(self._workers) == self._worker_manager.num_healthy_actors()
Should we move the else logic into a separate util function?
Yeah, let me go ahead and do that; this function has gotten too big.
…/avnishn/ray into add_state_loading_rl_module_spec
…l ci Signed-off-by: avnishn <[email protected]>
Ok, I finished my todos, but I still need to address Jun's nits and some of the comments that accidentally got left in, and then this should be ready to go.
for module_id, path in rl_module_ckpt_dirs.items():
    w.module[module_id].load_state(path / RLMODULE_STATE_DIR_NAME)

# remove the temporary directories on the worker if any were created
curious, do you really need to remove these, given that we used tempfile for them?
They are tempfiles, but doesn't /tmp/ only get cleared after one week?
I'm not using the with: scope for the tempfile, so the directories won't be automatically removed by the tempfile library.
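The distinction in this reply can be shown with the standard library alone. This is a generic sketch of the two tempfile patterns, not the PR's code: a directory from tempfile.mkdtemp() stays on disk until the caller deletes it, while tempfile.TemporaryDirectory used as a context manager cleans up when the with block exits. The file name is invented for the example.

```python
import shutil
import tempfile
from pathlib import Path

# Option A: mkdtemp() creates the directory and leaves cleanup to the caller.
tmp_dir = Path(tempfile.mkdtemp())
(tmp_dir / "module_state.txt").write_text("weights")
existed_before_cleanup = tmp_dir.exists()
# Without this explicit call the directory would remain on disk.
shutil.rmtree(tmp_dir, ignore_errors=True)
exists_after_cleanup = tmp_dir.exists()

# Option B: TemporaryDirectory as a context manager removes the directory
# (and its contents) as soon as the with block exits.
with tempfile.TemporaryDirectory() as scoped:
    scoped_path = Path(scoped)
    (scoped_path / "module_state.txt").write_text("weights")
    existed_inside_scope = scoped_path.exists()
exists_after_scope = scoped_path.exists()
```

Option A fits cases where the directory must outlive the function that created it, for example when its path is handed to another worker, at the cost of needing an explicit removal step like the one under review.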
This should be good to merge. The tests that failed on CI are unrelated or flaky. The GCE release test failed because the cluster failed to come up, but the AWS cluster came up and the tests passed.
Signed-off-by: Artur Niederfahrenhorst <[email protected]>
…ct#35180) Signed-off-by: e428265 <[email protected]>
Signed-off-by: Avnish [email protected]
Add the ability for RL modules and MARL modules to be created and have their states loaded immediately via the RL module spec.
Add tests for basic spec loading and multi-node uncheckpointing.
Why are these changes needed?
Related issue number
Checks
I've signed off every commit (git commit -s) in this PR.
I've run scripts/format.sh to lint the changes in this PR.
If I added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.