Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add partial support for dictionary observation spaces (bc, density) #785

Merged
merged 89 commits into from
Oct 5, 2023

Conversation

NixGD
Copy link
Contributor

@NixGD NixGD commented Sep 13, 2023

Description

Partially addresses #681, by adding support for dictionary observation spaces in:

  • Core trajectory gathering and processing code (types.py, rollout.py, etc.)
  • Behavioral Cloning
  • Density based algorithms

Does not add support for any other algorithms, or trajectory saving / loading.

Testing

  • Add Dict space to observation space parameterization over trajectories in test_types.py
  • Add explicit tests of ObsDict in test_types.py
  • Add tests of dict observation spaces to test_rollout.py
  • Add tests of dict observation spaces to test_bc.py
  • Add tests of dict observation spaces to density.py

src/imitation/data/types.py Outdated Show resolved Hide resolved
Copy link
Member

@AdamGleave AdamGleave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall this design looks good to me. As you mentioned lots of things that needed to be cleaned up but I think it's going in the right direction. I did a fairly detailed review of data/types.py, data/rollout.py and algorithms/bc.py but just skimmed the rest so do highlight if there are any other important areas.

src/imitation/data/types.py Outdated Show resolved Hide resolved
src/imitation/data/types.py Outdated Show resolved Hide resolved
src/imitation/data/types.py Outdated Show resolved Hide resolved
src/imitation/data/types.py Outdated Show resolved Hide resolved
src/imitation/data/types.py Outdated Show resolved Hide resolved
src/imitation/data/rollout.py Outdated Show resolved Hide resolved
src/imitation/data/rollout.py Outdated Show resolved Hide resolved
src/imitation/data/rollout.py Outdated Show resolved Hide resolved
src/imitation/algorithms/bc.py Outdated Show resolved Hide resolved
src/imitation/algorithms/bc.py Outdated Show resolved Hide resolved
Copy link
Member

@AdamGleave AdamGleave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing to think about is how much we want to push dict observations through the code. In many cases for the imitation algorithms we could basically just flatten the dict to a NumPy array and "add" dict support with minimal additional code changes.

But that kind of defeats the purpose of adding dict support -- may as well just flatten the observations at the environment level. On the other hand, we have to flatten them at some point: the neural network will take a tensor as input not a dict. So, a lot of the design decision is deciding where to do the flattening.

It seems nice to be able to preserve the dict up until calling the policy. This gives flexibility to the user. In our case, InteractivePolicy can ignore the non-rendering component. In general, a user might want to preprocess different components of the observation differently.

src/imitation/algorithms/density.py Outdated Show resolved Hide resolved
src/imitation/algorithms/mce_irl.py Show resolved Hide resolved
src/imitation/data/huggingface_utils.py Show resolved Hide resolved
@@ -371,3 +375,59 @@ def inc_batch_cnt():

# THEN
assert batch_cnt == no_yield_after_iter


class FloatReward(gym.RewardWrapper):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this? shouldn't reward be a float already?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see SimpleMultiObsEnvs sometimes returns 1 rather than 1.0. This is a bug we should probably fix upstream. I'm a maintainer of SB3 so if you make a PR I can review it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR here. Though the environment is more sketchy the more I look at it, e.g. they hardcode the possible transitions from each state in a way that doesn't depend on gridworld size. So maybe best to not depend on the environment much (or submit more fixes upstream).

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SimpleMultiObsEnvs is just a test env, it is not meant to be used except in tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, we'd only be using it in tests as well.

tests/algorithms/test_bc.py Outdated Show resolved Hide resolved
@NixGD
Copy link
Contributor Author

NixGD commented Sep 14, 2023

One sad thing to note is we don't support arbitrarily nested dictionaries.

@NixGD
Copy link
Contributor Author

NixGD commented Sep 15, 2023

One thing to think about is how much we want to push dict observations through the code. In many cases for the imitation algorithms we could basically just flatten the dict to a NumPy array and "add" dict support with minimal additional code changes.

But that kind of defeats the purpose of adding dict support -- may as well just flatten the observations at the environment level. On the other hand, we have to flatten them at some point: the neural network will take a tensor as input not a dict. So, a lot of the design decision is deciding where to do the flattening.

It seems nice to be able to preserve the dict up until calling the policy. This gives flexibility to the user. In our case, InteractivePolicy can ignore the non-rendering component. In general, a user might want to preprocess different components of the observation differently.

I agree, I think the Policy is the place to handle the dict -> network input transition. This is consistent with SB3 (see here) although the type signatures in SB3 obfuscate this fact.

@ZiyueWang25 ZiyueWang25 mentioned this pull request Oct 4, 2023
@ZiyueWang25
Copy link
Contributor

Finished addressing your comments. Please take another look.

I moved pytype upgrade related issue to #801 because it turns out to be more complicated than simply fixing some types.

Copy link
Member

@AdamGleave AdamGleave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM -- one minor typo to fix but no need for a re-review. Can merge once we get CI green.

src/imitation/data/rollout.py Show resolved Hide resolved
src/imitation/algorithms/dagger.py Outdated Show resolved Hide resolved
@codecov
Copy link

codecov bot commented Oct 5, 2023

Codecov Report

Merging #785 (0af3037) into master (573b086) will increase coverage by 0.07%.
The diff coverage is 97.13%.

@@            Coverage Diff             @@
##           master     #785      +/-   ##
==========================================
+ Coverage   96.33%   96.40%   +0.07%     
==========================================
  Files          98       98              
  Lines        9177     9441     +264     
==========================================
+ Hits         8841     9102     +261     
- Misses        336      339       +3     
Files Coverage Δ
src/imitation/algorithms/bc.py 98.32% <100.00%> (+<0.01%) ⬆️
src/imitation/algorithms/preference_comparisons.py 99.13% <100.00%> (ø)
src/imitation/data/buffer.py 95.38% <100.00%> (ø)
src/imitation/data/rollout.py 100.00% <100.00%> (+1.38%) ⬆️
src/imitation/data/wrappers.py 100.00% <100.00%> (ø)
src/imitation/policies/base.py 98.07% <100.00%> (+0.16%) ⬆️
src/imitation/policies/exploration_wrapper.py 100.00% <100.00%> (ø)
src/imitation/rewards/reward_wrapper.py 98.41% <100.00%> (+0.07%) ⬆️
tests/algorithms/conftest.py 100.00% <100.00%> (ø)
tests/algorithms/test_adversarial.py 100.00% <ø> (ø)
... and 15 more

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@AdamGleave AdamGleave merged commit e6d8886 into master Oct 5, 2023
12 of 15 checks passed
@AdamGleave AdamGleave deleted the support-dict-obs-space branch October 5, 2023 19:32
@saeed349
Copy link

Appreciate adding this support.
Would be really good to see dict support extended to other algorithms as well.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants