[RLlib] DreamerV3: Minor fixes and enhancements. (ray-project#37977)
sven1977 authored Aug 2, 2023
1 parent 2852dae commit 48488f5
Showing 11 changed files with 108 additions and 56 deletions.
14 changes: 13 additions & 1 deletion rllib/algorithms/dreamerv3/dreamerv3.py
@@ -50,7 +50,7 @@
SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
from ray.rllib.utils.typing import ResultDict
from ray.rllib.utils.typing import LearningRateOrSchedule, ResultDict


logger = logging.getLogger(__name__)
@@ -172,6 +172,9 @@ def training(
train_critic: Optional[bool] = NotProvided,
train_actor: Optional[bool] = NotProvided,
intrinsic_rewards_scale: Optional[float] = NotProvided,
world_model_lr: Optional[LearningRateOrSchedule] = NotProvided,
actor_lr: Optional[LearningRateOrSchedule] = NotProvided,
critic_lr: Optional[LearningRateOrSchedule] = NotProvided,
world_model_grad_clip_by_global_norm: Optional[float] = NotProvided,
critic_grad_clip_by_global_norm: Optional[float] = NotProvided,
actor_grad_clip_by_global_norm: Optional[float] = NotProvided,
@@ -225,6 +228,9 @@ def training(
must also be True (cannot train actor w/o training the critic).
intrinsic_rewards_scale: The factor to multiply intrinsic rewards with
before adding them to the extrinsic (environment) rewards.
world_model_lr: The learning rate or schedule for the world model optimizer.
actor_lr: The learning rate or schedule for the actor optimizer.
critic_lr: The learning rate or schedule for the critic optimizer.
world_model_grad_clip_by_global_norm: World model grad clipping value
(by global norm).
critic_grad_clip_by_global_norm: Critic grad clipping value
@@ -271,6 +277,12 @@ def training(
self.train_actor = train_actor
if intrinsic_rewards_scale is not NotProvided:
self.intrinsic_rewards_scale = intrinsic_rewards_scale
if world_model_lr is not NotProvided:
self.world_model_lr = world_model_lr
if actor_lr is not NotProvided:
self.actor_lr = actor_lr
if critic_lr is not NotProvided:
self.critic_lr = critic_lr
if world_model_grad_clip_by_global_norm is not NotProvided:
self.world_model_grad_clip_by_global_norm = (
world_model_grad_clip_by_global_norm
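
A minimal usage sketch (not part of this commit) of the three new per-optimizer learning-rate arguments added above. The numeric values are placeholders, and the piecewise [[timestep, value], ...] schedule format is an assumption about how LearningRateOrSchedule is interpreted, not something this diff specifies:

from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config

config = (
    DreamerV3Config()
    .environment("CartPole-v1")
    .training(
        # Plain floats keep each optimizer at a constant learning rate.
        world_model_lr=1e-4,
        critic_lr=3e-5,
        # Assumed schedule format: anneal the actor lr over the first 100k steps.
        actor_lr=[[0, 3e-5], [100_000, 1e-5]],
    )
)
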
2 changes: 1 addition & 1 deletion rllib/algorithms/dreamerv3/dreamerv3_catalog.py
@@ -27,7 +27,7 @@ def __init__(
model_config_dict=model_config_dict,
)

self.model_size = self.model_config_dict["model_size"]
self.model_size = self._model_config_dict["model_size"]
self.is_img_space = len(self.observation_space.shape) in [2, 3]
self.is_gray_scale = (
self.is_img_space and len(self.observation_space.shape) == 2
25 changes: 17 additions & 8 deletions rllib/algorithms/dreamerv3/dreamerv3_rl_module.py
@@ -4,6 +4,7 @@

import abc

import gymnasium as gym
import numpy as np

from ray.rllib.algorithms.dreamerv3.utils import do_symlog_obs
@@ -72,15 +73,23 @@ def setup(self):
np.expand_dims(self.config.observation_space.sample(), (0, 1)),
reps=(B, T) + (1,) * len(self.config.observation_space.shape),
)
- test_actions = np.tile(
-     np.expand_dims(
-         one_hot(
-             self.config.action_space.sample(), depth=self.config.action_space.n
-         ),
-         (0, 1),
-     ),
-     reps=(B, T, 1),
- )
+ if isinstance(self.config.action_space, gym.spaces.Discrete):
+     test_actions = np.tile(
+         np.expand_dims(
+             one_hot(
+                 self.config.action_space.sample(),
+                 depth=self.config.action_space.n,
+             ),
+             (0, 1),
+         ),
+         reps=(B, T, 1),
+     )
+ else:
+     test_actions = np.tile(
+         np.expand_dims(self.config.action_space.sample(), (0, 1)),
+         reps=(B, T, 1),
+     )

self.dreamer_model(
inputs=_convert_to_tf(test_obs),
actions=_convert_to_tf(test_actions.astype(np.float32)),
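
A standalone numpy sketch (not part of this commit) of the two action-encoding branches introduced above: Discrete actions are one-hot encoded to depth n before tiling, while other (e.g. Box) actions are tiled as sampled. It uses np.eye in place of RLlib's one_hot utility; the B/T values are arbitrary:

import gymnasium as gym
import numpy as np

B, T = 4, 8  # dummy batch and time dimensions

def make_test_actions(action_space):
    sample = action_space.sample()
    if isinstance(action_space, gym.spaces.Discrete):
        # One-hot encode the sampled action -> shape (n,).
        encoded = np.eye(action_space.n, dtype=np.float32)[sample]
    else:
        # Continuous actions are used as-is -> shape (act_dim,).
        encoded = np.asarray(sample, dtype=np.float32)
    # Add batch/time axes and tile to (B, T, ...).
    return np.tile(np.expand_dims(encoded, (0, 1)), reps=(B, T, 1))

print(make_test_actions(gym.spaces.Discrete(6)).shape)           # (4, 8, 6)
print(make_test_actions(gym.spaces.Box(-1.0, 1.0, (3,))).shape)  # (4, 8, 3)
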
21 changes: 7 additions & 14 deletions rllib/algorithms/dreamerv3/tests/test_dreamerv3.py
@@ -37,13 +37,10 @@ def test_dreamerv3_compilation(self):
# Build a DreamerV3Config object.
config = (
dreamerv3.DreamerV3Config()
.framework(eager_tracing=True)
.training(
# Keep things simple. Especially the long dream rollouts seem
# to take an enormous amount of time (initially).
batch_size_B=2 * 2, # shared w/ model AND learner AND env runner
batch_length_T=16,
horizon_H=5,
# TODO (sven): Fix having to provide this.
# Should be compiled automatically as `RLModuleConfig` by
# AlgorithmConfig (see comment below)?
@@ -139,17 +136,13 @@ def test_dreamerv3_dreamer_model_sizes(self):
"XL_atari": 9708799,
}

config = (
dreamerv3.DreamerV3Config()
.framework("tf2", eager_tracing=True)
.training(
model={
"batch_length_T": 16,
"horizon_H": 5,
"gamma": 0.997,
"symlog_obs": True,
}
)
config = dreamerv3.DreamerV3Config().training(
model={
"batch_length_T": 16,
"horizon_H": 5,
"gamma": 0.997,
"symlog_obs": True,
}
)

# Check all model_sizes described in the paper ([1]) on matching the number
13 changes: 6 additions & 7 deletions rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py
@@ -53,40 +53,39 @@ def configure_optimizers_for_module(
dreamerv3_module = self._module[module_id]

# World Model optimizer.
optim_world_model = tf.keras.optimizers.Adam(
learning_rate=hps.world_model_lr, epsilon=1e-8
)
optim_world_model = tf.keras.optimizers.Adam(epsilon=1e-8)
optim_world_model.build(dreamerv3_module.world_model.trainable_variables)
params_world_model = self.get_parameters(dreamerv3_module.world_model)
self.register_optimizer(
module_id=module_id,
optimizer_name="world_model",
optimizer=optim_world_model,
params=params_world_model,
lr_or_lr_schedule=hps.world_model_lr,
)

# Actor optimizer.
optim_actor = tf.keras.optimizers.Adam(learning_rate=hps.actor_lr, epsilon=1e-5)
optim_actor = tf.keras.optimizers.Adam(epsilon=1e-5)
optim_actor.build(dreamerv3_module.actor.trainable_variables)
params_actor = self.get_parameters(dreamerv3_module.actor)
self.register_optimizer(
module_id=module_id,
optimizer_name="actor",
optimizer=optim_actor,
params=params_actor,
lr_or_lr_schedule=hps.actor_lr,
)

# Critic optimizer.
optim_critic = tf.keras.optimizers.Adam(
learning_rate=hps.critic_lr, epsilon=1e-5
)
optim_critic = tf.keras.optimizers.Adam(epsilon=1e-5)
optim_critic.build(dreamerv3_module.critic.trainable_variables)
params_critic = self.get_parameters(dreamerv3_module.critic)
self.register_optimizer(
module_id=module_id,
optimizer_name="critic",
optimizer=optim_critic,
params=params_critic,
lr_or_lr_schedule=hps.critic_lr,
)

@override(TfLearner)
6 changes: 1 addition & 5 deletions rllib/algorithms/dreamerv3/utils/env_runner.py
@@ -13,6 +13,7 @@

import gymnasium as gym
import numpy as np
from supersuit.generic_wrappers import resize_v1
import tree # pip install dm_tree

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
@@ -48,11 +49,6 @@ def __init__(
# Create the gym.vector.Env object.
# Atari env.
if self.config.env.startswith("ALE/"):
# TODO (sven): This import currently causes a Tune test to fail. Either way,
# we need to figure out how to properly setup the CI environment with
# the correct versions of all gymnasium-related packages.
from supersuit.generic_wrappers import resize_v1

# [2]: "We down-scale the 84 × 84 grayscale images to 64 × 64 pixels so that
# we can apply the convolutional architecture of DreamerV1."
# ...
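
For context, a sketch (not part of this commit) of the down-scaling the now top-level import is used for: wrapping an Atari env so that observations come out at 64x64, as described in the quoted comment. The x_size/y_size keyword names are an assumption about the installed SuperSuit version:

import gymnasium as gym
from supersuit.generic_wrappers import resize_v1

env = gym.make("ALE/Pong-v5", frameskip=1)
env = resize_v1(env, x_size=64, y_size=64)  # down-scale observations to 64x64
obs, info = env.reset()
print(obs.shape)  # e.g. (64, 64, 3)
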
25 changes: 25 additions & 0 deletions rllib/core/learner/learner.py
@@ -893,12 +893,25 @@ def compile_results(
loss_per_module_numpy = convert_to_numpy(loss_per_module)

for module_id in list(batch.policy_batches.keys()) + [ALL_MODULES]:
# Report total loss per module and other registered metrics.
module_learner_stats[module_id].update(
{
self.TOTAL_LOSS_KEY: loss_per_module_numpy[module_id],
**convert_to_numpy(metrics_per_module[module_id]),
}
)
# Report registered optimizers' learning rates.
module_learner_stats[module_id].update(
{
f"{optim_name}_lr": convert_to_numpy(
self._get_optimizer_lr(optimizer)
)
for optim_name, optimizer in (
self.get_optimizers_for_module(module_id=module_id)
)
}
)

return dict(module_learner_stats)

@OverrideToImplementCustomLogic_CallToSuperRecommended
@@ -1636,6 +1649,18 @@ def _get_tensor_variable(
dtype and trainable/requires_grad property.
"""

@staticmethod
@abc.abstractmethod
def _get_optimizer_lr(optimizer: Optimizer) -> float:
"""Returns the current learning rate of the given local optimizer.
Args:
optimizer: The local optimizer to get the current learning rate for.
Returns:
The learning rate value (float) of the given optimizer.
"""

@staticmethod
@abc.abstractmethod
def _set_optimizer_lr(optimizer: Optimizer, lr: float) -> None:
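
A toy illustration (not part of this commit) of the keys the new reporting block in compile_results() produces: one "<optimizer_name>_lr" entry per optimizer registered for a module. The names and values below are placeholders that happen to match the DreamerV3 learner shown later in this commit; other algorithms report whatever optimizer names they registered:

# (optimizer_name, current_lr) pairs, mimicking get_optimizers_for_module() output.
registered = [("world_model", 1e-4), ("actor", 3e-5), ("critic", 3e-5)]
stats = {f"{optim_name}_lr": lr for optim_name, lr in registered}
print(stats)  # {'world_model_lr': 0.0001, 'actor_lr': 3e-05, 'critic_lr': 3e-05}
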
5 changes: 5 additions & 0 deletions rllib/core/learner/tf/tf_learner.py
@@ -460,6 +460,11 @@ def _get_tensor_variable(self, value, dtype=None, trainable=False) -> "tf.Tensor
),
)

@staticmethod
@override(Learner)
def _get_optimizer_lr(optimizer: "tf.Optimizer") -> float:
return optimizer.lr

@staticmethod
@override(Learner)
def _set_optimizer_lr(optimizer: "tf.Optimizer", lr: float) -> None:
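
A quick sanity check (not part of this commit) of the TF implementation above: a Keras optimizer exposes its current learning rate as a variable, so it can be read out as a float and, presumably by the corresponding _set_optimizer_lr(), assigned in place. Depending on the Keras version, .lr is an alias of .learning_rate:

import tensorflow as tf

opt = tf.keras.optimizers.Adam(epsilon=1e-8)  # learning rate left at the Keras default
print(float(opt.learning_rate))               # 0.001
opt.learning_rate.assign(3e-5)                # in-place update of the lr variable
print(float(opt.learning_rate))               # 3e-05
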
6 changes: 6 additions & 0 deletions rllib/core/learner/torch/torch_learner.py
@@ -446,6 +446,12 @@ def _get_tensor_variable(
),
)

@staticmethod
@override(Learner)
def _get_optimizer_lr(optimizer: "torch.optim.Optimizer") -> float:
for g in optimizer.param_groups:
return g["lr"]

@staticmethod
@override(Learner)
def _set_optimizer_lr(optimizer: "torch.optim.Optimizer", lr: float) -> None:
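
And the equivalent sanity check (not part of this commit) for the torch implementation above: PyTorch keeps the learning rate per parameter group, and the loop returns it from the first group, which assumes a single group per optimizer (the usual case for RLlib-built optimizers):

import torch

params = [torch.nn.Parameter(torch.zeros(3))]
opt = torch.optim.Adam(params, lr=1e-4)
print(opt.param_groups[0]["lr"])   # 0.0001
opt.param_groups[0]["lr"] = 3e-5   # roughly what _set_optimizer_lr() does on the torch side
print(opt.param_groups[0]["lr"])   # 3e-05
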
25 changes: 11 additions & 14 deletions rllib/tuned_examples/dreamerv3/atari_100k.py
@@ -19,19 +19,6 @@

config = (
DreamerV3Config()
# Switch on eager_tracing by default.
.framework("tf2", eager_tracing=True)
.resources(
num_learner_workers=0 if num_gpus == 1 else num_gpus,
num_gpus_per_learner_worker=1 if num_gpus else 0,
num_cpus_for_local_worker=1,
)
# TODO (sven): concretize this: If you use >1 GPU and increase the batch size
# accordingly, you might also want to increase the number of envs per worker
.rollouts(
num_envs_per_worker=(num_gpus or 1),
remote_worker_envs=True,
)
.environment(
# [2]: "We follow the evaluation protocol of Machado et al. (2018) with 200M
# environment steps, action repeat of 4, a time limit of 108,000 steps per
@@ -49,6 +36,17 @@
"frameskip": 1,
}
)
.resources(
num_learner_workers=0 if num_gpus == 1 else num_gpus,
num_gpus_per_learner_worker=1 if num_gpus else 0,
num_cpus_for_local_worker=1,
)
.rollouts(
# If we use >1 GPU and increase the batch size accordingly, we should also
# increase the number of envs per worker.
num_envs_per_worker=(num_gpus or 1),
remote_worker_envs=True,
)
.reporting(
metrics_num_episodes_for_smoothing=(num_gpus or 1),
report_images_and_videos=False,
@@ -60,7 +58,6 @@
model_size="S",
training_ratio=1024,
batch_size_B=16 * (num_gpus or 1),
# TODO
model={
"batch_length_T": 64,
"horizon_H": 15,
22 changes: 16 additions & 6 deletions rllib/tuned_examples/dreamerv3/dm_control_suite_vision.py
@@ -7,29 +7,39 @@
D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
https://arxiv.org/pdf/2010.02193.pdf
"""
from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config

# Run with:
# python run_regression_tests.py --dir [this file] --env DMC/[task]/[domain]
# e.g. --env=DMC/cartpole/swingup

from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config


# Number of GPUs to run on.
num_gpus = 1

config = (
DreamerV3Config()
# Use image observations.
.environment(env_config={"from_pixels": True})
.resources(
num_learner_workers=1,
num_gpus_per_learner_worker=1,
num_learner_workers=0 if num_gpus == 1 else num_gpus,
num_gpus_per_learner_worker=1 if num_gpus else 0,
num_cpus_for_local_worker=1,
)
.rollouts(num_envs_per_worker=4, remote_worker_envs=True)
.rollouts(num_envs_per_worker=4 * (num_gpus or 1), remote_worker_envs=True)
.reporting(
metrics_num_episodes_for_smoothing=(num_gpus or 1),
report_images_and_videos=False,
report_dream_data=False,
report_individual_batch_item_stats=False,
)
# See Appendix A.
.training(
model_size="S",
training_ratio=512,
batch_size_B=16 * (num_gpus or 1),
# TODO
model={
"batch_size_B": 16,
"batch_length_T": 64,
"horizon_H": 15,
"gamma": 0.997,
Expand Down
