Tune hyperparameters in tutorials for GAIL and AIRL (#772)
* Pin huggingface_sb3 version.

* Properly specify the compatible seals version so it does not auto-upgrade to 0.2.

* Make random_mdp test deterministic by seeding the environment.

* Tune hyperparameters in tutorials for GAIL and AIRL

* Modify .rst docs for GAIL and AIRL to match tutorials

* GAIL and AIRL tutorials: also report std in results

---------

Co-authored-by: Maximilian Ernestus <[email protected]>
Co-authored-by: Adam Gleave <[email protected]>
3 people authored Sep 7, 2023
1 parent f09aeea commit 74b63ff
Showing 4 changed files with 68 additions and 50 deletions.
18 changes: 11 additions & 7 deletions docs/algorithms/airl.rst
@@ -60,9 +60,13 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
 learner = PPO(
     env=env,
     policy=MlpPolicy,
-    batch_size=16,
-    learning_rate=0.0001,
-    n_epochs=2,
+    batch_size=64,
+    ent_coef=0.0,
+    learning_rate=0.0005,
+    gamma=0.95,
+    clip_range=0.1,
+    vf_coef=0.1,
+    n_epochs=5,
     seed=SEED,
 )
 reward_net = BasicShapedRewardNet(
@@ -72,9 +76,9 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
 )
 airl_trainer = AIRL(
     demonstrations=rollouts,
-    demo_batch_size=1024,
-    gen_replay_buffer_capacity=2048,
-    n_disc_updates_per_round=4,
+    demo_batch_size=2048,
+    gen_replay_buffer_capacity=512,
+    n_disc_updates_per_round=16,
     venv=env,
     gen_algo=learner,
     reward_net=reward_net,
@@ -84,7 +88,7 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
 learner_rewards_before_training, _ = evaluate_policy(
     learner, env, 100, return_episode_rewards=True,
 )
-airl_trainer.train(20000)
+airl_trainer.train(20000)  # Train for 2_000_000 steps to match expert.
 env.seed(SEED)
 learner_rewards_after_training, _ = evaluate_policy(
     learner, env, 100, return_episode_rewards=True,
15 changes: 8 additions & 7 deletions docs/algorithms/gail.rst
@@ -29,7 +29,7 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail`
 from imitation.data import rollout
 from imitation.data.wrappers import RolloutInfoWrapper
 from imitation.policies.serialize import load_policy
-from imitation.rewards.reward_nets import BasicShapedRewardNet
+from imitation.rewards.reward_nets import BasicRewardNet
 from imitation.util.networks import RunningNorm
 from imitation.util.util import make_vec_env

@@ -60,20 +60,21 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail`
     policy=MlpPolicy,
     batch_size=64,
     ent_coef=0.0,
-    learning_rate=0.00001,
-    n_epochs=1,
+    learning_rate=0.0004,
+    gamma=0.95,
+    n_epochs=5,
     seed=SEED,
 )
-reward_net = BasicShapedRewardNet(
+reward_net = BasicRewardNet(
     observation_space=env.observation_space,
     action_space=env.action_space,
     normalize_input_layer=RunningNorm,
 )
 gail_trainer = GAIL(
     demonstrations=rollouts,
     demo_batch_size=1024,
-    gen_replay_buffer_capacity=2048,
-    n_disc_updates_per_round=4,
+    gen_replay_buffer_capacity=512,
+    n_disc_updates_per_round=8,
     venv=env,
     gen_algo=learner,
     reward_net=reward_net,
@@ -86,7 +87,7 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail`
 )

 # train the learner and evaluate again
-gail_trainer.train(20000)
+gail_trainer.train(20000)  # Train for 800_000 steps to match expert.
 env.seed(SEED)
 learner_rewards_after_training, _ = evaluate_policy(
     learner, env, 100, return_episode_rewards=True,
38 changes: 20 additions & 18 deletions docs/tutorials/3_train_gail.ipynb
@@ -89,7 +89,7 @@
 "outputs": [],
 "source": [
 "from imitation.algorithms.adversarial.gail import GAIL\n",
-"from imitation.rewards.reward_nets import BasicShapedRewardNet\n",
+"from imitation.rewards.reward_nets import BasicRewardNet\n",
 "from imitation.util.networks import RunningNorm\n",
 "from stable_baselines3 import PPO\n",
 "from stable_baselines3.ppo import MlpPolicy\n",
@@ -100,20 +100,21 @@
 "    policy=MlpPolicy,\n",
 "    batch_size=64,\n",
 "    ent_coef=0.0,\n",
-"    learning_rate=0.00001,\n",
-"    n_epochs=1,\n",
+"    learning_rate=0.0004,\n",
+"    gamma=0.95,\n",
+"    n_epochs=5,\n",
 "    seed=SEED,\n",
 ")\n",
-"reward_net = BasicShapedRewardNet(\n",
+"reward_net = BasicRewardNet(\n",
 "    observation_space=env.observation_space,\n",
 "    action_space=env.action_space,\n",
 "    normalize_input_layer=RunningNorm,\n",
 ")\n",
 "gail_trainer = GAIL(\n",
 "    demonstrations=rollouts,\n",
 "    demo_batch_size=1024,\n",
-"    gen_replay_buffer_capacity=2048,\n",
-"    n_disc_updates_per_round=4,\n",
+"    gen_replay_buffer_capacity=512,\n",
+"    n_disc_updates_per_round=8,\n",
 "    venv=env,\n",
 "    gen_algo=learner,\n",
 "    reward_net=reward_net,\n",
@@ -126,7 +127,7 @@
 ")\n",
 "\n",
 "# train the learner and evaluate again\n",
-"gail_trainer.train(20000)\n",
+"gail_trainer.train(800_000)\n",
 "env.seed(SEED)\n",
 "learner_rewards_after_training, _ = evaluate_policy(\n",
 "    learner, env, 100, return_episode_rewards=True\n",
@@ -137,7 +138,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"When we look at the histograms of rewards before and after learning, we can see that the learner is not perfect yet, but it made some progress at least."
+"We can see that an untrained policy performs poorly, while GAIL matches expert returns (500):"
 ]
 },
 {
{
Expand All @@ -146,17 +147,18 @@
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"print(\"mean reward after training:\", np.mean(learner_rewards_after_training))\n",
"print(\"mean reward before training:\", np.mean(learner_rewards_before_training))\n",
"\n",
"plt.hist(\n",
" [learner_rewards_before_training, learner_rewards_after_training],\n",
" label=[\"untrained\", \"trained\"],\n",
"print(\n",
" \"Rewards before training:\",\n",
" np.mean(learner_rewards_before_training),\n",
" \"+/-\",\n",
" np.std(learner_rewards_before_training),\n",
")\n",
"plt.legend()\n",
"plt.show()"
"print(\n",
" \"Rewards after training:\",\n",
" np.mean(learner_rewards_after_training),\n",
" \"+/-\",\n",
" np.std(learner_rewards_after_training),\n",
")"
]
}
],
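For reference, here is a minimal end-to-end sketch of the GAIL tutorial as tuned by this commit, in plain Python rather than notebook JSON. The learner, reward net, trainer, training budget, and mean/std reporting come directly from the diff above; the environment setup and expert-rollout collection are not touched by this commit, so the `ppo-huggingface` expert from `HumanCompatibleAI`, `n_envs=8`, the `RolloutInfoWrapper` post-wrapper, and `min_episodes=60` below are assumptions carried over from the tutorial's earlier cells.

import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy

from imitation.algorithms.adversarial.gail import GAIL
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.serialize import load_policy
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env

SEED = 42

env = make_vec_env(
    "seals/CartPole-v0",
    rng=np.random.default_rng(SEED),
    n_envs=8,  # assumed from the tutorial; not changed by this commit
    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # needed for rollout stats
)

# Expert demonstrations: this step is unchanged by the commit; the pretrained
# HuggingFace expert and the episode count are assumptions from the tutorial.
expert = load_policy(
    "ppo-huggingface",
    organization="HumanCompatibleAI",
    env_name="seals-CartPole-v0",
    venv=env,
)
rollouts = rollout.rollout(
    expert,
    env,
    rollout.make_sample_until(min_timesteps=None, min_episodes=60),
    rng=np.random.default_rng(SEED),
)

# Generator policy with the hyperparameters tuned in this commit.
learner = PPO(
    env=env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
)
# GAIL now uses the unshaped BasicRewardNet (AIRL keeps BasicShapedRewardNet).
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)

# Evaluate, train for the notebook's 800k steps (the .rst example uses 20k), re-evaluate.
env.seed(SEED)
learner_rewards_before_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)
gail_trainer.train(800_000)
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)
print(
    "Rewards before training:",
    np.mean(learner_rewards_before_training),
    "+/-",
    np.std(learner_rewards_before_training),
)
print(
    "Rewards after training:",
    np.mean(learner_rewards_after_training),
    "+/-",
    np.std(learner_rewards_after_training),
)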
47 changes: 29 additions & 18 deletions docs/tutorials/4_train_airl.ipynb
@@ -31,6 +31,13 @@
 "\n",
 "SEED = 42\n",
 "\n",
+"FAST = True\n",
+"\n",
+"if FAST:\n",
+"    N_RL_TRAIN_STEPS = 800_000\n",
+"else:\n",
+"    N_RL_TRAIN_STEPS = 2_000_000\n",
+"\n",
 "env = make_vec_env(\n",
 "    \"seals/CartPole-v0\",\n",
 "    rng=np.random.default_rng(SEED),\n",
@@ -96,10 +103,13 @@
 "learner = PPO(\n",
 "    env=env,\n",
 "    policy=MlpPolicy,\n",
-"    batch_size=16,\n",
+"    batch_size=64,\n",
 "    ent_coef=0.0,\n",
-"    learning_rate=0.0001,\n",
-"    n_epochs=2,\n",
+"    learning_rate=0.0005,\n",
+"    gamma=0.95,\n",
+"    clip_range=0.1,\n",
+"    vf_coef=0.1,\n",
+"    n_epochs=5,\n",
 "    seed=SEED,\n",
 ")\n",
 "reward_net = BasicShapedRewardNet(\n",
@@ -109,9 +119,9 @@
 ")\n",
 "airl_trainer = AIRL(\n",
 "    demonstrations=rollouts,\n",
-"    demo_batch_size=1024,\n",
-"    gen_replay_buffer_capacity=2048,\n",
-"    n_disc_updates_per_round=4,\n",
+"    demo_batch_size=2048,\n",
+"    gen_replay_buffer_capacity=512,\n",
+"    n_disc_updates_per_round=16,\n",
 "    venv=env,\n",
 "    gen_algo=learner,\n",
 "    reward_net=reward_net,\n",
@@ -121,7 +131,7 @@
 "learner_rewards_before_training, _ = evaluate_policy(\n",
 "    learner, env, 100, return_episode_rewards=True\n",
 ")\n",
-"airl_trainer.train(20000)\n",
+"airl_trainer.train(N_RL_TRAIN_STEPS)\n",
 "env.seed(SEED)\n",
 "learner_rewards_after_training, _ = evaluate_policy(\n",
 "    learner, env, 100, return_episode_rewards=True\n",
@@ -132,7 +142,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"When we look at the histograms of rewards before and after learning, we can see that the learner is not perfect yet, but it made some progress at least."
+"We can see that an untrained policy performs poorly, while AIRL brings an improvement. To make it match the expert performance (500), set the flag `FAST` to `False` in the first cell."
 ]
 },
{
Expand All @@ -141,17 +151,18 @@
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"print(\"mean reward after training:\", np.mean(learner_rewards_after_training))\n",
"print(\"mean reward before training:\", np.mean(learner_rewards_before_training))\n",
"\n",
"plt.hist(\n",
" [learner_rewards_before_training, learner_rewards_after_training],\n",
" label=[\"untrained\", \"trained\"],\n",
"print(\n",
" \"Rewards before training:\",\n",
" np.mean(learner_rewards_before_training),\n",
" \"+/-\",\n",
" np.std(learner_rewards_before_training),\n",
")\n",
"plt.legend()\n",
"plt.show()"
"print(\n",
" \"Rewards after training:\",\n",
" np.mean(learner_rewards_after_training),\n",
" \"+/-\",\n",
" np.std(learner_rewards_after_training),\n",
")"
]
}
],
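And the AIRL counterpart with this commit's tuned values and the new FAST flag, again as a consolidated sketch rather than the notebook verbatim. The environment and expert-rollout setup repeats the (assumed) steps from the GAIL sketch above, and the BasicShapedRewardNet arguments are assumed to mirror the GAIL reward net, since they sit in a part of the file this diff does not show.

import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy

from imitation.algorithms.adversarial.airl import AIRL
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.serialize import load_policy
from imitation.rewards.reward_nets import BasicShapedRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env

SEED = 42
FAST = True
N_RL_TRAIN_STEPS = 800_000 if FAST else 2_000_000

# Environment and expert rollouts: same (assumed) setup as in the GAIL sketch above.
env = make_vec_env(
    "seals/CartPole-v0",
    rng=np.random.default_rng(SEED),
    n_envs=8,
    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
)
expert = load_policy(
    "ppo-huggingface",
    organization="HumanCompatibleAI",
    env_name="seals-CartPole-v0",
    venv=env,
)
rollouts = rollout.rollout(
    expert,
    env,
    rollout.make_sample_until(min_timesteps=None, min_episodes=60),
    rng=np.random.default_rng(SEED),
)

# Generator policy with this commit's tuned values (note the AIRL-specific
# clip_range and vf_coef, which the GAIL example does not set).
learner = PPO(
    env=env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0005,
    gamma=0.95,
    clip_range=0.1,
    vf_coef=0.1,
    n_epochs=5,
    seed=SEED,
)
# AIRL keeps the shaped reward net.
reward_net = BasicShapedRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)
airl_trainer = AIRL(
    demonstrations=rollouts,
    demo_batch_size=2048,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=16,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)

env.seed(SEED)
learner_rewards_before_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)
airl_trainer.train(N_RL_TRAIN_STEPS)  # 2_000_000 steps (FAST = False) to match the expert
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)
# Report mean +/- std, as the updated tutorial now does.
print("Rewards before training:", np.mean(learner_rewards_before_training), "+/-", np.std(learner_rewards_before_training))
print("Rewards after training:", np.mean(learner_rewards_after_training), "+/-", np.std(learner_rewards_after_training))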
