Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Remove reset on last step of a rollout #1936

Merged
merged 24 commits into from
Feb 21, 2024

Conversation

matteobettini
Copy link
Contributor

Discussion in #1929

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 20, 2024
Copy link

pytorch-bot bot commented Feb 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1936

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit 5222496 with merge base 23bf315 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@matteobettini matteobettini changed the title [BugFix] Remove rollout reset on last step [BugFix] Remove reset on last step of a rollout Feb 20, 2024
Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this! Ideally we'd like a non regression test.
Also now that there is a reset missing, if the user takes the tensordict passed as output and feeds it back to rollout, rollout will need to call reset on it. I don't think this was part of rollout contract before: either a tensordict is passed or reset is called but not both (as this PR will dictate)

@matteobettini
Copy link
Contributor Author

matteobettini commented Feb 20, 2024

Ideally we'd like a non regression test.

I'll work on the test.

Also now that there is a reset missing, if the user takes the tensordict passed as output and feeds it back to rollout, rollout will need to call reset on it. I don't think this was part of rollout contract before: either a tensordict is passed or reset is called but not both (as this PR will dictate)

So this is a design choice, let me explain my view.

Before the PR

Before the PR, a user could not take the last tensordict from a rollout and safely pass it back in.

This is because if the env was reset on the last step, the data from that reset would not be part of the rollout results.

Thus, if the user passed the output of a rollout back into a new rollout, It would be providing old data from a previous trajectory to a freshly reset env which is in a different state.

After this PR

In the current state of this PR, the user still cannot just take the last tensordict from a rollout and safely pass it back in.

Doing so would have the same effect as prior to this PR, with the difference that now at least the env is not in a new reset state.

In my opinion this should remain like this (and like it was).

If the users chooses auto_reset=False it will be the duty of the user to pass non-done data as the input to a rollout (just as before).

The difference that this PR allows is that now users that want to run rollouts in a row are able to by adding the reset logic outside the rollout function (as detailed in the snippet in #1929 (comment)). Before they could not do this.

@vmoens vmoens added bug Something isn't working Suitable for minor Suitable to be integrated in minor release (no new feature) labels Feb 20, 2024
@vmoens
Copy link
Contributor

vmoens commented Feb 20, 2024

If we patch things let's patch them correctly no?
Why not having something like

        if auto_reset:
            if tensordict is not None:
                raise RuntimeError(
                    "tensordict cannot be provided when auto_reset is True"
                )
            tensordict = self.reset()
        elif tensordict is None:
            raise RuntimeError("tensordict must be provided when auto_reset is False")
		else:
			aggregate_reset = _aggregate_end_of_traj(tensordict)
            if aggregate_reset.any():
				self.reset(tensordict)

I don't understand the reluctance against this, does this break anything?

@matteobettini
Copy link
Contributor Author

I don't understand the reluctance against this, does this break anything?

Oh no this works too! I'm not against it.

It will just be a new feature.

Aka rollout will check that the passed tensordict is done and it will reset the env if so.

We can do it, a few considerations:

  • this feature should be activated if also auto_reset is set. This would change the meaning of auto_reset as it could be set at the same time as an input tensordict, generating the following cases:
    - no tensordict passed -> auto_reset will either call reset or not (like before)
    - tensordict passed and auto_reset=True -> rollout will check if the passed td is done and if so reset (this is the new param combination that was not allowed before and it will trigger the new feature)
    - tensordict passed and auto_reset=False -> rollout will not check the input td for reset (like before)

    keeping the last case is good for bc compatibility and there might be cases where you do not want the rollout to call a reset

@vmoens
Copy link
Contributor

vmoens commented Feb 20, 2024

It will just be a new feature.

I think we're at the edge between new feature and bug fix. If you can't pass the tensordict you got from the last step instead of calling reset safely, this would be a bugfix.

@matteobettini
Copy link
Contributor Author

If you can't pass the tensordict you got from the last step instead of calling reset safely, this would be a bugfix.

Yeah if this was supposed to be a feature before, I guess this is a bug fix

@vmoens
Copy link
Contributor

vmoens commented Feb 20, 2024

keeping the last case is good for bc compatibility and there might be cases where you do not want the rollout to call a reset

If the env does not allow a step after done and you pass a tensordict that is done, it will result in an error. IMO we should capture that and make sure it does not happen for ease of use. I can't think of anyone using rollout reasonably right now who would be annoyed by this change and consider it bc-breaking, but I could be overlooking things.

To me auto_reset just means reset when starting. I don't think we should change the meaning. The consideration I was putting forward is orthogonal to auto_reset IMO.

@matteobettini
Copy link
Contributor Author

matteobettini commented Feb 20, 2024

If the env does not allow a step after done and you pass a tensordict that is done, it will result in an error. IMO we should capture that and make sure it does not happen for ease of use. I can't think of anyone using rollout reasonably right now who would be annoyed by this change and consider it bc-breaking, but I could be overlooking things.

To me auto_reset just means reset when starting. I don't think we should change the meaning. The consideration I was putting forward is orthogonal to auto_reset IMO.

This consideration makes an assumption about the env.
There could be users that might want to take steps in done environments

Forcing a reset on these users would narrow the flexibility of rollout with respect to prior versions.
I think if we introduce this auto resetting when a td is given, it is important to make it deactivatable.

If we want to have another param to deactivate it other than auto_reset that is fine too, I just though that auto_reset fits nicely as it was not allowed before and it is true by default.

EDIT: thinking about it, you might be right. Since rollout auto-resets anyway along the trajectory, it could makes sense that this is done on the input td as well. It will add computational complexity for checking the done tho, so maybe having it optional still makes sense?

@vmoens
Copy link
Contributor

vmoens commented Feb 20, 2024

Not sure I'm following here

Previously we had a broken version of rollout where the reset data of the last step was potentially lost. Anyone wanting to do anything with this was doomed to fail. I don't think that in this scenario working with env that do not reset when done was an option. So this scenario can be excluded entirely and considered as a separate issue.

If we want to land this, to me having it work with tensordict that are in a done state is a pre-requirement. The change I'm proposing isn't bc-breaking in any way since it build upon a bug-fix: anyone who was doing anything with rollout before can only have fewer bugs now.

The runtime won't be affected since the check I'm suggesting is already performed by step_and_maybe_reset, which we now replace by step.

@matteobettini
Copy link
Contributor Author

matteobettini commented Feb 20, 2024

There is something that we still did not discuss.

If you really want to chain calls of rollout, then the function should call sefl.maybe_reset on the input dict (like in the current version of the PR).
Which includes step_mdp.

Or are we assuming that users call step_mdp beteween rollout calls? (which i guess was the assumption before?)

@vmoens
Copy link
Contributor

vmoens commented Feb 20, 2024

It's safe to assume that users call step_mdp because the contract is that the tensordict you pass is similar to the one you'd get out of a call to reset

@matteobettini
Copy link
Contributor Author

Ok if you wanna have a look now it should do what we want

Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just missing a docstring and a test
(maybe also writing a note in rollout doctring about all of this? Like "how to use rollout as a data collector" or similar?)

@matteobettini
Copy link
Contributor Author

Further question:

if break_when_any_done is true, we still do the reset and ignore it?

@matteobettini
Copy link
Contributor Author

LGTM, just missing a docstring and a test (maybe also writing a note in rollout doctring about all of this? Like "how to use rollout as a data collector" or similar?)

We should be gucci!

Also, I still haven't come around the problem in #1929, so currently rollout is still not working for me as a data collector.
Have to see if it is a bug of my brain or a a very subtle bug of the library.

Nevertheless, this can be merged.

torchrl/envs/common.py Outdated Show resolved Hide resolved
torchrl/envs/common.py Outdated Show resolved Hide resolved
torchrl/envs/common.py Outdated Show resolved Hide resolved
torchrl/envs/common.py Outdated Show resolved Hide resolved
torchrl/envs/common.py Outdated Show resolved Hide resolved
torchrl/envs/common.py Outdated Show resolved Hide resolved
torchrl/envs/common.py Outdated Show resolved Hide resolved
@vmoens vmoens merged commit 03f4aa3 into pytorch:main Feb 21, 2024
65 of 68 checks passed
@matteobettini matteobettini deleted the fix_rollout branch February 21, 2024 22:11
vmoens added a commit that referenced this pull request Feb 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Suitable for minor Suitable to be integrated in minor release (no new feature)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants