-
Notifications
You must be signed in to change notification settings - Fork 230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CLI for SQIL #784
Add CLI for SQIL #784
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tests are currently failing (looks like a missing dependency for tqdm), please address.
Some small comments on the code itself.
sqil_trainer = SQIL( | ||
venv=venv, | ||
demonstrations=expert_trajs, | ||
policy=sqil["policy_model"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This duplicates the policy ingredient: https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/scripts/ingredients/policy.py
I think can replace with policy.make_policy(venv)
and remove the policy_model
config parameter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think can replace with policy.make_policy(venv) and remove the policy_model config parameter.
This wouldn't work because policy.make_policy(venv)
returns policies.BasePolicy
, whereas SQIL requires type[policies.BasePolicy]
(i.e. a constructor, not an instance).
But you are right, I could use policy["policy_cls"]
. One issue is that SQIL uses a DQN by, default which requires a DQNPolicy and the default for policy_cls
is base.FeedForward32Policy
, which is incompatible. This is bit unfortunate, but I can't think of a way around it currently. To make it easy I've made a named_config called sqil.dqn
which lets users set a dqn policy.
I currently can't think of a way to make sqil
work by default. We would somehow want to override a config conditional on the sqil command being used. Not sure if that's possible/desirable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This wouldn't work because policy.make_policy(venv) returns policies.BasePolicy, whereas SQIL requires type[policies.BasePolicy] (i.e. a constructor, not an instance).
Ah, good point, yes the difference between SQIL expecting classes v.s. the rest of our code expecting objects bites again.
But you are right, I could use policy["policy_cls"]. One issue is that SQIL uses a DQN by, default which requires a DQNPolicy and the default for policy_cls is base.FeedForward32Policy, which is incompatible. This is bit unfortunate, but I can't think of a way around it currently. To make it easy I've made a named_config called sqil.dqn which lets users set a dqn policy.
I currently can't think of a way to make sqil work by default. We would somehow want to override a config conditional on the sqil command being used. Not sure if that's possible/desirable.
Mm indeed this is messy and highlights a design flaw in Sacred. It is possible to have a different default depending on the context using a config hook, https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/scripts/ingredients/rl.py#L41 is an example of this. You can add a hook in SQIL that checks if policy_cls == base.FeedForward32Policy
and if so changes it to a DQNPolicy
. This is a bit nasty since if for some reason the user manually sets the policy to a FeedForward32Policy we'll just override that silently, but it seems OK.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems good. I've made the change.
Co-authored-by: Adam Gleave <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for making these changes! I think this is almost ready, a few minor comments
locals() # quieten flake8 unused variable warning | ||
|
||
|
||
@rl.rl_ingredient.config_hook |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this a config hook on the RL ingredient? I think all ingredients can modify any part of the config, so this could be a config hook on the SQIL ingredient directly. This would avoid mutating other ingredients (remember this is what caused issues with the tests previously), and would let you combine this with override_policy_cls
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried doing this originally, but it didn't work. For some reason it would set the variables inside of sqil. E.g. it would set config["sqil"]["rl"]["rl_cls"]
instead of config["rl"]["rl_cls"]
as intended.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can confirm having it in SQIL ingredient will set variables inside of SQIL. More problematically moving hook to train_imitation experiment seems to then have no effect on sub-ingredients. So, ugly though it is, I think we probably do need to keep it here. Good news is these hooks are no-op when command name is not sqil.
Co-authored-by: Adam Gleave <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
locals() # quieten flake8 unused variable warning | ||
|
||
|
||
@rl.rl_ingredient.config_hook |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can confirm having it in SQIL ingredient will set variables inside of SQIL. More problematically moving hook to train_imitation experiment seems to then have no effect on sub-ingredients. So, ugly though it is, I think we probably do need to keep it here. Good news is these hooks are no-op when command name is not sqil.
Description
This PR adds an option to
src/imitation/scripts
to train using SQIL as addressed in #780. This is still a work in progress, but I would welcome feedback. Some uncertainties I have include:log_interval
,total_timesteps
, andprogress_bar
, similarly to how it's done forBC
andDAgger
. The user set can set all the other arguments too, but including them would make them more salient.Testing
I added three tests that mirror those for DAgger and BC. Right now I have more tests for
SQIL
than forDAgger
, which might be overkill. Let me know if I should remove some.