[RLlib] Disable RL Modules for policy server examples. (ray-project#3…
ArturNiederfahrenhorst committed Apr 25, 2023
1 parent 131d4dc commit 790ef9e
Showing 7 changed files with 85 additions and 33 deletions.
73 changes: 47 additions & 26 deletions rllib/algorithms/algorithm_config.py
@@ -419,8 +419,9 @@ def __init__(self, algo_class=None):
         # `self.rl_module()`
         self.rl_module_spec = None
         self._enable_rl_module_api = False
-        # Whether to error out if exploration config is set when using RLModules.
-        self._validate_exploration_conf_and_rl_modules = True
+        # Helper to keep track of the original exploration config when dis-/enabling
+        # rl modules.
+        self.__prior_exploration_config = None
 
         # `self.experimental()`
         self._tf_policy_handles_more_than_one_loss = False
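A note on the new attribute: the double leading underscore means Python name-mangles `__prior_exploration_config` to `_AlgorithmConfig__prior_exploration_config`, which keeps the stash off the config's public attribute surface. A minimal, self-contained illustration (the class body here is a stand-in, not the real config):

    class AlgorithmConfig:
        def __init__(self):
            # Double leading underscore -> name-mangled per enclosing class.
            self.__prior_exploration_config = None

    config = AlgorithmConfig()
    # The attribute is stored under the class-qualified, mangled name:
    assert "_AlgorithmConfig__prior_exploration_config" in vars(config)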
@@ -568,6 +569,13 @@ def update_from_dict(
         """
         eval_call = {}
 
+        # We deal with this special key before all others because it may influence
+        # stuff like "exploration_config".
+        # Namely, we want to re-instantiate the exploration config this config had
+        # inside `self.rl_module()` before potentially overwriting it in the following.
+        if "_enable_rl_module_api" in config_dict:
+            self.rl_module(_enable_rl_module_api=config_dict["_enable_rl_module_api"])
+
         # Modify our properties one by one.
         for key, value in config_dict.items():
             key = self._translate_special_keys(key, warn_deprecated=False)
@@ -577,8 +585,11 @@
             if key == TRIAL_INFO:
                 continue
 
+            if key == "_enable_rl_module_api":
+                # We've dealt with this above.
+                continue
             # Set our multi-agent settings.
-            if key == "multiagent":
+            elif key == "multiagent":
                 kwargs = {
                     k: value[k]
                     for k in [
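The early handling matters because `rl_module(_enable_rl_module_api=...)` itself rewrites `exploration_config` (see the `rl_module()` hunk further down). A rough sketch of the intended ordering, using a hypothetical update dict that carries both keys:

    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

    config = AlgorithmConfig()
    config.rl_module(_enable_rl_module_api=True)  # stashes + clears exploration_config

    # Later, a dict update turns the new stack off again and supplies its own
    # exploration config. The flag is handled first, so rl_module() can restore
    # the stashed config; the explicit "exploration_config" entry is then
    # applied in the loop and wins.
    config.update_from_dict(
        {
            "_enable_rl_module_api": False,
            "exploration_config": {"type": "EpsilonGreedy"},
        }
    )
    # -> config.exploration_config now reflects the explicit entry rather than
    #    being clobbered by the flag handling.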
@@ -863,13 +874,13 @@ def validate(self) -> None:
         self.enable_connectors = True
 
         # Explore parameter cannot be False with RLModule API enabled.
-        # The reason is that the explore is not just a parameter that will get passed
+        # The reason is that `explore` is not just a parameter that will get passed
         # down to the policy.compute_actions() anymore. It is a phase in which RLModule.
-        # forward_exploration() will get called during smapling. If user needs to
+        # forward_exploration() will get called during sampling. If user needs to
         # really disable the stochasticity during this phase, they need to override the
         # RLModule.forward_exploration() method or setup model parameters such that it
-        # will disable the stocalisticity of this method (e.g. by setting the std to 0
-        # or setting temprature to 0 for the Categorical distribution).
+        # will disable the stochasticity of this method (e.g. by setting the std to 0
+        # or setting temperature to 0 for the Categorical distribution).
 
         if self._enable_rl_module_api and not self.explore:
             raise ValueError(
@@ -1002,25 +1013,16 @@ def validate(self) -> None:
                 self.rl_module_spec = default_rl_module_spec
 
             if self.exploration_config:
-                if self._validate_exploration_conf_and_rl_modules:
-                    # This is not compatible with RLModules, which have a method
-                    # `forward_exploration` to specify custom exploration behavior.
-                    raise ValueError(
-                        "When RLModule API are enabled, exploration_config can not be "
-                        "set. If you want to implement custom exploration behaviour, "
-                        "please modify the `forward_exploration` method of the "
-                        "RLModule at hand. On configs that have a default exploration "
-                        "config, this must be done with "
-                        "`config.exploration_config={}`."
-                    )
-                else:
-                    # RLModules don't support exploration_configs anymore.
-                    # AlgorithmConfig has a default exploration config.
-                    logger.warning(
-                        "When RLModule API are enabled, exploration_config "
-                        "will be ignored. Disable RLModule API make use of an "
-                        "exploration_config."
-                    )
+                # This is not compatible with RLModules, which have a method
+                # `forward_exploration` to specify custom exploration behavior.
+                raise ValueError(
+                    "When RLModule API are enabled, exploration_config can not be "
+                    "set. If you want to implement custom exploration behaviour, "
+                    "please modify the `forward_exploration` method of the "
+                    "RLModule at hand. On configs that have a default exploration "
+                    "config, this must be done with "
+                    "`config.exploration_config={}`."
+                )
 
             # make sure the resource requirements for learner_group is valid
             if self.num_learner_workers == 0 and self.num_gpus_per_worker > 1:
@@ -2420,7 +2422,26 @@ def rl_module(
                     "config, this must be done with "
                     "`config.exploration_config={}`."
                 )
+            self.__prior_exploration_config = self.exploration_config
+            self.exploration_config = {}
+        elif _enable_rl_module_api is False and not self.exploration_config:
+            if self.__prior_exploration_config is not None:
+                logger.warning(
+                    f"Setting `exploration_config="
+                    f"{self.__prior_exploration_config}` because you set "
+                    f"`_enable_rl_modules=False`. This exploration config was "
+                    f"restored from a prior exploration config that was overridden "
+                    f"when setting `_enable_rl_modules=True`. This occurs because "
+                    f"when RLModule API are enabled, exploration_config can not "
+                    f"be set."
+                )
+                self.exploration_config = self.__prior_exploration_config
+                self.__prior_exploration_config = None
+            else:
+                logger.warning(
+                    "config._enable_rl_module_api was set to False, but no prior "
+                    "exploration config was found to be restored."
+                )
         else:
             # throw a warning if the user has used this API but not enabled it.
             logger.warning(
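Taken together, the `rl_module()` changes above give the flag a save/restore round-trip. A minimal sketch of the intended behavior, assuming the config still carries its untouched default exploration config:

    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

    config = AlgorithmConfig()
    default_exploration = config.exploration_config
    config.rl_module(_enable_rl_module_api=True)
    assert config.exploration_config == {}         # stashed and cleared
    config.rl_module(_enable_rl_module_api=False)  # warns, then restores
    assert config.exploration_config == default_exploration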
5 changes: 3 additions & 2 deletions rllib/algorithms/tests/test_algorithm_config.py
@@ -4,7 +4,8 @@
 import ray
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.algorithms.callbacks import make_multi_callbacks
-from ray.rllib.algorithms.ppo import PPO, PPOConfig
+from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.algorithms.ppo import PPO
 from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
 from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
 from ray.rllib.core.rl_module.marl_module import (
@@ -16,7 +17,7 @@
 class TestAlgorithmConfig(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        ray.init(num_cpus=6)
+        ray.init(num_cpus=6, local_mode=True)
 
     @classmethod
     def tearDownClass(cls):
4 changes: 4 additions & 0 deletions rllib/examples/serving/cartpole_server.py
@@ -180,6 +180,10 @@ def _input(ioctx):
         # Set to INFO so we'll see the server's actual address:port.
         .debugging(log_level="INFO")
     )
+    # Disable RLModules because they need connectors
+    # TODO(Artur): Deprecate ExternalEnv and reenable connectors and RL Modules here
+    config.rl_module(_enable_rl_module_api=False)
+    config.training(_enable_learner_api=False)
 
     # DQN.
     if args.run == "DQN" or args.run == "APEX" or args.run == "R2D2":
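For context, a condensed sketch of where these two toggles sit in a policy-server setup like this example's (the address, port, and spaces are illustrative placeholders, not values from this file):

    import gymnasium as gym
    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.env.policy_server_input import PolicyServerInput

    config = (
        PPOConfig()
        # External clients drive the episodes, so no env is registered here.
        .environment(
            env=None,
            observation_space=gym.spaces.Box(-1.0, 1.0, (4,)),
            action_space=gym.spaces.Discrete(2),
        )
        .offline_data(
            input_=lambda ioctx: PolicyServerInput(ioctx, "localhost", 9900)
        )
        .debugging(log_level="INFO")
    )
    # Same two toggles as in the diff: the external-env serving path predates
    # connectors, which the RLModule/Learner stack requires.
    config.rl_module(_enable_rl_module_api=False)
    config.training(_enable_learner_api=False)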
5 changes: 5 additions & 0 deletions rllib/examples/serving/unity3d_server.py
@@ -151,6 +151,11 @@ def _input(ioctx):
         .evaluation(off_policy_estimation_methods={})
     )
 
+    # Disable RLModules because they need connectors
+    # TODO(Artur): Deprecate ExternalEnv and reenable connectors and RL Modules here
+    config.rl_module(_enable_rl_module_api=False)
+    config._enable_learner_api = False
+
     # Create the Trainer used for Policy serving.
     algo = config.build()
 
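Note that this file flips the learner flag differently from cartpole_server.py above; both forms end up toggling the same config attribute:

    config.training(_enable_learner_api=False)  # via the builder method
    config._enable_learner_api = False          # direct attribute assignment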
10 changes: 8 additions & 2 deletions rllib/policy/tests/test_policy.py
@@ -32,7 +32,10 @@ def test_policy_get_and_set_state(self):
         policy.set_state(state1)
         state3 = policy.get_state()
         # Make sure everything is the same.
-        check(state1["_exploration_state"], state3["_exploration_state"])
+        # This is only supported without RLModule API. See AlgorithmConfig for
+        # more info.
+        if not config._enable_rl_module_api:
+            check(state1["_exploration_state"], state3["_exploration_state"])
         check(state1["global_timestep"], state3["global_timestep"])
         check(state1["weights"], state3["weights"])
 
@@ -42,7 +45,10 @@
         if isinstance(policy, (EagerTFPolicyV2, DynamicTFPolicyV2, TorchPolicyV2)):
             policy_restored_from_scratch = Policy.from_state(state3)
             state4 = policy_restored_from_scratch.get_state()
-            check(state3["_exploration_state"], state4["_exploration_state"])
+            # This is only supported without RLModule API. See AlgorithmConfig for
+            # more info.
+            if not config._enable_rl_module_api:
+                check(state3["_exploration_state"], state4["_exploration_state"])
             check(state3["global_timestep"], state4["global_timestep"])
             # For tf static graph, the new model has different layer names
             # (as it gets written into the same graph as the old one).
18 changes: 15 additions & 3 deletions rllib/tests/test_rllib_train_and_evaluate.py
Expand Up @@ -96,12 +96,23 @@ def learn_test_plus_evaluate(algo: str, env="CartPole-v1"):
print("Saving results to {}".format(tmp_dir))

rllib_dir = str(Path(__file__).parent.parent.absolute())

# This is only supported without RLModule API. See AlgorithmConfig for
# more info. We need to prefetch the default config that will be used when we
# call rllib train here to see if the RLModule API is enabled.
algo_cls = get_trainable_cls(algo)
config = algo_cls.get_default_config()
if config._enable_rl_module_api:
eval_ = ', \\"evaluation_config\\": {}'
else:
eval_ = ', \\"evaluation_config\\": {\\"explore\\": false}'

print("RLlib dir = {}\nexists={}".format(rllib_dir, os.path.exists(rllib_dir)))
os.system(
"python {}/train.py --local-dir={} --run={} "
"--checkpoint-freq=1 --checkpoint-at-end ".format(rllib_dir, tmp_dir, algo)
+ '--config="{\\"num_gpus\\": 0, \\"num_workers\\": 1, '
'\\"evaluation_config\\": {\\"explore\\": false}'
+ '--config="{\\"num_gpus\\": 0, \\"num_workers\\": 1'
+ eval_
+ fw_
+ '}" '
+ '--stop="{\\"episode_reward_mean\\": 100.0}"'
@@ -182,7 +193,8 @@ def policy_fn(agent_id, episode, **kwargs):
             policy_mapping_fn=policy_fn,
         )
         .resources(num_gpus=0)
-        .evaluation(evaluation_config=AlgorithmConfig.overrides(explore=False))
+        .evaluation(evaluation_config=AlgorithmConfig.overrides(explore=True))
+        .evaluation(evaluation_config=AlgorithmConfig.overrides(explore=True))
         .rl_module(
             rl_module_spec=MultiAgentRLModuleSpec(
                 module_specs={
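To make the string splicing in the first hunk concrete, a sketch that mirrors the test's selection logic (algo name hardcoded for illustration; the `fw_` fragment is omitted):

    from ray.tune.registry import get_trainable_cls

    config = get_trainable_cls("PPO").get_default_config()
    eval_ = (
        ', \\"evaluation_config\\": {}'  # new stack: explore=False not allowed
        if config._enable_rl_module_api
        else ', \\"evaluation_config\\": {\\"explore\\": false}'  # old stack
    )
    # On the new stack, the composed CLI flag then reads:
    #   --config="{\"num_gpus\": 0, \"num_workers\": 1, \"evaluation_config\": {}}"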
3 changes: 3 additions & 0 deletions rllib/utils/exploration/tests/test_explorations.py
@@ -28,6 +28,9 @@ def do_test_explorations(config, dummy_obs, prev_a=None, expected_mean_action=No
    for exploration in [None, "Random"]:
        local_config = config.copy()
        if exploration == "Random":
+            if local_config._enable_rl_module_api:
+                # TODO(Artur): Support Random exploration with RL Modules.
+                continue
            local_config.exploration(exploration_config={"type": "Random"})
        print("exploration={}".format(exploration or "default"))