[rllib] Basic IMPALA implementation (using deepmind's reference vtrace.py) (ray-project#2504)

* Rename AsyncSamplesOptimizer -> AsyncReplayOptimizer
* Add a new AsyncSamplesOptimizer that implements the IMPALA architecture
* Integrate V-trace with the A3C policy graph
* Audit the V-trace integration
* Benchmark against A3C, with V-trace on and off
IMPALA on PongNoFrameskip-v4, scaling from 16 to 128 workers and solving Pong in under 10 minutes. For reference, solving this environment takes roughly 40 minutes with Ape-X and several hours with A3C.
ericl authored Aug 2, 2018
1 parent e4f68ff commit 9ea57c2
Showing 22 changed files with 1,131 additions and 200 deletions.
19 changes: 19 additions & 0 deletions LICENSE
@@ -224,3 +224,22 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

--------------------------------------------------------------------------------

Code in python/ray/rllib/impala/vtrace.py from
https://github.com/deepmind/scalable_agent

Copyright 2018 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Binary file added doc/source/impala.png
18 changes: 16 additions & 2 deletions doc/source/rllib-algorithms.rst
@@ -16,7 +16,7 @@ Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/mas
Asynchronous Advantage Actor-Critic
-----------------------------------
`[paper] <https://arxiv.org/abs/1602.01783>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/a3c/a3c.py>`__
RLlib's A3C uses the AsyncGradientsOptimizer to apply gradients computed remotely on policy evaluation actors. It scales up to 16-32 worker processes, depending on the environment. Both TensorFlow (with LSTM support) and PyTorch versions are available.
RLlib's A3C uses the AsyncGradientsOptimizer to apply gradients computed remotely on policy evaluation actors. It scales up to 16-32 worker processes, depending on the environment. Both TensorFlow (with LSTM support) and PyTorch versions are available. Note that if you have a GPU, `IMPALA <#importance-weighted-actor-learner-architecture>`__ will probably perform better than A3C.

Tuned examples: `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-a3c.yaml>`__, `PyTorch version <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml>`__

@@ -47,6 +47,20 @@ Tuned examples: `Humanoid-v1 <https://github.com/ray-project/ray/blob/master/pyt

RLlib's ES implementation scales further and is faster than a reference Redis implementation.

Importance Weighted Actor-Learner Architecture
----------------------------------------------

`[paper] <https://arxiv.org/abs/1802.01561>`__
`[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/impala/impala.py>`__
In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code <https://github.com/deepmind/scalable_agent/blob/master/vtrace.py>`__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model <rllib-models.html#custom-models>`__.

Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-impala.yaml>`__, `vectorized configuration <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml>`__

.. figure:: impala.png
:align: center

RLlib's IMPALA implementation scales from 16 to 128 workers on PongNoFrameskip-v4. With vectorization, learning performance similar to 128 workers can be achieved with only 32 workers. This is about an order of magnitude faster than A3C, with similar sample efficiency.
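
As background for the ``vtrace``, ``vtrace_clip_rho_threshold``, and ``vtrace_clip_pg_rho_threshold`` options this commit introduces: V-trace corrects off-policy sample batches by clipping the importance ratios between the learner's current policy and the behaviour policy that generated the samples. Below is a minimal NumPy sketch of the value-target recursion from the IMPALA paper; it is an editorial illustration only (the commit vendors DeepMind's reference ``vtrace.py``, not this code), and the function name, constant discount, and 1-D time-major inputs are assumptions.

import numpy as np

def vtrace_targets(behaviour_logp, target_logp, rewards, values,
                   bootstrap_value, gamma=0.99,
                   clip_rho_threshold=1.0, clip_c_threshold=1.0):
    """Illustrative V-trace value targets for a single trajectory."""
    rhos = np.exp(target_logp - behaviour_logp)          # pi(a|x) / mu(a|x)
    clipped_rhos = np.minimum(clip_rho_threshold, rhos)  # rho_t in the paper
    cs = np.minimum(clip_c_threshold, rhos)              # c_t in the paper

    # V(x_{t+1}) for every step, bootstrapping from the state after the batch.
    next_values = np.append(values[1:], bootstrap_value)
    deltas = clipped_rhos * (rewards + gamma * next_values - values)

    # Backward recursion: v_t - V(x_t) = delta_t + gamma * c_t * (v_{t+1} - V(x_{t+1})).
    vs_minus_v = np.zeros_like(values)
    acc = 0.0
    for t in reversed(range(len(rewards))):
        acc = deltas[t] + gamma * cs[t] * acc
        vs_minus_v[t] = acc
    return vs_minus_v + values

The policy-gradient advantage in the paper then combines these targets with a separately clipped ratio, which is what the ``vtrace_clip_pg_rho_threshold`` setting controls.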

Policy Gradients
----------------
`[paper] <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/pg/pg.py>`__ We include a vanilla policy gradients implementation as an example algorithm. This is usually outperformed by PPO.
@@ -64,4 +78,4 @@ Tuned examples: `Humanoid-v1 <https://github.com/ray-project/ray/blob/master/pyt
:width: 500px
:align: center

RLlib's PPO is more cost effective and faster than a reference PPO implementation.
RLlib's multi-GPU PPO can scale to hundreds of cores and be more cost effective than MPI-based implementations by leveraging mixed GPU and high-CPU machines.
8 changes: 7 additions & 1 deletion doc/source/rllib-training.rst
@@ -74,13 +74,19 @@ located at ``~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint-1``
and renders its behavior in the environment specified by ``--env``.

Tuned Examples
--------------
~~~~~~~~~~~~~~

Some good hyperparameters and settings are available in
`the repository <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples>`__
(some of them are tuned to run on GPUs). If you find better settings or tune
an algorithm on a different domain, consider submitting a Pull Request!

You can run these with the ``train.py`` script as follows:

.. code-block:: bash

    python ray/python/ray/rllib/train.py -f /path/to/tuned/example.yaml

Python API
----------

1 change: 1 addition & 0 deletions doc/source/rllib.rst
@@ -47,6 +47,7 @@ Algorithms
* `Deep Deterministic Policy Gradients <rllib-algorithms.html#deep-deterministic-policy-gradients>`__
* `Deep Q Networks <rllib-algorithms.html#deep-q-networks>`__
* `Evolution Strategies <rllib-algorithms.html#evolution-strategies>`__
* `Importance Weighted Actor-Learner Architecture <rllib-algorithms.html#importance-weighted-actor-learner-architecture>`__
* `Policy Gradients <rllib-algorithms.html#policy-gradients>`__
* `Proximal Policy Optimization <rllib-algorithms.html#proximal-policy-optimization>`__

2 changes: 1 addition & 1 deletion python/ray/rllib/__init__.py
@@ -19,7 +19,7 @@
def _register_all():
for key in [
"PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "DDPG", "APEX_DDPG",
"__fake", "__sigmoid_fake_data", "__parameter_tuning"
"IMPALA", "__fake", "__sigmoid_fake_data", "__parameter_tuning"
]:
from ray.rllib.agents.agent import get_agent_class
register_trainable(key, get_agent_class(key))
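
With "IMPALA" registered as a trainable above, the tuned YAML examples (``run: IMPALA``) and programmatic Tune experiments can launch the new agent by name. A rough sketch of the latter, assuming the ``run_experiments`` experiment-spec keys of this Ray version; the worker count and stopping criterion below are illustrative assumptions, not values taken from this diff.

import ray
from ray.tune import run_experiments

ray.init()

# Launch IMPALA via its registered trainable name. The config keys mirror
# DEFAULT_CONFIG in impala.py, but these particular values are assumptions.
run_experiments({
    "pong-impala": {
        "run": "IMPALA",
        "env": "PongNoFrameskip-v4",
        "stop": {"episode_reward_mean": 18},
        "config": {
            "num_workers": 16,
            "sample_batch_size": 50,
            "train_batch_size": 500,
        },
    },
})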
36 changes: 18 additions & 18 deletions python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py
@@ -1,3 +1,5 @@
"""Note: Keep in sync with changes to VTracePolicyGraph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -77,8 +79,6 @@ def __init__(self, observation_space, action_space, config):
("advantages", advantages),
("value_targets", v_target),
]
self.state_in = self.model.state_in
self.state_out = self.model.state_out
TFPolicyGraph.__init__(
self,
observation_space,
@@ -88,29 +88,21 @@ def __init__(self, observation_space, action_space, config):
action_sampler=action_dist.sample(),
loss=self.loss.total_loss,
loss_inputs=loss_in,
state_inputs=self.state_in,
state_outputs=self.state_out,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
seq_lens=self.model.seq_lens,
max_seq_len=self.config["model"]["max_seq_len"])

if self.config.get("summarize"):
bs = tf.to_float(tf.shape(self.observations)[0])
tf.summary.scalar("model/policy_graph", self.loss.pi_loss / bs)
tf.summary.scalar("model/value_loss", self.loss.vf_loss / bs)
tf.summary.scalar("model/entropy", self.loss.entropy / bs)
tf.summary.scalar("model/grad_gnorm", tf.global_norm(self._grads))
tf.summary.scalar("model/var_gnorm", tf.global_norm(self.var_list))
self.summary_op = tf.summary.merge_all()

self.sess.run(tf.global_variables_initializer())

def extra_compute_action_fetches(self):
return {"vf_preds": self.vf}

def value(self, ob, *args):
feed_dict = {self.observations: [ob]}
assert len(args) == len(self.state_in), (args, self.state_in)
for k, v in zip(self.state_in, args):
feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
assert len(args) == len(self.model.state_in), \
(args, self.model.state_in)
for k, v in zip(self.model.state_in, args):
feed_dict[k] = v
vf = self.sess.run(self.vf, feed_dict)
return vf[0]
@@ -126,7 +118,15 @@ def gradients(self, optimizer):

def extra_compute_grad_fetches(self):
if self.config.get("summarize"):
return {"summary": self.summary_op}
return {
"stats": {
"policy_loss": self.loss.pi_loss,
"value_loss": self.loss.vf_loss,
"entropy": self.loss.entropy,
"grad_gnorm": tf.global_norm(self._grads),
"var_gnorm": tf.global_norm(self.var_list),
},
}
else:
return {}

@@ -139,7 +139,7 @@ def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
last_r = 0.0
else:
next_state = []
for i in range(len(self.state_in)):
for i in range(len(self.model.state_in)):
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
last_r = self.value(sample_batch["new_obs"][-1], *next_state)
return compute_advantages(sample_batch, last_r, self.config["gamma"],
3 changes: 3 additions & 0 deletions python/ray/rllib/agents/agent.py
@@ -360,6 +360,9 @@ def get_agent_class(alg):
elif alg == "PG":
from ray.rllib.agents import pg
return pg.PGAgent
elif alg == "IMPALA":
from ray.rllib.agents import impala
return impala.ImpalaAgent
elif alg == "script":
from ray.tune import script_runner
return script_runner.ScriptRunner
2 changes: 1 addition & 1 deletion python/ray/rllib/agents/ddpg/apex.py
@@ -9,7 +9,7 @@
APEX_DDPG_DEFAULT_CONFIG = merge_dicts(
DDPG_CONFIG,
{
"optimizer_class": "AsyncSamplesOptimizer",
"optimizer_class": "AsyncReplayOptimizer",
"optimizer": merge_dicts(
DDPG_CONFIG["optimizer"], {
"max_weight_sync_delay": 400,
2 changes: 1 addition & 1 deletion python/ray/rllib/agents/dqn/apex.py
@@ -9,7 +9,7 @@
APEX_DEFAULT_CONFIG = merge_dicts(
DQN_CONFIG,
{
"optimizer_class": "AsyncSamplesOptimizer",
"optimizer_class": "AsyncReplayOptimizer",
"optimizer": merge_dicts(
DQN_CONFIG["optimizer"], {
"max_weight_sync_delay": 400,
3 changes: 3 additions & 0 deletions python/ray/rllib/agents/impala/__init__.py
@@ -0,0 +1,3 @@
from ray.rllib.agents.impala.impala import ImpalaAgent, DEFAULT_CONFIG

__all__ = ["ImpalaAgent", "DEFAULT_CONFIG"]
123 changes: 123 additions & 0 deletions python/ray/rllib/agents/impala/impala.py
@@ -0,0 +1,123 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pickle
import os
import time

import ray
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.optimizers import AsyncSamplesOptimizer
from ray.rllib.utils import FilterManager
from ray.tune.trial import Resources

OPTIMIZER_SHARED_CONFIGS = [
"sample_batch_size",
"train_batch_size",
]

DEFAULT_CONFIG = with_common_config({
# V-trace params (see vtrace.py).
"vtrace": True,
"vtrace_clip_rho_threshold": 1.0,
"vtrace_clip_pg_rho_threshold": 1.0,

# System params.
"sample_batch_size": 50,
"train_batch_size": 500,
"min_iter_time_s": 10,
"summarize": False,
"gpu": True,
"num_workers": 2,
"num_cpus_per_worker": 1,
"num_gpus_per_worker": 0,

# Learning params.
"grad_clip": 40.0,
"lr": 0.0001,
"vf_loss_coeff": 0.5,
"entropy_coeff": -0.01,

# Model and preprocessor options.
"clip_rewards": True,
"preprocessor_pref": "deepmind",
"model": {
"use_lstm": False,
"max_seq_len": 20,
"dim": 80,
},
})


class ImpalaAgent(Agent):
"""IMPALA implementation using DeepMind's V-trace."""

_agent_name = "IMPALA"
_default_config = DEFAULT_CONFIG

@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
return Resources(
cpu=1,
gpu=cf["gpu"] and 1 or 0,
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])

def _init(self):
for k in OPTIMIZER_SHARED_CONFIGS:
if k not in self.config["optimizer"]:
self.config["optimizer"][k] = self.config[k]
if self.config["vtrace"]:
policy_cls = VTracePolicyGraph
else:
policy_cls = A3CPolicyGraph
self.local_evaluator = self.make_local_evaluator(
self.env_creator, policy_cls)
self.remote_evaluators = self.make_remote_evaluators(
self.env_creator, policy_cls, self.config["num_workers"],
{"num_cpus": 1})
self.optimizer = AsyncSamplesOptimizer(self.local_evaluator,
self.remote_evaluators,
self.config["optimizer"])

def _train(self):
prev_steps = self.optimizer.num_steps_sampled
start = time.time()
self.optimizer.step()
while time.time() - start < self.config["min_iter_time_s"]:
self.optimizer.step()
FilterManager.synchronize(self.local_evaluator.filters,
self.remote_evaluators)
result = self.optimizer.collect_metrics()
result = result._replace(
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps)
return result

def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for ev in self.remote_evaluators:
ev.__ray_terminate__.remote()

def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
agent_state = ray.get(
[a.save.remote() for a in self.remote_evaluators])
extra_data = {
"remote_state": agent_state,
"local_state": self.local_evaluator.save()
}
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path

def _restore(self, checkpoint_path):
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
ray.get([
a.restore.remote(o)
for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
])
self.local_evaluator.restore(extra_data["local_state"])
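
For completeness, a direct Python-API sketch of the new agent. The constructor keywords and training loop follow the common ``Agent`` pattern in this tree; treat the exact signature and the values below as assumptions rather than something this diff exercises.

import ray
from ray.rllib.agents import impala

ray.init()

# Hypothetical usage: IMPALA with V-trace on a handful of workers.
config = impala.DEFAULT_CONFIG.copy()
config["num_workers"] = 4
config["vtrace"] = True  # set False to fall back to the plain A3C policy graph

agent = impala.ImpalaAgent(config=config, env="PongNoFrameskip-v4")
for i in range(5):
    result = agent.train()
    print("iteration", i, result)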
