Commit

[RLlib] DreamerV3: Make 200M (XL model) work; mixed float16 option (ray-project#38461)

Signed-off-by: sven1977 <[email protected]>
sven1977 committed Aug 25, 2023
1 parent 3f11cf4 commit b004579
Showing 35 changed files with 947 additions and 259 deletions.
2 changes: 1 addition & 1 deletion doc/source/rllib/images/dreamerv3/atari100k_1_vs_4gpus.svg
2 changes: 1 addition & 1 deletion doc/source/rllib/images/dreamerv3/dmc_1_vs_4gpus.svg
Binary file added doc/source/rllib/images/dreamerv3/dreamerv3.png
2 changes: 1 addition & 1 deletion doc/source/rllib/images/dreamerv3/pong_1_2_and_4gpus.svg
1 change: 0 additions & 1 deletion rllib/BUILD
@@ -4479,7 +4479,6 @@ py_test_module_list(
files = [
"tests/test_dnc.py",
"tests/test_perf.py",
"algorithms/dreamerv3/tests/test_dreamerv3.py",
"env/wrappers/tests/test_kaggle_wrapper.py",
"examples/env/tests/test_cliff_walking_wall_env.py",
"examples/env/tests/test_coin_game_non_vectorized_env.py",
168 changes: 130 additions & 38 deletions rllib/algorithms/dreamerv3/README.md
@@ -1,71 +1,58 @@
# DreamerV3

![DreamerV3](../../../doc/source/rllib/images/dreamerv3/dreamerv3.png)

## Overview
An RLlib-based implementation of the
[DreamerV3 model-based reinforcement learning algorithm](https://arxiv.org/pdf/2301.04104v1.pdf)
by D. Hafner et al. (Google DeepMind) 2023, in TensorFlow/Keras.

This implementation allows scaling up training by using multi-GPU machines for
neural network updates (see below for tips and tricks, example configs, and command lines).

DreamerV3 trains a world model in supervised fashion using real environment
interactions. The world model's objective is to correctly predict all aspects
of the transition dynamics of the RL environment, which includes predicting the
correct next environment state, received rewards, as well as a boolean episode
continuation flag.
Just like in a standard policy gradient algorithm (e.g. REINFORCE), the critic tries to
predict a correct value function and the actor tries to come up with good action
choices that maximize accumulated rewards over time.
However, both actor and critic are never trained on real environment data, but solely on
dreamed trajectories produced by the world model.

For more specific details about the DreamerV3 architecture and math, refer to the
[original paper](https://arxiv.org/pdf/2301.04104v1.pdf) (see below for all references).

## Note on Hyperparameter Tuning for DreamerV3
DreamerV3 is an extremely versatile and stable algorithm that not only works well on
different action- and observation spaces (i.e. discrete and continuous actions, as well
as image and vector observations) and reward functions (sparse or dense),
but also has very few hyperparameters that require tuning.

All you need is a simple "model size" setting (from "XS" to "XL") and a value for the training ratio, which
specifies how many steps to replay from the buffer for a training update vs how many
steps to take in the actual environment.
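
For instance, a minimal config sketch could look like this (the env name and the
concrete values are hypothetical placeholders, not tuned recommendations):

```python
from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config

config = (
    DreamerV3Config()
    # Any registered gymnasium env works; "CartPole-v1" is just a placeholder.
    .environment("CartPole-v1")
    .training(
        model_size="XS",     # anything from "XS" to "XL"
        training_ratio=512,  # replayed (trained-on) steps per actual env step
    )
)
```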

For examples on how to set these config settings within your `DreamerV3Config` objects,
see below.

## Example Configs and Command Lines

<b>Note:</b> For a quick setup guide on how to get started with RLlib, refer to this
[documentation page](https://docs.ray.io/en/latest/rllib/index.html#rllib-in-60-seconds).

Use the config examples and templates in the
[tuned_examples folder](../../tuned_examples/dreamerv3)
in combination with the following scripts and command lines in order to run RLlib's DreamerV3 algorithm in your experiments:

### [Atari100k](../../tuned_examples/dreamerv3/atari_100k.py)
```shell
$ cd ray/rllib
$ rllib train file tuned_examples/dreamerv3/atari_100k.py --env ALE/Pong-v5
```

### [DeepMind Control Suite (vision)](../../tuned_examples/dreamerv3/dm_control_suite_vision.py)
```shell
$ cd ray/rllib
$ rllib train file tuned_examples/dreamerv3/dm_control_suite_vision.py --env DMC/cartpole/swingup
```

@@ -74,26 +61,131 @@ Other `--env` options for the DM Control Suite would be `--env DMC/hopper/hop`,
Note that you can also switch on WandB logging with the above script via the options
`--wandb-key=[your WandB API key] --wandb-project=[some project name] --wandb-run-name=[some run name]`

## Running DreamerV3 with arbitrary Envs and Configs
Can I run DreamerV3 with any gym or custom environments? Yes, you can!

<img src="../../../doc/source/rllib/images/dreamerv3/flappy_bird_env.png" alt="Flappy Bird gymnasium env" width="300" height="300" />

Let's try the Flappy Bird gymnasium env. Its image space is a cellphone-style
288 x 512 RGB, very different from DreamerV3's Atari benchmark norm (which is 64x64 RGB).
So we will have to custom-wrap observations to resize/normalize FlappyBird's ``Box(0, 255, (288, 512, 3), f32)``
space into a new ``Box(-1, 1, (64, 64, 3), f32)``.

First we quickly install ``flappy_bird_gymnasium`` in our dev environment:
```shell
$ pip install flappy_bird_gymnasium
```

Now, let's create a new python file for this RLlib experiment and call it ``flappy_bird.py``:

```python
from ray import tune
from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config


def _env_creator(ctx):
    import flappy_bird_gymnasium  # doctest: +SKIP
    import gymnasium as gym
    from supersuit.generic_wrappers import resize_v1
    from ray.rllib.algorithms.dreamerv3.utils.env_runner import NormalizedImageEnv

    return NormalizedImageEnv(
        resize_v1(  # resize to 64x64 and normalize images
            gym.make("FlappyBird-rgb-v0", audio_on=False), x_size=64, y_size=64
        )
    )


# Register the FlappyBird-rgb-v0 env including necessary wrappers via the
# `tune.register_env()` API.
tune.register_env("flappy-bird", _env_creator)

# Define the `config` variable to use for training.
config = (
    DreamerV3Config()
    # set the env to the pre-registered string
    .environment("flappy-bird")
    # play around with the insanely high number of hyperparameters for DreamerV3 ;)
    .training(
        model_size="S",
        training_ratio=1024,
    )
)

# Run the tuner job.
results = tune.Tuner(trainable="DreamerV3", param_space=config).fit()
```

Great! Now, let's run this experiment:

```shell
$ python flappy_bird.py
```

This should be it. Feel free to try running this on multiple GPUs using the
more advanced config examples [here (Atari100k)](../../tuned_examples/dreamerv3/atari_100k.py) and
[here (DM Control Suite)](../../tuned_examples/dreamerv3/dm_control_suite_vision.py).
Also see the notes below on good recipes for running on multiple GPUs.

<b>IMPORTANT:</b> DreamerV3 out-of-the-box only supports image observation spaces of
shape 64x64x3 as well as any vector observations (1D float32 Box spaces).
Should you require a special world model encoder and decoder for other observation
spaces (e.g. a text embedding or images of other dimensions), you will have to
subclass [DreamerV3's catalog class](dreamerv3_catalog.py) and then configure this
new catalog via your ``DreamerV3Config`` object as follows:

```python
from ray.rllib.algorithms.dreamerv3.tf.dreamerv3_tf_rl_module import DreamerV3TfRLModule
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec

config.rl_module(
    rl_module_spec=SingleAgentRLModuleSpec(
        module_class=DreamerV3TfRLModule,
        catalog_class=[your DreamerV3Catalog subclass],
    )
)
```


## Note on multi-GPU Training with DreamerV3
We found that when using multiple GPUs for DreamerV3 training, the following simple
adjustments should be made on top of the default config.

- Multiply the batch size (default `B=16`) by the number of GPUs you are using.
Use the `DreamerV3Config.training(batch_size_B=..)` API for this. For example, for 2 GPUs,
use a batch size of `B=32`.
- Multiply the number of environments you sample from in parallel by the number of GPUs you are using.
Use the `DreamerV3Config.rollouts(num_envs_per_worker=..)` for this.
For example, for 4 GPUs and a default environment count of 8 (the single-GPU default for
this setting depends on the benchmark you are running), use 32
parallel environments instead.
- Roughly use learning rates that are the default values multiplied by the square root of the number of GPUs.
For example, when using 4 GPUs, multiply all default learning rates (for world model, critic, and actor) by 2.
- Additionally, a "priming"-style warmup schedule might help: increase the learning rates from 0.0
to their final value(s) over the first ~10% of the total env steps needed for the experiment.
- For an example of how to apply these multi-GPU adjustments within your `DreamerV3Config`, see the sketch right after this list.
- [See here](https://aws.amazon.com/blogs/machine-learning/the-importance-of-hyperparameter-tuning-for-scaling-deep-learning-training-to-multiple-gpus/) for more details on learning rate "priming".
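
As a rough sketch (not a definitive recipe), the batch-size and environment-count
adjustments above could be expressed in a `DreamerV3Config` as follows. The GPU count,
env name, and all concrete values are hypothetical, and the `.resources()` parameter
names assume the Ray 2.x multi-GPU learner stack (adjust to your RLlib version):

```python
from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config

num_gpus = 4  # hypothetical setup

config = (
    DreamerV3Config()
    .environment("ALE/Pong-v5")
    # Scale the batch size with the number of GPUs (single-GPU default: B=16).
    .training(
        model_size="XL",
        training_ratio=64,
        batch_size_B=16 * num_gpus,
    )
    # Scale the number of parallel envs with the number of GPUs
    # (a single-GPU default of 8 is assumed here; it depends on the benchmark).
    .rollouts(num_envs_per_worker=8 * num_gpus)
    # Assumption: one learner worker per GPU.
    .resources(num_learner_workers=num_gpus, num_gpus_per_learner_worker=1)
)
```

The learning-rate scaling and "priming" warmup described above would go into the
`world_model_lr`, `critic_lr`, and `actor_lr` settings of `.training()`; check the
tuned examples for the exact schedule format supported by your RLlib version.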


## Results
Our results on the Atari 100k and (visual) DeepMind Control Suite benchmarks match those
reported in the paper.

### Pong-v5 (100k) 1GPU vs 2GPUs vs 4GPUs
<img src="../../../doc/source/rllib/images/dreamerv3/pong_1_2_and_4gpus.svg">
<img src="../../../doc/source/rllib/images/dreamerv3/pong_1_2_and_4gpus.svg" style="display:block;">

### Atari 100k
<img src="../../../doc/source/rllib/images/dreamerv3/atari100k_1_vs_4gpus.svg">
<img src="../../../doc/source/rllib/images/dreamerv3/atari100k_1_vs_4gpus.svg" style="display:block;">

### DeepMind Control Suite (vision)
<img src="../../../doc/source/rllib/images/dreamerv3/dmc_1_vs_4gpus.svg">
<img src="../../../doc/source/rllib/images/dreamerv3/dmc_1_vs_4gpus.svg" style="display:block;">


## References
For more algorithm details, see the original Dreamer-V3 paper:

[1] [Mastering Diverse Domains through World Models - 2023 D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap](https://arxiv.org/pdf/2301.04104v1.pdf)

.. and the (predecessor) Dreamer-V2 paper:

[2] [Mastering Atari with Discrete World Models - 2021 D. Hafner, T. Lillicrap, M. Norouzi, J. Ba](https://arxiv.org/pdf/2010.02193.pdf)
56 changes: 38 additions & 18 deletions rllib/algorithms/dreamerv3/dreamerv3.py
@@ -128,6 +128,7 @@ def __init__(self, algo_class=None):
self.critic_grad_clip_by_global_norm = 100.0
self.actor_grad_clip_by_global_norm = 100.0
self.symlog_obs = "auto"
self.use_float16 = False

# Reporting.
# DreamerV3 is super sample efficient and only needs very few episodes
@@ -169,6 +170,7 @@ def model(self):
"horizon_H": self.horizon_H,
"model_size": self.model_size,
"symlog_obs": self.symlog_obs,
"use_float16": self.use_float16,
}
)
return model
@@ -196,6 +198,7 @@ def training(
critic_grad_clip_by_global_norm: Optional[float] = NotProvided,
actor_grad_clip_by_global_norm: Optional[float] = NotProvided,
symlog_obs: Optional[Union[bool, str]] = NotProvided,
use_float16: Optional[bool] = NotProvided,
replay_buffer_config: Optional[dict] = NotProvided,
**kwargs,
) -> "DreamerV3Config":
@@ -257,6 +260,10 @@ def training(
symlog_obs: Whether to symlog observations or not. If set to "auto"
(default), will check for the environment's observation space and then
only symlog if not an image space.
use_float16: Whether to train with mixed float16 precision. In this mode,
model parameters are stored as float32, but all computations are
performed in float16 space (except for losses and distribution params
and outputs).
replay_buffer_config: Replay buffer config.
Only serves in DreamerV3 to set the capacity of the replay buffer.
Note though that in the paper ([1]) a size of 1M is used for all
@@ -314,6 +321,8 @@ def training(
self.actor_grad_clip_by_global_norm = actor_grad_clip_by_global_norm
if symlog_obs is not NotProvided:
self.symlog_obs = symlog_obs
if use_float16 is not NotProvided:
self.use_float16 = use_float16
if replay_buffer_config is not NotProvided:
# Override entire `replay_buffer_config` if `type` key changes.
# Update, if `type` key remains the same or is not specified.
@@ -441,6 +450,7 @@ def get_learner_hyperparameters(self) -> LearnerHyperparameters:
),
actor_grad_clip_by_global_norm=self.actor_grad_clip_by_global_norm,
critic_grad_clip_by_global_norm=self.critic_grad_clip_by_global_norm,
use_float16=self.use_float16,
report_individual_batch_item_stats=(
self.report_individual_batch_item_stats
),
@@ -542,21 +552,36 @@ def training_step(self) -> ResultDict:
# c) we have not sampled at all yet in this `training_step()` call.
or not have_sampled
):
# Sample using the env runner's module.
done_episodes, ongoing_episodes = env_runner.sample()
have_sampled = True

# We took B x T env steps.
env_steps_last_regular_sample = sum(
len(eps) for eps in done_episodes + ongoing_episodes
)
total_sampled = env_steps_last_regular_sample

# If we have never sampled before (just started the algo and not
# recovered from a checkpoint), sample B random actions first.
if self._counters[NUM_AGENT_STEPS_SAMPLED] == 0:
d_, o_ = env_runner.sample(
num_timesteps=(
self.config.batch_size_B * self.config.batch_length_T
)
- env_steps_last_regular_sample,
random_actions=True,
)
self.replay_buffer.add(episodes=d_ + o_)
total_sampled += sum(len(eps) for eps in d_ + o_)

# Add ongoing and finished episodes into buffer. The buffer will
# automatically take care of properly concatenating (by episode IDs)
# the different chunks of the same episodes, even if they come in via
# separate `add()` calls.
self.replay_buffer.add(episodes=done_episodes + ongoing_episodes)
self._counters[NUM_AGENT_STEPS_SAMPLED] += total_sampled
self._counters[NUM_ENV_STEPS_SAMPLED] += total_sampled

# Summarize environment interaction and buffer data.
results[ALL_MODULES] = report_sampling_and_replay_buffer(
@@ -569,14 +594,13 @@ def training_step(self) -> ResultDict:
# go back and collect more samples again from the actual environment.
# However, when calculating the `training_ratio` here, we use only the
# trained steps in this very `training_step()` call over the most recent sample
# amount (`env_steps_last_regular_sample`), not the global values. This is to
# avoid a heavy overtraining at the very beginning when we have just pre-filled
# the buffer with the minimum amount of samples.
replayed_steps_this_iter = sub_iter = 0
while (
replayed_steps_this_iter / env_steps_last_regular_sample
) < self.config.training_ratio:

# Time individual batch updates.
with self._timers[LEARN_ON_BATCH_TIMER]:
logger.info(f"\tSub-iteration {self.training_iteration}/{sub_iter})")
Expand All @@ -589,10 +613,6 @@ def training_step(self) -> ResultDict:
replayed_steps = self.config.batch_size_B * self.config.batch_length_T
replayed_steps_this_iter += replayed_steps

# Convert some bool columns to float32 and one-hot actions.
sample["is_first"] = sample["is_first"].astype(np.float32)
sample["is_last"] = sample["is_last"].astype(np.float32)
sample["is_terminated"] = sample["is_terminated"].astype(np.float32)
if isinstance(env_runner.env.single_action_space, gym.spaces.Discrete):
sample["actions_ints"] = sample[SampleBatch.ACTIONS]
sample[SampleBatch.ACTIONS] = one_hot(
Expand All @@ -612,7 +632,7 @@ def training_step(self) -> ResultDict:
# update.
with self._timers["critic_ema_update"]:
self.learner_group.additional_update(
timestep=self._counters[NUM_ENV_STEPS_SAMPLED],
reduce_fn=self._reduce_results,
)

1 change: 1 addition & 0 deletions rllib/algorithms/dreamerv3/dreamerv3_learner.py
@@ -46,6 +46,7 @@ class to configure your algorithm.
world_model_grad_clip_by_global_norm: float = None
actor_grad_clip_by_global_norm: float = None
critic_grad_clip_by_global_norm: float = None
use_float16: bool = None
# Reporting settings.
report_individual_batch_item_stats: bool = None
report_dream_data: bool = None