
[Feature] Extend TensorDictPrimer default_value options #2071

Merged (21 commits, Apr 18, 2024)

Conversation

@albertbou92 (Contributor) commented Apr 9, 2024

Description

This PR extends the possible values taken by the tensors added by the TensorDictPrimer transform, allowing callables to be used to create them.
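The idea can be sketched in a few lines of plain Python. This is a hypothetical illustration of the mechanism only (the names `resolve_default` and `defaults` are made up here, not TorchRL's actual implementation): a default value may now be either a fixed value or a zero-argument callable that is evaluated each time the entry is created, e.g. at environment reset.

```python
# Pure-Python sketch of the mechanism this PR enables (hypothetical names,
# not TorchRL's actual implementation): a primer's default_value may be a
# fixed value or a zero-argument callable evaluated when the entry is created.

def resolve_default(default):
    """Call `default` if it is callable; otherwise use it as-is."""
    return default() if callable(default) else default

# One callable default and one plain default, keyed like primer entries.
defaults = {
    "hidden_state": lambda: [0.0, 0.0, 0.0],  # created fresh on each reset
    "step_count": 0,                          # constant default
}

primed = {key: resolve_default(value) for key, value in defaults.items()}
```

Evaluating the callable at creation time, rather than storing a fixed tensor, lets each reset produce a fresh (possibly random) value.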

Motivation and Context

Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax close #15213 if this solves the issue #15213

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@pytorch-bot (bot) commented Apr 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2071

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures, 7 Unrelated Failures

As of commit 1625514 with merge base acf168e:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Apr 9, 2024
@albertbou92 changed the title from [WIP] Extend TensorDictPrimer default_value options to Extend TensorDictPrimer default_value options Apr 10, 2024
@vmoens changed the title from Extend TensorDictPrimer default_value options to [Feature] Extend TensorDictPrimer default_value options Apr 10, 2024
@vmoens (Contributor) left a comment

Great work, left a bunch of comments.
Thanks a million!

torchrl/envs/transforms/transforms.py (outdated, resolved)
torchrl/envs/transforms/transforms.py (outdated, resolved)
torchrl/envs/transforms/transforms.py (resolved)
try:
    expanded_spec = self._try_expand_shape(spec)
except AttributeError:
    raise RuntimeError(

@vmoens (Contributor) commented:

When will this be reached?
@albertbou92 (Contributor, Author) commented Apr 14, 2024:

If, for any reason, self.parent is None.

@vmoens (Contributor) commented:

When would transform_observation_spec be called when parent is None?

torchrl/envs/transforms/transforms.py (outdated, resolved)
torchrl/envs/transforms/transforms.py (resolved)
torchrl/envs/transforms/transforms.py (outdated, resolved)
torchrl/envs/transforms/transforms.py (outdated, resolved)
torchrl/envs/transforms/transforms.py (outdated, resolved)
@albertbou92 requested a review from vmoens April 14, 2024 15:48
@vmoens added the enhancement label Apr 15, 2024
torchrl/envs/transforms/transforms.py (outdated, resolved)
self.random = random
if isinstance(default_value, dict):
    primer_keys = {unravel_key(key) for key in self.primers.keys(True, True)}
    default_value_keys = {unravel_key(key) for key in default_value.keys()}

@vmoens (Contributor) commented:

What about passing a tensordict to represent this?
For instance, this format will be messy with nested keys:

default_values = {("a", "b"): 1, ("c", "d"): lambda: torch.randn(()), "e": {"f": lambda: torch.zeros(())}}

but if you use tensordict nightly you get a nice representation:

default_values = TensorDict(default_values, []).to_dict()
default_values

which prints

{'a': {'b': tensor(1)},
 'c': {'d': <function __main__.<lambda>()>},
 'e': {'f': <function __main__.<lambda>()>}}
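The normalization the reviewer points at — mixed tuple keys and nested dicts mapping to one canonical nested structure — can be sketched in plain Python. The helper `to_nested` below is hypothetical and only illustrates the idea; it is not tensordict's implementation.

```python
# Pure-Python sketch (hypothetical helper, not tensordict's implementation):
# tuple keys like ("a", "b") and plain nested dicts both normalize to one
# canonical nested-dict structure.

def to_nested(flat):
    """Expand tuple keys into nested dicts, recursing into dict values,
    so every entry ends up at its fully nested location."""
    out = {}
    for key, value in flat.items():
        path = key if isinstance(key, tuple) else (key,)
        node = out
        for part in path[:-1]:
            node = node.setdefault(part, {})
        if isinstance(value, dict):
            node[path[-1]] = to_nested(value)
        else:
            node[path[-1]] = value
    return out

default_values = {("a", "b"): 1, "e": {"f": 0.0}}
nested = to_nested(default_values)  # {"a": {"b": 1}, "e": {"f": 0.0}}
```

A single canonical form is what lets the default-value structure line up key-for-key with the nested spec structure.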

@vmoens (Contributor) commented:

Since tensordict accepts whatever value now, we could even not transform it back to a dict.

@vmoens (Contributor) commented:

That way the default value structure will be 100% identical to the CompositeSpec that we use to represent the specs.

@albertbou92 (Contributor, Author) commented Apr 16, 2024:

Ah, very cool option @vmoens!

At the moment I need to transform it back to a dict, because I get the following behaviour for non-tensor data:

import torch
from tensordict import TensorDict
default_value = {
    "mykey1": lambda: torch.ones(3),
    "mykey2": lambda: torch.tensor(1, dtype=torch.int64),
}
default_value = TensorDict(default_value, [])
keys = default_value.keys(True, True)
print(keys)

output:

_TensorDictKeysView([],
    include_nested=True,
    leaves_only=True)

So non-tensor data are not considered leaves.

@vmoens (Contributor) commented Apr 16, 2024:
Are you using the latest nightly?
You can always define your own is_leaf for keys:

import tensordict
import torch
from tensordict import TensorDict
default_value = {
    "mykey1": lambda: torch.ones(3),
    "mykey2": lambda: torch.tensor(1, dtype=torch.int64),
}
default_value = TensorDict(default_value, [])
print(default_value)
keys = list(default_value.keys(True, True, is_leaf=lambda x: issubclass(x, (tensordict.NonTensorData, torch.Tensor))))
print(keys)
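The is_leaf idea above — a predicate on the node's *type* deciding where recursion stops — can be sketched without tensordict. The helper `iter_leaf_keys` is hypothetical, named here only for illustration; tensordict's actual `keys(..., is_leaf=...)` works on its own node types.

```python
# Pure-Python sketch (hypothetical helper, not tensordict's API) of the
# is_leaf mechanism: a predicate on each node's *type* decides whether
# traversal stops there, so non-dict payloads of any kind count as leaves.

def iter_leaf_keys(nested, is_leaf, prefix=()):
    """Yield nested key tuples for every node whose type satisfies is_leaf."""
    for key, value in nested.items():
        if is_leaf(type(value)):
            yield prefix + (key,)
        elif isinstance(value, dict):
            yield from iter_leaf_keys(value, is_leaf, prefix + (key,))

tree = {"mykey1": lambda: 1, "nested": {"mykey2": 3}}
# Treat anything that is not a sub-dict as a leaf, mirroring a custom is_leaf.
keys = list(iter_leaf_keys(tree, lambda t: not issubclass(t, dict)))
```

With the default predicate excluding callables, "mykey1" would be silently skipped; widening the predicate is what makes non-tensor entries visible, as in the reviewer's snippet.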

@albertbou92 (Contributor, Author) commented:

I was using the latest commit, yes. But this solution works fine :)

@albertbou92 requested a review from vmoens April 16, 2024 09:59
@vmoens (Contributor) left a comment

LGTM!

@vmoens (Contributor) commented Apr 16, 2024:

TestgSDE is failing because we patched the behaviour for wrong primers; can you fix that?

@albertbou92 (Contributor, Author) commented:

Done!

@albertbou92 requested a review from vmoens April 18, 2024 09:36
@vmoens (Contributor) left a comment

LGTM

@vmoens merged commit 6b87184 into pytorch:main Apr 18, 2024
40 of 53 checks passed
Labels: CLA Signed, enhancement
3 participants