[RLlib] Turn doc tests into '.. doctest::' (ray-project#37492)
Signed-off-by: Artur Niederfahrenhorst <[email protected]>
ArturNiederfahrenhorst committed Jul 18, 2023
1 parent f0c9513 commit 1ff3b1d
Showing 10 changed files with 332 additions and 288 deletions.
8 changes: 0 additions & 8 deletions doc/source/rllib/rllib-connector.rst
@@ -260,14 +260,6 @@ for a new mock Cartpole environment that returns additional features and require
     :start-after: __sphinx_doc_begin__
     :end-before: __sphinx_doc_end__
 
-
-End-to-end Example
-------------------
-
-TODO: End-to-end case study: adapting an old policy to bootstrap the training of new LSTM policies,
-then serve the newly trained policy in a server/client setup.
-
-
 Notable TODOs
 -------------
 
35 changes: 1 addition & 34 deletions rllib/algorithms/apex_ddpg/apex_ddpg.py
@@ -15,40 +15,7 @@
 
 
 class ApexDDPGConfig(DDPGConfig):
-    """Defines a configuration class from which an ApexDDPG can be built.
-
-    Example:
-        .. code-block:: python
-
-            from ray.rllib.algorithms.apex_ddpg.apex_ddpg import ApexDDPGConfig
-            config = ApexDDPGConfig().training(lr=0.01).resources(num_gpus=1)
-            print(config.to_dict())
-            # Build an Algorithm object from the config and run one training iteration.
-            algo = config.build(env="Pendulum-v1")
-            algo.train()
-
-    Example:
-        .. code-block:: python
-
-            from ray.rllib.algorithms.apex_ddpg.apex_ddpg import ApexDDPGConfig
-            from ray import tune
-            import ray.air as air
-
-            config = ApexDDPGConfig()
-            # Print out some default values.
-            print(config.lr)
-            config.training(lr=tune.grid_search([0.001, 0.0001]))
-            # Set the config object's env.
-            config.environment(env="Pendulum-v1")
-            # Use to_dict() to get the old-style python config dict
-            # when running with tune.
-            tune.Tuner(
-                "APEX_DDPG",
-                run_config=air.RunConfig(stop={"episode_reward_mean": 200}),
-                param_space=config.to_dict(),
-            ).fit()
-    """
+    """Defines a configuration class from which an ApexDDPG can be built."""
 
     def __init__(self, algo_class=None):
         """Initializes an ApexDDPGConfig instance."""
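For reference, the first example deleted from the docstring above boils down to the following config-builder pattern. This is a minimal sketch adapted from the removed text, assuming a 2023-era Ray 2.x installation in which APEX-DDPG still ships; it mirrors the deleted example rather than documenting a current API:

    from ray.rllib.algorithms.apex_ddpg.apex_ddpg import ApexDDPGConfig

    # Build a config, then an Algorithm, and run one training iteration.
    # resources(num_gpus=1) assumes a GPU is available; drop it otherwise.
    config = ApexDDPGConfig().training(lr=0.01).resources(num_gpus=1)
    algo = config.build(env="Pendulum-v1")
    algo.train()
    algo.stop()  # Release the algorithm's actors and resources when done.

Presumably such examples now live in the documentation build as doctests rather than in the class docstring, which is the stated point of this commit.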
19 changes: 8 additions & 11 deletions rllib/algorithms/callbacks.py
@@ -446,24 +446,21 @@ def on_episode_end(
         episode.custom_metrics["tracemalloc/worker/vms"] = worker_vms
 
 
-def make_multi_callbacks(callback_class_list: List[Type[DefaultCallbacks]]):
+def make_multi_callbacks(
+    callback_class_list: List[Type[DefaultCallbacks]],
+) -> DefaultCallbacks:
     """Allows combining multiple sub-callbacks into one new callbacks class.
 
-    Example:
-        .. code-block:: python
-
-            config.callbacks(make_multi_callbacks([
-                MyCustomStatsCallbacks,
-                MyCustomVideoCallbacks,
-                MyCustomTraceCallbacks,
-                ....
-            ]))
-
     The resulting DefaultCallbacks will call all the sub-callbacks' callbacks
     when called.
 
     Args:
         callback_class_list: The list of sub-classes of DefaultCallbacks to
             be baked into the to-be-returned class. All of these sub-classes'
             implemented methods will be called in the given order.
+
+    Returns:
+        A DefaultCallbacks subclass that combines all the given sub-classes.
     """
 
     class _MultiCallbacks(DefaultCallbacks):
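The Example block deleted from this docstring showed the intended call pattern. A self-contained sketch of the same idea follows; the two callback classes are hypothetical stand-ins for the MyCustom*Callbacks placeholders in the removed text, and PPO is used only as a convenient host algorithm:

    from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
    from ray.rllib.algorithms.ppo import PPOConfig

    class StatsCallbacks(DefaultCallbacks):
        def on_episode_end(self, **kwargs):
            print("collecting stats")

    class VideoCallbacks(DefaultCallbacks):
        def on_episode_end(self, **kwargs):
            print("recording video")

    # The combined class invokes each sub-callback's methods in list order.
    config = PPOConfig().callbacks(make_multi_callbacks([StatsCallbacks, VideoCallbacks]))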
99 changes: 53 additions & 46 deletions rllib/connectors/connector.py
@@ -149,59 +149,63 @@ class AgentConnector(Connector):
     AgentConnectorDataTypes can be used to specify arbitrary type of env data,
 
     Example:
-        .. code-block:: python
-
-            # A dict of multi-agent data from one env step() call.
-            ac = AgentConnectorDataType(
-                env_id="env_1",
-                agent_id=None,
-                data={
-                    "agent_1": np.array(...),
-                    "agent_2": np.array(...),
-                }
-            )
+        .. testcode::
+
+            import numpy as np
+
+            # Represent a list of agent data from one env step() call.
+            ac = AgentConnectorDataType(
+                env_id="env_1",
+                agent_id=None,
+                data={
+                    "agent_1": np.array([1, 2, 3]),
+                    "agent_2": np.array([4, 5, 6]),
+                }
+            )
 
-    Example:
-        .. code-block:: python
-
-            # Single agent data ready to be preprocessed.
-            ac = AgentConnectorDataType(
-                env_id="env_1",
-                agent_id="agent_1",
-                data=np.array(...)
-            )
+            # ... or a single agent data ready to be preprocessed.
+            ac = AgentConnectorDataType(
+                env_id="env_1",
+                agent_id="agent_1",
+                data=np.array([1, 2, 3]),
+            )
 
-    We can adapt a simple stateless function into an agent connector by using
-    register_lambda_agent_connector:
-
-    .. code-block:: python
-
-        TimesTwoAgentConnector = register_lambda_agent_connector(
-            "TimesTwoAgentConnector", lambda data: data * 2
-        )
+            # We can also adapt a simple stateless function into an agent connector by
+            # using register_lambda_agent_connector:
+
+            import numpy as np
+            from ray.rllib.connectors.agent.lambdas import (
+                register_lambda_agent_connector
+            )
+            TimesTwoAgentConnector = register_lambda_agent_connector(
+                "TimesTwoAgentConnector", lambda data: data * 2
+            )
 
-    More complicated agent connectors can be implemented by extending this
-    AgentConnector class:
-
-    Example:
-        .. code-block:: python
-
-            class FrameSkippingAgentConnector(AgentConnector):
-                def __init__(self, n):
-                    self._n = n
-                    self._frame_count = default_dict(str, default_dict(str, int))
-
-                def reset(self, env_id: str):
-                    del self._frame_count[env_id]
-
-                def __call__(
-                    self, ac_data: List[AgentConnectorDataType]
-                ) -> List[AgentConnectorDataType]:
-                    ret = []
-                    for d in ac_data:
-                        assert d.env_id and d.agent_id, "Frame skipping works per agent"
-
-                        count = self._frame_count[ac_data.env_id][ac_data.agent_id]
-                        self._frame_count[ac_data.env_id][ac_data.agent_id] = count + 1
-
-                        if count % self._n == 0:
-                            ret.append(d)
-                    return ret
+            # More complicated agent connectors can be implemented by extending this
+            # AgentConnector class:
+
+            class FrameSkippingAgentConnector(AgentConnector):
+                def __init__(self, n):
+                    self._n = n
+                    self._frame_count = default_dict(str, default_dict(str, int))
+
+                def reset(self, env_id: str):
+                    del self._frame_count[env_id]
+
+                def __call__(
+                    self, ac_data: List[AgentConnectorDataType]
+                ) -> List[AgentConnectorDataType]:
+                    ret = []
+                    for d in ac_data:
+                        assert d.env_id and d.agent_id, "Skipping works per agent!"
+
+                        count = self._frame_count[ac_data.env_id][ac_data.agent_id]
+                        self._frame_count[ac_data.env_id][ac_data.agent_id] = (
+                            count + 1
+                        )
+
+                        if count % self._n == 0:
+                            ret.append(d)
+                    return ret
 
     As shown, an agent connector may choose to emit an empty list to stop input
     observations from being further processed.
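Taken together, the new testcode block can be exercised end to end roughly as follows. This is a hedged sketch: the default-constructed ConnectorContext() and the import paths for ConnectorContext and AgentConnectorDataType are assumptions inferred from the surrounding diff, not verified against this exact revision:

    import numpy as np
    from ray.rllib.connectors.connector import ConnectorContext
    from ray.rllib.connectors.agent.lambdas import register_lambda_agent_connector
    from ray.rllib.utils.typing import AgentConnectorDataType

    # Wrap a stateless function into an agent connector class, then instantiate it.
    TimesTwoAgentConnector = register_lambda_agent_connector(
        "TimesTwoAgentConnector", lambda data: data * 2
    )
    connector = TimesTwoAgentConnector(ConnectorContext())

    # Agent connectors consume and emit lists of AgentConnectorDataType.
    ac = AgentConnectorDataType(
        env_id="env_1", agent_id="agent_1", data=np.array([1, 2, 3])
    )
    print(connector([ac])[0].data)  # Expected: [2 4 6]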
@@ -279,7 +283,10 @@ class ActionConnector(Connector):
     into an ActionConnector by using register_lambda_action_connector.
 
     Example:
-        .. code-block:: python
+        .. testcode::
+
+            from ray.rllib.connectors.action.lambdas import (
+                register_lambda_action_connector
+            )
             ZeroActionConnector = register_lambda_action_connector(
                 "ZeroActionsConnector",
                 lambda actions, states, fetches: (
[Diff truncated here; the diffs for the remaining changed files in this commit are not shown.]
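Although the last hunk is cut off mid-lambda, the visible signature is enough to sketch the action-side counterpart. The pass-through connector below is illustrative only; its name and behavior are assumptions, not taken from the commit:

    from ray.rllib.connectors.action.lambdas import register_lambda_action_connector

    # The wrapped function receives the policy's actions, RNN states, and extra
    # fetches, and must return the (possibly transformed) triple; here it is
    # returned unchanged.
    IdentityActionConnector = register_lambda_action_connector(
        "IdentityActionConnector",
        lambda actions, states, fetches: (actions, states, fetches),
    )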
