Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Extend TensorDictPrimer default_value options #2071

Merged
merged 21 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 69 additions & 9 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6406,17 +6406,11 @@ def test_trans_parallel_env_check(self):
finally:
env.close()

def test_trans_serial_env_check(self):
with pytest.raises(RuntimeError, match="The leading shape of the primer specs"):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([4])),
)
_ = env.observation_spec

@pytest.mark.parametrize("spec_shape", [[4], [2, 4]])
def test_trans_serial_env_check(self, spec_shape):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])),
TensorDictPrimer(mykey=UnboundedContinuousTensorSpec(spec_shape)),
)
check_env_specs(env)
assert "mykey" in env.reset().keys()
Expand Down Expand Up @@ -6516,6 +6510,72 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done):
r1 = env.rollout(100, break_when_any_done=break_when_any_done)
tensordict.tensordict.assert_allclose_td(r0, r1)

def test_callable_default_value(self):
    """A callable default_value must be invoked to populate the primer entry."""

    def make_ones():
        return torch.ones(3)

    env = TransformedEnv(
        ContinuousActionVecMockEnv(),
        TensorDictPrimer(
            mykey=UnboundedContinuousTensorSpec([3]), default_value=make_ones
        ),
    )
    check_env_specs(env)
    # The primed key must appear both at reset time and in rollout "next" entries.
    reset_keys = env.reset().keys()
    assert "mykey" in reset_keys
    rollout_keys = env.rollout(3).keys(True)
    assert ("next", "mykey") in rollout_keys

def test_dict_default_value(self):
    """default_value given as a dict (of floats or of callables) must populate
    each primer key with its own value, both at reset and during rollouts."""

    # Test with a dict of float default values
    key1_spec = UnboundedContinuousTensorSpec([3])
    key2_spec = UnboundedContinuousTensorSpec([3])
    env = TransformedEnv(
        ContinuousActionVecMockEnv(),
        TensorDictPrimer(
            mykey1=key1_spec,
            mykey2=key2_spec,
            default_value={
                "mykey1": 1.0,
                "mykey2": 2.0,
            },
        ),
    )
    check_env_specs(env)
    reset_td = env.reset()
    assert "mykey1" in reset_td.keys()
    assert "mykey2" in reset_td.keys()
    rollout_td = env.rollout(3)
    assert ("next", "mykey1") in rollout_td.keys(True)
    assert ("next", "mykey2") in rollout_td.keys(True)
    assert (rollout_td.get(("next", "mykey1")) == 1.0).all()
    assert (rollout_td.get(("next", "mykey2")) == 2.0).all()

    # Test with a dict of callable default values
    key1_spec = UnboundedContinuousTensorSpec([3])
    key2_spec = DiscreteTensorSpec(3, dtype=torch.int64)
    env = TransformedEnv(
        ContinuousActionVecMockEnv(),
        TensorDictPrimer(
            mykey1=key1_spec,
            mykey2=key2_spec,
            default_value={
                "mykey1": lambda: torch.ones(3),
                "mykey2": lambda: torch.tensor(1, dtype=torch.int64),
            },
        ),
    )
    check_env_specs(env)
    reset_td = env.reset()
    assert "mykey1" in reset_td.keys()
    assert "mykey2" in reset_td.keys()
    rollout_td = env.rollout(3)
    assert ("next", "mykey1") in rollout_td.keys(True)
    assert ("next", "mykey2") in rollout_td.keys(True)
    # BUG FIX: the original asserted on the bound method `.all` (always truthy)
    # instead of calling it, so these checks never actually ran.
    assert (rollout_td.get(("next", "mykey1")) == torch.ones(3)).all()
    assert (
        rollout_td.get(("next", "mykey2")) == torch.tensor(1, dtype=torch.int64)
    ).all()


class TestTimeMaxPool(TransformBase):
@pytest.mark.parametrize("T", [2, 4])
Expand Down
110 changes: 88 additions & 22 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4431,8 +4431,12 @@ class TensorDictPrimer(Transform):
random (bool, optional): if ``True``, the values will be drawn randomly from
the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed.
Defaults to `False`.
default_value (float, optional): if non-random filling is chosen, this
value will be used to populate the tensors. Defaults to `0.0`.
default_value (float, Callable, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): If non-random
filling is chosen, `default_value` will be used to populate the tensors. If `default_value` is a float,
all elements of the tensors will be set to that value. If it is a callable, this callable is expected to
return a tensor fitting the specs, and it will be used to generate the tensors. Finally, if `default_value`
is a dictionary of tensors or a dictionary of callables with keys matching those of the specs, these will
be used to generate the corresponding tensors. Defaults to `0.0`.
reset_key (NestedKey, optional): the reset key to be used as partial
reset indicator. Must be unique. If not provided, defaults to the
only reset key of the parent environment (if it has only one)
Expand Down Expand Up @@ -4489,8 +4493,11 @@ class TensorDictPrimer(Transform):
def __init__(
self,
primers: dict | CompositeSpec = None,
random: bool = False,
default_value: float = 0.0,
random: bool | None = None,
default_value: float
| Callable
| Dict[NestedKey, float]
| Dict[NestedKey, Callable] = None,
reset_key: NestedKey | None = None,
**kwargs,
):
Expand All @@ -4505,8 +4512,30 @@ def __init__(
if not isinstance(kwargs, CompositeSpec):
kwargs = CompositeSpec(kwargs)
self.primers = kwargs
if random and default_value:
raise ValueError(
"Setting random to True and providing a default_value are incompatible."
)
default_value = (
default_value or 0.0
) # if not random and no default value, use 0.0
self.random = random
if isinstance(default_value, dict):
primer_keys = {unravel_key(key) for key in self.primers.keys(True, True)}
albertbou92 marked this conversation as resolved.
Show resolved Hide resolved
default_value_keys = {unravel_key(key) for key in default_value.keys()}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about passing through a tensordict to represent this?
For instance, this format will be messy with nested keys

default_values = {("a", "b"): 1, ("c", "d"): lambda: torch.randn(()), "e": {"f": lambda: torch.zeros(())}}

but if you use tensordict nightly you get a nice representation:

default_values = TensorDict(default_values, []).to_dict()
default_values

which prints

{'a': {'b': tensor(1)},
 'c': {'d': <function __main__.<lambda>()>},
 'e': {'f': <function __main__.<lambda>()>}}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since tensordict accepts whatever value now, we could even not transform it back to a dict

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That way the default value structure will be 100% identical with the CompositeSpec that we use to represent the specs

Copy link
Contributor Author

@albertbou92 albertbou92 Apr 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah very cool option @vmoens !

atm I need to transform it back to dict, because I get the following behaviour for non-tensor data:

import torch
from tensordict import TensorDict
default_value = {
    "mykey1": lambda: torch.ones(3),
    "mykey2": lambda: torch.tensor(1, dtype=torch.int64),
}
default_value = TensorDict(default_value, [])
keys = default_value.keys(True, True)
print(keys)

output:

_TensorDictKeysView([],
    include_nested=True,
    leaves_only=True)

So non-tensor data are not considered leaves

Copy link
Contributor

@vmoens vmoens Apr 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you using the latest nightly?
You can always define your own is_leaf for keys:

import tensordict
import torch
from tensordict import TensorDict
default_value = {
    "mykey1": lambda: torch.ones(3),
    "mykey2": lambda: torch.tensor(1, dtype=torch.int64),
}
default_value = TensorDict(default_value, [])
print(default_value)
keys = list(default_value.keys(True, True, is_leaf=lambda x: issubclass(x, (tensordict.NonTensorData, torch.Tensor))))
print(keys)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was using the latest commit, yes. But this solution works fine :)

if primer_keys != default_value_keys:
raise ValueError(
"If a default_value dictionary is provided, it must match the primers keys."
)
default_value = {
key: default_value[key] for key in self.primers.keys(True, True)
}
else:
default_value = {
key: default_value for key in self.primers.keys(True, True)
}
self.default_value = default_value
albertbou92 marked this conversation as resolved.
Show resolved Hide resolved
self._validated = False
self.reset_key = reset_key

# sanity check
Expand Down Expand Up @@ -4559,6 +4588,9 @@ def to(self, *args, **kwargs):
self.primers = self.primers.to(device)
return super().to(*args, **kwargs)

def _try_expand_shape(self, spec):
return spec.expand((*self.parent.batch_size, *spec.shape))

def transform_observation_spec(
self, observation_spec: CompositeSpec
) -> CompositeSpec:
Expand All @@ -4568,15 +4600,20 @@ def transform_observation_spec(
)
for key, spec in self.primers.items():
if spec.shape[: len(observation_spec.shape)] != observation_spec.shape:
raise RuntimeError(
f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. "
f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}."
)
try:
expanded_spec = self._try_expand_shape(spec)
except AttributeError:
raise RuntimeError(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When will this be reached?

Copy link
Contributor Author

@albertbou92 albertbou92 Apr 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if for any reason self.parent is None

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when would transform_observation_spec be called when parent is None?

f"The leading shape of the primer specs ({self.__class__}) should match the one of the "
f"parent env. Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's "
f"shape is {expanded_spec.shape}."
)
spec = expanded_spec
try:
device = observation_spec.device
except RuntimeError:
device = self.device
observation_spec[key] = spec.to(device)
observation_spec[key] = self.primers[key] = spec.to(device)
return observation_spec

def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
Expand All @@ -4589,8 +4626,13 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
def _batch_size(self):
return self.parent.batch_size

def _validate_value_tensor(self, value, spec):
if not spec.is_in(value):
albertbou92 marked this conversation as resolved.
Show resolved Hide resolved
raise RuntimeError(f"Value ({value}) is not in the spec domain ({spec}).")
return True

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
for key, spec in self.primers.items():
for key, spec in self.primers.items(True, True):
if spec.shape[: len(tensordict.shape)] != tensordict.shape:
raise RuntimeError(
"The leading shape of the spec must match the tensordict's, "
Expand All @@ -4601,11 +4643,21 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.random:
value = spec.rand()
else:
value = torch.full_like(
spec.zero(),
self.default_value,
)
value = self.default_value[key]
if callable(value):
value = value()
if not self._validated:
self._validate_value_tensor(value, spec)
else:
value = torch.full(
spec.shape,
value,
device=spec.device,
)

tensordict.set(key, value)
if not self._validated:
self._validated = True
return tensordict

def _step(
Expand Down Expand Up @@ -4634,22 +4686,36 @@ def _reset(
)
_reset = _get_reset(self.reset_key, tensordict)
if _reset.any():
for key, spec in self.primers.items():
for key, spec in self.primers.items(True, True):
if self.random:
value = spec.rand(shape)
else:
value = torch.full_like(
spec.zero(shape),
self.default_value,
)
prev_val = tensordict.get(key, 0.0)
value = torch.where(expand_as_right(_reset, value), value, prev_val)
value = self.default_value[key]
if callable(value):
value = value()
if not self._validated:
self._validate_value_tensor(value, spec)
else:
value = torch.full(
spec.shape,
value,
device=spec.device,
)
prev_val = tensordict.get(key, 0.0)
value = torch.where(
expand_as_right(_reset, value), value, prev_val
)
tensordict_reset.set(key, value)
self._validated = True
return tensordict_reset

def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(primers={self.primers}, default_value={self.default_value}, random={self.random})"
default_value = {
key: value if isinstance(value, float) else "Callable"
for key, value in self.default_value.items()
}
return f"{class_name}(primers={self.primers}, default_value={default_value}, random={self.random})"


class PinMemoryTransform(Transform):
Expand Down
Loading