[RLlib] Learner API: Fix and unify grad-clipping configs and behaviors. #34464

Merged
6 changes: 6 additions & 0 deletions rllib/algorithms/a3c/a3c.py
Expand Up @@ -67,7 +67,13 @@ def __init__(self, algo_class=None):
self.use_critic = True
self.use_gae = True
self.lambda_ = 1.0

self.grad_clip = 40.0
# Note: Only when using _enable_learner_api=True can the clipping mode be
# configured by the user. On the old API stack, RLlib will always clip by
# global_norm, no matter the value of `grad_clip_by`.
self.grad_clip_by = "global_norm"

self.lr_schedule = None
self.vf_loss_coeff = 0.5
self.entropy_coeff = 0.01
44 changes: 42 additions & 2 deletions rllib/algorithms/algorithm_config.py
Expand Up @@ -316,6 +316,8 @@ def __init__(self, algo_class=None):
# `self.training()`
self.gamma = 0.99
self.lr = 0.001
self.grad_clip = None
self.grad_clip_by = "global_norm"
self.train_batch_size = 32
self.model = copy.deepcopy(MODEL_DEFAULTS)
self.optimizer = {}
Expand Down Expand Up @@ -881,7 +883,6 @@ def validate(self) -> None:
# RLModule.forward_exploration() method or setup model parameters such that it
# will disable the stochasticity of this method (e.g. by setting the std to 0
# or setting temperature to 0 for the Categorical distribution).

if self._enable_rl_module_api and not self.explore:
raise ValueError(
"When RLModule API is enabled, explore parameter cannot be False. "
Expand All @@ -895,6 +896,13 @@ def validate(self) -> None:
"setting temperature to 0 for the Categorical distribution)."
)

# Validate grad clipping settings.
if self.grad_clip_by not in ["value", "norm", "global_norm"]:
raise ValueError(
f"`grad_clip_by` ({self.grad_clip_by}) must be one of: 'value', "
"'norm', or 'global_norm'!"
)

# TODO: Deprecate self.simple_optimizer!
# Multi-GPU settings.
if self.simple_optimizer is True:
Expand Down Expand Up @@ -1031,7 +1039,7 @@ def validate(self) -> None:
"(i.e. num_learner_workers = 0)"
)

# resolve learner class
# Resolve learner class.
if self._enable_learner_api and self.learner_class is None:
learner_class_path = self.get_default_learner_class()
self.learner_class = deserialize_type(learner_class_path)
Expand Down Expand Up @@ -1591,8 +1599,11 @@ def rollouts(

def training(
self,
*,
gamma: Optional[float] = NotProvided,
lr: Optional[float] = NotProvided,
grad_clip: Optional[float] = NotProvided,
grad_clip_by: Optional[str] = NotProvided,
train_batch_size: Optional[int] = NotProvided,
model: Optional[dict] = NotProvided,
optimizer: Optional[dict] = NotProvided,
Expand All @@ -1605,6 +1616,29 @@ def training(
Args:
gamma: Float specifying the discount factor of the Markov Decision process.
lr: The default learning rate.
grad_clip: The value to use for gradient clipping. Depending on the
`grad_clip_by` setting, gradients will either be clipped by value,
norm, or global_norm (see docstring on `grad_clip_by` below for more
details). If `grad_clip` is None, gradients will be left unclipped.
grad_clip_by: If 'value': Will clip all computed gradients individually
inside the interval [-grad_clip, +grad_clip].
If 'norm', will compute the L2-norm of each weight/bias
gradient tensor and then clip all gradients such that this L2-norm does
not exceed `grad_clip`. The L2-norm of a tensor is computed via:
`sqrt(SUM(w0^2, w1^2, ..., wn^2))` where w[i] are the elements of the
tensor (no matter what the shape of this tensor is).
If 'global_norm', will compute the square of the L2-norm of each
weight/bias gradient tensor, sum up all these squared L2-norms across
all given gradient tensors (e.g. the entire module to
be updated), square root that overall sum, and then clip all gradients
such that this "global" L2-norm does not exceed the given value.
The global L2-norm over a list of tensors (e.g. W and V) is computed
via:
`sqrt[SUM(w0^2, w1^2, ..., wn^2) + SUM(v0^2, v1^2, ..., vm^2)]`, where
w[i] and v[j] are the elements of the tensors W and V (no matter what
the shapes of these tensors are).
Note that if `grad_clip` is None, the `grad_clip_by` setting has no
effect.
train_batch_size: Training batch size, if applicable.
model: Arguments passed into the policy model. See models/catalog.py for a
full list of the available model options.
Expand Down Expand Up @@ -1633,6 +1667,10 @@ def training(
self.gamma = gamma
if lr is not NotProvided:
self.lr = lr
if grad_clip is not NotProvided:
self.grad_clip = grad_clip
if grad_clip_by is not NotProvided:
self.grad_clip_by = grad_clip_by
if train_batch_size is not NotProvided:
self.train_batch_size = train_batch_size
if model is not NotProvided:
Expand Down Expand Up @@ -3089,6 +3127,8 @@ def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfi
# TODO (Kourosh): optimizer config can now be more complicated.
optimizer_config={
"lr": self.lr,
"grad_clip": self.grad_clip,
"grad_clip_by": self.grad_clip_by,
},
learner_hps=self.learner_hps,
)
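For reference, a minimal usage sketch of the two new training() arguments added above. This is not part of the diff: PPOConfig is only used as an arbitrary example config class, and the flag for switching to the new Learner API stack is omitted.

from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().training(
    lr=0.0005,
    # Clip whenever the global L2-norm over all gradient tensors exceeds 40.0.
    grad_clip=40.0,
    # One of "value", "norm", or "global_norm". Per the note added to the
    # default configs in this PR, the old API stack ignores this setting and
    # always clips by global_norm.
    grad_clip_by="global_norm",
)
# The new check in AlgorithmConfig.validate() raises a ValueError for any
# other `grad_clip_by` string.
config.validate()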
28 changes: 8 additions & 20 deletions rllib/algorithms/appo/appo.py
Expand Up @@ -16,7 +16,6 @@
from ray.rllib.algorithms.impala.impala import Impala, ImpalaConfig
from ray.rllib.algorithms.appo.tf.appo_tf_learner import AppoHPs, LEARNER_RESULTS_KL_KEY
from ray.rllib.algorithms.ppo.ppo import UpdateKL
from ray.rllib.execution.common import _get_shared_metrics, STEPS_SAMPLED_COUNTER
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
Expand Down Expand Up @@ -104,7 +103,13 @@ def __init__(self, algo_class=None):
self.learner_queue_timeout = 300
self.max_sample_requests_in_flight_per_worker = 2
self.broadcast_interval = 1

self.grad_clip = 40.0
# Note: Only when using _enable_learner_api=True can the clipping mode be
# configured by the user. On the old API stack, RLlib will always clip by
# global_norm, no matter the value of `grad_clip_by`.
self.grad_clip_by = "global_norm"

self.opt_type = "adam"
self.lr = 0.0005
self.lr_schedule = None
Expand Down Expand Up @@ -237,29 +242,12 @@ def validate(self) -> None:
self._learner_hps.clip_param = self.clip_param


# Still used by one of the old checkpoints in tests.
# Keep a shim version of this around.
class UpdateTargetAndKL:
def __init__(self, workers, config):
self.workers = workers
self.config = config
self.update_kl = UpdateKL(workers)
self.target_update_freq = (
config["num_sgd_iter"] * config["minibatch_buffer_size"]
)

def __call__(self, fetches):
metrics = _get_shared_metrics()
cur_ts = metrics.counters[STEPS_SAMPLED_COUNTER]
last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update > self.target_update_freq:
metrics.counters[NUM_TARGET_UPDATES] += 1
metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts
# Update Target Network
self.workers.local_worker().foreach_policy_to_train(
lambda p, _: p.update_target()
)
# Also update KL Coeff
if self.config.use_kl_loss:
self.update_kl(fetches)


class APPO(Impala):
6 changes: 6 additions & 0 deletions rllib/algorithms/dreamer/dreamer.py
Expand Up @@ -77,7 +77,13 @@ def __init__(self):
self.td_model_lr = 6e-4
self.actor_lr = 8e-5
self.critic_lr = 8e-5

self.grad_clip = 100.0
# Note: Only when using _enable_learner_api=True can the clipping mode be
# configured by the user. On the old API stack, RLlib will always clip by
# global_norm, no matter the value of `grad_clip_by`.
self.grad_clip_by = "global_norm"

self.lambda_ = 0.95
self.dreamer_train_iters = 100
self.batch_size = 50
27 changes: 12 additions & 15 deletions rllib/algorithms/impala/impala.py
@@ -1,4 +1,5 @@
import copy
from functools import partial
import logging
import platform
import queue
Expand All @@ -15,10 +16,6 @@
_reduce_impala_results,
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.learner.learner_group_config import (
LearnerGroupConfig,
ModuleSpec,
)
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.evaluation.worker_set import handle_remote_call_result_errors
from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
Expand Down Expand Up @@ -125,7 +122,13 @@ def __init__(self, algo_class=None):
self.timeout_s_aggregator_manager = 0.0
self.broadcast_interval = 1
self.num_aggregation_workers = 0

self.grad_clip = 40.0
# Note: Only when using _enable_learner_api=True can the clipping mode be
# configured by the user. On the old API stack, RLlib will always clip by
# global_norm, no matter the value of `grad_clip_by`.
self.grad_clip_by = "global_norm"

self.opt_type = "adam"
self.lr_schedule = None
self.decay = 0.99
Expand Down Expand Up @@ -422,16 +425,6 @@ def validate(self) -> None:
self.vtrace_clip_pg_rho_threshold
)

@override(AlgorithmConfig)
PR author comment: Not needed anymore with this PR; grad clipping has been universally moved into Learner.postprocess_gradients().

def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfig:
lg_config = super().get_learner_group_config(module_spec)
optim_config = lg_config.optimizer_config
# TODO(avnishn): Make grad_clip a default parameter in algorithm_config's base
# class
optim_config.update({"grad_clip": self.grad_clip})
lg_config = lg_config.learner(optimizer_config=optim_config)
return lg_config

def get_replay_ratio(self) -> float:
"""Returns replay ratio (between 0.0 and 1.0) based off self.replay_proportion.

Expand Down Expand Up @@ -1051,6 +1044,10 @@ def process_experiences_tree_aggregation(
workers.

"""

def _process_episodes(actor, batch):
return actor.process_episodes(ray.get(batch))

for _, batch in worker_to_sample_batches_refs:
assert isinstance(batch, ObjectRef), (
"For efficiency, process_experiences_tree_aggregation should "
Expand All @@ -1061,7 +1058,7 @@ def process_experiences_tree_aggregation(
self._aggregator_actor_manager.healthy_actor_ids()
)
calls_placed = self._aggregator_actor_manager.foreach_actor_async(
lambda actor: actor.process_episodes(ray.get(batch)),
partial(_process_episodes, batch=batch),
remote_actor_ids=[aggregator_id],
)
if calls_placed <= 0:
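The lambda-to-partial change above is worth a note: a lambda defined inside the loop closes over the loop variable batch itself, not its value at definition time, so a callable executed later (as with foreach_actor_async) can see a different batch than intended. That late-binding behavior is presumably why the explicit _process_episodes helper plus functools.partial is used here. A standalone sketch of the pitfall and the fix (plain Python, not RLlib code):

from functools import partial

batches = ["batch_0", "batch_1", "batch_2"]

# Late binding: every lambda sees the final value of `batch` once the loop ends.
late_bound = [lambda: "processing " + batch for batch in batches]
# Binding now: partial freezes the current value at creation time.
bound_now = [partial(lambda b: "processing " + b, b=batch) for batch in batches]

print([fn() for fn in late_bound])  # ['processing batch_2', 'processing batch_2', 'processing batch_2']
print([fn() for fn in bound_now])   # ['processing batch_0', 'processing batch_1', 'processing batch_2']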
7 changes: 6 additions & 1 deletion rllib/algorithms/qmix/qmix.py
Expand Up @@ -77,7 +77,12 @@ def __init__(self):
self.double_q = True
self.optim_alpha = 0.99
self.optim_eps = 0.00001
self.grad_clip = 10

self.grad_clip = 10.0
# Note: Only when using _enable_learner_api=True can the clipping mode be
# configured by the user. On the old API stack, RLlib will always clip by
# global_norm, no matter the value of `grad_clip_by`.
self.grad_clip_by = "global_norm"

# QMix-torch overrides the TorchPolicy's learn_on_batch w/o specifying a
# alternative `learn_on_loaded_batch` alternative for the GPU.
8 changes: 7 additions & 1 deletion rllib/algorithms/simple_q/simple_q.py
Expand Up @@ -117,7 +117,13 @@ def __init__(self, algo_class=None):
self.store_buffer_in_checkpoints = False
self.lr_schedule = None
self.adam_epsilon = 1e-8
self.grad_clip = 40

self.grad_clip = 40.0
# Note: Only when using _enable_learner_api=True can the clipping mode be
# configured by the user. On the old API stack, RLlib will always clip by
# global_norm, no matter the value of `grad_clip_by`.
self.grad_clip_by = "global_norm"

self.tau = 1.0
# __sphinx_doc_end__
# fmt: on
9 changes: 0 additions & 9 deletions rllib/algorithms/slateq/slateq_tf_policy.py
Expand Up @@ -207,15 +207,6 @@ def build_slateq_stats(policy: Policy, batch) -> Dict[str, TensorType]:
"q_loss": policy._q_loss,
"mean_actions": policy._mean_actions,
}
# if hasattr(policy, "_mean_grads_0"):
# stats.update({"mean_grads_0": policy._mean_grads_0})
# stats.update({"mean_grads_1": policy._mean_grads_1})
# stats.update({"mean_grads_2": policy._mean_grads_2})
# stats.update({"mean_grads_3": policy._mean_grads_3})
# stats.update({"mean_grads_4": policy._mean_grads_4})
# stats.update({"mean_grads_5": policy._mean_grads_5})
# stats.update({"mean_grads_6": policy._mean_grads_6})
# stats.update({"mean_grads_7": policy._mean_grads_7})
return stats


20 changes: 11 additions & 9 deletions rllib/core/learner/learner.py
Expand Up @@ -522,14 +522,14 @@ def compile_results(

# We put the stats for all modules under the ALL_MODULES key. e.g. average of
# the gradients across all modules will go here.
mean_grads = [
np.mean(grad)
mean_abs_grads = [
np.mean(np.abs(grad))
for grad in convert_to_numpy(postprocessed_gradients.values())
if grad is not None
]

module_learner_stats[ALL_MODULES] = {
"mean_gradient": np.mean(mean_grads),
"mean_abs_postprocessed_gradients": np.mean(mean_abs_grads),
self.TOTAL_LOSS_KEY: loss_numpy[self.TOTAL_LOSS_KEY],
}
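The switch from a plain mean to a mean of absolute values in the stats above matters because signed gradient entries cancel each other out, hiding the magnitude that clipping acts on. A tiny standalone numpy illustration (not RLlib code):

import numpy as np

grad = np.array([-0.5, 0.5, -0.25, 0.25])
print(np.mean(grad))          # 0.0   -> says nothing about gradient magnitude
print(np.mean(np.abs(grad)))  # 0.375 -> tracks the typical gradient size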

Expand Down Expand Up @@ -754,19 +754,21 @@ def additional_update_per_module(

@OverrideToImplementCustomLogic
def postprocess_gradients(
self, gradients_dict: Mapping[str, Any]
self,
gradients_dict: Mapping[str, Any],
) -> Mapping[str, Any]:
"""Applies potential postprocessings to the gradients.
"""Applies potential postprocessing operations on the gradients.

In some algorithms, we may want to perform some postprocessing on the
gradients before they are applied. This method is called after gradients
have been computed, and modifies them before they are applied.
This method is called after gradients have been computed, and modifies them
before they are applied to the respective module(s).
This includes grad clipping by value, norm, or global-norm, or other
algorithm specific gradient postprocessing steps.

Args:
gradients_dict: A dictionary of gradients.

Returns:
A dictionary of updated gradients.
A dictionary with the updated gradients.
"""
return gradients_dict
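To make the clipping modes referenced here concrete, below is a plain-numpy sketch of a gradient-postprocessing helper that implements the three modes following the formulas documented in AlgorithmConfig.training() above. It is an illustration only, not the framework-specific code the Tf/Torch Learner subclasses use, and the helper name is made up.

from typing import Any, Mapping, Optional

import numpy as np


def clip_gradients(
    gradients_dict: Mapping[str, Any],
    grad_clip: Optional[float] = None,
    grad_clip_by: str = "global_norm",
) -> Mapping[str, Any]:
    # No clipping requested: `grad_clip_by` has no effect.
    if grad_clip is None:
        return dict(gradients_dict)

    if grad_clip_by == "value":
        # Clip every gradient element into [-grad_clip, +grad_clip].
        return {k: np.clip(g, -grad_clip, grad_clip) for k, g in gradients_dict.items()}

    if grad_clip_by == "norm":
        # Rescale each tensor individually so its own L2-norm does not exceed grad_clip.
        out = {}
        for k, g in gradients_dict.items():
            l2 = np.sqrt(np.sum(np.square(g)))
            out[k] = g * (grad_clip / l2) if l2 > grad_clip else g
        return out

    # "global_norm": one shared scaling factor based on the L2-norm computed
    # over all gradient tensors together.
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in gradients_dict.values()))
    if global_norm <= grad_clip:
        return dict(gradients_dict)
    return {k: g * (grad_clip / global_norm) for k, g in gradients_dict.items()}


# Example: global norm of these gradients is sqrt(9 + 16 + 144) = 13, so all
# tensors get scaled by 5 / 13.
grads = {"w": np.array([3.0, -4.0]), "b": np.array([12.0])}
print(clip_gradients(grads, grad_clip=5.0, grad_clip_by="global_norm"))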
