Skip to content

Commit

Permalink
[RLlib] Cleanups: Learner API and Catalog. (ray-project#35982)
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 committed Jun 2, 2023
1 parent 0ba48e4 commit baf2d72
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 108 deletions.
4 changes: 0 additions & 4 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1046,21 +1046,18 @@ py_test(
size = "small",
srcs = ["algorithms/dt/tests/test_segmentation_buffer.py"]
)

py_test(
name = "test_dt_model",
tags = ["team:rllib", "algorithms_dir"],
size = "small",
srcs = ["algorithms/dt/tests/test_dt_model.py"]
)

py_test(
name = "test_dt_policy",
tags = ["team:rllib", "algorithms_dir"],
size = "small",
srcs = ["algorithms/dt/tests/test_dt_policy.py"]
)

py_test(
name = "test_dt",
tags = ["team:rllib", "algorithms_dir", "ray_data"],
Expand Down Expand Up @@ -1102,7 +1099,6 @@ py_test(
size = "large",
srcs = ["algorithms/impala/tests/test_impala_off_policyness.py"]
)

py_test(
name = "test_impala_learner",
tags = ["team:rllib", "algorithms_dir"],
Expand Down
4 changes: 2 additions & 2 deletions rllib/core/learner/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,11 +601,11 @@ def postprocess_gradients_for_module(

@OverrideToImplementCustomLogic
@abc.abstractmethod
def apply_gradients(self, gradients: ParamDict) -> None:
def apply_gradients(self, gradients_dict: ParamDict) -> None:
"""Applies the gradients to the MultiAgentRLModule parameters.
Args:
gradients: A dictionary of gradients in the same (flat) format as
gradients_dict: A dictionary of gradients in the same (flat) format as
self._params. Note that top-level structures, such as module IDs,
will not be present anymore in this dict. It will merely map gradient
tensor references to gradient tensors.
Expand Down
23 changes: 11 additions & 12 deletions rllib/core/learner/tf/tf_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,22 +112,21 @@ def compute_gradients(
return grads

@override(Learner)
def apply_gradients(self, gradients: ParamDict):
def apply_gradients(self, gradients_dict: ParamDict) -> None:
# TODO (Avnishn, kourosh): apply gradients doesn't work in cases where
# only some agents have a sample batch that is passed but not others.
# This is probably because of the way that we are iterating over the
# parameters in the optim_to_param_dictionary.
for optimizer, param_ref_seq in self._optimizer_parameters.items():
variable_list = [
self._params[param_ref]
for param_ref in param_ref_seq
if gradients[param_ref] is not None
]
gradient_list = [
gradients[param_ref]
for param_ref in param_ref_seq
if gradients[param_ref] is not None
]
for optimizer in self._optimizer_parameters:
optim_grad_dict = self.filter_param_dict_for_optimizer(
optimizer=optimizer, param_dict=gradients_dict
)
variable_list = []
gradient_list = []
for param_ref, grad in optim_grad_dict.items():
if grad is not None:
variable_list.append(self._params[param_ref])
gradient_list.append(grad)
optimizer.apply_gradients(zip(gradient_list, variable_list))

@override(Learner)
Expand Down
4 changes: 2 additions & 2 deletions rllib/core/learner/torch/torch_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ def compute_gradients(
return grads

@override(Learner)
def apply_gradients(self, gradients: ParamDict) -> None:
def apply_gradients(self, gradients_dict: ParamDict) -> None:
# Make sure the parameters do not carry gradients on their own.
for optim in self._optimizer_parameters:
optim.zero_grad(set_to_none=True)

# Set the gradient of the parameters.
for pid, grad in gradients.items():
for pid, grad in gradients_dict.items():
self._params[pid].grad = grad

# For each optimizer call its step function.
Expand Down
174 changes: 87 additions & 87 deletions rllib/core/models/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,93 +25,6 @@
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space


def _multi_action_dist_partial_helper(
catalog_cls: "Catalog", action_space: gym.Space, framework: str
) -> Distribution:
"""Helper method to get a partial of a MultiActionDistribution.
This is useful for when we want to create MultiActionDistributions from
logits only (!) later, but know the action space now already.
Args:
catalog_cls: The ModelCatalog class to use.
action_space: The action space to get the child distribution classes for.
framework: The framework to use.
Returns:
A partial of the TorchMultiActionDistribution class.
"""
action_space_struct = get_base_struct_from_space(action_space)
flat_action_space = flatten_space(action_space)
child_distribution_cls_struct = tree.map_structure(
lambda s: catalog_cls.get_dist_cls_from_action_space(
action_space=s,
framework=framework,
),
action_space_struct,
)
flat_distribution_clses = tree.flatten(child_distribution_cls_struct)

logit_lens = [
int(dist_cls.required_input_dim(space))
for dist_cls, space in zip(flat_distribution_clses, flat_action_space)
]

if framework == "torch":
from ray.rllib.models.torch.torch_distributions import (
TorchMultiDistribution,
)

multi_action_dist_cls = TorchMultiDistribution
elif framework == "tf2":
from ray.rllib.models.tf.tf_distributions import TfMultiDistribution

multi_action_dist_cls = TfMultiDistribution
else:
raise ValueError(f"Unsupported framework: {framework}")

partial_dist_cls = multi_action_dist_cls.get_partial_dist_cls(
space=action_space,
child_distribution_cls_struct=child_distribution_cls_struct,
input_lens=logit_lens,
)
return partial_dist_cls


def _multi_categorical_dist_partial_helper(
action_space: gym.Space, framework: str
) -> Distribution:
"""Helper method to get a partial of a MultiCategorical Distribution.
This is useful for when we want to create MultiCategorical Distribution from
logits only (!) later, but know the action space now already.
Args:
action_space: The action space to get the child distribution classes for.
framework: The framework to use.
Returns:
A partial of the MultiCategorical class.
"""

if framework == "torch":
from ray.rllib.models.torch.torch_distributions import TorchMultiCategorical

multi_categorical_dist_cls = TorchMultiCategorical
elif framework == "tf2":
from ray.rllib.models.tf.tf_distributions import TfMultiCategorical

multi_categorical_dist_cls = TfMultiCategorical
else:
raise ValueError(f"Unsupported framework: {framework}")

partial_dist_cls = multi_categorical_dist_cls.get_partial_dist_cls(
space=action_space, input_lens=list(action_space.nvec)
)

return partial_dist_cls


class Catalog:
"""Describes the sub-modules architectures to be used in RLModules.
Expand Down Expand Up @@ -593,3 +506,90 @@ def get_preprocessor(observation_space: gym.Space, **kwargs) -> Preprocessor:
cls = get_preprocessor(observation_space)
prep = cls(observation_space, options)
return prep


def _multi_action_dist_partial_helper(
catalog_cls: "Catalog", action_space: gym.Space, framework: str
) -> Distribution:
"""Helper method to get a partial of a MultiActionDistribution.
This is useful for when we want to create MultiActionDistributions from
logits only (!) later, but know the action space now already.
Args:
catalog_cls: The ModelCatalog class to use.
action_space: The action space to get the child distribution classes for.
framework: The framework to use.
Returns:
A partial of the TorchMultiActionDistribution class.
"""
action_space_struct = get_base_struct_from_space(action_space)
flat_action_space = flatten_space(action_space)
child_distribution_cls_struct = tree.map_structure(
lambda s: catalog_cls.get_dist_cls_from_action_space(
action_space=s,
framework=framework,
),
action_space_struct,
)
flat_distribution_clses = tree.flatten(child_distribution_cls_struct)

logit_lens = [
int(dist_cls.required_input_dim(space))
for dist_cls, space in zip(flat_distribution_clses, flat_action_space)
]

if framework == "torch":
from ray.rllib.models.torch.torch_distributions import (
TorchMultiDistribution,
)

multi_action_dist_cls = TorchMultiDistribution
elif framework == "tf2":
from ray.rllib.models.tf.tf_distributions import TfMultiDistribution

multi_action_dist_cls = TfMultiDistribution
else:
raise ValueError(f"Unsupported framework: {framework}")

partial_dist_cls = multi_action_dist_cls.get_partial_dist_cls(
space=action_space,
child_distribution_cls_struct=child_distribution_cls_struct,
input_lens=logit_lens,
)
return partial_dist_cls


def _multi_categorical_dist_partial_helper(
action_space: gym.Space, framework: str
) -> Distribution:
"""Helper method to get a partial of a MultiCategorical Distribution.
This is useful for when we want to create MultiCategorical Distribution from
logits only (!) later, but know the action space now already.
Args:
action_space: The action space to get the child distribution classes for.
framework: The framework to use.
Returns:
A partial of the MultiCategorical class.
"""

if framework == "torch":
from ray.rllib.models.torch.torch_distributions import TorchMultiCategorical

multi_categorical_dist_cls = TorchMultiCategorical
elif framework == "tf2":
from ray.rllib.models.tf.tf_distributions import TfMultiCategorical

multi_categorical_dist_cls = TfMultiCategorical
else:
raise ValueError(f"Unsupported framework: {framework}")

partial_dist_cls = multi_categorical_dist_cls.get_partial_dist_cls(
space=action_space, input_lens=list(action_space.nvec)
)

return partial_dist_cls
2 changes: 1 addition & 1 deletion rllib/utils/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def clip_gradients(
gradients_dict: "ParamDict",
*,
grad_clip: Optional[float] = None,
grad_clip_by: str = "value",
grad_clip_by: str,
) -> Optional[float]:
"""Performs gradient clipping on a grad-dict based on a clip value and clip mode.
Expand Down

0 comments on commit baf2d72

Please sign in to comment.