[RLlib] APPO/IMPALA: Enable using 2 separate optimizers for policy and vf (and 2 learning rates) on the old API stack. #40927
Conversation
# Figure out which parameters of the model belong to the value
# function (and which to the policy net).
dummy_batch = self._lazy_tensor_dict(
    self._get_dummy_batch_from_view_requirements()
)
# Zero out all gradients (set to None).
for param in self.model.parameters():
    param.grad = None
# Perform a dummy forward pass (through the policy net, which should be
# separated from the value function in this particular user setup).
out = self.model(dummy_batch)
# Perform a (dummy) backward pass to be able to see which params have
# gradients and are therefore used for the policy computations (vs vf
# computations).
torch.sum(out[0]).backward()  # [0] -> Model returns out and state-outs.
# Collect policy vs value function params separately.
policy_params = []
value_params = []
for param in self.model.parameters():
    if param.grad is None:
        value_params.append(param)
    else:
        policy_params.append(param)
I don't understand the need for this. Why can't you directly index into the model and ask it to give you .value.parameters() and .policy.parameters()? There should be a better way than treating self.model as a black box, with only the knowledge that a forward pass on the model directly will use the parameters that are used for the policy. Also, what if there are shared parameters between the value and policy components? This will lump them into the policy's optimizer, and they won't get updated based on the loss from the value function.
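For illustration, here is a minimal, self-contained torch sketch of the shared-parameter failure mode described above; the SharedTrunkModel class and its layer sizes are hypothetical, not code from this PR. A trunk shared by both heads receives gradients from the policy-only forward pass, so the probing logic above would file it under policy_params:

import torch
import torch.nn as nn

# Hypothetical model with a trunk shared by the policy and value heads.
class SharedTrunkModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(4, 8)       # shared by both heads
        self.policy_head = nn.Linear(8, 2)
        self.value_head = nn.Linear(8, 1)

    def forward(self, obs):
        # The "policy" forward pass still runs through the shared trunk.
        return self.policy_head(self.trunk(obs))

model = SharedTrunkModel()
for p in model.parameters():
    p.grad = None

# Dummy policy-only forward/backward pass, as in the snippet above.
torch.sum(model(torch.randn(1, 4))).backward()

# The shared trunk now has gradients, so the probing logic would assign it
# to policy_params -- it would never be updated by the value-function loss.
assert all(p.grad is not None for p in model.trunk.parameters())
assert all(p.grad is None for p in model.value_head.parameters())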
Good point. The problem here is that the API is NOT defined at all: some users might have self.policy, others self.policy_net, etc. The only thing required of you, if you want a value function to be present, is to implement the self.value_function() method. Take a look at our torch default models (ModelV2): they are all different in how they store the (separate) value sub-networks. It's quite a mess. I'm with you that this is not the normal way we should solve this, but since this is the old API stack, which will be 100% retired very soon, I'm personally fine with this. Suggestions?
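To illustrate why attribute probing isn't reliable here, a hedged sketch of the kind of custom ModelV2-style model users write; the only contract is the value_function() method, and the attribute name _my_vf is arbitrary (and hypothetical):

import torch.nn as nn

class MyCustomModel(nn.Module):
    def __init__(self, obs_dim=4, num_actions=2):
        super().__init__()
        self.pi = nn.Linear(obs_dim, num_actions)
        # The value branch could be stored under any attribute name --
        # there is no standard attribute to index into from the outside.
        self._my_vf = nn.Linear(obs_dim, 1)
        self._last_obs = None

    def forward(self, obs):
        self._last_obs = obs
        # Return (out, state-outs), matching the out[0] convention above.
        return self.pi(obs), []

    def value_function(self):
        # The only required entry point for the value function.
        return self._my_vf(self._last_obs).squeeze(-1)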
Thanks for the explanation; I figured that might be the reason. We should be explicit about this in the comments.
rllib/algorithms/registry.py (outdated)
@@ -229,6 +229,7 @@ def _import_leela_chess_zero():
     "DreamerV3": _import_dreamerv3,
     "DT": _import_dt,
     "IMPALA": _import_impala,
+    "Impala": _import_impala,
Wait, where is this coming from? It will mess up our telemetry.
Let me explain: We added a new test case in this PR, which is IMPALA (separate policy and vf) on CartPole. The new tuned_example file is a Python file (I'm trying to create as few new YAMLs as possible nowadays), so in there I'm using the ImpalaConfig() class/object. It seems to not work well with tune.run_experiments, for whatever reason.
I didn't think about telemetry. Let me see whether there is a better way that would not break things ...
Ah, ok, this is the culprit here (in rllib/train.py):
experiments = {
f"default_{uuid.uuid4().hex}": {
"run": algo_config.__class__.__name__.replace("Config", ""),
"env": config.get("env"),
"config": config,
"stop": stop,
}
}
Ok, let me provide a better fix.
Yeah, we should import the name directly from the registry if possible, or avoid tune.run_experiments?
All algo config objects know what their corresponding algo class is, so this is now solved much more elegantly.
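A hedged sketch of what that fix could look like in rllib/train.py, assuming the config exposes its algorithm class via an algo_class attribute (Tune's "run" accepts a trainable class as well as a registry string):

import uuid

experiments = {
    f"default_{uuid.uuid4().hex}": {
        # Use the algorithm class the config already knows about, instead
        # of string-munging the config's class name.
        "run": algo_config.algo_class,
        "env": config.get("env"),
        "config": config,
        "stop": stop,
    }
}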
One big comment about how the value vs policy parameters are retrieved.
Fixed. @kouroshHakha, thanks for the review! Please take another look.
The tests are failing. Let's hold off on merging until the issue is resolved. @can-anyscale, can you tell me what is wrong with the tests? All RLlib tests are complaining about a missing gRPC plugin.
APPO/IMPALA: Enable using 2 separate optimizers for policy and value function (and 2 learning rates) on the old API stack.
Note that this feature had already existed for tf/tf2, but not for torch.
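A hedged usage sketch of the feature, assuming the torch path reuses the config keys _separate_vf_optimizer and _lr_vf from the existing tf implementation, and that the model keeps its policy and value networks separate:

from ray.rllib.algorithms.impala import ImpalaConfig

config = (
    ImpalaConfig()
    .environment("CartPole-v1")
    .framework("torch")
    .training(
        lr=0.0005,                    # learning rate for the policy optimizer
        _separate_vf_optimizer=True,  # second optimizer for the value function
        _lr_vf=0.0005,                # learning rate for that vf optimizer
        # The model must not share layers between policy and vf.
        model={"vf_share_layers": False},
    )
)
algo = config.build()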
Why are these changes needed?
Related issue number
Checks
- I've signed off every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I've added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.