Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of the SQIL algorithm #744

Merged
merged 60 commits into from
Aug 10, 2023
Merged

Implementation of the SQIL algorithm #744

merged 60 commits into from
Aug 10, 2023

Conversation

RedTachyon
Copy link
Contributor

@RedTachyon RedTachyon commented Jul 4, 2023

Description

Fixes #740

Right now it's a basic implementation based on SB3.

It seems to work on a basic level (i.e. I trained it on CartPole and it converged), but still needs a bunch of cleanup and testing.

One note is that there's quite a few #type: ignore[...] annotations. I tried to minimize them, but a good chunk of the code is either modifying SB3 code, or closely interfacing with it, and SB3 seems to have more lax type checking.

Testing

WiP

@codecov
Copy link

codecov bot commented Jul 4, 2023

Codecov Report

Merging #744 (d2124a2) into master (2743c28) will increase coverage by 0.04%.
Report is 1 commits behind head on master.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #744      +/-   ##
==========================================
+ Coverage   96.33%   96.38%   +0.04%     
==========================================
  Files          93       95       +2     
  Lines        8789     8901     +112     
==========================================
+ Hits         8467     8579     +112     
  Misses        322      322              
Files Changed Coverage Δ
src/imitation/data/rollout.py 100.00% <ø> (ø)
src/imitation/algorithms/adversarial/common.py 96.83% <100.00%> (ø)
src/imitation/algorithms/base.py 98.73% <100.00%> (-0.05%) ⬇️
src/imitation/algorithms/bc.py 98.33% <100.00%> (ø)
src/imitation/algorithms/density.py 94.48% <100.00%> (ø)
src/imitation/algorithms/sqil.py 100.00% <100.00%> (ø)
src/imitation/data/types.py 98.21% <100.00%> (+0.01%) ⬆️
src/imitation/testing/expert_trajectories.py 100.00% <100.00%> (ø)
src/imitation/util/util.py 99.19% <100.00%> (+0.01%) ⬆️
tests/algorithms/test_sqil.py 100.00% <100.00%> (ø)

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

Copy link
Member

@AdamGleave AdamGleave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

High-level review, can take a closer look once PR is further along!

src/imitation/algorithms/sqil.py Outdated Show resolved Hide resolved
src/imitation/algorithms/sqil.py Outdated Show resolved Hide resolved
src/imitation/algorithms/sqil.py Show resolved Hide resolved
src/imitation/algorithms/sqil.py Outdated Show resolved Hide resolved
src/imitation/algorithms/sqil.py Outdated Show resolved Hide resolved
RedTachyon and others added 17 commits July 5, 2023 16:51
Remove redundant parameter
* Pin SB3 version to 1.7.0 (#738)

* Update conftest.py (#742)

* Custom environment tutorial (#746)

* Custom environment tutorial draft

* Update the docs website

* Clean notebook

* Text clarification and new environment

* Decrease training duration to hopefully make CI happy

* Clarify that BC itself does not learn rewards

---------

Co-authored-by: Ariel Kwiatkowski <[email protected]>

* Tutorial on comparing algorithm performance (#747)

* Add a new tutorial

* Update index.rst

* Improvements to the tutorial

* Some more caution words

* Fix typos

---------

Co-authored-by: Ariel Kwiatkowski <[email protected]>

---------

Co-authored-by: Adam Gleave <[email protected]>
Copy link
Member

@AdamGleave AdamGleave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the implementation! Algorithm looks correct. SQIL is a strange beast but at least it's quite simple conceptually.

I think the implementation could be simplified / code duplication could be reduced by moving most of the logic into a ReplayBuffer wrapper, then setting DQN to use that replay buffer (probably with replay_buffer_class). I think you can still do this in composition (which I agree seems right approach -- SQIL is not an RL algorithm, so probably shouldn't inherit from DQN, and multiple inheritance gets messy). I may be missing some subtlety though, but if right this would let us eliminate many lines of code making it much easier to read. Could also pave the way to having this support any OffPolicyAlgorithm rather than just DQN which would be neat.

Other main area is it'd be nice to have slightly more comprehensive tests. Although to be fair the algorithm is so trivial there's not that much to actually test (beyond correct functioning of DQN). Checking it makes some progress on a simple environment might add something. If returns improving is too flaky (at least without too expensive a # of timesteps), could also check the Q-network is moving in right direction (e.g. assigns higher Q-value to expert demos than randomly chosen observation/action pairs)?

docs/index.rst Show resolved Hide resolved
src/imitation/algorithms/sqil.py Outdated Show resolved Hide resolved
src/imitation/algorithms/sqil.py Show resolved Hide resolved
src/imitation/algorithms/sqil.py Show resolved Hide resolved
src/imitation/algorithms/sqil.py Show resolved Hide resolved
src/imitation/algorithms/sqil.py Outdated Show resolved Hide resolved
tests/algorithms/test_sqil.py Outdated Show resolved Hide resolved
tests/algorithms/test_sqil.py Show resolved Hide resolved
"\n",
"Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) is a simple algorithm that can be used to clone expert behavior.\n",
"It's fundamentally a modification of the DQN algorithm. At each training step, whenever we sample a batch of data from the replay buffer,\n",
"we also sample a batch of expert data. Expert demonstrations are assigned a reward of 1, while the agent's own transitions are assigned a reward of 0.\n",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an accurate description but does highlight the algorithm is a bit bizarre. If the agent perfectly mimicked expert demos, they'd still get assigned a reward of 0. Whereas at least with stuff like AIRL/GAIL, they'd get the same reward (as discriminator could no longer distinguish them).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is bizarre, but that's probably unavoidable, at least in continuous observation spaces (i.e. everything that's not tabular) -- if we tried to relabel generated data if it matches the demonstrations, that would basically never happen because we're comparing floats for equality.

This makes me wonder if this method would be particularly vulnerable to adversarial attacks, since there's a big difference between being in a state [1.1234] vs state [1.1235]. A larger network could probably overfit to that, which would be less likely with a more dense reward

docs/tutorials/10_train_sqil.ipynb Outdated Show resolved Hide resolved
@AdamGleave
Copy link
Member

Asking @jas-ho to look at fixing the failing/flaky tests

demonstrations=rollouts,
policy="MlpPolicy",
)
# Hint: set to 1_000_000 to match the expert performance.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

100_000 was already sufficient to reach expert performance (tried only a couple of times though)

src/imitation/algorithms/sqil.py Show resolved Hide resolved
src/imitation/algorithms/sqil.py Show resolved Hide resolved
src/imitation/algorithms/sqil.py Outdated Show resolved Hide resolved
dqn_kwargs=dict(
learning_starts=500,
learning_rate=0.002,
batch_size=220,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I searched for good hyperparams using optuna. With these hyperparams I ended up with 1 failure out of 64 on my machine whereas before it was as many as 1 out of 5 (and on the CI pipeline it seems to have failed almost every time for the last ten-ish runs).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've also tried pushing the number of episodes in evaluate_policy up to 100 but it did not improve further so I reverted it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AdamGleave based on the hyperparam search and the manual testing I did I do not think the flakiness here points to a bug. Therefore, I think our best option is to actually fix the seeds for this specific test. Given how slow the CI pipeline is I think it's not good to have even a 2% remaining rate of flakiness.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see 2bf467d which passes on CI (except for codecov)
I also ran pytest --flake-finder tests/algorithms/test_sqil.py --flake-runs=16 -n 8 locally and found no failures for test_sqil_performance and test_sqil_demonstration_buffer which were the ones failing in CI previously

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for doing the hyperparam sweep! Yeah, let's fix the seed. Given it tests for a significant improvement in reward it should be a non-trivial test even with only a single seed (if we were just checking for any improvement then it'd be 50/50 whether it passed even if the algorithm was no better than random). We could also @pytest.mark.parametrize the seed for extra robustness if the test runs quick enough.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_sqil_performance already takes ~20 sec on my machine. Given that I did not cherry-pick the seed I don't think additional parametrization improves robustness enough to trade off favorably against decreased dev velocity

Copy link
Member

@AdamGleave AdamGleave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for tuning hyperparams, agree seems unlikely to be a bug given high success rate after tuning.

Happy for you to make the other changes you suggested; please request a re-review from me & ernestum once done.

tests/algorithms/test_sqil.py Outdated Show resolved Hide resolved
tests/algorithms/test_sqil.py Outdated Show resolved Hide resolved
@jas-ho
Copy link
Contributor

jas-ho commented Aug 9, 2023

Thanks for tuning hyperparams, agree seems unlikely to be a bug given high success rate after tuning.

Happy for you to make the other changes you suggested; please request a re-review from me & ernestum once done.

I addressed your comments @AdamGleave . From my side it's ready for final review.

@ernestum
Copy link
Collaborator

ernestum commented Aug 9, 2023

@AdamGleave will you review this or should I have a look?

Copy link
Member

@AdamGleave AdamGleave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Member

@AdamGleave AdamGleave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please decide if you want to make this change or not prior to merging @jas-ho , but I don't need to re-review just for that.

src/imitation/algorithms/sqil.py Show resolved Hide resolved
@jas-ho
Copy link
Contributor

jas-ho commented Aug 10, 2023

Please decide if you want to make this change or not prior to merging @jas-ho , but I don't need to re-review just for that.

That change was implemented already so it might have just been a gh display issue. -> LGTM :)

Copy link
Collaborator

@ernestum ernestum left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ernestum ernestum merged commit fd4d8f0 into master Aug 10, 2023
15 checks passed
@ernestum ernestum deleted the redtachyon/740-sqil branch August 10, 2023 08:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement Soft Q imitation learning (SQIL)
4 participants