[RLlib] Learner API enhancements and cleanups (prep. for DreamerV3). #35877
Conversation
gradients_dict[k] = (
    None if v is None else nn.utils.clip_grad_norm_(v, grad_clip)
)
if v is not None:
bug fix
@@ -112,37 +115,6 @@ def compute_gradients(
    return grads

@OverrideToImplementCustomLogic_CallToSuperRecommended
Moved logic to base class. Possible b/c of new enhanced APIs.
@@ -109,28 +112,12 @@ def compute_gradients(
    return grads

@override(Learner)
def postprocess_gradients(
Moved to base class. Possible b/c of new enhanced APIs.
@@ -454,31 +442,6 @@ def helper(_batch):
    return self._strategy.run(helper, args=(batch,))

@OverrideToImplementCustomLogic_CallToSuperRecommended
Moved to base class. Possible b/c of new enhanced APIs.
""" | ||
for key, value in metrics_dict.items(): | ||
self.register_metric(module_id, key, value) | ||
|
||
def get_weights(self, module_ids: Optional[Set[str]] = None) -> Mapping[str, Any]: | ||
"""Returns the weights of the underlying MultiAgentRLModule. | ||
def get_optimizer( |
New APIs to more easily retrieve registered optimizers by their module_ids and names AND to compile parameter sub-dictionaries that only contain the param_refs -> params key/value pairs of a particular optimizer.
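To make the idea concrete, here is a minimal, self-contained PyTorch sketch of the "parameter sub-dictionary per optimizer" concept. It is illustrative only: the helper name, the use of `id()` as a param ref, and the registry layout are assumptions, not the actual Learner internals.

```python
import torch

net = torch.nn.Linear(4, 2)
params = list(net.parameters())

# Registry roughly as a Learner might keep it: (module_id, optimizer_name) -> optimizer,
# plus one global param_ref -> param dict.
optimizers = {("default_policy", "default_optimizer"): torch.optim.Adam(params, lr=1e-3)}
param_dict = {id(p): p for p in params}

def filter_param_dict_for_optimizer(param_dict, optimizer):
    # Keep only the param_ref -> param pairs this optimizer actually updates.
    refs = {id(p) for group in optimizer.param_groups for p in group["params"]}
    return {ref: p for ref, p in param_dict.items() if ref in refs}

adam = optimizers[("default_policy", "default_optimizer")]
sub_dict = filter_param_dict_for_optimizer(param_dict, adam)
print(len(sub_dict))  # -> 2 (weight and bias of the Linear layer)
```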
def _configure_optimizers_for_module_helper(
This method has become superfluous due to the new `register_optimizer` API.
optimizers), but rather override the `self.configure_optimizers_for_module(
module_id=..)` method and return those optimizers from there that you need for
the given module.
def register_optimizer(
New API allowing for more flexibility in users overriding `configure_optimizers` OR `configure_optimizers_for_module`.
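A minimal sketch of this override pattern, assuming a TorchLearner subclass; the exact keyword arguments of `register_optimizer()`, the `configure_optimizers_for_module()` signature, and the import path are assumptions based on this discussion, not verified against the final code.

```python
import torch
from ray.rllib.core.learner.torch.torch_learner import TorchLearner


class MyTorchLearner(TorchLearner):
    def configure_optimizers_for_module(self, module_id, **kwargs):
        # Look up the RLModule this call is about and collect its parameters.
        module = self._module[module_id]
        params = self.get_parameters(module)
        # Register one optimizer for this module (under the default optimizer
        # name); the Learner keeps it under the (module_id, optimizer_name) pair.
        self.register_optimizer(
            module_id=module_id,
            optimizer=torch.optim.Adam(params, lr=1e-3),
            params=params,
        )
```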
Very nice API change. I love it.
Lint error:
Everything looks good. I just noticed you introduce an API here that is not used (i.e. `_check_registered_optimizer`).
rllib/core/learner/learner.py
named_optimizers.append((optim_name, optimizer))
return named_optimizers

def compile_param_dict_for_optimizer(
super nit: the name does not click with me right off the bat. Maybe `filter_param_dict_for_optimizer`?
Yes, "filter" is much better. :)
done
new_lr = self._optimizer_lr_schedules[optimizer].update(
    timestep=timestep
)
self._set_optimizer_lr(optimizer, lr=new_lr)
super nit: Thinking about possible extensions coming from prospective users, I think we probably wanna also make the components of this method public API, like introducing `get_lr_scheduler(optimizer_name, optimizer)` and `set_optimizer_lr(optimizer, lr)`:
for optimizer_name, optimizer in self.get_optimizers_for_module(module_id):
    # Only update this optimizer's lr, if a scheduler has been registered
    # along with it.
    scheduler = get_lr_scheduler(optimizer_name, optimizer)
    if scheduler is not None:
        new_lr = scheduler.update(timestep=timestep)
        self.set_optimizer_lr(optimizer, lr=new_lr)
        # Make sure our returned results differentiate by optimizer name
        # (if not the default name).
        stats_name = LEARNER_RESULTS_CURR_LR_KEY
        if optimizer_name != DEFAULT_OPTIMIZER:
            stats_name += "_" + optimizer_name
        results.update({stats_name: new_lr})
@@ -1189,6 +1368,23 @@ def _make_module(self) -> MultiAgentRLModule:
    module = module.as_multi_agent()
    return module

def _check_registered_optimizer(
I couldn't find where we call this function.
Hmm, ok, could be I erred :) Lemme check ... and fix.
Totally, I erased the main call in `Learner.register_optimizer()` :) Fixed.
Thanks for the thorough review @kouroshHakha, on this PR and the closed predecessor! I know it was a hell of a lot of work to go through these changes and your suggestions helped a ton to make the Learner API better. Hopefully, this will get us into quieter waters now on Learner (unblocking Dreamer, etc.). 👍 💯 😃
@@ -483,7 +604,7 @@ def apply_gradients(self, gradients: ParamDict) -> None:
    """

def register_metric(self, module_id: str, key: str, value: Any) -> None:
-    """Registers a single key/value metric pair for loss and gradient stats.
+    """Registers a single key/value metric pair for loss- and gradient stats.
did you mean to add a - after loss?
Yes, as in: "loss stats and gradient stats" -> compress to -> "loss- and gradient stats"
Learner API enhancements and cleanups (prep. for DreamerV3). This is a replacement PR for #35574, which has been closed.
Please read this summary carefully before reviewing, it explains all important changes:
Learner API

- Users can override the `configure_optimizers_for_module` method and possibly bring in their own configs and learner HPs.
- A new `register_optimizer()` method has been added to `Learner`, allowing RLlib and users that write custom Learner subclasses to create and register their own optimizers (always under a ModuleID (`ALL_MODULES` by default) AND an optimizer name (`DEFAULT_OPTIMIZER` by default)).
- Via the new `register_optimizer` method, users now can override either(!) `configure_optimizers` (MARL case where no individual RLModules have their dedicated optimizers) OR `configure_optimizers_for_module` (normal, independent MARL case) and, in there, call the `register_optimizer()` method. This is hence analogous to the just introduced `register_metrics` API (see the sketch after this list).
- **This way, we massively simplified the optimizer management and customization options, getting rid of the `_configure_optimizers_for_module_helper` method and typedefs such as `ParamOptimPair(s)`, `NamedParamOptimPair`, etc.**
- `get_optimizer(module_id=default, optimizer_name=default)` and `get_optimizers_for_module(module_id)` have been added for simpler optimizer access w/o going through the private properties (e.g. `self._named_optimizers`) anymore.
- Renamed `Learner.get_weights` into `Learner.get_module_state()` for clarity and unity.
- Renamed `Learner.get_optimizer_weights` into `Learner.get_optimizer_state` for clarity and unity.
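The sketch referenced above: two differently named optimizers registered for the same ModuleID (hypothetical `actor`/`critic` sub-modules; the method and keyword names shown are assumptions drawn from this description, not the verified signatures).

```python
import torch

# Method body of a hypothetical TorchLearner subclass.
def configure_optimizers_for_module(self, module_id, **kwargs):
    module = self._module[module_id]
    actor_params = self.get_parameters(module.actor)
    critic_params = self.get_parameters(module.critic)
    # Two named optimizers under the same ModuleID.
    self.register_optimizer(
        module_id=module_id,
        optimizer_name="actor",
        optimizer=torch.optim.Adam(actor_params, lr=3e-4),
        params=actor_params,
    )
    self.register_optimizer(
        module_id=module_id,
        optimizer_name="critic",
        optimizer=torch.optim.Adam(critic_params, lr=1e-3),
        params=critic_params,
    )

# Later, both can be looked up without touching private properties, e.g.:
# critic_optim = learner.get_optimizer(module_id="p0", optimizer_name="critic")
# for name, optim in learner.get_optimizers_for_module("p0"): ...
```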
MARLModules

- Per-module hyperparameter overrides are supported (analogous to the old `PolicySpec(config={"lr": 0.001})` syntax). These overrides are compiled only inside the LearnerHyperparameter instances (away from the user) and a new `LearnerHyperparameter.get_hps_for_module` API has been created to extract module-specific settings, e.g. for usage in `compute_loss_for_module()` (see the toy sketch after this list).
- The module-specific settings are also passed into the different `..._for_module(module_id, hps, ...)` methods for convenience.
- Users can override `configure_optimizers` directly (and call `register_optimizer` from therein) and define optimizers that operate on the entire MARLModule or on several RLModules therein.
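The toy sketch referenced above: how a per-module override could be compiled into, and extracted from, a learner-HP object. This is not RLlib's actual `LearnerHyperparameters` class, just an illustration of the `get_hps_for_module` idea.

```python
from dataclasses import dataclass, replace
from typing import Dict, Optional


@dataclass
class LearnerHPs:
    lr: float = 0.0003
    entropy_coeff: float = 0.01
    # ModuleID -> dict of per-module overrides, compiled away from the user.
    per_module_overrides: Optional[Dict[str, dict]] = None

    def get_hps_for_module(self, module_id: str) -> "LearnerHPs":
        overrides = (self.per_module_overrides or {}).get(module_id, {})
        # Return a copy with the module-specific values swapped in.
        return replace(self, **overrides)


hps = LearnerHPs(per_module_overrides={"policy_1": {"lr": 0.001}})
assert hps.get_hps_for_module("policy_1").lr == 0.001
assert hps.get_hps_for_module("policy_2").lr == 0.0003
```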
Test cases

Scheduler API

- Users do not have to override `add_module` anymore, just to set up a correct entropy (or other) Scheduler (a toy sketch of the schedule idea follows below).
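The toy sketch referenced above: a timestep-driven, piecewise-linear schedule of the kind such a Scheduler would compute for, e.g., an entropy coefficient. Purely illustrative; this is not RLlib's Scheduler API.

```python
def piecewise_linear(points, timestep):
    # points: sorted list of (timestep, value) pairs; interpolate between them.
    for (t0, v0), (t1, v1) in zip(points, points[1:]):
        if t0 <= timestep <= t1:
            frac = (timestep - t0) / (t1 - t0)
            return v0 + frac * (v1 - v0)
    # Past the last point: hold the final value.
    return points[-1][1]


schedule = [(0, 0.01), (100_000, 0.001)]
print(piecewise_linear(schedule, 50_000))  # -> 0.0055
```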
Why are these changes needed?

Related issue number

Checks

- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.