
[RLlib] Learner API enhancements and cleanups (prep. for DreamerV3). #35877

Conversation

@sven1977 (Contributor) commented May 30, 2023:

Learner API enhancements and cleanups (prep. for DreamerV3). This is a replacement PR for #35574, which has been closed.

Please read this summary carefully before reviewing; it explains all important changes:

Learner API

  • optimizer_config has been removed and all settings related to optimizers are now part of the base LearnerHyperparameters class. Purely through AlgorithmConfig, the user can only influence: lr (or an lr schedule) and basic grad-clipping behavior for a single Adam optimizer per module. In all other cases (non-Adam optimizers, more than one optimizer per module, c'tor args other than lr, non-piecewise lr schedules), the user must override the configure_optimizers_for_module method and possibly bring their own configs and Learner HPs.
  • A new register_optimizer() method has been added to Learner, allowing RLlib (and users who write custom Learner subclasses) to create and register their own optimizers, always under a ModuleID (ALL_MODULES by default) AND an optimizer name (DEFAULT_OPTIMIZER by default).
  • With this register_optimizer method, users can now override either(!) configure_optimizers (MARL case, where no individual RLModules have their own dedicated optimizers) OR configure_optimizers_for_module (normal, independent MARL case) and, in there, call the register_optimizer() method (see the sketch after this list). This is hence analogous to the just-introduced register_metrics API.
    ** This way, we massively simplified the optimizer management and customization options, getting rid of the _configure_optimizers_for_module_helper method and typedefs such as ParamOptimPair(s), NamedParamOptimPair, etc.
    ** get_optimizer(module_id=default, optimizer_name=default) and get_optimizers_for_module(module_id) have been added for simpler optimizer access, w/o going through private properties (e.g. self._named_optimizers) anymore.
  • Renamed Learner.get_weights to Learner.get_module_state() for clarity and consistency.
  • Renamed Learner.get_optimizer_weights to Learner.get_optimizer_state() for clarity and consistency.
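
To make the new pattern concrete, here is a minimal sketch of a custom TorchLearner that registers a non-default optimizer per module (illustrative only; the exact import path, the params= keyword, and the get_parameters()/self.module accessors are assumptions based on the description above and may not match the merged code exactly):

import torch

from ray.rllib.core.learner.torch.torch_learner import TorchLearner


class SGDTorchLearner(TorchLearner):
    def configure_optimizers_for_module(self, module_id, hps):
        # Parameters of the one RLModule this call is concerned with.
        module = self.module[module_id]
        params = self.get_parameters(module)

        # Any torch optimizer works here, not just Adam.
        sgd = torch.optim.SGD(params, lr=0.0005, momentum=0.9)

        # Register it under this module_id and a custom optimizer name
        # (ALL_MODULES / DEFAULT_OPTIMIZER would be the defaults otherwise).
        self.register_optimizer(
            module_id=module_id,
            optimizer_name="sgd",
            optimizer=sgd,
            params=params,
        )

Registered optimizers can then be fetched again via get_optimizer(module_id, optimizer_name) or iterated over with get_optimizers_for_module(module_id), w/o touching private attributes such as self._named_optimizers.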

MARLModules

  • Users can now specify AlgorithmConfig overrides per RLModule within a MultiAgentRLModule (analogous to the old stack's PolicySpec(config={"lr": 0.001}) syntax). These overrides are compiled only inside the LearnerHyperparameters instances (away from the user), and a new LearnerHyperparameters.get_hps_for_module API has been created to extract module-specific settings, e.g. for usage in compute_loss_for_module() (see the sketch after this list).
  • The module_id-specific HPs are now passed into all of Learner's ..._for_module(module_id, hps, ...) methods for convenience.
  • For complex MARL cases, the user may override configure_optimizers directly (and call register_optimizer from there) and define optimizers that operate on the entire MARLModule or on several RLModules therein.
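
A rough sketch of how these per-module HPs surface in a custom loss (the batch/fwd_out keys and the my_coeff hyperparameter are hypothetical; only the (module_id, hps, ...) calling convention is taken from this PR):

import torch

from ray.rllib.core.learner.torch.torch_learner import TorchLearner


class MyLossTorchLearner(TorchLearner):
    def compute_loss_for_module(self, *, module_id, hps, batch, fwd_out):
        # `hps` is already the module-specific view: the base Learner applies
        # LearnerHyperparameters.get_hps_for_module(module_id) before calling
        # this method, so per-module AlgorithmConfig overrides show up here
        # automatically.
        coeff = hps.my_coeff  # hypothetical per-module hyperparameter

        # Toy policy-gradient-style loss, just to show where `hps` is used.
        logp = fwd_out["action_logp"]
        advantages = batch["advantages"]
        return -coeff * torch.mean(logp * advantages)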

Test cases

  • The test_torch_learner.py file has been merged with test_learner.py. The resulting test case now tests both frameworks (tf2 and torch) identically.

Scheduler API

  • Scheduler objects are no longer per-module as this was a design flaw. Learner itself should handle different schedulers for different module_ids.
  • A new default dict (LambdaDefaultDict) has been added that can assign default values depending on the accessed key (a regular Python defaultdict always returns the same default value, regardless of the key). This allows us to simplify the individual algorithms' Learners (e.g. APPOLearner), which no longer have to override add_module just to set up a correct entropy (or other) Scheduler. A behavioral sketch follows below.
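
For illustration, a key-aware default dict of the kind described above can be sketched in a few lines of plain Python (a behavioral sketch only, not the actual RLlib implementation):

class LambdaDefaultDict(dict):
    """A dict whose default factory receives the missing key.

    Unlike collections.defaultdict, the factory can therefore return a
    different default value per key (e.g. a per-module_id Scheduler).
    """

    def __init__(self, default_factory, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._default_factory = default_factory

    def __missing__(self, key):
        # Called by dict.__getitem__ for a missing key: build, cache, return.
        value = self._default_factory(key)
        self[key] = value
        return value


# Example: lazily create per-module entropy coefficients that depend on the
# accessed module_id.
entropy_coeffs = LambdaDefaultDict(
    lambda module_id: 0.01 if module_id == "policy_0" else 0.001
)
assert entropy_coeffs["policy_0"] == 0.01
assert entropy_coeffs["policy_1"] == 0.001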

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

gradients_dict[k] = (
    None if v is None else nn.utils.clip_grad_norm_(v, grad_clip)
)
if v is not None:

sven1977 (Contributor Author):
bug fix

@@ -112,37 +115,6 @@ def compute_gradients(

return grads

@OverrideToImplementCustomLogic_CallToSuperRecommended

sven1977 (Contributor Author):
moved logic to base class. Possible b/c of new enhanced APIs.

@@ -109,28 +112,12 @@ def compute_gradients(
return grads

@override(Learner)
def postprocess_gradients(

sven1977 (Contributor Author):
Moved to base class. Possible b/c of new enhanced APIs.

@@ -454,31 +442,6 @@ def helper(_batch):

return self._strategy.run(helper, args=(batch,))

@OverrideToImplementCustomLogic_CallToSuperRecommended

sven1977 (Contributor Author):
Moved to base class. Possible b/c of new enhanced APIs.

"""
for key, value in metrics_dict.items():
self.register_metric(module_id, key, value)

def get_weights(self, module_ids: Optional[Set[str]] = None) -> Mapping[str, Any]:
"""Returns the weights of the underlying MultiAgentRLModule.
def get_optimizer(

sven1977 (Contributor Author):
New APIs to more easily retrieve registered optimizers by their module_ids and names AND to compile parameter sub-dictionaries that only contain the param_ref -> param key/value pairs of a particular optimizer.
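
A hypothetical usage sketch of these accessors (not verbatim RLlib code; the argument order of the param-dict helper is an assumption based on this comment):

def log_optimizer_param_counts(learner, module_id, param_dict):
    # Iterate over all (name, optimizer) pairs registered for this module.
    for optim_name, optim in learner.get_optimizers_for_module(module_id):
        # Sub-dict with only the param_ref -> param entries of this optimizer
        # (the helper discussed here; renamed to filter_param_dict_for_optimizer
        # further down in this thread).
        params = learner.filter_param_dict_for_optimizer(param_dict, optim)
        print(f"{module_id}/{optim_name}: {len(params)} params")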


def _configure_optimizers_for_module_helper(

sven1977 (Contributor Author):
This method has become superfluous due to the new register_optimizer API.

optimizers), but rather override the `self.configure_optimizers_for_module(
module_id=..)` method and return those optimizers from there that you need for
the given module.
def register_optimizer(

sven1977 (Contributor Author):
New API allowing users more flexibility when overriding configure_optimizers OR configure_optimizers_for_module.

kouroshHakha (Contributor):

Very nice API change. I love it.

@ollie-iterators

Lint error:
rllib/core/testing/utils.py:7:1: F401 'ray.rllib.core.learner.learner.LearnerHyperparameters' imported but unused

@kouroshHakha (Contributor) left a comment:

Everything looks good. I just noticed you introduce an API here that is not used (i.e., _check_registered_optimizer).

named_optimizers.append((optim_name, optimizer))
return named_optimizers

def compile_param_dict_for_optimizer(

kouroshHakha (Contributor):
super nit: the name does not click with me right off the bat. Maybe filter_param_dict_for_optimizer?

sven1977 (Contributor Author):
Yes, "filter" is much better. :)

sven1977 (Contributor Author):
done

new_lr = self._optimizer_lr_schedules[optimizer].update(
    timestep=timestep
)
self._set_optimizer_lr(optimizer, lr=new_lr)

kouroshHakha (Contributor):
super nit: Thinking about possible extensions coming from prospective users, I think we probably wanna also make the components of this method public API, like introducing get_lr_scheduler(optimizer_name, optimizer) and set_optimizer_lr(optimizer, lr).

for optimizer_name, optimizer in self.get_optimizers_for_module(module_id):
    # Only update this optimizer's lr, if a scheduler has been registered
    # along with it.
    scheduler = self.get_lr_scheduler(optimizer_name, optimizer)
    if scheduler is not None:
        new_lr = scheduler.update(timestep=timestep)
        self.set_optimizer_lr(optimizer, lr=new_lr)
        # Make sure our returned results differentiate by optimizer name
        # (if not the default name).
        stats_name = LEARNER_RESULTS_CURR_LR_KEY
        if optimizer_name != DEFAULT_OPTIMIZER:
            stats_name += "_" + optimizer_name
        results.update({stats_name: new_lr})

@@ -1189,6 +1368,23 @@ def _make_module(self) -> MultiAgentRLModule:
module = module.as_multi_agent()
return module

def _check_registered_optimizer(

kouroshHakha (Contributor):
I couldn't find where we call this function.

sven1977 (Contributor Author):
Hmm, ok, could be I erred :) Lemme check ... and fix.

sven1977 (Contributor Author):
Totally, I erased the main call in Learner.register_optimizer() :) Fixed.

@sven1977 (Contributor Author) commented May 30, 2023:

Thanks for the thorough review @kouroshHakha ! On this PR and the closed predecessor. I know it was a hell of a lot of work to go through these changes and your suggestions helped a ton to make the Learner API better. Hopefully, this will get us into quieter waters now on Learner (unblocking Dreamer, etc..). 👍 💯 😃
I'll fix the items you suggested and then merge.

@@ -483,7 +604,7 @@ def apply_gradients(self, gradients: ParamDict) -> None:
"""

def register_metric(self, module_id: str, key: str, value: Any) -> None:
"""Registers a single key/value metric pair for loss and gradient stats.
"""Registers a single key/value metric pair for loss- and gradient stats.

Member:
did you mean to add a - after loss?

sven1977 (Contributor Author):
Yes, as in: "loss stats and gradient stats" -> compress to -> "loss- and gradient stats"

@sven1977 sven1977 merged commit f1f714c into ray-project:master May 31, 2023
2 checks passed