SAC for Mujoco Environments (#6642)
michaelzhiluo authored and ericl committed Dec 31, 2019
1 parent cdc1ce4 commit 1cb3354
Showing 3 changed files with 48 additions and 7 deletions.
12 changes: 10 additions & 2 deletions doc/source/rllib-algorithms.rst
@@ -274,9 +274,17 @@ Soft Actor Critic (SAC)

SAC architecture (same as DQN)

- RLlib's soft-actor critic implementation is ported from the `official SAC repo <https://github.com/rail-berkeley/softlearning>`__ to better integrate with RLlib APIs. Note that SAC has two fields to configure for custom models: ``policy_model`` and ``Q_model``, and currently has no support for non-continuous action distributions. It is also currently *experimental*.
+ RLlib's soft-actor critic implementation is ported from the `official SAC repo <https://github.com/rail-berkeley/softlearning>`__ to better integrate with RLlib APIs. Note that SAC has two fields to configure for custom models: ``policy_model`` and ``Q_model``, and currently has no support for non-continuous action distributions.

- Tuned examples: `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/regression_tests/pendulum-sac.yaml>`__
+ Tuned examples: `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/regression_tests/pendulum-sac.yaml>`__, `HalfCheetah-v3 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/halfcheetah-sac.yaml>`__

**MuJoCo results @500k steps:** `more details <https://github.com/ray-project/rl-experiments>`__

============= ========== ===================
MuJoCo env    RLlib SAC  Haarnoja et al SAC
============= ========== ===================
HalfCheetah   8752       ~9000
============= ========== ===================

**SAC-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):

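As a hedged illustration of the two custom-model fields called out above, the sketch below configures ``Q_model`` and ``policy_model`` with the same layer sizes as the tuned HalfCheetah example added in this commit; the ``SACTrainer`` import path and the ``Pendulum-v0`` environment are assumptions for this RLlib version, not part of the commit:

    # Illustrative sketch only: the config keys mirror the tuned YAML added in
    # this commit; the import path and environment choice are assumptions.
    import ray
    from ray.rllib.agents.sac import SACTrainer

    ray.init()

    trainer = SACTrainer(
        env="Pendulum-v0",  # any continuous-action env; discrete actions are unsupported
        config={
            "Q_model": {
                "hidden_activation": "relu",
                "hidden_layer_sizes": [256, 256],
            },
            "policy_model": {
                "hidden_activation": "relu",
                "hidden_layer_sizes": [256, 256],
            },
        })

    result = trainer.train()  # runs one training iteration
    print(result["episode_reward_mean"])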
6 changes: 1 addition & 5 deletions rllib/agents/sac/sac_model.py
@@ -81,7 +81,6 @@ def __init__(self,
            shape=(num_outputs, ), name="model_out")
        self.actions = tf.keras.layers.Input(
            shape=(self.action_dim, ), name="actions")

        shift_and_log_scale_diag = tf.keras.Sequential([
            tf.keras.layers.Dense(
                units=hidden,
@@ -90,10 +89,7 @@ def __init__(self,
            for i, hidden in enumerate(actor_hiddens)
        ] + [
            tf.keras.layers.Dense(
-                units=tfp.layers.MultivariateNormalTriL.params_size(
-                    self.action_dim),
-                activation=None,
-                name="action_out")
+                units=2 * self.action_dim, activation=None, name="action_out")
        ])(self.model_out)

        shift, log_scale_diag = tf.keras.layers.Lambda(
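The single functional change above swaps ``tfp.layers.MultivariateNormalTriL.params_size(self.action_dim)`` for ``2 * self.action_dim``: the action head now emits only a mean (``shift``) and a per-dimension log scale, which the ``Lambda`` layer right after it splits apart. A minimal sketch of that shape reasoning, with an illustrative batch and NumPy standing in for the actual Keras tensors:

    # Shape sketch only; names and the batch are illustrative, not the model's tensors.
    import numpy as np

    action_dim = 6                               # HalfCheetah-v3 has 6-dimensional actions
    action_out = np.zeros((32, 2 * action_dim))  # output of Dense(units=2 * action_dim)

    # The downstream Lambda layer effectively performs this split into the mean
    # and the per-dimension log standard deviation of a diagonal Gaussian policy.
    shift, log_scale_diag = np.split(action_out, 2, axis=-1)
    assert shift.shape == (32, action_dim)
    assert log_scale_diag.shape == (32, action_dim)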
37 changes: 37 additions & 0 deletions rllib/tuned_examples/halfcheetah-sac.yaml
@@ -0,0 +1,37 @@
# Our implementation of SAC can reach 9k reward in 400k timesteps
halfcheetah_sac:
    env: HalfCheetah-v3
    run: SAC
    stop:
        episode_reward_mean: 9000
    config:
        horizon: 1000
        soft_horizon: False
        Q_model:
            hidden_activation: relu
            hidden_layer_sizes: [256, 256]
        policy_model:
            hidden_activation: relu
            hidden_layer_sizes: [256, 256]
        tau: 0.005
        target_entropy: auto
        no_done_at_end: True
        n_step: 1
        sample_batch_size: 1
        prioritized_replay: False
        train_batch_size: 256
        target_network_update_freq: 1
        timesteps_per_iteration: 1000
        learning_starts: 10000
        exploration_enabled: True
        optimization:
            actor_learning_rate: 0.0003
            critic_learning_rate: 0.0003
            entropy_learning_rate: 0.0003
        num_workers: 0
        num_gpus: 0
        clip_actions: False
        normalize_actions: True
        evaluation_interval: 1
        metrics_smoothing_episodes: 5
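Beyond RLlib's ``rllib train -f <yaml>`` entry point, a hedged sketch of launching the new tuned example from Python is shown below; the YAML-loading code is illustrative and assumes the ray repository root as the working directory plus a working MuJoCo/``mujoco_py`` install for ``HalfCheetah-v3``:

    # Hedged sketch: load the tuned-example YAML above and hand it to Tune.
    import yaml
    import ray
    from ray import tune

    with open("rllib/tuned_examples/halfcheetah-sac.yaml") as f:
        experiments = yaml.safe_load(f)

    name, spec = next(iter(experiments.items()))    # "halfcheetah_sac"
    config = dict(spec["config"], env=spec["env"])  # env sits at the experiment level in the YAML

    ray.init()
    tune.run(
        spec["run"],        # "SAC"
        name=name,
        stop=spec["stop"],  # {"episode_reward_mean": 9000}
        config=config,
    )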
