Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CLI for SQIL #784

Merged
merged 25 commits into from
Sep 16, 2023
Merged

Add CLI for SQIL #784

merged 25 commits into from
Sep 16, 2023

Conversation

lukasberglund
Copy link
Contributor

Description

This PR adds an option to src/imitation/scripts to train using SQIL as addressed in #780. This is still a work in progress, but I would welcome feedback. Some uncertainties I have include:

  • What arguments to explicitly add to the sacred config for SQIL -- Right now I'm only including the keyword-args for SQIL as well as log_interval, total_timesteps, and progress_bar, similarly to how it's done for BC and DAgger. The user set can set all the other arguments too, but including them would make them more salient.
  • Adding warmstart capabilities -- BC and DAgger both allow to continue training a previously trained model, but the SQIL API doesn't. Consequently I didn't include in the SQIL CLI. I could add this feature pretty easily though.
  • Logging -- Currently, I don't see any logs to stdout when training with SQIL, except at the end. I'm not sure if I'm using the wrong hyperparameters or this is a bug. I will look into this more tomorrow.

Testing

I added three tests that mirror those for DAgger and BC. Right now I have more tests for SQIL than for DAgger, which might be overkill. Let me know if I should remove some.

@lukasberglund lukasberglund added the enhancement New feature or request label Sep 12, 2023
@lukasberglund lukasberglund linked an issue Sep 12, 2023 that may be closed by this pull request
Copy link
Member

@AdamGleave AdamGleave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests are currently failing (looks like a missing dependency for tqdm), please address.

Some small comments on the code itself.

setup.py Outdated Show resolved Hide resolved
src/imitation/algorithms/base.py Outdated Show resolved Hide resolved
src/imitation/scripts/config/train_imitation.py Outdated Show resolved Hide resolved
sqil_trainer = SQIL(
venv=venv,
demonstrations=expert_trajs,
policy=sqil["policy_model"],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This duplicates the policy ingredient: https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/scripts/ingredients/policy.py

I think can replace with policy.make_policy(venv) and remove the policy_model config parameter.

Copy link
Contributor Author

@lukasberglund lukasberglund Sep 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think can replace with policy.make_policy(venv) and remove the policy_model config parameter.

This wouldn't work because policy.make_policy(venv) returns policies.BasePolicy, whereas SQIL requires type[policies.BasePolicy] (i.e. a constructor, not an instance).

But you are right, I could use policy["policy_cls"]. One issue is that SQIL uses a DQN by, default which requires a DQNPolicy and the default for policy_cls is base.FeedForward32Policy, which is incompatible. This is bit unfortunate, but I can't think of a way around it currently. To make it easy I've made a named_config called sqil.dqn which lets users set a dqn policy.

I currently can't think of a way to make sqil work by default. We would somehow want to override a config conditional on the sqil command being used. Not sure if that's possible/desirable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wouldn't work because policy.make_policy(venv) returns policies.BasePolicy, whereas SQIL requires type[policies.BasePolicy] (i.e. a constructor, not an instance).

Ah, good point, yes the difference between SQIL expecting classes v.s. the rest of our code expecting objects bites again.

But you are right, I could use policy["policy_cls"]. One issue is that SQIL uses a DQN by, default which requires a DQNPolicy and the default for policy_cls is base.FeedForward32Policy, which is incompatible. This is bit unfortunate, but I can't think of a way around it currently. To make it easy I've made a named_config called sqil.dqn which lets users set a dqn policy.

I currently can't think of a way to make sqil work by default. We would somehow want to override a config conditional on the sqil command being used. Not sure if that's possible/desirable.

Mm indeed this is messy and highlights a design flaw in Sacred. It is possible to have a different default depending on the context using a config hook, https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/scripts/ingredients/rl.py#L41 is an example of this. You can add a hook in SQIL that checks if policy_cls == base.FeedForward32Policy and if so changes it to a DQNPolicy. This is a bit nasty since if for some reason the user manually sets the policy to a FeedForward32Policy we'll just override that silently, but it seems OK.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems good. I've made the change.

src/imitation/scripts/train_imitation.py Outdated Show resolved Hide resolved
src/imitation/scripts/train_imitation.py Outdated Show resolved Hide resolved
src/imitation/scripts/train_imitation.py Outdated Show resolved Hide resolved
tests/scripts/test_scripts.py Outdated Show resolved Hide resolved
tests/scripts/test_scripts.py Show resolved Hide resolved
setup.py Outdated Show resolved Hide resolved
src/imitation/scripts/ingredients/sqil.py Outdated Show resolved Hide resolved
src/imitation/scripts/config/train_imitation.py Outdated Show resolved Hide resolved
src/imitation/scripts/train_imitation.py Outdated Show resolved Hide resolved
src/imitation/scripts/train_imitation.py Outdated Show resolved Hide resolved
setup.py Show resolved Hide resolved
Copy link
Member

@AdamGleave AdamGleave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making these changes! I think this is almost ready, a few minor comments

src/imitation/scripts/ingredients/rl.py Outdated Show resolved Hide resolved
src/imitation/scripts/ingredients/sqil.py Outdated Show resolved Hide resolved
locals() # quieten flake8 unused variable warning


@rl.rl_ingredient.config_hook
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this a config hook on the RL ingredient? I think all ingredients can modify any part of the config, so this could be a config hook on the SQIL ingredient directly. This would avoid mutating other ingredients (remember this is what caused issues with the tests previously), and would let you combine this with override_policy_cls.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried doing this originally, but it didn't work. For some reason it would set the variables inside of sqil. E.g. it would set config["sqil"]["rl"]["rl_cls"] instead of config["rl"]["rl_cls"] as intended.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can confirm having it in SQIL ingredient will set variables inside of SQIL. More problematically moving hook to train_imitation experiment seems to then have no effect on sub-ingredients. So, ugly though it is, I think we probably do need to keep it here. Good news is these hooks are no-op when command name is not sqil.

tests/test_benchmarking.py Outdated Show resolved Hide resolved
Copy link
Member

@AdamGleave AdamGleave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

locals() # quieten flake8 unused variable warning


@rl.rl_ingredient.config_hook
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can confirm having it in SQIL ingredient will set variables inside of SQIL. More problematically moving hook to train_imitation experiment seems to then have no effect on sub-ingredients. So, ugly though it is, I think we probably do need to keep it here. Good news is these hooks are no-op when command name is not sqil.

@AdamGleave AdamGleave merged commit 885beff into master Sep 16, 2023
1 of 8 checks passed
@AdamGleave AdamGleave deleted the sqil_cli branch September 16, 2023 00:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add CLI for SQIL
2 participants