[RLlib] RLModule: Add ValueFunctionAPI #46657
Conversation
Signed-off-by: sven1977 <[email protected]>
LGTM. I left some comments. As this is work in progress, I assume this is one of several iterations.
```diff
-    @override(PPORLModule)
-    def _compute_values(self, batch, device=None):
+    @override(ValueFunctionAPI)
+    def compute_values(self, batch: Dict[str, Any]) -> TensorType:
         infos = batch.pop(Columns.INFOS, None)
         batch = tree.map_structure(lambda s: tf.convert_to_tensor(s), batch)
```
Although this comes from the old stack: I am not sure if `tf.convert_to_tensor` uses the GPU in this context. I guess we need a `with tf.device()` context here. Is the `device` argument then still needed?
Tf doesn't really need to do any `.to([some device])` when processing tensors. Unlike torch, it can handle CPU inputs even if the model is on the GPU (torch can't, and requires you to explicitly move the tensor to the GPU first). So I think this simplification here is ok. We will soon get rid of tf support altogether.
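As background, the `tree.map_structure` call in the diff applies a converter to every leaf of the (possibly nested) batch dict. A framework-free sketch of that pattern, where the hypothetical `FakeTensor` class stands in for `tf.convert_to_tensor` and `map_structure` is a tiny stand-in for the dm-tree function of the same name (neither is RLlib code):

```python
from typing import Any, Callable, Dict


def map_structure(fn: Callable[[Any], Any], struct: Any) -> Any:
    """Recursively apply `fn` to every leaf of a nested dict/list/tuple.

    Illustrative stand-in for `tree.map_structure` (dm-tree).
    """
    if isinstance(struct, dict):
        return {k: map_structure(fn, v) for k, v in struct.items()}
    if isinstance(struct, (list, tuple)):
        return type(struct)(map_structure(fn, v) for v in struct)
    return fn(struct)


class FakeTensor:
    """Illustrative stand-in for a framework tensor."""

    def __init__(self, data: Any):
        self.data = data


# Mimic the preprocessing in `compute_values`: drop infos, convert all leaves.
batch: Dict[str, Any] = {"obs": [1.0, 2.0], "infos": [{}, {}]}
batch.pop("infos", None)
converted = map_structure(FakeTensor, batch)
```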
```python
from ray.rllib.utils.typing import TensorType


class ValueFunctionAPI(abc.ABC):
```
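As context for the snippet above, here is a minimal sketch of what such a purely abstract API and an implementer could look like. The interface name and method signature follow the diff; `ToyModule` and its "values" are a made-up toy, not RLlib code:

```python
import abc
from typing import Any, Dict

# Stand-in for ray.rllib.utils.typing.TensorType in this sketch.
TensorType = Any


class ValueFunctionAPI(abc.ABC):
    """Purely abstract interface: no __init__, no state, safe to mix in."""

    @abc.abstractmethod
    def compute_values(self, batch: Dict[str, Any]) -> TensorType:
        """Return one value estimate per item in the batch."""


class ToyModule(ValueFunctionAPI):
    """Toy implementer: 'values' are just the observations' sums."""

    def compute_values(self, batch: Dict[str, Any]) -> TensorType:
        return [sum(obs) for obs in batch["obs"]]
```

Because the API carries no constructor, a user's module can inherit from it alongside its framework base class without any multiple-inheritance call-order issues.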
I am getting a better idea of how this should look. I think these APIs (or better: abstract interfaces) give users a great structure for how to think about their `RLModule`s. That said, it does not simplify the major complexity of `RLModule`s imo, as the latter is specifically driven by the complex setup with an `__init__`, a `__post_init__`, a `setup`, and a given call order through multiple parents.

What becomes easier is defining modules usable with different RLlib algorithms, e.g. PPO, DQN, etc. We need, however, to give some thought to how best to load a state into modules when some parts (e.g. the Q function) are not part of the loaded state, but are part of the loading `RLModule`. Here we either need to take care of this in the APIs, or via the `Checkpointable` (checking which components of a module are available both in the module and in the state, matching these, and leaving the rest untouched, i.e. initialized by the framework's initializers).

I also have a suggestion: if we use these new abstract APIs, it would be nice if the inference-only version were then managed by these APIs themselves. The basic `RLModule`s should have only an actor (or, in the case of DQN, an encoder).
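The "inference-only managed by the APIs" idea could look roughly like the following sketch, where an inference-only module keeps just the components a pure actor needs and drops everything else. All names here (`ToyRLModule`, `INFERENCE_COMPONENTS`, the string-valued "nets") are hypothetical, not RLlib code:

```python
class ToyRLModule:
    """Toy module holding named sub-components as plain dict entries."""

    # Components needed purely for action inference (assumed set).
    INFERENCE_COMPONENTS = {"encoder", "pi"}

    def __init__(self, inference_only: bool = False):
        # Strings stand in for actual networks.
        components = {"encoder": "enc-net", "pi": "pi-net", "vf": "vf-net"}
        if inference_only:
            # Drop everything a pure actor does not need, e.g. the value net.
            components = {
                k: v for k, v in components.items()
                if k in self.INFERENCE_COMPONENTS
            }
        self.components = components
```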
Very good points! Let me iterate a bit toward a future final design, trying to bring in your thoughts:

- I agree. These new APIs should help (a bit) to ease the burden on the user of having to think too much up front about the algo being used. For example, in an ideal world, the user would only be concerned with the policy net producing the actions.
- I also agree that the call orders and having to think about `__init__`, `__post_init__`, and `setup` are still too complex at this point. However, if we could settle on a design in which the user would only have to override `setup`, then that would be fine.
- The new APIs will help with the complex call order, b/c we can now get rid of multi-inheritance. This only holds if the APIs are purely abstract (and thus don't have a constructor themselves <- I think this is crucial and we need to stick to this rule 100%).
- Inference-only: This should probably be a) an integral part of `RLModule`, or b) another API, but I tend to prefer a).
- Checkpointing: I think an example would help here: Pre-train a simple (custom) policy-only `RLModule` with BC, then checkpoint the decently trained policy. Write a subclass of the policy-only `RLModule`, in which we add a value function net (separate or shared, doesn't matter). If in this subclass we don't touch the original policy-net structure, we should be able to load the checkpoint into this new `ValueFunctionAPI`-implementing RLModule. The loaded module will then have the trained (BC) policy and a randomly initialized value function that can be fine-tuned online with e.g. PPO.
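The BC-to-PPO checkpointing example could be served by state loading along these lines: only components present in the checkpoint are overwritten, and everything else keeps its fresh initialization. This is a sketch with plain dicts and a hypothetical `load_partial_state` helper, not the actual `Checkpointable` logic:

```python
from typing import Any, Dict


def load_partial_state(module_state: Dict[str, Any],
                       ckpt_state: Dict[str, Any]) -> Dict[str, Any]:
    """Overwrite only the components present in both module and checkpoint."""
    return {
        name: ckpt_state.get(name, weights)
        for name, weights in module_state.items()
    }


# Freshly built PPO-style module: policy net plus new value net (strings
# stand in for actual weights).
module_state = {"pi": "random-pi", "vf": "random-vf"}
# Checkpoint from the policy-only BC run: no value-function weights.
ckpt_state = {"pi": "bc-trained-pi"}

merged = load_partial_state(module_state, ckpt_state)
```

The value-function entry survives with its random initialization, ready for online fine-tuning.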
RLModule: Add `ValueFunctionAPI` to PPO/IMPALA/APPO.

Our goal for custom, user-written RLModules is that users should only have to write "a simple torch.nn.Module", subclassing from TorchRLModule (defining the three different forward passes). No subclassing from complex algo-specific RLModules should be required, unless(!) users want to use our RLlib default models and develop on top of these.

Our current algo-specific RLModule classes are too complex and scare users away from subclassing and implementing them. Also, switching between algos once a custom network architecture has been written out is currently extremely difficult.

We instead propose simple APIs (if possible w/o any code/implementations, just abstract method definitions), e.g. this `ValueFunctionAPI` here, to be plugged into the users' custom RLModules. This way, an algo-specific Learner can communicate to the user that the provided custom `RLModule` lacks an implementation of API xyz, and all the user then has to do is inherit from API xyz as well and implement one or more abstract methods described by that API.

Why are these changes needed?
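The "Learner communicates the missing API" mechanism described above might amount to little more than an `isinstance` check with a helpful error. A sketch under that assumption, where the class names mirror the PR but the checker function and error text are made up:

```python
import abc


class ValueFunctionAPI(abc.ABC):
    """Purely abstract API an algo's Learner may require from the module."""

    @abc.abstractmethod
    def compute_values(self, batch):
        ...


def check_module_for_learner(module: object) -> None:
    """Raise a descriptive error if the custom module lacks the needed API."""
    if not isinstance(module, ValueFunctionAPI):
        raise TypeError(
            f"{type(module).__name__} must implement ValueFunctionAPI "
            "(inherit from it and override `compute_values`) to be trained "
            "with this Learner."
        )


class PolicyOnlyModule:
    """A user module that never implemented the value-function API."""
```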
Related issue number

Checks

- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.