
[RLlib] Learner API (+DreamerV3 prep): Learner.register_metrics API, cleanup, etc.. #35573

Conversation

@sven1977 sven1977 commented May 20, 2023


This PR ...

  • adds a Learner.register_metric(s) API for registering/storing metrics inside loss, compute_gradients, apply_gradients, and postprocess_gradients (see the sketch after this list).
  • unifies Learner.update/_update between torch and tf by condensing everything that should "go into the graph" (e.g. eager tracing) into self._update and making TorchLearner and TfLearner only override this smaller subsection of the logic (e.g. TfLearner needs to gradient-tape forward_train AND compute_loss).
  • fixes the inconsistent naming of some LearnerHyperparameters subclasses.
  • renames the Learner..._per_module methods to Learner..._for_module() for a more correct semantic meaning: in these methods, we don't compute stuff PER (all) modules, but rather for a single one.
  • dissolves ImpalaLearner.compile_results into the base class (ALL_MODULES' trained env/agent steps should be recorded for all algorithms, not just IMPALA).
  • fixes some "tf" -> "tf2" typos on the new stack (the new stack does NOT use "tf" anymore).
  • removes LEARNER_STATS_KEY from the dict returned by Learner.update() to reduce information leakage! The Algorithm itself should handle the different layers (sampler results, learner results, worker health, etc.).
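
A minimal sketch of what using the new metrics API from a custom loss could look like (the class name, method signature, batch keys, and loss terms here are illustrative assumptions, not verbatim RLlib code):

import torch
from ray.rllib.core.learner.torch.torch_learner import TorchLearner

class MyTorchLearner(TorchLearner):
    def compute_loss_for_module(self, *, module_id, batch, fwd_out):
        # Two hypothetical loss terms:
        pi_loss = -torch.mean(fwd_out["logp"] * batch["advantages"])
        vf_loss = torch.mean((fwd_out["vf_preds"] - batch["returns"]) ** 2.0)
        # Register per-module metrics right where they are computed, so they
        # can stay "in-graph" under tf2 eager tracing (trivially ok in torch):
        self.register_metrics(
            module_id, {"pi_loss": pi_loss, "vf_loss": vf_loss}
        )
        # The returned total is what gets reported as the module's total loss.
        return pi_loss + vf_loss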

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@@ -883,6 +973,31 @@ def update(
return results
return reduce_fn(results)

@abc.abstractmethod
sven1977 (author):

The new "core" in-graph/traced method. Subclasses only need to override this method, not update() itself.
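
Roughly, the split looks like this (a structural sketch only; the real update() in learner.py does more, and _compile here is a hypothetical stand-in for the results-compilation step):

import abc

class Learner(abc.ABC):
    def update(self, batch):
        # Un-traced, framework-agnostic driver: preprocessing, calling the
        # core step, compiling results.
        fwd_out, loss_per_module, metrics = self._update(batch)
        return self._compile(loss_per_module, metrics)

    @abc.abstractmethod
    def _update(self, batch):
        # The "core" step TorchLearner/TfLearner override; for tf2, this is
        # what gets wrapped by tf.function (eager tracing).
        ...

    def _compile(self, loss_per_module, metrics):
        # Hypothetical stand-in for the framework-agnostic result compilation.
        return {"loss": loss_per_module, "metrics": metrics}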

@@ -1093,27 +1208,8 @@ def _check_result(self, result: Mapping[str, Any]) -> None:
f"module id. Valid module ids are: {list(self.module.keys())}."
)

@OverrideToImplementCustomLogic_CallToSuperRecommended
sven1977 (author):

Moved specialized implementations of _update to TfLearner and TorchLearner.

return reduce_fn(results)

def _do_update_fn(
def _update(
sven1977 (author):

simplified ;)

assert isinstance(rl_module, DiscreteBCTFModule)

def test_bc_algorithm_w_custom_marl_module(self):
"""Tests the independent multi-agent case with shared encoders."""

policies = {"policy_1", "policy_2"}
sven1977 (author):

Moved this here to avoid a linter warning.

@kouroshHakha (Contributor) left a comment:

AWESOME. This PR made my day, man :) Very nice and high-quality cleanups. Just a few comments and questions. Let's get this merged ASAP.

@@ -220,7 +222,7 @@ def compute_loss(self, fwd_out, batch):
# compute the loss based on batch and output of the forward pass
# to access the learner hyper-parameters use `self._hps`

return {self.TOTAL_LOSS_KEY: loss}
kouroshHakha:

We should remove self.TOTAL_LOSS_KEY

sven1977 (author):

See my answer below. I do think we still need it, to distinguish it from other (registered) metrics.

@@ -313,12 +324,12 @@ def configure_optimizers(self) -> ParamOptimizerPairs:
"""
param_optimizer_pairs = []
name_to_optim = {}
for module_id in self._module.keys():
if self._is_module_compatible_with_learner(self._module[module_id]):
kouroshHakha:

Further cleanup: In this method's docstring we should mention that the user is not expected to override this method, and should instead override configure_optimizer_for_module.

sven1977 (author):

done

"""Applies the gradients to the MultiAgentRLModule parameters.

Args:
gradients: A dictionary of gradients, in the same format as self._params.
gradients: A dictionary of gradients in the same (flat) format as
kouroshHakha:

typo: two spaces.

sven1977 (author):

fixed

postprocessed_gradients: Mapping[str, Any],
loss_per_module: Mapping[str, TensorType],
postprocessed_gradients: ParamDict,
metrics_per_module: Dict[ModuleID, Dict[str, Any]],
kouroshHakha:

Do we need to pass in the metrics_per_module, or will having access to a property for self._metrics do the work here?

sven1977 (author):

I feel like it's cleaner to do so. What I do have doubts about is the passing of the gradients: these tensors could potentially be massive (for larger models), and converting them to numpy is very costly, even though most of the time we won't need to extract further stats from them (the global norm is already computed and returned).

sven1977 (author):

Also, this way, when users override this method, they have access to the metrics_per_module information.

sven1977 (author):

On the gradients: I feel like now that users have register_metrics, they can do any metrics computation inside postprocess_gradients (while still "in-graph") and register the results. Then we don't need to move around large grad tensors.
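
For example (a hedged sketch of the idea; the hook and registry names are assumed from this PR's description, and the method would live on the Learner):

import torch

def postprocess_gradients(self, gradients_dict):
    # Derive any gradient stats in-graph/on-device here ...
    global_norm = torch.sqrt(
        sum(torch.sum(g ** 2.0) for g in gradients_dict.values())
    )
    # ... and register only the (scalar) stat, not the raw grad tensors.
    # "all_modules" is a stand-in aggregation key for this sketch.
    self.register_metric("all_modules", "gradients_global_norm", global_norm)
    return gradients_dict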

kouroshHakha:

Makes sense. Then let’s do that?

sven1977 (author):

ok, done

# We restructure the loss to be module_id -> LEARNER_STATS_KEY -> key-values.
# This matches what the legacy RLlib policies used to return.
# We compile the metrics to have the structure:
# top-level key: module_id -> [key, e.g. self.TOTAL_LOSS_KEY] -> [value].
kouroshHakha:

We still have self.TOTAL_LOSS_KEY?

sven1977 (author):

Yeah, I do think we still need it.

loss_per_module = self.compute_loss(...)
grads = self.compute_gradients(...)
...
results = self.compile_results(...)

Now, results has the structure:

[module ID]
   total_loss: 0.1
   grad_norm: ...
   vf_loss: ...

ALL_MODULES:
   total_loss: 0.1

So we still need it to indicate the total_loss (e.g. the sum of pi-, vf-, and entropy loss terms), even within some module ID, but also for ALL_MODULES (b/c ALL_MODULES might have other stats as well).

self,
*,
module_ids_to_update: Sequence[ModuleID] = None,
timestep: int,
kouroshHakha:

yep nice.

@kouroshHakha (Contributor) commented May 22, 2023:

Yep. This is exactly what I had implemented in my old PR that attempted to do the same thing: https://github.com/ray-project/ray/blob/6d3cf2d37b8134da7b982344c09903a83d70a8f8/rllib/core/learner/tf/tf_learner.py#LL153C18-L153C44

I have one recommendation for easier digestion of this tracing: instead of the _update method itself being traced vs. non-traced, let's have its implementation always be like:

def _update(self, batch):
    return self._possibly_traced_update(batch)

def _untraced_update(self, batch):
    def helper(_batch):
        with tf.GradientTape() as tape:
            fwd_out = self._module.forward_train(_batch)
            loss_per_module = self.compute_loss(fwd_out=fwd_out, batch=_batch)
        gradients = self.compute_gradients(loss_per_module, gradient_tape=tape)
        postprocessed_gradients = self.postprocess_gradients(gradients)
        self.apply_gradients(postprocessed_gradients)

        return fwd_out, loss_per_module, postprocessed_gradients, self._metrics

    return self._strategy.run(helper, args=(batch,))

And have _possibly_traced_update get assigned to either _untraced_update or tf.function(_untraced_update). Then the user won't be confused to see _update being a traced vs. untraced fn.
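
In code, that assignment could look something like this (a sketch; the flag name is an assumption, not necessarily the real config attribute):

# E.g., somewhere in TfLearner's setup code:
if eager_tracing_enabled:
    self._possibly_traced_update = tf.function(self._untraced_update)
else:
    self._possibly_traced_update = self._untraced_update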

sven1977 (author):

Ok, good point. Will change.

sven1977 (author):

fixed

self,
batch: MultiAgentBatch,
batch: NestedDict,
kouroshHakha:

What is this _ray_trace_ctx? Do we need that? I don't know how it ended up here.

sven1977 (author):

There is this super strange interaction between tf-eager-tracing and ray tracing :(

Putting this extra dummy-ish arg here does suppress/fix an error. It took me a while to figure this out, almost 2 years ago, when we cleaned up the eager_tracing_policy stuff.
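
For context, the workaround has roughly this shape (a sketch reconstructed from the comment above; the surrounding method name and signature are assumptions):

def _traced_update(self, batch, _ray_trace_ctx=None):
    # `_ray_trace_ctx` is deliberately unused: Ray's tracing machinery can
    # inject this kwarg into the call, and without it in the signature, the
    # tf.function-traced method fails with a signature mismatch.
    ...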

kouroshHakha:

Can we put a comment on it then, explaining the reason, citing the Stack Overflow issue, etc.?

sven1977 (author):

Yeah, will do!

@kouroshHakha (Contributor) left a comment:

Nice. A few nits on docstrings …

@sven1977 sven1977 added the tests-ok label ("The tagger certifies test failures are unrelated and assumes personal liability.") May 22, 2023
@sven1977 sven1977 merged commit 0fd06ad into ray-project:master May 22, 2023
2 checks passed
scv119 pushed a commit to scv119/ray that referenced this pull request Jun 16, 2023
arvind-chandra pushed a commit to lmco/ray that referenced this pull request Aug 31, 2023
Labels
tests-ok The tagger certifies test failures are unrelated and assumes personal liability.