[RLlib] DreamerV3: Main algo code and required changes to some RLlib APIs (RolloutWorker). (ray-project#35386)
sven1977 committed Jun 19, 2023
1 parent d207361 commit 8290bd1
Showing 50 changed files with 3,149 additions and 379 deletions.
13 changes: 12 additions & 1 deletion rllib/BUILD
@@ -1066,14 +1066,24 @@ py_test(
srcs = ["algorithms/dqn/tests/test_repro_dqn.py"]
)

# Dreamer
# Dreamer (V1)
py_test(
name = "test_dreamer",
tags = ["team:rllib", "algorithms_dir"],
size = "medium",
srcs = ["algorithms/dreamer/tests/test_dreamer.py"]
)

# DreamerV3
# TODO (sven): Enable once the version conflict for gymnasium/supersuit/pettingzoo
# /shimmy/mujoco has been resolved.
#py_test(
# name = "test_dreamerv3",
# tags = ["team:rllib", "algorithms_dir"],
# size = "large",
# srcs = ["algorithms/dreamerv3/tests/test_dreamerv3.py"]
#)

# DT
py_test(
name = "test_segmentation_buffer",
@@ -4345,6 +4355,7 @@ py_test_module_list(
files = [
"tests/test_dnc.py",
"tests/test_perf.py",
"algorithms/dreamerv3/tests/test_dreamerv3.py",
"env/wrappers/tests/test_kaggle_wrapper.py",
"examples/env/tests/test_cliff_walking_wall_env.py",
"examples/env/tests/test_coin_game_non_vectorized_env.py",
18 changes: 15 additions & 3 deletions rllib/algorithms/algorithm.py
@@ -706,7 +706,19 @@ def setup(self, config: AlgorithmConfig) -> None:
# the two we need to loop through the policy modules and create a simple
# MARLModule from the RLModule within each policy.
local_worker = self.workers.local_worker()
module_spec = local_worker.marl_module_spec
policy_dict, _ = self.config.get_multi_agent_setup(
env=local_worker.env,
spaces=getattr(local_worker, "spaces", None),
)
# TODO (Sven): Unify the inference of the MARLModuleSpec. Right now,
# we get this from the RolloutWorker's `marl_module_spec` property.
# However, this is hacky (information leak) and should not remain this
# way. For other EnvRunner classes (that don't have this property),
# Algorithm should infer this itself.
if hasattr(local_worker, "marl_module_spec"):
module_spec = local_worker.marl_module_spec
else:
module_spec = self.config.get_marl_module_spec(policy_dict=policy_dict)
learner_group_config = self.config.get_learner_group_config(module_spec)
self.learner_group = learner_group_config.build()

@@ -871,7 +883,7 @@ def evaluate(
# Sync weights to the evaluation WorkerSet.
if self.evaluation_workers is not None:
self.evaluation_workers.sync_weights(
from_worker_or_trainer=self.workers.local_worker()
from_worker_or_learner_group=self.workers.local_worker()
)
self._sync_filters_if_needed(
central_worker=self.workers.local_worker(),
@@ -1409,7 +1421,7 @@ def training_step(self) -> ResultDict:
if self.config._enable_learner_api:
from_worker_or_trainer = self.learner_group
self.workers.sync_weights(
from_worker_or_trainer=from_worker_or_trainer,
from_worker_or_learner_group=from_worker_or_trainer,
policies=list(train_results.keys()),
global_vars=global_vars,
)
67 changes: 33 additions & 34 deletions rllib/algorithms/algorithm_config.py
@@ -303,10 +303,8 @@ def __init__(self, algo_class=None):
self.normalize_actions = True
self.clip_actions = False
self.disable_env_checking = False
# Whether this env is an atari env (for atari-specific preprocessing).
# If not specified, we will try to auto-detect this.
self.is_atari = None
self.auto_wrap_old_gym_envs = True
self._is_atari = None

# `self.rollouts()`
self.env_runner_cls = None
@@ -718,31 +716,6 @@ def freeze(self) -> None:
# of themselves? This way, users won't even be able to alter those values
# directly anymore.

def _detect_atari_env(self) -> bool:
"""Returns whether this configured env is an Atari env or not.
Returns:
True, if specified env is an Atari env, False otherwise.
"""
# Atari envs are usually specified via a string like "PongNoFrameskip-v4"
# or "ALE/Breakout-v5".
# We do NOT attempt to auto-detect Atari env for other specified types like
# a callable, to avoid running heavy logics in validate().
# For these cases, users can explicitly set `environment(atari=True)`.
if not type(self.env) == str:
return False

try:
if self.env.startswith("ALE/"):
env = gym.make("GymV26Environment-v0", env_id=self.env)
else:
env = gym.make(self.env)
except gym.error.NameNotFound:
# Not an Atari env if this is not a gym env.
return False

return is_atari(env)

@OverrideToImplementCustomLogic_CallToSuperRecommended
def validate(self) -> None:
"""Validates all values in this config."""
@@ -988,10 +961,6 @@ def validate(self) -> None:
f"config.framework({self.framework_str})!"
)

# Detect if specified env is an Atari env.
if self.is_atari is None:
self.is_atari = self._detect_atari_env()

if self.input_ == "sampler" and self.off_policy_estimation_methods:
raise ValueError(
"Off-policy estimation methods can only be used if the input is a "
@@ -1368,7 +1337,7 @@ def environment(
disable_env_checking: If True, disable the environment pre-checking module.
is_atari: This config can be used to explicitly specify whether the env is
an Atari env or not. If not specified, RLlib will try to auto-detect
this during config validation.
this.
auto_wrap_old_gym_envs: Whether to auto-wrap old gym environments (using
the pre 0.24 gym APIs, e.g. reset() returning single obs and no info
dict). If True, RLlib will automatically wrap the given gym env class
@@ -1405,7 +1374,7 @@
if disable_env_checking is not NotProvided:
self.disable_env_checking = disable_env_checking
if is_atari is not NotProvided:
self.is_atari = is_atari
self._is_atari = is_atari
if auto_wrap_old_gym_envs is not NotProvided:
self.auto_wrap_old_gym_envs = auto_wrap_old_gym_envs

@@ -2319,6 +2288,8 @@ def reporting(
In case there are more than this many episodes collected in a single
training iteration, use all of these episodes for metrics computation,
meaning don't ever cut any "excess" episodes.
Set this to 1 to disable smoothing and to always report only the most
recently collected episode's return.
min_time_s_per_iteration: Minimum time to accumulate within a single
`train()` call. This value does not affect learning,
only the number of times `Algorithm.training_step()` is called by
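
As a hedged illustration of the smoothing setting documented above (not part of this diff; any `AlgorithmConfig` subclass exposes `reporting()`, PPO is used here only as an example):

```python
from ray.rllib.algorithms.ppo import PPOConfig

# Keep only the single most recently collected episode for metrics
# (i.e. disable smoothing over a window of past episodes).
config = PPOConfig().reporting(metrics_num_episodes_for_smoothing=1)
```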
@@ -2645,6 +2616,34 @@ def learner_class(self) -> Type["Learner"]:
"""
return self._learner_class or self.get_default_learner_class()

@property
def is_atari(self) -> bool:
"""True if if specified env is an Atari env."""

# Not yet determined, try to figure this out.
if self._is_atari is None:
# Atari envs are usually specified via a string like "PongNoFrameskip-v4"
# or "ALE/Breakout-v5".
# We do NOT attempt to auto-detect Atari env for other specified types like
# a callable, to avoid running heavy logics in validate().
# For these cases, users can explicitly set `environment(atari=True)`.
if not type(self.env) == str:
return False
try:
if self.env.startswith("ALE/"):
env = gym.make("GymV26Environment-v0", env_id=self.env)
else:
env = gym.make(self.env)
# Any gymnasium error -> Cannot be an Atari env.
except gym.error.Error:
return False

self._is_atari = is_atari(env)
# Clean up env's resources, if any.
env.close()

return self._is_atari
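
A hedged sketch of how the new lazy `is_atari` property behaves (illustration only, not part of the diff; it assumes the Atari extras such as `ale-py`/`shimmy` are installed so that `gym.make()` can construct the env):

```python
from ray.rllib.algorithms.ppo import PPOConfig

# First access runs the gym.make()-based auto-detection and caches the result.
atari_conf = PPOConfig().environment("ALE/Breakout-v5")
print(atari_conf.is_atari)    # -> True (if the Atari env can be constructed)

# Non-Atari string envs report False.
classic_conf = PPOConfig().environment("CartPole-v1")
print(classic_conf.is_atari)  # -> False

# For env classes/callables, auto-detection is skipped; set the flag explicitly,
# e.g. config.environment(env=MyEnvClass, is_atari=True)  (MyEnvClass is hypothetical).
```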

# TODO: Make rollout_fragment_length as read-only property and replace the current
# self.rollout_fragment_length a private variable.
def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
27 changes: 27 additions & 0 deletions rllib/algorithms/dreamerv3/README.md
@@ -0,0 +1,27 @@
# DreamerV3
Implementation (TensorFlow/Keras) of the "DreamerV3" model-based reinforcement learning
(RL) algorithm by D. Hafner et al., 2023.

DreamerV3 trains a world model in a supervised fashion using real environment
interactions. The world model utilizes a recurrent GRU-based architecture
("recurrent state space model" or RSSM) and uses it to predict rewards,
episode continuation flags, as well as observations.
With these predictions ("dreams") made by the world model, both actor
and critic are trained in classic REINFORCE fashion. In other words, the
actual RL components of the model are never trained on real environment data,
but only on dreamed trajectories.
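
A minimal usage sketch (hedged: it assumes the default `DreamerV3Config` settings suffice for a simple environment such as CartPole; exact options may differ):

```python
from ray.rllib.algorithms.dreamerv3 import DreamerV3Config

config = DreamerV3Config().environment("CartPole-v1")
algo = config.build()

for _ in range(3):
    print(algo.train())

algo.stop()
```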

For more algorithm details, see:

[1] Mastering Diverse Domains through World Models - 2023
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
https://arxiv.org/pdf/2301.04104v1.pdf

.. and the "DreamerV2" paper:

[2] Mastering Atari with Discrete World Models - 2021
D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
https://arxiv.org/pdf/2010.02193.pdf

## Results
TODO
15 changes: 15 additions & 0 deletions rllib/algorithms/dreamerv3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
[1] Mastering Diverse Domains through World Models - 2023
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
https://arxiv.org/pdf/2301.04104v1.pdf
[2] Mastering Atari with Discrete World Models - 2021
D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
https://arxiv.org/pdf/2010.02193.pdf
"""
from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3, DreamerV3Config

__all__ = [
"DreamerV3",
"DreamerV3Config",
]