
[RLlib] DreamerV3: Main algo code and required changes to some RLlib APIs (RolloutWorker) #35386

Merged
merged 155 commits into ray-project:master on Jun 19, 2023

Conversation

sven1977
Contributor

@sven1977 sven1977 commented May 16, 2023

DreamerV3:

  • Main algo code (dreamerv3.py, README), plus compilation and model-size (architecture) tests.

  • Added DreamerV3Catalog.

  • Added DreamerV3 Algorithm class and config.

  • Some changes to RLlib:

    • The class to use for sampling (default: RolloutWorker) is now publicly configurable via the AlgorithmConfig.rollouts(env_runner_class=...) setting.
    • Had to make the gradient tape in TfLearner.update() persistent=True.
  • Managed to keep the Learner API as-is by simply overriding the DreamerV3TfLearner.compute_gradients() method. Without this override, DreamerV3 on TF will not learn, as computing gradients for the TOTAL_LOSS_KEY over all model params messes up the world-model gradients.
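The last point can be illustrated with a framework-free sketch (hypothetical names, not the actual RLlib code): instead of differentiating one total loss through every parameter, each component's loss is differentiated only w.r.t. that component's own parameters, so e.g. actor/critic losses cannot perturb world-model weights:

```python
# Toy sketch (hypothetical, not RLlib code) of per-component gradient
# computation. Differentiating one TOTAL_LOSS through ALL params would leak
# actor/critic gradients into the world model; here each sub-loss only
# touches its own parameter set.

def compute_gradients_per_component(losses, params_by_component, grad_fn):
    """Return {component: gradients}, computed independently per component."""
    grads = {}
    for name, loss in losses.items():
        # Only this component's parameters receive gradients from its loss.
        grads[name] = grad_fn(loss, params_by_component[name])
    return grads

if __name__ == "__main__":
    # Stand-in "autodiff": gradient of loss w.r.t. each param is loss * param.
    fake_grad_fn = lambda loss, params: [loss * p for p in params]
    losses = {"world_model": 2.0, "actor": 3.0}
    params = {"world_model": [1.0, 2.0], "actor": [10.0]}
    print(compute_gradients_per_component(losses, params, fake_grad_fn))
```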

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: sven1977 <[email protected]>
@@ -296,6 +295,9 @@ def __init__(self, algo_class=None):
self.auto_wrap_old_gym_envs = True

# `self.rollouts()`
# TODO (sven): Clean up the configuration of fully customizable

We can now publicly configure the class used for rollouts. This used to be configurable via config.debugging(worker_cls=..), but that path was not working correctly.
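A usage sketch of the new setting (the `MyEnvRunner` subclass here is hypothetical, standing in for any user-defined sampling class):

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    # Previously attempted via config.debugging(worker_cls=MyEnvRunner),
    # which did not work correctly; now an official rollouts() setting:
    .rollouts(env_runner_class=MyEnvRunner)
)
```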

@@ -838,7 +843,7 @@ def validate(self) -> None:
self.model["_disable_action_flattening"] = True
if self.model.get("custom_preprocessor"):
deprecation_warning(
old="model_config['custom_preprocessor']",
old="AlgorithmConfig.training(model={'custom_preprocessor': ...})",

Enhanced the deprecation message to point at the new AlgorithmConfig API.

@@ -2716,12 +2731,22 @@ def get_multi_agent_setup(
# Normal env (gym.Env or MultiAgentEnv): These should have the
# `observation_space` and `action_space` properties.
elif env is not None:
if hasattr(env, "observation_space") and isinstance(
if hasattr(env, "single_observation_space") and isinstance(

Support new gym.vector.Env envs, which have single_observation_space and single_action_space properties.

@@ -60,7 +60,7 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:

return output

@override(TfRLModule)
@override(RLModule)

Fixed the override decorator: this method is defined on RLModule, not TfRLModule.

@@ -352,7 +352,7 @@ def _configure_optimizers_per_module_helper(
pairs.append(pair)
elif isinstance(pair_or_pairs, dict):
# pair_or_pairs is a NamedParamOptimizerPairs
for name, pair in pairs.items():
for name, pair in pair_or_pairs.items():

This was a bug, but not visible for Learners that only use the default (single) optimizer path.
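A toy reproduction of the bug pattern (hypothetical helper, not the real RLlib function): the inner loop must iterate the *local* dict (`pair_or_pairs`), not the accumulator (`pairs`) still being built, which is empty on the first pass and silently drops every named pair:

```python
# Sketch of the bug class fixed above. With the buggy loop
# (`for name, pair in named.items()`), the dict being built is iterated
# instead of the input, so all named optimizer pairs are lost.

def collect_named_pairs(pairs_or_pairs_list):
    named = {}
    for i, pair_or_pairs in enumerate(pairs_or_pairs_list):
        if isinstance(pair_or_pairs, dict):
            # Correct: iterate the input dict, not the accumulator.
            for name, pair in pair_or_pairs.items():
                named[name] = pair
        else:
            # Single unnamed pair: auto-generate a name.
            named[f"optim_{i}"] = pair_or_pairs
    return named

print(collect_named_pairs([{"actor": "pair_a", "critic": "pair_b"}, "pair_c"]))
```

The bug was invisible for Learners on the default single-optimizer path, which never enters the dict branch.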

@@ -435,8 +435,25 @@ def compute_gradients(self, loss: Mapping[str, Any]) -> ParamDictType:
The gradients in the same format as self._params.
"""

@OverrideToImplementCustomLogic

Sorted this into a better position. These three should always appear together, in this order:
compute_gradients
postprocess_gradients
apply_gradients
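A minimal toy of the pipeline these three hooks form inside `update()` (hypothetical class, not the RLlib Learner): gradients are computed, then postprocessed (e.g. clipped), then applied:

```python
# Toy Learner-like pipeline: compute -> postprocess -> apply, in that order.

class ToyLearner:
    def __init__(self, params):
        self.params = params  # {name: value}

    def compute_gradients(self, loss_per_param):
        # Stand-in: pretend the gradient equals the per-param loss.
        return dict(loss_per_param)

    def postprocess_gradients(self, grads, clip=1.0):
        # E.g. gradient clipping happens between compute and apply.
        return {k: max(-clip, min(clip, g)) for k, g in grads.items()}

    def apply_gradients(self, grads, lr=0.1):
        for k, g in grads.items():
            self.params[k] -= lr * g

    def update(self, loss_per_param):
        grads = self.compute_gradients(loss_per_param)
        grads = self.postprocess_gradients(grads)
        self.apply_gradients(grads)
        return self.params

print(ToyLearner({"w": 1.0}).update({"w": 5.0}))  # {'w': 0.9}
```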

@abc.abstractmethod
def apply_gradients(self, gradients: ParamDictType) -> None:
def apply_gradients(self, gradients_dict: ParamDictType) -> None:

Consistent args naming.

forward passes within this method, and to use the "forward_train" outputs to
compute the required tensors for loss calculation.
"fwd_out". The returned dictionary must contain a key called
`self.TOTAL_LOSS_KEY`, which will be used to compute gradients. It is

Use constant name for this key.

@@ -811,7 +807,7 @@ def update(
reduce_fn: Callable[[List[Mapping[str, Any]]], ResultDict] = (
_reduce_mean_results
),
) -> Mapping[str, Any]:
) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:

If reduce_fn is None, update() may return a list of dicts (one per minibatch) instead of a single reduced dict.
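A sketch of why the return annotation was widened to a Union (toy functions, not the real signature): with a reduce_fn, `update()` returns one reduced dict; with `reduce_fn=None`, the caller gets the raw per-minibatch list:

```python
# Toy version of the reduce behavior behind the widened return type.

def _reduce_mean_results(results):
    keys = results[0].keys()
    return {k: sum(r[k] for r in results) / len(results) for k in keys}

def update(minibatch_results, reduce_fn=_reduce_mean_results):
    if reduce_fn is None:
        return minibatch_results      # List[Mapping[str, Any]]
    return reduce_fn(minibatch_results)  # Mapping[str, Any]

results = [{"loss": 2.0}, {"loss": 4.0}]
print(update(results))        # {'loss': 3.0}
print(update(results, None))  # [{'loss': 2.0}, {'loss': 4.0}]
```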

@@ -124,7 +122,7 @@ def postprocess_gradients(
return gradients_dict

@override(Learner)
def apply_gradients(self, gradients: ParamDictType) -> None:
def apply_gradients(self, gradients_dict: ParamDictType) -> None:

Same: more consistent args naming.

@@ -490,11 +489,13 @@ def helper(_batch):
# constraint on forward_train and compute_loss APIs. This seems to be
# inefficient. Make it efficient.
_batch = NestedDict(_batch)
with tf.GradientTape() as tape:
with tf.GradientTape(persistent=True) as tape:

Necessary for multiple optimizers that operate on the same RLModule.
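A framework-free toy of why this is needed (mimicking, not using, tf.GradientTape): a non-persistent tape releases its resources after the first `gradient()` call, so a second optimizer pulling gradients from the same tape fails; `persistent=True` lets every optimizer of the same RLModule reuse it:

```python
# Toy tape mimicking tf.GradientTape's one-shot behavior.

class ToyTape:
    def __init__(self, persistent=False):
        self.persistent = persistent
        self._used = False

    def gradient(self, loss, params):
        if self._used and not self.persistent:
            raise RuntimeError("Non-persistent tape already consumed.")
        self._used = True
        return [loss * p for p in params]  # stand-in gradient values

# With two optimizers on the same module, each pulls its own gradients:
tape = ToyTape(persistent=True)
for _ in ["world_model_optim", "actor_optim"]:
    tape.gradient(1.0, [2.0, 3.0])  # works on every call

tape = ToyTape(persistent=False)
tape.gradient(1.0, [2.0, 3.0])
try:
    tape.gradient(1.0, [2.0, 3.0])  # second optimizer would fail here
except RuntimeError as e:
    print("second call failed:", e)
```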

@@ -93,24 +94,8 @@ def compute_gradients(

return grads

@OverrideToImplementCustomLogic_CallToSuperRecommended

Moved down to a better location in the file.

self._params[pid].grad = grad

# for each optimizer call its step function with the gradients
for optim in self._optimizer_parameters:
optim.step()

@OverrideToImplementCustomLogic_CallToSuperRecommended

.. to here :)

@@ -25,93 +25,6 @@
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space


def _multi_action_dist_partial_helper(

Moved down to make the main Catalog class in this file more prominent. We should generally move private functions to the end of a file, to avoid confusion and keep the main class(es) visible.

Signed-off-by: sven1977 <[email protected]>
@sven1977 sven1977 merged commit 8290bd1 into ray-project:master Jun 19, 2023
2 checks passed
krfricke added a commit that referenced this pull request Jun 20, 2023
…e RLlib APIs (RolloutWorker). (#35386)"

This reverts commit 8290bd1.
krfricke added a commit that referenced this pull request Jun 20, 2023
…e RLlib APIs (RolloutWorker). (#35386)" (#36564)

This reverts commit 8290bd1.
vitsai pushed a commit to vitsai/ray that referenced this pull request Jun 21, 2023
scottsun94 pushed a commit that referenced this pull request Jun 21, 2023
…e RLlib APIs (RolloutWorker). (#35386)" (#36564)

This reverts commit 8290bd1.
SongGuyang pushed a commit to alipay/ant-ray that referenced this pull request Jul 12, 2023
harborn pushed a commit to harborn/ray that referenced this pull request Aug 17, 2023
arvind-chandra pushed a commit to lmco/ray that referenced this pull request Aug 31, 2023