[RLlib] Learner API (+DreamerV3 prep): Learner.register_metrics API, cleanup, etc.. #35573
Conversation
Signed-off-by: sven1977 <[email protected]>
@@ -883,6 +973,31 @@ def update(
    return results
    return reduce_fn(results)

@abc.abstractmethod
The new "core" in-graph/traced method. Subclasses only need to override this method, not update()
itself.
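The division of labor described here (a framework-agnostic `update()` in the base class, with subclasses overriding only the core `_update()` step) can be sketched roughly as follows. This is an illustrative toy, not the real RLlib implementation; only the `update()`/`_update()` split mirrors the actual Learner API.

```python
import abc


class Learner(abc.ABC):
    """Framework-agnostic shell: batching, result compilation, reduction."""

    def update(self, batch, reduce_fn=None):
        # Framework-agnostic bookkeeping lives here, once, in the base class.
        results = [self._update(batch)]
        return results if reduce_fn is None else reduce_fn(results)

    @abc.abstractmethod
    def _update(self, batch):
        """Core (possibly traced) update step; subclasses override only this."""


class ToyLearner(Learner):
    def _update(self, batch):
        # A real TfLearner/TorchLearner would run forward/backward passes here.
        return {"total_loss": float(sum(batch))}


learner = ToyLearner()
print(learner.update([1.0, 2.0], reduce_fn=lambda rs: rs[0]))
```

Users and algorithm code only ever call `update()`; the traced-vs-untraced question is contained entirely inside `_update()`.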
@@ -1093,27 +1208,8 @@ def _check_result(self, result: Mapping[str, Any]) -> None:
    f"module id. Valid module ids are: {list(self.module.keys())}."
)

@OverrideToImplementCustomLogic_CallToSuperRecommended
Moved the specialized implementations of _update to TfLearner and TorchLearner.
rllib/core/learner/tf/tf_learner.py
Outdated
return reduce_fn(results)

def _do_update_fn(
def _update(
simplified ;)
assert isinstance(rl_module, DiscreteBCTFModule)

def test_bc_algorithm_w_custom_marl_module(self):
    """Tests the independent multi-agent case with shared encoders."""

    policies = {"policy_1", "policy_2"}
Moved this here to avoid a linter warning.
AWESOME. This PR made my day man :) Very nice and high-quality clean ups. Just a few comments and questions. Let's get this merged asap.
@@ -220,7 +222,7 @@ def compute_loss(self, fwd_out, batch):
    # compute the loss based on batch and output of the forward pass
    # to access the learner hyper-parameters use `self._hps`

    return {self.TOTAL_LOSS_KEY: loss}
We should remove self.TOTAL_LOSS_KEY
See my answer below. I do think we still need it, to distinguish the total loss from other (registered) metrics.
@@ -313,12 +324,12 @@ def configure_optimizers(self) -> ParamOptimizerPairs:
    """
    param_optimizer_pairs = []
    name_to_optim = {}
    for module_id in self._module.keys():
        if self._is_module_compatible_with_learner(self._module[module_id]):
Further cleanup: In this method's docstring, we should mention that the user is not expected to override this method and should instead override configure_optimizer_for_module.
done
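The intended division of labor here (a base-class configure_optimizers that only loops over modules, with users overriding a per-module hook) might look like the following toy sketch. Only the two method names come from the discussion above; the class, its attributes, and the string "optimizers" are made up for illustration.

```python
class ToyLearner:
    """Illustrative sketch only; not the real RLlib Learner."""

    def __init__(self, module):
        # Toy stand-in for the MultiAgentRLModule: dict of module_id -> module.
        self._module = module

    def configure_optimizers(self):
        # Users should NOT override this method. It only iterates over all
        # modules and delegates the actual work to the per-module hook.
        pairs = []
        for module_id in self._module:
            pairs.extend(self.configure_optimizer_for_module(module_id))
        return pairs

    def configure_optimizer_for_module(self, module_id):
        # The hook users ARE expected to override (one module at a time).
        return [(module_id, f"optimizer_for_{module_id}")]


learner = ToyLearner({"policy_1": object(), "policy_2": object()})
print(learner.configure_optimizers())
```

This way, custom per-module optimizer logic never has to re-implement the module-iteration boilerplate.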
rllib/core/learner/learner.py
Outdated
"""Applies the gradients to the MultiAgentRLModule parameters. | ||
|
||
Args: | ||
gradients: A dictionary of gradients, in the same format as self._params. | ||
gradients: A dictionary of gradients in the same (flat) format as |
typo: two spaces.
fixed
rllib/core/learner/learner.py
Outdated
postprocessed_gradients: Mapping[str, Any],
loss_per_module: Mapping[str, TensorType],
postprocessed_gradients: ParamDict,
metrics_per_module: Dict[ModuleID, Dict[str, Any]],
Do we need to pass in metrics_per_module, or will having access to a property for self._metric do the work here?
I feel like it's cleaner to do so. What I do have doubts about is the passing of the gradients. These tensors could potentially be massive (for larger models), and converting them to numpy, even though most of the time we won't need to extract further stats from them (the global norm is already computed and returned), is very costly.
Also, this way, when users override this method, they have access to the metrics_per_module information.
On the gradients: I feel like now that users have register_metrics, they can do any metrics computations inside postprocess_gradients (while still "in-graph") and register those. Then we don't need to move around large grad tensors.
Makes sense. Then let’s do that?
ok, done
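The pattern agreed on above — compute cheap gradient stats inside postprocess_gradients and register them, rather than shipping raw gradient tensors around — might look like this toy sketch. register_metrics and postprocess_gradients are real Learner API names, but everything else here (the class, the plain-float "gradients", the module ID) is invented for the example.

```python
import math


class ToyLearner:
    """Illustrative sketch only; not the real RLlib Learner."""

    def __init__(self):
        self._metrics = {}

    def register_metrics(self, module_id, metrics):
        # Simplified stand-in for Learner.register_metrics: stash per-module stats.
        self._metrics.setdefault(module_id, {}).update(metrics)

    def postprocess_gradients(self, gradients):
        # Compute summary stats while still "in-graph", register them, and
        # return only the (possibly clipped) gradients themselves. No large
        # tensors ever need to leave this method just for stats purposes.
        global_norm = math.sqrt(sum(g * g for g in gradients.values()))
        self.register_metrics("default_policy", {"gradients_global_norm": global_norm})
        return gradients


learner = ToyLearner()
learner.postprocess_gradients({"w": 3.0, "b": 4.0})
print(learner._metrics)
```

The registered metrics then travel with the (small) results dict instead of the gradient tensors themselves.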
# We restructure the loss to be module_id -> LEARNER_STATS_KEY -> key-values.
# This matches what the legacy RLlib policies used to return.
# We compile the metrics to have the structure:
# top-level key: module_id -> [key, e.g. self.TOTAL_LOSS_KEY] -> [value].
We still have self.TOTAL_LOSS_KEY?
Yeah, I do think we still need it.

loss_per_module = self.compute_loss(...)
grads = self.compute_gradients(...)
...
results = self.compile_results(...)

Now, results has the structure:

[module ID]:
    total_loss: 0.1
    grad_norm: ...
    vf_loss: ...
ALL_MODULES:
    total_loss: 0.1

So we still need it to indicate the total loss (e.g. the sum of the pi-, vf-, and entropy loss terms), even within some module ID, but also for ALL_MODULES (b/c ALL_MODULES might have other stats as well).
self,
*,
module_ids_to_update: Sequence[ModuleID] = None,
timestep: int,
yep nice.
yep. This is exactly what I had implemented in my old PR that attempted to do the same thing. https://github.com/ray-project/ray/blob/6d3cf2d37b8134da7b982344c09903a83d70a8f8/rllib/core/learner/tf/tf_learner.py#LL153C18-L153C44
I have one recommendation for easier digestion of this tracing. Instead of making the _update method itself be traced vs. non-traced, let's have its implementation always be:

def _update(...):
    return self._possibly_traced_update(...)

def _untraced_update(...):
    def helper(_batch):
        with tf.GradientTape() as tape:
            fwd_out = self._module.forward_train(_batch)
            loss_per_module = self.compute_loss(fwd_out=fwd_out, batch=_batch)
        gradients = self.compute_gradients(loss_per_module, gradient_tape=tape)
        postprocessed_gradients = self.postprocess_gradients(gradients)
        self.apply_gradients(postprocessed_gradients)
        return fwd_out, loss_per_module, postprocessed_gradients, self._metrics

    return self._strategy.run(helper, args=(batch,))

And have _possibly_traced_update() get assigned to either _untraced_update() or tf.function(_untraced_update). Then the user won't be confused to see _update being a traced vs. an untraced fn.
Ok, good point. Will change.
fixed
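The assignment trick suggested above (decide traced-vs-untraced once, so _update is always the same plain method) can be illustrated framework-free. Here traced() is a hypothetical stand-in for tf.function, and the whole class is a toy, not the real TfLearner.

```python
def traced(fn):
    # Hypothetical stand-in for tf.function: in real code this would
    # compile/trace `fn`; here it just tags it and forwards the call.
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapper.is_traced = True
    return wrapper


class ToyLearner:
    def __init__(self, eager_tracing: bool):
        # Decide once, at construction time, whether to trace the update fn.
        self._possibly_traced_update = (
            traced(self._untraced_update) if eager_tracing else self._untraced_update
        )

    def _update(self, batch):
        # Always the same plain method; no traced-vs-untraced confusion here.
        return self._possibly_traced_update(batch)

    def _untraced_update(self, batch):
        return {"num_items": len(batch)}


print(ToyLearner(eager_tracing=True)._update([1, 2, 3]))
```

Callers always go through _update; whether tracing happens underneath is an implementation detail chosen in the constructor.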
self,
batch: MultiAgentBatch,
batch: NestedDict,
What is this _ray_trace_ctx? Do we need that? I don't know how it appeared here?
There is this super strange interaction between tf-eager-tracing and Ray tracing :( Putting this extra "dummy"-ish arg here does suppress/fix an error. It took me a while to figure this out, almost 2 years ago, when we cleaned up the eager_tracing_policy stuff.
Can we put a comment for it then? Explain the reason? Cite the stack overflow issue, etc?
Yeah, will do!
…mer_v3_02_1_1_learner_register_metrics_api
Nice. A few nits on docstrings …
Learner API (+DreamerV3 prep): Learner.register_metrics API, cleanup, etc..

This PR ...
- renames the Learner....per_module methods into Learner. ... for_module() for a more correct semantic meaning. In these methods, we don't compute stuff PER (all) modules, but rather for a single one.
- moves ImpalaLearner.compile_results into the base class (ALL_MODULE's trained env/agent steps should be recorded for all algorithms, not just IMPALA).
- removes LEARNER_STATS_KEY from the Learner.update() returned dict to reduce information leakage! The Algorithm itself should handle the different layers (sampler results, learner results, worker health, etc..).

Why are these changes needed?
Related issue number

Checks
- I've signed off every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I've added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.