Commit

Change SQIL SAC to use Pendulum (#800)
AdamGleave authored Oct 5, 2023
1 parent cd76326 commit 573b086
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions docs/tutorials/8a_train_sqil_sac.ipynb
@@ -7,14 +7,14 @@
"[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8a_train_sqil_sac.ipynb)\n",
"# Train an Agent using Soft Q Imitation Learning with SAC\n",
"\n",
-"In the previous tutorial, we used Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) on top of the DQN base algorithm. In fact, SQIL can be combined with any off-policy algorithm from `stable_baselines3`. Here, we train a HalfCheetah agent using SQIL + SAC."
+"In the previous tutorial, we used Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) on top of the DQN base algorithm. In fact, SQIL can be combined with any off-policy algorithm from `stable_baselines3`. Here, we train a Pendulum agent using SQIL + SAC."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
-"First, we need some expert trajectories in our environment (`seals/HalfCheetah-v0`).\n",
+"First, we need some expert trajectories in our environment (`Pendulum-v1`).\n",
"Note that you can use other environments, but the action space must be continuous."
]
},
@@ -28,7 +28,7 @@
"from imitation.data import huggingface_utils\n",
"\n",
"# Download some expert trajectories from the HuggingFace Datasets Hub.\n",
-"dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-seals-HalfCheetah-v0\")\n",
+"dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-Pendulum-v1\")\n",
"\n",
"# Convert the dataset to a format usable by the imitation library.\n",
"expert_trajectories = huggingface_utils.TrajectoryDatasetSequence(dataset[\"train\"])"
@@ -75,12 +75,11 @@
"from imitation.util.util import make_vec_env\n",
"import numpy as np\n",
"from stable_baselines3 import sac\n",
-"import seals # noqa: F401 # needed to load \"seals/\" environments\n",
"\n",
"SEED = 42\n",
"\n",
"venv = make_vec_env(\n",
-"    \"seals/HalfCheetah-v1\",\n",
+"    \"Pendulum-v1\",\n",
"    rng=np.random.default_rng(seed=SEED),\n",
")\n",
"\n",
