
[RLlib] RLModule: Add ValueFunctionAPI. #46657

Merged: 9 commits merged into ray-project:master from rlmodule_value_function_api on Jul 18, 2024

Conversation

@sven1977 (Contributor) commented Jul 16, 2024

RLModule: Add ValueFunctionAPI to PPO/IMPALA/APPO.

Our goal for custom, user-written RLModules is that users should only have to write "a simple torch.nn.Module", subclassing from TorchRLModule (and defining the three different forward passes). No subclassing from complex, algo-specific RLModules should be required, unless(!) users want to use RLlib's default models and develop on top of these.

Our current algo-specific RLModule classes are too complex and scare users away from subclassing and implementing them. Also, switching between algos once a custom network architecture has been written out is currently extremely difficult.

Instead, we propose simple APIs (where possible without any code or implementations, just abstract method definitions), such as this ValueFunctionAPI, to be plugged into users' custom RLModules. This way, an algo-specific Learner can communicate to the user that the provided custom RLModule lacks an implementation of API xyz; all the user then has to do is also inherit from API xyz and implement the one or more abstract methods it defines.
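To make the proposed pattern concrete, here is a minimal, hedged sketch. The module class, network shapes, and the "obs" batch key are hypothetical illustrations (and a real module would subclass TorchRLModule rather than plain torch.nn.Module); only the ValueFunctionAPI and compute_values() names are from this PR:

```python
import abc
from typing import Any, Dict

import torch

# Stand-in for the API added in this PR (see its actual definition below).
class ValueFunctionAPI(abc.ABC):
    @abc.abstractmethod
    def compute_values(self, batch: Dict[str, Any]) -> torch.Tensor: ...

# Hypothetical user module: a plain nn.Module that also opts into the
# value-function API so that PPO/IMPALA/APPO Learners can train it.
class MyCustomModule(torch.nn.Module, ValueFunctionAPI):
    def __init__(self, obs_dim: int = 4, num_actions: int = 2):
        super().__init__()
        self.pi = torch.nn.Linear(obs_dim, num_actions)  # policy head
        self.vf = torch.nn.Linear(obs_dim, 1)  # value head

    def compute_values(self, batch: Dict[str, Any]) -> torch.Tensor:
        # One V(s) estimate per batch row.
        return self.vf(batch["obs"]).squeeze(-1)
```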

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: sven1977 <[email protected]>
@sven1977 sven1977 enabled auto-merge (squash) July 17, 2024 10:02
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Jul 17, 2024
@simonsays1980 (Collaborator) left a comment:

LGTM. I left some comments. As this is work in progress, I assume this is one of several iterations.

```diff
- @override(PPORLModule)
- def _compute_values(self, batch, device=None):
+ @override(ValueFunctionAPI)
+ def compute_values(self, batch: Dict[str, Any]) -> TensorType:
```

```python
infos = batch.pop(Columns.INFOS, None)
batch = tree.map_structure(lambda s: tf.convert_to_tensor(s), batch)
```
@simonsays1980 (Collaborator) commented:
Although this is from the old stack: I am not sure whether tf.convert_to_tensor places the tensor on the GPU in this context. I suspect we would need a with tf.device(...) context here. Is the device argument needed then?

@sven1977 (Contributor, Author) replied:
TF doesn't really need an explicit .to([some device]) call when processing tensors. Unlike torch, it can handle CPU inputs even if the model is on the GPU (torch can't, and requires you to explicitly move the tensor to the GPU first). So I think this simplification here is ok. We will soon get rid of tf support altogether.
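For reference, the explicit device placement the question above alludes to would look roughly like this (a sketch only; as noted in the reply, TF handles host-to-device transfers automatically, so this is likely unnecessary):

```python
import tensorflow as tf
import tree  # dm-tree

def convert_batch_on(batch, device: str = "/GPU:0"):
    # Pin the converted tensors to an explicit device instead of relying
    # on TF's automatic placement.
    with tf.device(device):
        return tree.map_structure(tf.convert_to_tensor, batch)
```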

```python
import abc
from typing import Any, Dict

from ray.rllib.utils.typing import TensorType

class ValueFunctionAPI(abc.ABC):
    # Abstract method signature as shown in the diff above.
    @abc.abstractmethod
    def compute_values(self, batch: Dict[str, Any]) -> TensorType: ...
```
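As described in the PR text above, an algo-specific Learner can then surface a clear error when the API is missing. A hypothetical sketch of such a check (not the actual RLlib Learner code):

```python
def validate_module(module) -> None:
    # Hypothetical validation a PPO Learner might run on a user's module.
    if not isinstance(module, ValueFunctionAPI):
        raise TypeError(
            "PPO requires your RLModule to implement ValueFunctionAPI. "
            "Inherit from it and implement `compute_values()`."
        )
```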
@simonsays1980 (Collaborator) commented:
I'm getting a better idea of how this should look. I think these APIs (or better: abstract interfaces) give users a great structure for how to think about their RLModules. That said, it does not simplify the major complexity of RLModules imo, as the latter is driven specifically by the complex setup with an __init__, a __post_init__, a setup, and a given call order through multiple parents.

What becomes easier is defining modules usable with different RLlib algorithms, e.g. PPO, DQN, etc. We need, however, to give some thought to how best to load a state into modules when some parts (e.g. a Q function) are not part of the loaded state but are part of the loading RLModule. Here we either need to take care of this in the APIs, or via the Checkpointable (checking which components of a module are available both in the module and in the state, matching these, and leaving the rest untouched, i.e. initialized by the framework's initializers).

I also have a suggestion: if we use these new abstract APIs, it would be nice if the inference-only version were then managed by the APIs themselves. The basic RLModule would then have only an actor (or, in the case of DQN, an encoder).
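A rough sketch of the matching idea described above (a hypothetical helper, not an existing RLlib or Checkpointable API):

```python
import torch

def load_partial_state(module: torch.nn.Module, saved_state: dict) -> None:
    # Load only weights present in both the checkpoint and the module;
    # new components (e.g. a Q or value head) keep their
    # framework-initialized values.
    current = module.state_dict()
    matched = {
        k: v
        for k, v in saved_state.items()
        if k in current and current[k].shape == v.shape
    }
    current.update(matched)
    module.load_state_dict(current)
```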

@sven1977 (Contributor, Author) replied:
Very good points! Let me iterate a bit toward a future, final design, trying to bring in your thoughts:

  • I agree. These new APIs should help (a bit) to ease the burden on users of having to think too much up front about the algo being used. In an ideal world, the user would only be concerned with the policy net producing the actions.
  • I also agree that the call order and having to think about __init__, __post_init__, and setup are still too complex at this point. However, if we could settle on a design in which the user only has to override setup, that would be fine.
  • The new APIs will help with the complex call order, b/c we can now get rid of multi-inheritance. This only holds if the APIs are purely abstract (and thus don't have constructors themselves <- I think this is crucial and we need to stick to this rule 100%).
  • Inference-only: This should probably be a) an integral part of RLModule, or b) another API; I tend to prefer a).
  • Checkpointing: I think an example helps here: Pre-train a simple (custom) policy-only RLModule with BC, then checkpoint the decently trained policy. Write a subclass of the policy-only RLModule in which we add a value function net (separate or shared, doesn't matter). If this subclass doesn't touch the original policy-net structure, we should be able to load the checkpoint into the new, ValueFunctionAPI-implementing RLModule. The loaded module will then have the trained (BC) policy and a randomly initialized value function that can be fine-tuned online, e.g. with PPO (see the sketch below).
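A hedged sketch of that checkpoint workflow (all class names and shapes here are hypothetical, and real code would subclass TorchRLModule and use RLlib's checkpointing utilities):

```python
import torch

# Stage 1: policy-only module, pre-trained with BC.
class PolicyOnlyModule(torch.nn.Module):
    def __init__(self, obs_dim: int = 4, num_actions: int = 2):
        super().__init__()
        self.pi = torch.nn.Linear(obs_dim, num_actions)

# Stage 2: subclass that adds a value head but leaves the policy net
# untouched, so the BC checkpoint still maps onto the `pi.*` keys.
class PolicyAndValueModule(PolicyOnlyModule):
    def __init__(self, obs_dim: int = 4, num_actions: int = 2):
        super().__init__(obs_dim, num_actions)
        self.vf = torch.nn.Linear(obs_dim, 1)

bc_module = PolicyOnlyModule()
# ... BC training would happen here ...
ckpt = bc_module.state_dict()

ppo_module = PolicyAndValueModule()
# strict=False: `pi.*` loads from the checkpoint; `vf.*` stays randomly
# initialized and can be fine-tuned online, e.g. with PPO.
ppo_module.load_state_dict(ckpt, strict=False)
```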

@github-actions github-actions bot disabled auto-merge July 17, 2024 14:57
@sven1977 sven1977 enabled auto-merge (squash) July 17, 2024 15:23
@github-actions github-actions bot disabled auto-merge July 17, 2024 17:15
Signed-off-by: sven1977 <[email protected]>
@sven1977 sven1977 enabled auto-merge (squash) July 18, 2024 07:59
@github-actions github-actions bot disabled auto-merge July 18, 2024 07:59
@sven1977 sven1977 enabled auto-merge (squash) July 18, 2024 09:23
@sven1977 sven1977 merged commit 3fd2ab8 into ray-project:master Jul 18, 2024
7 checks passed
@sven1977 sven1977 deleted the rlmodule_value_function_api branch July 18, 2024 16:25