[BugFix] Remove reset on last step of a rollout #1936

Merged
merged 24 commits into from
Feb 21, 2024
docs
matteobettini committed Feb 21, 2024
commit 8289d047bce86a25c8845f0fa5ab3a5e7659c9f4
51 changes: 46 additions & 5 deletions torchrl/envs/common.py
@@ -2273,7 +2273,9 @@ def rollout(
called on the sub-envs that are done. Default is True.
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
tensordict (TensorDict, optional): if auto_reset is False, an initial
tensordict must be provided.
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
environment in those dimensions (if needed). This normally should not occur if ``tensordict`` is the
output of a reset, but can occur if ``tensordict`` is the last step of a previous rollout.

Returns:
TensorDict object containing the resulting trajectory.
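
The input check described in the ``tensordict`` argument can be pictured with a small pure-Python sketch. This is not TorchRL code: the batch is a plain list of dicts, and `reset_done_entries` and `fresh_state` are hypothetical names used only for illustration. It shows the idea that entries already flagged as done are replaced by a fresh reset before stepping, while the others pass through untouched.

```python
from typing import Callable, Dict, List

State = Dict[str, object]


def fresh_state(index: int) -> State:
    # Hypothetical reset for one sub-environment: initial observation, not done.
    return {"obs": 0, "done": False}


def reset_done_entries(
    batch: List[State], reset_one: Callable[[int], State]
) -> List[State]:
    # Replace only the entries whose done flag is set; keep the rest as-is.
    return [reset_one(i) if s["done"] else s for i, s in enumerate(batch)]


# The second entry plays the role of the last step of a previous rollout.
batch = [{"obs": 7, "done": False}, {"obs": 9, "done": True}]
checked = reset_done_entries(batch, fresh_state)
print(checked)
```

If the input comes straight from a reset, no entry is flagged and the batch is returned unchanged.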
@@ -2369,6 +2371,33 @@ def rollout(
>>> print(rollout.names)
[None, 'time']

Rollouts can be used in a loop to emulate data collection.
To do so, you need to pass as input the last tensordict coming from the previous rollout after calling
:meth:`step_mdp` on it.

Examples:
>>> from torchrl.envs import GymEnv, step_mdp
>>> env = GymEnv("CartPole-v1")
>>> epochs = 10
>>> reset_td = env.reset()
>>> for i in range(epochs):
...     rollout_td = env.rollout(
...         max_steps=100,
...         policy=None,
...         break_when_any_done=False,
...         auto_reset=False,
...         tensordict=reset_td,
...     )
...     reset_td = step_mdp(
...         rollout_td[..., -1],
...         keep_other=True,
...         exclude_action=False,
...         exclude_reward=True,
...         reward_keys=env.reward_keys,
...         action_keys=env.action_keys,
...         done_keys=env.done_keys,
...     )

"""
if auto_cast_to_device:
try:
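
The looping pattern from the docstring above, together with the behaviour named in the PR title, can be traced with a toy stand-in. The sketch below assumes nothing about TorchRL internals: `CountdownEnv` and `advance` are hypothetical, states are plain dicts, and `advance` only mimics the role `step_mdp` plays in the docstring. The point it illustrates is that a rollout does not reset after its final step; the done flag travels with the last transition, and the reset happens at the start of the next rollout.

```python
from typing import Dict, List


def advance(transition: Dict) -> Dict:
    # Stand-in for step_mdp: promote the "next" entry to the current state.
    return dict(transition["next"])


class CountdownEnv:
    """Toy episodic environment: an episode ends after `horizon` steps."""

    def __init__(self, horizon: int = 3) -> None:
        self.horizon = horizon
        self.resets = 0  # counts how many times reset() was called

    def reset(self) -> Dict:
        self.resets += 1
        return {"t": 0, "done": False}

    def step(self, state: Dict) -> Dict:
        t = state["t"] + 1
        return {"t": state["t"], "next": {"t": t, "done": t >= self.horizon}}

    def maybe_reset(self, state: Dict) -> Dict:
        return self.reset() if state["done"] else state

    def rollout(self, max_steps: int, state: Dict) -> List[Dict]:
        # Check the incoming state first: it may be the done last step
        # of a previous rollout.
        state = self.maybe_reset(state)
        trajectory = []
        for i in range(max_steps):
            transition = self.step(state)
            trajectory.append(transition)
            if i == max_steps - 1:
                break  # no reset on the last step of the rollout
            state = self.maybe_reset(advance(transition))
        return trajectory


env = CountdownEnv(horizon=3)
state = env.reset()
for _ in range(2):
    traj = env.rollout(max_steps=3, state=state)
    state = advance(traj[-1])
print(env.resets, state)
```

Here each rollout ends exactly when the episode does, so the only resets are the explicit one before the loop and the one triggered by the input check at the start of the second rollout.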
@@ -2566,15 +2595,27 @@ def step_and_maybe_reset(
tensordict_ = self.maybe_reset(tensordict_)
return tensordict, tensordict_

def maybe_reset(self, tensordict_: TensorDictBase) -> TensorDictBase:
def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Checks the done keys of the input tensordict and, if needed, resets the environment where it is done.

Args:
tensordict (TensorDictBase): a tensordict coming from the output of :meth:`step_mdp`
Returns:
TensorDictBase: a tensordict that is identical to the input one where the environment was
not reset and contains the new reset data where the environment was reset

This method is part of :meth:`~.step_and_maybe_reset` and should be called on the output of a :meth:`~.step`
on which :meth:`step_mdp` has been called.

"""
any_done = _terminated_or_truncated(
tensordict_,
tensordict,
full_done_spec=self.output_spec["full_done_spec"],
key="_reset",
)
if any_done:
tensordict_ = self.reset(tensordict_)
return tensordict_
tensordict = self.reset(tensordict)
return tensordict
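
The contract stated in the `maybe_reset` docstring (input returned as-is where nothing is done, fresh reset data where something is) can be sketched without TorchRL. Everything below is an assumption made for illustration: the batch is a dict of plain lists, `terminated_or_truncated` and `partial_reset` are hypothetical helpers that only loosely mirror the roles of `_terminated_or_truncated` and `reset` in the diff, and a reset observation is taken to be `0.0`.

```python
from typing import Dict, List

Batch = Dict[str, List]


def terminated_or_truncated(td: Batch, key: str = "_reset") -> bool:
    # Write a per-environment mask under `key` and report whether any is done.
    mask = [t or u for t, u in zip(td["terminated"], td["truncated"])]
    td[key] = mask
    return any(mask)


def partial_reset(td: Batch) -> Batch:
    # Hypothetical reset that only touches the entries flagged in "_reset".
    mask = td.pop("_reset")
    return {
        "obs": [0.0 if m else o for o, m in zip(td["obs"], mask)],
        "terminated": [False if m else t for t, m in zip(td["terminated"], mask)],
        "truncated": [False if m else u for u, m in zip(td["truncated"], mask)],
    }


def maybe_reset(td: Batch) -> Batch:
    if terminated_or_truncated(td):
        return partial_reset(td)
    td.pop("_reset")  # nothing to reset: hand the input back unchanged
    return td


mixed = {"obs": [1.5, 2.5], "terminated": [False, True], "truncated": [False, False]}
out = maybe_reset(mixed)
print(out)
```

Only the second entry is replaced; the first keeps its observation and flags.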

def empty_cache(self):
"""Erases all the cached values.