Revert "[RLlib] DreamerV3: Main algo code and required changes to som…
Browse files Browse the repository at this point in the history
…e RLlib APIs (RolloutWorker). (ray-project#35386)" (ray-project#36564)

This reverts commit 8290bd1.
krfricke committed Jun 20, 2023
1 parent 3eba893 commit 42e06e3
Showing 50 changed files with 379 additions and 3,149 deletions.
13 changes: 1 addition & 12 deletions rllib/BUILD
@@ -1066,24 +1066,14 @@ py_test(
     srcs = ["algorithms/dqn/tests/test_repro_dqn.py"]
 )

-# Dreamer (V1)
+# Dreamer
 py_test(
     name = "test_dreamer",
     tags = ["team:rllib", "algorithms_dir"],
     size = "medium",
     srcs = ["algorithms/dreamer/tests/test_dreamer.py"]
 )

-# DreamerV3
-# TODO (sven): Enable once the version conflict for gymnasium/supersuit/pettingzoo
-# /shimmy/mujoco has been resolved.
-#py_test(
-#    name = "test_dreamerv3",
-#    tags = ["team:rllib", "algorithms_dir"],
-#    size = "large",
-#    srcs = ["algorithms/dreamerv3/tests/test_dreamerv3.py"]
-#)
-
 # DT
 py_test(
     name = "test_segmentation_buffer",
@@ -4355,7 +4345,6 @@ py_test_module_list(
     files = [
         "tests/test_dnc.py",
         "tests/test_perf.py",
-        "algorithms/dreamerv3/tests/test_dreamerv3.py",
         "env/wrappers/tests/test_kaggle_wrapper.py",
         "examples/env/tests/test_cliff_walking_wall_env.py",
         "examples/env/tests/test_coin_game_non_vectorized_env.py",
18 changes: 3 additions & 15 deletions rllib/algorithms/algorithm.py
@@ -706,19 +706,7 @@ def setup(self, config: AlgorithmConfig) -> None:
             # the two we need to loop through the policy modules and create a simple
             # MARLModule from the RLModule within each policy.
             local_worker = self.workers.local_worker()
-            policy_dict, _ = self.config.get_multi_agent_setup(
-                env=local_worker.env,
-                spaces=getattr(local_worker, "spaces", None),
-            )
-            # TODO (Sven): Unify the inference of the MARLModuleSpec. Right now,
-            #  we get this from the RolloutWorker's `marl_module_spec` property.
-            #  However, this is hacky (information leak) and should not remain this
-            #  way. For other EnvRunner classes (that don't have this property),
-            #  Algorithm should infer this itself.
-            if hasattr(local_worker, "marl_module_spec"):
-                module_spec = local_worker.marl_module_spec
-            else:
-                module_spec = self.config.get_marl_module_spec(policy_dict=policy_dict)
+            module_spec = local_worker.marl_module_spec
             learner_group_config = self.config.get_learner_group_config(module_spec)
             self.learner_group = learner_group_config.build()
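For context, a minimal sketch of the control flow this hunk restores, using only names that appear in the hunk itself (a sketch, assuming `self.workers` is a built WorkerSet; not tied to any particular Ray version):

    # Post-revert flow in Algorithm.setup(): trust the local RolloutWorker to
    # expose the inferred MARLModuleSpec directly via its property ...
    local_worker = self.workers.local_worker()
    module_spec = local_worker.marl_module_spec
    # ... then build the LearnerGroup from that spec.
    learner_group_config = self.config.get_learner_group_config(module_spec)
    self.learner_group = learner_group_config.build()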

@@ -883,7 +871,7 @@ def evaluate(
         # Sync weights to the evaluation WorkerSet.
         if self.evaluation_workers is not None:
             self.evaluation_workers.sync_weights(
-                from_worker_or_learner_group=self.workers.local_worker()
+                from_worker_or_trainer=self.workers.local_worker()
             )
             self._sync_filters_if_needed(
                 central_worker=self.workers.local_worker(),
@@ -1421,7 +1409,7 @@ def training_step(self) -> ResultDict:
             if self.config._enable_learner_api:
                 from_worker_or_trainer = self.learner_group
             self.workers.sync_weights(
-                from_worker_or_learner_group=from_worker_or_trainer,
+                from_worker_or_trainer=from_worker_or_trainer,
                 policies=list(train_results.keys()),
                 global_vars=global_vars,
             )
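The two sync_weights() hunks above amount to a single keyword rename back to `from_worker_or_trainer`. A minimal usage sketch under the restored signature (a sketch only, assuming `workers` is a built WorkerSet whose local worker holds the freshest weights):

    # Push the local worker's weights out to all remote rollout workers,
    # using the keyword restored by this revert:
    workers.sync_weights(from_worker_or_trainer=workers.local_worker())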
67 changes: 34 additions & 33 deletions rllib/algorithms/algorithm_config.py
@@ -303,8 +303,10 @@ def __init__(self, algo_class=None):
         self.normalize_actions = True
         self.clip_actions = False
         self.disable_env_checking = False
+        # Whether this env is an atari env (for atari-specific preprocessing).
+        # If not specified, we will try to auto-detect this.
+        self.is_atari = None
         self.auto_wrap_old_gym_envs = True
-        self._is_atari = None

         # `self.rollouts()`
         self.env_runner_cls = None
@@ -716,6 +718,31 @@ def freeze(self) -> None:
         #  of themselves? This way, users won't even be able to alter those values
         #  directly anymore.

+    def _detect_atari_env(self) -> bool:
+        """Returns whether this configured env is an Atari env or not.
+
+        Returns:
+            True, if specified env is an Atari env, False otherwise.
+        """
+        # Atari envs are usually specified via a string like "PongNoFrameskip-v4"
+        # or "ALE/Breakout-v5".
+        # We do NOT attempt to auto-detect Atari env for other specified types like
+        # a callable, to avoid running heavy logics in validate().
+        # For these cases, users can explicitly set `environment(is_atari=True)`.
+        if not type(self.env) == str:
+            return False
+
+        try:
+            if self.env.startswith("ALE/"):
+                env = gym.make("GymV26Environment-v0", env_id=self.env)
+            else:
+                env = gym.make(self.env)
+        except gym.error.NameNotFound:
+            # Not an Atari env if this is not a gym env.
+            return False
+
+        return is_atari(env)
+
     @OverrideToImplementCustomLogic_CallToSuperRecommended
     def validate(self) -> None:
         """Validates all values in this config."""
@@ -961,6 +988,10 @@ def validate(self) -> None:
                 f"config.framework({self.framework_str})!"
             )

+        # Detect if specified env is an Atari env.
+        if self.is_atari is None:
+            self.is_atari = self._detect_atari_env()
+
         if self.input_ == "sampler" and self.off_policy_estimation_methods:
             raise ValueError(
                 "Off-policy estimation methods can only be used if the input is a "
@@ -1337,7 +1368,7 @@ def environment(
             disable_env_checking: If True, disable the environment pre-checking module.
             is_atari: This config can be used to explicitly specify whether the env is
                 an Atari env or not. If not specified, RLlib will try to auto-detect
-                this.
+                this during config validation.
             auto_wrap_old_gym_envs: Whether to auto-wrap old gym environments (using
                 the pre 0.24 gym APIs, e.g. reset() returning single obs and no info
                 dict). If True, RLlib will automatically wrap the given gym env class
@@ -1374,7 +1405,7 @@ def environment(
         if disable_env_checking is not NotProvided:
             self.disable_env_checking = disable_env_checking
         if is_atari is not NotProvided:
-            self._is_atari = is_atari
+            self.is_atari = is_atari
         if auto_wrap_old_gym_envs is not NotProvided:
             self.auto_wrap_old_gym_envs = auto_wrap_old_gym_envs

@@ -2288,8 +2319,6 @@ def reporting(
                 In case there are more than this many episodes collected in a single
                 training iteration, use all of these episodes for metrics computation,
                 meaning don't ever cut any "excess" episodes.
-                Set this to 1 to disable smoothing and to always report only the most
-                recently collected episode's return.
             min_time_s_per_iteration: Minimum time to accumulate within a single
                 `train()` call. This value does not affect learning,
                 only the number of times `Algorithm.training_step()` is called by
@@ -2616,34 +2645,6 @@ def learner_class(self) -> Type["Learner"]:
         """
         return self._learner_class or self.get_default_learner_class()

-    @property
-    def is_atari(self) -> bool:
-        """True if the specified env is an Atari env."""
-
-        # Not yet determined, try to figure this out.
-        if self._is_atari is None:
-            # Atari envs are usually specified via a string like "PongNoFrameskip-v4"
-            # or "ALE/Breakout-v5".
-            # We do NOT attempt to auto-detect Atari env for other specified types like
-            # a callable, to avoid running heavy logics in validate().
-            # For these cases, users can explicitly set `environment(is_atari=True)`.
-            if not type(self.env) == str:
-                return False
-            try:
-                if self.env.startswith("ALE/"):
-                    env = gym.make("GymV26Environment-v0", env_id=self.env)
-                else:
-                    env = gym.make(self.env)
-            # Any gymnasium error -> Cannot be an Atari env.
-            except gym.error.Error:
-                return False
-
-            self._is_atari = is_atari(env)
-            # Clean up env's resources, if any.
-            env.close()
-
-        return self._is_atari
-
     # TODO: Make rollout_fragment_length as read-only property and replace the current
     #  self.rollout_fragment_length a private variable.
     def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
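Taken together, these hunks move Atari detection from a lazy `is_atari` property back to a plain attribute that `validate()` resolves once. A hedged illustration of the resulting behavior (assuming `ale-py` is installed so the env can actually be built):

    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

    config = AlgorithmConfig().environment(env="ALE/Breakout-v5")
    print(config.is_atari)  # None -- not detected yet
    config.validate()       # runs _detect_atari_env() as shown above
    print(config.is_atari)  # True, since "ALE/..." envs resolve to Atari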
27 changes: 0 additions & 27 deletions rllib/algorithms/dreamerv3/README.md

This file was deleted.

15 changes: 0 additions & 15 deletions rllib/algorithms/dreamerv3/__init__.py

This file was deleted.

(Diffs for the remaining changed files are not shown here.)
