[RLlib] Learner API enhancements and cleanups (prep. for DreamerV3). (r…
sven1977 committed May 31, 2023
1 parent 2ead1ce commit f1f714c
Showing 36 changed files with 1,438 additions and 1,195 deletions.
8 changes: 0 additions & 8 deletions rllib/BUILD
@@ -2014,14 +2014,6 @@ py_test(
srcs = ["core/learner/tests/test_learner.py"]
)

# TODO (Kourosh): to be removed in favor of test_learner.py
py_test(
name = "test_torch_learner",
tags = ["team:rllib", "core", "ray_data"],
size = "medium",
srcs = ["core/learner/torch/tests/test_torch_learner.py"]
)

py_test(
name = "tests/test_algorithm_save_load_checkpoint_learner",
tags = ["team:rllib", "core"],
143 changes: 94 additions & 49 deletions rllib/algorithms/algorithm_config.py
@@ -8,7 +8,6 @@
Callable,
Container,
Dict,
List,
Mapping,
Optional,
Tuple,
@@ -27,7 +26,7 @@
ModuleSpec,
)
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import ModuleID, SingleAgentRLModuleSpec
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.wrappers.atari_wrappers import is_atari
@@ -67,6 +66,7 @@
AlgorithmConfigDict,
EnvConfigDict,
EnvType,
LearningRateOrSchedule,
MultiAgentPolicyConfigDict,
PartialAlgorithmConfigDict,
PolicyID,
@@ -331,7 +331,6 @@ def __init__(self, algo_class=None):
# `self.training()`
self.gamma = 0.99
self.lr = 0.001
self.lr_schedule = None
self.grad_clip = None
self.grad_clip_by = "global_norm"
self.train_batch_size = 32
@@ -352,6 +351,7 @@

# `self.multi_agent()`
self.policies = {DEFAULT_POLICY_ID: PolicySpec()}
self.algorithm_config_overrides_per_module = {}
self.policy_map_capacity = 100
self.policy_mapping_fn = self.DEFAULT_POLICY_MAPPING_FN
self.policies_to_train = None
@@ -911,7 +911,11 @@ def validate(self) -> None:

# LR-schedule checking.
if self._enable_learner_api:
Scheduler.validate(self.lr_schedule, "lr_schedule", "learning rate")
Scheduler.validate(
fixed_value_or_schedule=self.lr,
setting_name="lr",
description="learning rate",
)

# Validate grad clipping settings.
if self.grad_clip_by not in ["value", "norm", "global_norm"]:
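For context, `Scheduler.validate()` now receives the unified `lr` setting (a plain float or a piecewise schedule) via explicit keyword arguments. A minimal sketch of the kind of check this implies, based solely on the schedule format documented in the `training()` docstring below; this is illustrative only, not RLlib's actual `Scheduler.validate` implementation:

    def validate_lr(fixed_value_or_schedule):
        """Illustrative check for a fixed lr or a [[timestep, value], ...] schedule."""
        # A plain numeric value is always a valid fixed learning rate.
        if isinstance(fixed_value_or_schedule, (int, float)):
            return
        # Otherwise, expect a piecewise schedule whose first entry starts at timestep 0.
        if not fixed_value_or_schedule or fixed_value_or_schedule[0][0] != 0:
            raise ValueError(
                "lr schedule must start at timestep 0, "
                "e.g. [[0, 0.001], [1000000, 0.0001]]"
            )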
@@ -1652,8 +1656,7 @@ def training(
self,
*,
gamma: Optional[float] = NotProvided,
lr: Optional[float] = NotProvided,
lr_schedule: Optional[List[List[Union[int, float]]]] = NotProvided,
lr: Optional[LearningRateOrSchedule] = NotProvided,
grad_clip: Optional[float] = NotProvided,
grad_clip_by: Optional[str] = NotProvided,
train_batch_size: Optional[int] = NotProvided,
@@ -1667,40 +1670,46 @@
Args:
gamma: Float specifying the discount factor of the Markov Decision process.
lr: The default learning rate.
lr_schedule: Learning rate schedule. In the format of
lr: The learning rate (float) or learning rate schedule in the format of
[[timestep, lr-value], [timestep, lr-value], ...]
Intermediary timesteps will be assigned to interpolated learning rate
values. A schedule config's first entry must start with timestep 0,
i.e.: [[0, initial_value], [...]].
grad_clip: The value to use for gradient clipping. Depending on the
`grad_clip_by` setting, gradients will either be clipped by value,
norm, or global_norm (see docstring on `grad_clip_by` below for more
details). If `grad_clip` is None, gradients will be left unclipped.
grad_clip_by: If 'value': Will clip all computed gradients individually
inside the interval [-grad_clip, +grad_clip].
If 'norm', will compute the L2-norm of each weight/bias
gradient tensor and then clip all gradients such that this L2-norm does
not exceed `grad_clip`. The L2-norm of a tensor is computed via:
`sqrt(SUM(w0^2, w1^2, ..., wn^2))` where w[i] are the elements of the
tensor (no matter what the shape of this tensor is).
If 'global_norm', will compute the square of the L2-norm of each
weight/bias gradient tensor, sum up all these squared L2-norms across
all given gradient tensors (e.g. the entire module to
In case of a schedule, intermediary timesteps will be assigned to
linearly interpolated learning rate values. A schedule config's first
entry must start with timestep 0, i.e.: [[0, initial_value], [...]].
Note: If you require a) more than one optimizer (per RLModule),
b) optimizer types that are not Adam, c) a learning rate schedule that
is not a linearly interpolated, piecewise schedule as described above,
or d) specifying c'tor arguments of the optimizer that are not the
learning rate (e.g. Adam's epsilon), then you must override your
Learner's `configure_optimizer_for_module()` method and handle
lr-scheduling yourself.
grad_clip: If None, no gradient clipping will be applied. Otherwise,
depending on the setting of `grad_clip_by`, the (float) value of
`grad_clip` will have the following effect:
If `grad_clip_by=value`: Will clip all computed gradients individually
inside the interval [-`grad_clip`, +`grad_clip`].
If `grad_clip_by=norm`, will compute the L2-norm of each weight/bias
gradient tensor individually and then clip all gradients such that these
L2-norms do not exceed `grad_clip`. The L2-norm of a tensor is computed
via: `sqrt(SUM(w0^2, w1^2, ..., wn^2))` where w[i] are the elements of
the tensor (no matter what the shape of this tensor is).
If `grad_clip_by=global_norm`, will compute the square of the L2-norm of
each weight/bias gradient tensor individually, sum up all these squared
L2-norms across all given gradient tensors (e.g. the entire module to
be updated), square root that overall sum, and then clip all gradients
such that this "global" L2-norm does not exceed the given value.
such that this global L2-norm does not exceed the given value.
The global L2-norm over a list of tensors (e.g. W and V) is computed
via:
`sqrt[SUM(w0^2, w1^2, ..., wn^2) + SUM(v0^2, v1^2, ..., vm^2)]`, where
w[i] and v[j] are the elements of the tensors W and V (no matter what
the shapes of these tensors are).
Note that if `grad_clip` is None, the `grad_clip_by` setting has no
effect.
grad_clip_by: See `grad_clip` for the effect of this setting on gradient
clipping. Allowed values are `value`, `norm`, and `global_norm`.
train_batch_size: Training batch size, if applicable.
model: Arguments passed into the policy model. See models/catalog.py for a
full list of the available model options.
TODO: Provide ModelConfig objects instead of dicts.
optimizer: Arguments to pass to the policy optimizer.
optimizer: Arguments to pass to the policy optimizer. This setting is not
used when `_enable_learner_api=True`.
max_requests_in_flight_per_sampler_worker: Max number of inflight requests
to each sampling worker. See the FaultTolerantActorManager class for
more details.
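As a numeric illustration of the `global_norm` formula quoted in the docstring above (the square root of the summed squared elements across all given gradient tensors), here is a small sketch in torch; `global_norm` is a hypothetical helper written for this illustration, not an RLlib function:

    import torch

    def global_norm(tensors):
        # sqrt(SUM of squared elements across ALL given tensors).
        return torch.sqrt(sum(torch.sum(t ** 2) for t in tensors))

    w = torch.tensor([3.0, 4.0])   # SUM(w^2) = 25
    v = torch.tensor([12.0])       # SUM(v^2) = 144
    print(global_norm([w, v]))     # sqrt(25 + 144) = 13.0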
@@ -1724,8 +1733,6 @@ def training(
self.gamma = gamma
if lr is not NotProvided:
self.lr = lr
if lr_schedule is not NotProvided:
self.lr_schedule = lr_schedule
if grad_clip is not NotProvided:
self.grad_clip = grad_clip
if grad_clip_by is not NotProvided:
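Putting the reworked settings together, configuring a learning-rate schedule plus global-norm gradient clipping might look like this (a usage sketch; PPOConfig stands in for any AlgorithmConfig subclass and the numbers are made up):

    from ray.rllib.algorithms.ppo import PPOConfig

    config = PPOConfig().training(
        # Piecewise schedule: linearly decay lr from 0.001 to 0.0001 over 1M timesteps.
        lr=[[0, 0.001], [1000000, 0.0001]],
        # Clip gradients such that their global L2-norm does not exceed 40.
        grad_clip=40.0,
        grad_clip_by="global_norm",
    )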
@@ -2103,6 +2110,9 @@ def multi_agent(
self,
*,
policies=NotProvided,
algorithm_config_overrides_per_module: Optional[
Dict[ModuleID, PartialAlgorithmConfigDict]
] = NotProvided,
policy_map_capacity: Optional[int] = NotProvided,
policy_mapping_fn: Optional[
Callable[[AgentID, "Episode"], PolicyID]
@@ -2130,6 +2140,18 @@
4-tuples of (policy_cls, obs_space, act_space, config) or PolicySpecs.
These tuples or PolicySpecs define the class of the policy, the
observation- and action spaces of the policies, and any extra config.
algorithm_config_overrides_per_module: Only used if both
`_enable_learner_api` and `_enable_rl_module_api` are True.
A mapping from ModuleIDs to
per-module AlgorithmConfig override dicts, which apply certain settings,
e.g. the learning rate, from the main AlgorithmConfig only to this
particular module (within a MultiAgentRLModule).
You can create override dicts by using the `AlgorithmConfig.overrides`
utility. For example, to override your learning rate and (PPO) lambda
setting just for a single RLModule within your MultiAgentRLModule, do:
config.multi_agent(algorithm_config_overrides_per_module={
"module_1": PPOConfig.overrides(lr=0.0002, lambda_=0.75),
})
policy_map_capacity: Keep this many policies in the "policy_map" (before
writing least-recently used ones to disk/S3).
policy_mapping_fn: Function mapping agent ids to policy ids. The signature
@@ -2198,6 +2220,11 @@
)
self.policies = policies

if algorithm_config_overrides_per_module is not NotProvided:
self.algorithm_config_overrides_per_module = (
algorithm_config_overrides_per_module
)

if policy_map_capacity is not NotProvided:
self.policy_map_capacity = policy_map_capacity
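A slightly fuller usage sketch of the per-module override mechanism described in the docstring above (module IDs and hyperparameter values are made up for illustration):

    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .training(lr=0.001, lambda_=0.95)
        .multi_agent(
            policies={"module_1", "module_2"},
            # Only "module_1" trains with the overridden settings;
            # "module_2" keeps the main config's lr and lambda.
            algorithm_config_overrides_per_module={
                "module_1": PPOConfig.overrides(lr=0.0002, lambda_=0.75),
            },
        )
    )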

@@ -2513,7 +2540,7 @@ def rl_module(
if _enable_rl_module_api is True and self.exploration_config:
logger.warning(
"Setting `exploration_config={}` because you set "
"`_enable_rl_modules=True`. When RLModule API are "
"`_enable_rl_module_api=True`. When RLModule API are "
"enabled, exploration_config can not be "
"set. If you want to implement custom exploration behaviour, "
"please modify the `forward_exploration` method of the "
@@ -2526,13 +2553,13 @@
elif _enable_rl_module_api is False and not self.exploration_config:
if self.__prior_exploration_config is not None:
logger.warning(
f"Setting `exploration_config="
"Setting `exploration_config="
f"{self.__prior_exploration_config}` because you set "
f"`_enable_rl_modules=False`. This exploration config was "
f"restored from a prior exploration config that was overriden "
f"when setting `_enable_rl_modules=True`. This occurs because "
f"when RLModule API are enabled, exploration_config can not "
f"be set."
"`_enable_rl_module_api=False`. This exploration config was "
"restored from a prior exploration config that was overriden "
"when setting `_enable_rl_module_api=True`. This occurs "
"because when RLModule API are enabled, exploration_config "
"can not be set."
)
self.exploration_config = self.__prior_exploration_config
self.__prior_exploration_config = None
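The warning above points users at the RLModule's `forward_exploration` method for custom exploration behavior. A bare-bones sketch of such an override, assuming the torch RLModule base class and the `_forward_exploration` hook of the RLModule API at the time (both the import path and the signature are assumptions, not part of this diff):

    from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule

    class MyExploringRLModule(TorchRLModule):
        def _forward_exploration(self, batch, **kwargs):
            # Start from the plain inference pass, then add exploration noise,
            # e.g. epsilon-greedy action picking or added action-dist entropy.
            out = self._forward_inference(batch)
            ...  # custom exploration logic goes here
            return out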
@@ -3226,23 +3253,14 @@ def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfig:
if not self._is_frozen:
raise ValueError(
"Cannot call `get_learner_group_config()` on an unfrozen "
"AlgorithmConfig! Please call `freeze()` first."
"AlgorithmConfig! Please call `AlgorithmConfig.freeze()` first."
)

config = (
LearnerGroupConfig()
.module(module_spec)
.learner(
learner_class=self.learner_class,
# TODO (Kourosh): optimizer config can now be more complicated.
# TODO (Sven): Shouldn't optimizer config be part of learner HPs?
# E.g. if we have a lr schedule, this will have to be managed by
# the learner, NOT the optimizer directly.
optimizer_config={
"lr": self.lr,
"grad_clip": self.grad_clip,
"grad_clip_by": self.grad_clip_by,
},
learner_hyperparameters=self.get_learner_hyperparameters(),
)
.resources(
@@ -3275,8 +3293,35 @@ def get_learner_hyperparameters(self) -> LearnerHyperparameters:
Note that LearnerHyperparameters should always be derived directly from an
AlgorithmConfig object's own settings and considered frozen/read-only.
Returns:
A LearnerHyperparameters instance for the respective Learner.
"""
return LearnerHyperparameters(lr_schedule=self.lr_schedule)
# Compile the per-module learner hyperparameter instances (if applicable).
per_module_learner_hp_overrides = {}
if self.algorithm_config_overrides_per_module:
for (
module_id,
overrides,
) in self.algorithm_config_overrides_per_module.items():
# Copy this AlgorithmConfig object (unfreeze copy), update copy from
# the provided override dict for this module_id, then
# create a new LearnerHyperparameter object from this altered
# AlgorithmConfig.
config_for_module = self.copy(copy_frozen=False).update_from_dict(
overrides
)
config_for_module.algorithm_config_overrides_per_module = None
per_module_learner_hp_overrides[
module_id
] = config_for_module.get_learner_hyperparameters()

return LearnerHyperparameters(
learning_rate=self.lr,
grad_clip=self.grad_clip,
grad_clip_by=self.grad_clip_by,
_per_module_overrides=per_module_learner_hp_overrides,
)

def __setattr__(self, key, value):
"""Gatekeeper in case we are in frozen state and need to error."""
6 changes: 3 additions & 3 deletions rllib/algorithms/appo/appo.py
@@ -361,9 +361,9 @@ def _get_additional_update_kwargs(self, train_results) -> dict:
return dict(
last_update=self._counters[LAST_TARGET_UPDATE_TS],
mean_kl_loss_per_module={
mid: r[LEARNER_RESULTS_KL_KEY]
for mid, r in train_results.items()
if mid != ALL_MODULES
module_id: r[LEARNER_RESULTS_KL_KEY]
for module_id, r in train_results.items()
if module_id != ALL_MODULES
},
)
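The APPO change is a pure rename (`mid` to `module_id`) in the comprehension that collects per-module KL losses while skipping the aggregate ALL_MODULES key. Illustrated with dummy results below; the constant values are stand-ins for RLlib's real ALL_MODULES and LEARNER_RESULTS_KL_KEY (assumed here for illustration):

    # Stand-ins for RLlib's constants (illustrative values only):
    ALL_MODULES = "__all_modules__"
    LEARNER_RESULTS_KL_KEY = "mean_kl_loss"

    train_results = {
        "module_1": {LEARNER_RESULTS_KL_KEY: 0.02},
        "module_2": {LEARNER_RESULTS_KL_KEY: 0.05},
        ALL_MODULES: {LEARNER_RESULTS_KL_KEY: 0.035},  # aggregate; excluded below
    }
    mean_kl_loss_per_module = {
        module_id: r[LEARNER_RESULTS_KL_KEY]
        for module_id, r in train_results.items()
        if module_id != ALL_MODULES
    }
    # -> {"module_1": 0.02, "module_2": 0.05}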
