"""This is a simple example demonstrating how to clone the behavior of an expert.
Refer to the jupyter notebooks for more detailed examples of how to use the algorithms.
"""
import gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.ppo import MlpPolicy

from imitation.algorithms import bc
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
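
# A single CartPole environment and a seeded RNG are shared by the
# expert-training, rollout-collection, and behavior-cloning steps below.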
env = gym.make("CartPole-v1")
rng = np.random.default_rng(0)


def train_expert():
    print("Training an expert.")
    expert = PPO(
        policy=MlpPolicy,
        env=env,
        seed=0,
        batch_size=64,
        ent_coef=0.0,
        learning_rate=0.0003,
        n_epochs=10,
        n_steps=64,
    )
    expert.learn(100)  # Note: change this to 100000 to train a decent expert.
    return expert
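

# RolloutInfoWrapper stores each finished episode in the final step's
# `info` dict, which `rollout.rollout` uses to reconstruct full
# trajectories even through the VecEnv wrapper.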
def sample_expert_transitions():
    expert = train_expert()

    print("Sampling expert transitions.")
    rollouts = rollout.rollout(
        expert,
        DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
        rollout.make_sample_until(min_timesteps=None, min_episodes=50),
        rng=rng,
    )
    return rollout.flatten_trajectories(rollouts)
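

# `sample_expert_transitions` returns a flat `Transitions` dataset of
# (obs, acts, next_obs, dones) samples, the demonstration format that
# `bc.BC` consumes.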
transitions = sample_expert_transitions()

bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
)
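
# Evaluate the untrained policy first as a baseline. Three rendered
# episodes give a quick sanity check, not a precise estimate.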
reward, _ = evaluate_policy(
    bc_trainer.policy,  # type: ignore[arg-type]
    env,
    n_eval_episodes=3,
    render=True,
)
print(f"Reward before training: {reward}")

print("Training a policy using Behavior Cloning")
bc_trainer.train(n_epochs=1)
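
# Evaluate the cloned policy. With the default 100-step expert above,
# expect only a modest gain; train the expert longer (see the note in
# train_expert) for a convincing demo.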
reward, _ = evaluate_policy(
    bc_trainer.policy,  # type: ignore[arg-type]
    env,
    n_eval_episodes=3,
    render=True,
)
print(f"Reward after training: {reward}")