Add CLI for SQIL (#784)
* Add sqil cli

* Lints

* More lints

* Add shine requirement, used for DQN progress bar.

* Undo removal of src.policy

* Remove old comment

* Add trailing commas

* change dependencies

* Update src/imitation/scripts/config/train_imitation.py

Co-authored-by: Adam Gleave <[email protected]>

* Move save_policy and reconstruct_policy

* Respond to fix save_policy issue

* Remove some boilerplate

* fix use of save_policy

* Fix bug in sqil

* Update src/imitation/scripts/ingredients/sqil.py

Co-authored-by: Adam Gleave <[email protected]>

* address PR

* fix typing error

* fix typing error

* change shine to rich

* remove line

* Update src/imitation/scripts/ingredients/sqil.py

Co-authored-by: Adam Gleave <[email protected]>

* respond to adam comments

* make line shorter

* Simplify RL hook

---------

Co-authored-by: Adam Gleave <[email protected]>
lukasberglund and AdamGleave committed Sep 16, 2023
1 parent cb93fb0 commit 885beff
Showing 11 changed files with 177 additions and 47 deletions.
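
The new `sqil` command registered by this commit can be driven from Sacred either on the command line or programmatically. The sketch below mirrors the `test_train_sqil_main_with_demonstrations_from_huggingface` test added further down; the `seals_cartpole` named config, the HuggingFace PPO demonstrations, and the shortened `total_timesteps` are that test's choices plus an illustrative override, not required defaults.

```python
# Hedged sketch of invoking the new "sqil" command, mirroring the test added in
# this commit. The config values below are illustrative, not required defaults.
from imitation.scripts import train_imitation

run = train_imitation.train_imitation_ex.run(
    command_name="sqil",
    named_configs=["seals_cartpole"],
    config_updates=dict(
        logging=dict(log_root="output"),  # assumption: any writable directory
        demonstrations=dict(source="huggingface", algo_name="ppo"),
        sqil=dict(total_timesteps=1_000),  # default is 3e5; keep the smoke run short
    ),
)
assert run.status == "COMPLETED"
print(run.result["imit_stats"])
```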
1 change: 1 addition & 0 deletions setup.py
@@ -202,6 +202,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
"numpy>=1.15",
"torch>=1.4.0",
"tqdm",
"rich",
"scikit-learn>=0.21.2",
"seals~=0.1.5",
STABLE_BASELINES3,
8 changes: 0 additions & 8 deletions src/imitation/algorithms/bc.py
@@ -483,11 +483,3 @@ def process_batch():
# if there remains an incomplete batch
batch_num += 1
process_batch()

def save_policy(self, policy_path: types.AnyPath) -> None:
"""Save policy to a path. Can be reloaded by `.reconstruct_policy()`.
Args:
policy_path: path to save policy to.
"""
th.save(self.policy, util.parse_path(policy_path))
10 changes: 1 addition & 9 deletions src/imitation/algorithms/dagger.py
@@ -544,18 +544,10 @@ def save_trainer(self) -> Tuple[pathlib.Path, pathlib.Path]:
self.scratch_dir / "policy-latest.pt",
]
for policy_path in policy_paths:
self.save_policy(policy_path)
util.save_policy(self.policy, policy_path)

return checkpoint_paths[0], policy_paths[0]

def save_policy(self, policy_path: types.AnyPath) -> None:
"""Save the current policy only (and not the rest of the trainer).
Args:
policy_path: path to save policy to.
"""
self.bc_trainer.save_policy(policy_path)


class SimpleDAggerTrainer(DAggerTrainer):
"""Simpler subclass of DAggerTrainer for training with synthetic feedback."""
5 changes: 3 additions & 2 deletions src/imitation/scripts/config/train_imitation.py
@@ -6,18 +6,18 @@
from imitation.scripts.ingredients import demonstrations as demos_common
from imitation.scripts.ingredients import environment, expert
from imitation.scripts.ingredients import logging as logging_ingredient
from imitation.scripts.ingredients import policy, policy_evaluation
from imitation.scripts.ingredients import policy_evaluation, sqil

train_imitation_ex = sacred.Experiment(
"train_imitation",
ingredients=[
logging_ingredient.logging_ingredient,
demos_common.demonstrations_ingredient,
policy.policy_ingredient,
expert.expert_ingredient,
environment.environment_ingredient,
policy_evaluation.policy_evaluation_ingredient,
bc.bc_ingredient,
sqil.sqil_ingredient,
],
)

@@ -100,3 +100,4 @@ def seals_humanoid():
def fast():
dagger = dict(total_timesteps=50)
bc = dict(train_kwargs=dict(n_batches=50))
sqil = dict(total_timesteps=50)
10 changes: 7 additions & 3 deletions src/imitation/scripts/ingredients/rl.py
@@ -40,10 +40,14 @@ def config():

@rl_ingredient.config_hook
def config_hook(config, command_name, logger):
"""Sets defaults equivalent to sb3.PPO default hyperparameters."""
del command_name, logger
"""Sets defaults equivalent to sb3.PPO default hyperparameters.
This hook is a no-op if command_name is "sqil" (used only in train_imitation),
which has its own config hook.
"""
del logger
res = {}
if config["rl"]["rl_cls"] is None or config["rl"]["rl_cls"] == sb3.PPO:
if config["rl"]["rl_cls"] == None and command_name != "sqil":
res["rl_cls"] = sb3.PPO
res["batch_size"] = 2048 # rl_kwargs["n_steps"] = batch_size // venv.num_envs
res["rl_kwargs"] = dict(
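
The hook above and the new hooks in `sqil.py` below coordinate through Sacred's `config_hook` mechanism: each hook receives the current config plus the command name and returns only the keys it wants to override. A minimal, self-contained sketch of that mechanism, assuming nothing beyond Sacred itself (all names here are illustrative, not from this diff):

```python
# Minimal sketch of Sacred's config_hook pattern used by the rl/sqil ingredients.
import sacred

demo_ingredient = sacred.Ingredient("demo")


@demo_ingredient.config
def config():
    algo = None  # chosen by the hook below when left unset
    locals()  # quieten flake8 unused variable warning


@demo_ingredient.config_hook
def pick_default_algo(config, command_name, logger):
    del logger
    updates = {}
    # Only fill in a default when the user has not chosen one, and leave
    # "sqil" runs for a different hook to handle (as in the diff above).
    if config["demo"]["algo"] is None and command_name != "sqil":
        updates["algo"] = "ppo"
    return updates  # Sacred merges the returned keys back into the config


demo_ex = sacred.Experiment("demo_ex", ingredients=[demo_ingredient])


@demo_ex.main
def main(demo):
    print(demo["algo"])  # -> "ppo" unless overridden or run as a "sqil" command


if __name__ == "__main__":
    demo_ex.run()
```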
49 changes: 49 additions & 0 deletions src/imitation/scripts/ingredients/sqil.py
@@ -0,0 +1,49 @@
"""This ingredient provides a SQIL algorithm instance."""
import sacred
from stable_baselines3 import dqn as dqn_algorithm

from imitation.policies import base
from imitation.scripts.ingredients import policy, rl

sqil_ingredient = sacred.Ingredient(
"sqil",
ingredients=[rl.rl_ingredient, policy.policy_ingredient],
)


@sqil_ingredient.config
def config():
total_timesteps = 3e5
train_kwargs = dict(
log_interval=4, # Number of updates between Tensorboard/stdout logs
progress_bar=True,
)

locals() # quieten flake8 unused variable warning


@rl.rl_ingredient.config_hook
def override_rl_cls(config, command_name, logger):
# want to remove arguments added by the rl ingredient but keep
# the ones that are added by others
del logger

res = {}
if command_name == "sqil" and config["rl"]["rl_cls"] is None:
res["rl_cls"] = dqn_algorithm.DQN

return res


@policy.policy_ingredient.config_hook
def override_policy_cls(config, command_name, logger): # noqa
del logger

res = {}
if (
command_name == "sqil"
and config["policy"]["policy_cls"] == base.FeedForward32Policy
):
res["policy_cls"] = "MlpPolicy"

return res
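
Because both hooks only fill in defaults when nothing was chosen (DQN and `"MlpPolicy"` for the `sqil` command), the usual Sacred overrides still apply on top of them. A hedged sketch of overriding the ingredient's values programmatically; the particular hyperparameters are illustrative, and `rl_kwargs` is simply forwarded to the DQN constructor by the `sqil` command below:

```python
# Hedged sketch: overriding the defaults set by the hooks above. The values are
# illustrative; rl_kwargs is passed through to DQN by the "sqil" command.
from imitation.scripts import train_imitation

run = train_imitation.train_imitation_ex.run(
    command_name="sqil",
    named_configs=["seals_cartpole"],
    config_updates=dict(
        demonstrations=dict(source="huggingface", algo_name="ppo"),
        sqil=dict(
            total_timesteps=100_000,
            train_kwargs=dict(log_interval=10, progress_bar=False),
        ),
        rl=dict(rl_kwargs=dict(learning_rate=1e-4, buffer_size=50_000)),
    ),
)
```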
68 changes: 56 additions & 12 deletions src/imitation/scripts/train_imitation.py
@@ -8,13 +8,15 @@
import numpy as np
from sacred.observers import FileStorageObserver

from imitation.algorithms.dagger import SimpleDAggerTrainer
from imitation.algorithms import dagger as dagger_algorithm
from imitation.algorithms import sqil as sqil_algorithm
from imitation.data import rollout, types
from imitation.scripts.config.train_imitation import train_imitation_ex
from imitation.scripts.ingredients import bc as bc_ingredient
from imitation.scripts.ingredients import demonstrations, environment, expert
from imitation.scripts.ingredients import logging as logging_ingredient
from imitation.scripts.ingredients import policy_evaluation
from imitation.util import util

logger = logging.getLogger(__name__)

@@ -40,6 +42,18 @@ def _try_computing_expert_stats(
return None


def _collect_stats(
imit_stats: Mapping[str, float],
expert_trajs: Sequence[types.Trajectory],
) -> Mapping[str, Mapping[str, Any]]:
stats = {"imit_stats": imit_stats}
expert_stats = _try_computing_expert_stats(expert_trajs)
if expert_stats is not None:
stats["expert_stats"] = expert_stats

return stats


@train_imitation_ex.command
def bc(
bc: Dict[str, Any],
@@ -68,14 +82,12 @@ def bc(

bc_trainer.train(**bc_train_kwargs)
# TODO(adam): add checkpointing to BC?
bc_trainer.save_policy(policy_path=osp.join(log_dir, "final.th"))
util.save_policy(bc_trainer.policy, policy_path=osp.join(log_dir, "final.th"))

imit_stats = policy_evaluation.eval_policy(bc_trainer.policy, venv)

stats = {"imit_stats": imit_stats}
expert_stats = _try_computing_expert_stats(expert_trajs)
if expert_stats is not None:
stats["expert_stats"] = expert_stats
stats = _collect_stats(imit_stats, expert_trajs)

return stats


@@ -112,7 +124,7 @@ def dagger(

expert_policy = expert.get_expert_policy(venv)

dagger_trainer = SimpleDAggerTrainer(
dagger_trainer = dagger_algorithm.SimpleDAggerTrainer(
venv=venv,
scratch_dir=osp.join(log_dir, "scratch"),
expert_trajs=expert_trajs,
@@ -132,16 +144,48 @@

imit_stats = policy_evaluation.eval_policy(bc_trainer.policy, venv)

stats = {"imit_stats": imit_stats}
assert dagger_trainer._all_demos is not None
expert_stats = _try_computing_expert_stats(dagger_trainer._all_demos)
if expert_stats is not None:
stats["expert_stats"] = expert_stats
stats = _collect_stats(imit_stats, dagger_trainer._all_demos)

return stats


@train_imitation_ex.command
def sqil(
sqil: Mapping[str, Any],
policy: Mapping[str, Any],
rl: Mapping[str, Any],
_run,
_rnd: np.random.Generator,
) -> Mapping[str, Mapping[str, float]]:
custom_logger, log_dir = logging_ingredient.setup_logging()
expert_trajs = demonstrations.get_expert_trajectories()

with environment.make_venv() as venv:
sqil_trainer = sqil_algorithm.SQIL(
venv=venv,
demonstrations=expert_trajs,
policy=policy["policy_cls"],
custom_logger=custom_logger,
rl_algo_class=rl["rl_cls"],
rl_kwargs=rl["rl_kwargs"],
)

sqil_trainer.train(
total_timesteps=int(sqil["total_timesteps"]),
**sqil["train_kwargs"],
)
util.save_policy(sqil_trainer.policy, policy_path=osp.join(log_dir, "final.th"))

imit_stats = policy_evaluation.eval_policy(sqil_trainer.policy, venv)

stats = _collect_stats(imit_stats, expert_trajs)

return stats


def main_console():
observer_path = pathlib.Path.cwd() / "output" / "sacred" / "train_dagger"
observer_path = pathlib.Path.cwd() / "output" / "sacred" / "train_imitation"
observer = FileStorageObserver(observer_path)
train_imitation_ex.observers.append(observer)
train_imitation_ex.run_commandline()
12 changes: 11 additions & 1 deletion src/imitation/util/util.py
@@ -26,12 +26,22 @@
import numpy as np
import torch as th
from gym.wrappers import TimeLimit
from stable_baselines3.common import monitor
from stable_baselines3.common import monitor, policies
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv

from imitation.data.types import AnyPath


def save_policy(policy: policies.BasePolicy, policy_path: AnyPath) -> None:
"""Save policy to a path.
Args:
policy: policy to save.
policy_path: path to save policy to.
"""
th.save(policy, parse_path(policy_path))


def oric(x: np.ndarray) -> np.ndarray:
"""Optimal rounding under integer constraints.
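
The new `util.save_policy` helper pairs with the existing `bc.reconstruct_policy`, as the updated tests below exercise. A self-contained sketch of the round trip; the CartPole environment and `ActorCriticPolicy` construction are illustrative assumptions, not part of this diff:

```python
# Hedged sketch of util.save_policy + bc.reconstruct_policy; the env and policy
# setup here are illustrative assumptions.
import gym
import torch as th
from stable_baselines3.common.policies import ActorCriticPolicy

from imitation.algorithms import bc
from imitation.util import util

env = gym.make("CartPole-v1")
policy = ActorCriticPolicy(
    observation_space=env.observation_space,
    action_space=env.action_space,
    lr_schedule=lambda _: 3e-4,
)

util.save_policy(policy, "policy.pt")          # thin wrapper around th.save
reloaded = bc.reconstruct_policy("policy.pt")  # loads the full policy object back

# Parameters should survive the round trip.
for p_old, p_new in zip(policy.parameters(), reloaded.parameters()):
    assert th.allclose(p_old, p_new)
```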
2 changes: 1 addition & 1 deletion tests/algorithms/test_bc.py
@@ -277,7 +277,7 @@ def test_that_policy_reconstruction_preserves_parameters(
original_parameters = list(cartpole_bc_trainer.policy.parameters())

# WHEN
cartpole_bc_trainer.save_policy(pol_path)
util.save_policy(cartpole_bc_trainer.policy, pol_path)
reconstructed_policy = bc.reconstruct_policy(pol_path)

# THEN
2 changes: 1 addition & 1 deletion tests/algorithms/test_dagger.py
@@ -520,7 +520,7 @@ def test_simple_dagger_trainer_train(
def test_policy_save_reload(tmpdir, trainer):
# just make sure the methods run; we already test them in test_bc.py
policy_path = os.path.join(tmpdir, "policy.pt")
trainer.save_policy(policy_path)
util.save_policy(trainer.policy, policy_path)
pol = bc.reconstruct_policy(policy_path)
assert isinstance(pol, policies.BasePolicy)

57 changes: 47 additions & 10 deletions tests/scripts/test_scripts.py
@@ -338,15 +338,7 @@ def test_train_bc_main_with_demonstrations_from_huggingface(tmpdir):
)


@pytest.fixture(
params=[
"expert_from_path",
"expert_from_huggingface",
"random_expert",
"zero_expert",
],
)
def bc_config(tmpdir, request):
def generate_imitation_config(tmpdir, request, command_name: str) -> Dict:
environment_named_config = "seals_cartpole"

if request.param == "expert_from_path":
@@ -365,7 +357,7 @@ def bc_config(tmpdir, request):
expert_config = dict(policy_type="zero")

return dict(
command_name="bc",
command_name=command_name,
named_configs=[environment_named_config] + ALGO_FAST_CONFIGS["imitation"],
config_updates=dict(
logging=dict(log_root=tmpdir),
@@ -375,6 +367,30 @@
)


@pytest.fixture(
params=[
"expert_from_path",
"expert_from_huggingface",
"random_expert",
"zero_expert",
],
)
def bc_config(tmpdir, request):
return generate_imitation_config(tmpdir, request, "bc")


@pytest.fixture(
params=[
"expert_from_path",
"expert_from_huggingface",
"random_expert",
"zero_expert",
],
)
def sqil_config(tmpdir, request):
return generate_imitation_config(tmpdir, request, "sqil")


def test_train_bc_main(bc_config):
run = train_imitation.train_imitation_ex.run(**bc_config)
assert run.status == "COMPLETED"
@@ -409,6 +425,27 @@ def test_train_bc_warmstart(tmpdir):
assert isinstance(run_warmstart.result, dict)


def test_train_sqil_main(sqil_config):
# NOTE: Having four different expert types as in bc might be overkill for sqil
run = train_imitation.train_imitation_ex.run(**sqil_config)
assert run.status == "COMPLETED"
assert isinstance(run.result, dict)


def test_train_sqil_main_with_demonstrations_from_huggingface(tmpdir):
train_imitation.train_imitation_ex.run(
command_name="sqil",
named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
config_updates=dict(
logging=dict(log_root=tmpdir),
demonstrations=dict(
source="huggingface",
algo_name="ppo",
),
),
)


@pytest.fixture(params=["cold_start", "warm_start"])
def rl_train_ppo_config(request, tmpdir):
config = dict(logging=dict(log_root=tmpdir))
