Commit
* Add sqil cli
* Lints
* More lints
* Add shine requirement, used for DQN progress bar.
* Undo removal of src.policy
* Remove old comment
* Add trailing commas
* change dependencies
* Update src/imitation/scripts/config/train_imitation.py
  Co-authored-by: Adam Gleave <[email protected]>
* Move save_policy and reconstruct_policy
* Respond to fix save_policy issue
* Remove some boilerplate
* fix use of save_policy
* Fix bug in sqil
* Update src/imitation/scripts/ingredients/sqil.py
  Co-authored-by: Adam Gleave <[email protected]>
* address PR
* fix typing error
* fix typing error
* change shine to rich
* remove line
* Update src/imitation/scripts/ingredients/sqil.py
  Co-authored-by: Adam Gleave <[email protected]>
* respond to adam comments
* make line shorter
* Simplify RL hook

---------

Co-authored-by: Adam Gleave <[email protected]>
1 parent cb93fb0, commit 885beff
Showing 11 changed files with 177 additions and 47 deletions.
@@ -0,0 +1,49 @@
+"""This ingredient provides a SQIL algorithm instance."""
+import sacred
+from stable_baselines3 import dqn as dqn_algorithm
+
+from imitation.policies import base
+from imitation.scripts.ingredients import policy, rl
+
+sqil_ingredient = sacred.Ingredient(
+    "sqil",
+    ingredients=[rl.rl_ingredient, policy.policy_ingredient],
+)
+
+
+@sqil_ingredient.config
+def config():
+    total_timesteps = 3e5
+    train_kwargs = dict(
+        log_interval=4,  # Number of updates between Tensorboard/stdout logs
+        progress_bar=True,
+    )
+
+    locals()  # quieten flake8 unused variable warning
+
+
+@rl.rl_ingredient.config_hook
+def override_rl_cls(config, command_name, logger):
+    # want to remove arguments added by the rl ingredient but keep
+    # the ones that are added by others
+    del logger
+
+    res = {}
+    if command_name == "sqil" and config["rl"]["rl_cls"] is None:
+        res["rl_cls"] = dqn_algorithm.DQN
+
+    return res
+
+
+@policy.policy_ingredient.config_hook
+def override_policy_cls(config, command_name, logger):  # noqa
+    del logger
+
+    res = {}
+    if (
+        command_name == "sqil"
+        and config["policy"]["policy_cls"] == base.FeedForward32Policy
+    ):
+        res["policy_cls"] = "MlpPolicy"
+
+    return res
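
The two config hooks in this hunk only replace a default when the `sqil` command is running and the user has left that default untouched. The sketch below mimics that decision logic in plain Python so it can be run without Sacred, Stable-Baselines3, or imitation installed. The placeholder classes, the `apply_hooks` helper, the shape of the `defaults` dictionary, and the `"train_rl"` command name are all assumptions made for illustration; how Sacred actually merges hook results into the ingredient config is not shown in the diff.

```python
"""Pure-Python sketch of the override logic in the two config hooks."""


class FeedForward32Policy:
    """Placeholder standing in for imitation.policies.base.FeedForward32Policy."""


class DQN:
    """Placeholder standing in for stable_baselines3.dqn.DQN."""


def override_rl_cls(config, command_name):
    """Pick DQN only for the `sqil` command when no RL class was chosen."""
    res = {}
    if command_name == "sqil" and config["rl"]["rl_cls"] is None:
        res["rl_cls"] = DQN
    return res


def override_policy_cls(config, command_name):
    """Swap the default feed-forward policy for "MlpPolicy" under `sqil`."""
    res = {}
    if (
        command_name == "sqil"
        and config["policy"]["policy_cls"] == FeedForward32Policy
    ):
        res["policy_cls"] = "MlpPolicy"
    return res


def apply_hooks(config, command_name):
    """Hypothetical helper: merge each hook's updates into its own section."""
    merged = {name: dict(section) for name, section in config.items()}
    merged["rl"].update(override_rl_cls(config, command_name))
    merged["policy"].update(override_policy_cls(config, command_name))
    return merged


# Assumed defaults: no RL class chosen, default feed-forward policy.
defaults = {
    "rl": {"rl_cls": None},
    "policy": {"policy_cls": FeedForward32Policy},
}

sqil_cfg = apply_hooks(defaults, "sqil")
other_cfg = apply_hooks(defaults, "train_rl")  # hypothetical other command

# A user-supplied RL class is left alone even under `sqil`.
custom = {"rl": {"rl_cls": object}, "policy": {"policy_cls": "CnnPolicy"}}
custom_cfg = apply_hooks(custom, "sqil")
```

Under these assumptions, `sqil_cfg` ends up with the DQN placeholder and `"MlpPolicy"`, while `other_cfg` and `custom_cfg` keep the values they started with. Each hook returns an empty dictionary when its condition fails, which is what makes the override a fallback rather than a forced setting.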