[Feature] step_and_maybe_reset in env #1611

Merged
merged 122 commits into from
Oct 24, 2023

Conversation

vmoens
Contributor

@vmoens vmoens commented Oct 6, 2023

Contribution

This PR proposes a step_and_maybe_reset method in EnvBase.
This method executes a step and then, if necessary, a reset.
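
As a rough illustration of the step-then-reset pattern described here, the sketch below uses plain dictionaries in place of tensordicts. `ToyEnv`, its `horizon` parameter and the two-value return are assumptions made for this example; they are not taken from the TorchRL implementation.

```python
class ToyEnv:
    """Hypothetical counter environment; plain dicts stand in for tensordicts."""

    def __init__(self, horizon=3):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return {"observation": 0}

    def step(self, data):
        self.t += 1
        nxt = {"observation": self.t, "done": self.t >= self.horizon}
        return {**data, "next": nxt}

    def step_and_maybe_reset(self, data):
        # Step once, then build the root data for the following step:
        # a fresh reset if the episode ended, the "next" observation otherwise.
        stepped = self.step(data)
        nxt = stepped["next"]
        if nxt["done"]:
            root = self.reset()
        else:
            root = {"observation": nxt["observation"]}
        return stepped, root


env = ToyEnv(horizon=3)
root = env.reset()
seen, dones = [], []
for _ in range(4):
    stepped, root = env.step_and_maybe_reset(root)
    seen.append(stepped["observation"])
    dones.append(stepped["next"]["done"])
```

In this toy loop the caller never checks the done flag itself: the second returned value is always ready to be fed to the next call.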

We also make reset more robust by ensuring that partial resets are handled uniformly. This is necessary since step_and_maybe_reset must take care of this functionality, and from our perspective handling partial resets is the responsibility of reset (the user should not have to worry about data not being updated properly).

This has repercussions on the logic behind TransformedEnv._reset and BatchedEnv._reset.

I'm now considering having batched envs call reset rather than _reset to make sure that the data is well formed, since the input tensordict is now updated with tensordict_reset after _reset runs (hence, the output of _reset in SerialEnv is incomplete).
This could introduce some overhead, but the impact is limited since step_and_maybe_reset is now there to handle things faster.
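
To make the notion of a partial reset concrete, here is a minimal sketch with Python lists standing in for batched tensors. The helper name `apply_partial_reset` and the list-based mask are assumptions for illustration only; the real code operates on tensordicts and a "_reset" entry.

```python
def apply_partial_reset(current, fresh, reset_mask):
    """Return a copy of `current` where flagged entries are replaced by `fresh`."""
    if reset_mask is None:
        # No mask: treat the call as a full reset.
        return list(fresh)
    # Keep entries whose mask is False, overwrite the ones flagged True.
    return [f if m else c for c, f, m in zip(current, fresh, reset_mask)]


observations = [5, 7, 9]
after_reset = apply_partial_reset(observations, [0, 0, 0], [False, True, False])
```

Only the second sub-environment is reset here; the other two keep their current data, which is the merge that the description says should be the responsibility of reset rather than of the user.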

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 6, 2023
@vmoens vmoens added enhancement New feature or request performance Performance issue or suggestion for improvement Refactoring Refactoring of an existing feature labels Oct 6, 2023
@matteobettini matteobettini added the Environments Adds or modifies an environment wrapper label Oct 6, 2023
@@ -464,12 +480,18 @@ def _reset(
self, tensordict: Optional[TensorDictBase] = None, **kwargs
) -> TensorDictBase:

_reset = tensordict.get("_reset", None)
Contributor

This crashes when tensordict is None
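
A guard of the following kind would avoid the crash. This is only a sketch: `read_reset_mask` is a hypothetical helper, and a plain mapping stands in for the tensordict; it is not the fix that was actually committed.

```python
from typing import Any, Mapping, Optional


def read_reset_mask(tensordict: Optional[Mapping[str, Any]] = None) -> Optional[Any]:
    """Return the "_reset" entry if present, or None when there is nothing to read."""
    if tensordict is None:
        # reset() was called without input data: there is no mask, so reset everything.
        return None
    return tensordict.get("_reset", None)
```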

Contributor Author

Yeah, sorry about that, I can't test PettingZoo locally so I'm always moving in the dark...

@vmoens
Contributor Author

vmoens commented Oct 22, 2023

@matteobettini @albertbou92 @BY571 this should be (almost) mergeable.
Tests are passing (VMAS sporadically failing after tests complete and old deps broken because of #1622)

@matteobettini
Contributor

Why would vmas sporadically fail?

@vmoens
Contributor Author

vmoens commented Oct 22, 2023

Why would vmas sporadically fail?

Not sure, you can have a look.
All tests pass but it fails when closing.
Seems like an issue with tensors not released on CUDA, it happens often with VMAS. It isn't even a flaky test, since the tests all pass; it's just a weird exit status.

@matteobettini
Contributor

Why would vmas sporadically fail?

Not sure you can have a look. All tests pass but it fails when closing. Seems like an issue with tensors not released on CUDA, it happens often with VMAS. It isn't even a flaky test, since they all pass, it's just a weird exit status

Was this happening before this PR?

@vmoens
Contributor Author

vmoens commented Oct 22, 2023

it happens often with vmas

I was referring to VMAS CI beyond this PR

Contributor

@matteobettini matteobettini left a comment

Some final questions/comments

# goes through the tensordict and brings the _reset information to
# a boolean tensor of the shape of the tensordict.
batch_size = data.batch_size
n = len(batch_size)

if done_keys is not None and reset_keys is None:
    reset_keys = {_replace_last(key, "done") for key in done_keys}
Contributor

I am not following this.
We are taking the done keys (with all the terminated and truncated entries), replacing the last element of each with "done", and collecting the result in a set that we call reset keys.

This is counterintuitive, as reset_keys have a _reset ending and not a done ending.
This change seems to come from the fact that you aim to use this function in 2 contexts:

  • normally on the root td with the reset_keys as input
  • on the "next" td in collectors with the done_keys as input

I think we should try to write this better, here are some suggestions:

  • always call the function on the root td and pass the keys with preappended "next" if you want to use that
  • do this key filtering and conversion outside of the function and let the function just operate on a set of keys to be considered as reset keys
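
To make the key mapping discussed in this comment concrete, here is a small sketch. `replace_last` is a stand-in written for this example and assumed to behave like the `_replace_last` helper in the diff (swap the last element of a possibly nested key); the list of done keys is made up.

```python
def replace_last(key, new_ending):
    """Swap the last element of a key; a plain string key is replaced entirely."""
    if isinstance(key, str):
        return new_ending
    return key[:-1] + (new_ending,)


# Hypothetical done keys for an env with a root group and an "agents" group.
done_keys = [
    "done",
    "terminated",
    "truncated",
    ("agents", "done"),
    ("agents", "terminated"),
]

# Several done-like keys per group collapse onto a single "done" key per group.
reset_like_keys = {replace_last(key, "done") for key in done_keys}
```

The resulting set holds one key per group, ending in "done" rather than "_reset", which is the naming mismatch the comment points at.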

Contributor Author

I don't see the problem, can you elaborate why we should try to write this better?
Is it a naming problem? We can rename the function _aggregate_stop or something similar.

always call the function on the root td and pass the keys with preappended "next" if you want to use that

That introduces some unwanted overhead when we can directly access "next" and read the done_keys. Recall that td.get(("next", "key")) is considerably slower than next_td.get("key") as we do here.

do this key filtering and conversion outside of the function and let the function just operate on a set of keys to be considered as reset keys

What's your suggestion for _update_traj_ids in collectors.py for instance? We don't have a "_reset" in the "next" tensordict, but I think this function does its job of aggregating the done signals to read what the trajectory ids are.
What I understand is that the confusion comes from the "reset" in the function name, but what this function really does is just aggregating end-of-trajectory signals (either reset or done) to the root.
Given this, I don't see why it should be changed. It's a private function, properly tested and I think it serves its purpose.
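
The aggregation described in this reply can be sketched as follows, with nested lists standing in for tensors. All names (`get_nested`, `any_flag`, `aggregate_end_of_traj`) are hypothetical, and the sketch assumes a plain OR across keys, which may differ from what the private function in the PR does.

```python
def get_nested(data, key):
    """Fetch a value from nested dicts using a string or tuple key."""
    if isinstance(key, str):
        key = (key,)
    for part in key:
        data = data[part]
    return data


def any_flag(value):
    """True if any boolean inside a (possibly nested) list is True."""
    if isinstance(value, list):
        return any(any_flag(v) for v in value)
    return bool(value)


def aggregate_end_of_traj(data, keys, batch_size):
    """Reduce end-of-trajectory flags of any shape to one boolean per batch element."""
    out = [False] * batch_size
    for key in keys:
        value = get_nested(data, key)
        for i in range(batch_size):
            out[i] = out[i] or any_flag(value[i])
    return out


# Three sub-environments, each with one root flag and two per-agent flags.
next_data = {
    "done": [[False], [False], [True]],
    "agents": {
        "done": [[[False], [True]], [[False], [False]], [[False], [False]]],
    },
}
ended = aggregate_end_of_traj(next_data, ["done", ("agents", "done")], batch_size=3)
```

The function only needs a set of keys to reduce over, so it works equally on reset keys at the root or on done keys in the "next" data.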

Contributor

connecting to the comment below, we could have

  • _aggregate_keys(keys=) which can be called on anything

alternatively we could have both _aggregate_dones and _aggregate_resets where one calls the other

Comment on lines 789 to 790
traj_sop = _aggregate_resets(
    tensordict.get("next"), done_keys=self.env.done_keys
Contributor

I assume we have to do this because the reset keys are no longer visible to the collector.
Since this is quite counterintuitive, what about an _aggregate_dones()?
I also have other suggestions in the other comment relating to this function.

Contributor Author

Yep, we can rename the function. I would rather say that it's more direct to aggregate the dones than the "_reset" entries, which are derived from the dones.

action_keys=self.action_keys,
done_keys=self.done_keys,
)
any_done = _terminated_or_truncated(
Contributor

@matteobettini matteobettini Oct 24, 2023

Help me understand this a little bit better.

In an example case (e.g., PettingZoo) where I have

{
"done": [False],
"agents": {"done": [True, False]}
}

Is the any_done triggered?

If so, this is a problem for envs like PettingZoo where _reset() will be called with {"_reset": [False]}

Contributor Author

The best thing is trying it out :)

from tensordict import TensorDict
from torchrl.envs.utils import _terminated_or_truncated

data = TensorDict({"done": [False], ("agent", "done"): [True, False]}, [])
print(_terminated_or_truncated(data))

which returns True
So what you're saying is that it should be False since there's a False at the root?
I can correct that

Contributor

I think it should, to follow the dominance rule we imposed, right?
Or at least in this context, definitely, because we do not want to call reset.
I don't know in what other contexts this function is used, but if its primary use is to decide when to call reset, then yes.
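
One possible reading of the dominance rule mentioned here is that a done-like entry found at a given level decides for everything nested below it. The sketch below implements that reading on plain dicts and lists; `any_done` and `DONE_NAMES` are hypothetical names, and this is not claimed to match the behaviour of `_terminated_or_truncated` after the fix.

```python
DONE_NAMES = ("done", "terminated", "truncated")


def any_done(data):
    """Check done-like flags, letting a level with its own flags override nested ones."""
    flags = [data[name] for name in DONE_NAMES if name in data]
    if flags:
        # A done-like entry at this level decides for the whole subtree.
        return any(any(flag) for flag in flags)
    # No flag at this level: look into nested groups instead.
    return any(any_done(value) for value in data.values() if isinstance(value, dict))


root_says_no = any_done({"done": [False], "agents": {"done": [True, False]}})
only_nested = any_done({"agents": {"done": [True, False]}})
```

Under this reading, the example from the question above does not trigger a reset, while the same nested flags without a root entry do.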

@vmoens
Contributor Author

vmoens commented Oct 24, 2023

@matteobettini I addressed all your comments

Contributor

@matteobettini matteobettini left a comment

LGTM

@vmoens vmoens merged commit 3b355dd into main Oct 24, 2023
51 of 59 checks passed
@vmoens vmoens deleted the step_maybe_reset branch October 24, 2023 10:31
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request Environments Adds or modifies an environment wrapper major major refactoring in the code base performance Performance issue or suggestion for improvement Refactoring Refactoring of an existing feature
3 participants