[Reopen #114036] Allow "must recompute" in torch.compile + selective checkpointing (SAC) #129295
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129295
Note: Links to docs will display an error until the doc builds have completed.
✅ You can merge normally! (2 unrelated failures.) As of commit 5ca28a5 with merge base aa4ee2c:
BROKEN TRUNK - The following job failed but was also present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…checkpointing (SAC) ghstack-source-id: dfc5d8f7d50844b2fa49726ed114a55c014ab89e Pull Request resolved: #129295
```diff
-    must_recompute(user)
-    and user.meta["recompute"] > node.meta["recompute"]
+    prefer_recompute(user)
+    and user.meta["ac_graph_id"] > node.meta["ac_graph_id"]
```
Is this an independent bugfix? Or is this part of the "partitioner should respect `prefer_recompute`" change?
Mostly a curiosity question: my understanding from the comment above is that if you run code like:
checkpoint_f = checkpoint(f)
checkpoint_g = checkpoint(g)
out = checkpoint_f(checkpoint_g(inp))
then AC requires us to save the inputs to f (the outputs of g), but in our tag-based system every node would carry the recompute tag: so you need some notion of "which AC subgraph does a node belong to" (`ac_graph_id`) to tell the partitioner that it should save the output of the first subgraph. Is that right?
This is part of the "recompute" tag cleanup to disentangle its two meanings: (1) recompute policy and (2) AC subgraph ID, by splitting them into two separate tags.
And yes, your understanding is exactly right :)
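To make the subgraph-boundary point concrete, here is a minimal sketch (a hypothetical, simplified model — not the actual PyTorch partitioner; `Node` and `must_save_between_subgraphs` are illustrative stand-ins) of why a per-subgraph ID is needed: with only a "recompute" tag, the boundary between two checkpointed regions is invisible, so the partitioner cannot tell that g's output must be saved as the input to f's recomputation.

```python
class Node:
    """Toy stand-in for an FX node carrying AC metadata."""
    def __init__(self, name, prefer_recompute, ac_graph_id):
        self.name = name
        self.meta = {
            "recompute": prefer_recompute,   # policy: OK to recompute this node?
            "ac_graph_id": ac_graph_id,      # which AC region the node belongs to
        }

def must_save_between_subgraphs(node, user):
    # A node feeding a user in a *later* AC region sits on a subgraph
    # boundary: its value is an input to the later region's recomputation,
    # so it has to be saved even though both nodes carry the recompute tag.
    return (
        node.meta["recompute"]
        and user.meta["recompute"]
        and user.meta["ac_graph_id"] > node.meta["ac_graph_id"]
    )

g_out = Node("g_out", prefer_recompute=True, ac_graph_id=1)  # output of checkpoint(g)
f_in = Node("f_in", prefer_recompute=True, ac_graph_id=2)    # input to checkpoint(f)

print(must_save_between_subgraphs(g_out, f_in))  # boundary crossing -> True
```

Without `ac_graph_id`, both nodes look identical to the partitioner and nothing would be saved between the two regions.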
torch/_functorch/partitioners.py
Outdated
```diff
@@ -808,8 +813,7 @@ def should_ban_recomputation(node):
         return False
     if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]:
         return False
-    # NB: "recompute" == 0 means that must save this node.
-    if node.meta.get("recompute", None) == 0:
+    if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE:
```
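The diff replaces the magic number `0` with a named enum member. A minimal self-contained sketch of the check (the `CheckpointPolicy` class here is a stand-in mirroring the names used in the diff, and `should_ban_recomputation` is reduced to just the tag check; the real versions live in PyTorch):

```python
from enum import Enum, auto

class CheckpointPolicy(Enum):
    # Stand-in enum; member names mirror those referenced in the diff.
    MUST_SAVE = auto()
    PREFER_SAVE = auto()
    MUST_RECOMPUTE = auto()
    PREFER_RECOMPUTE = auto()

def should_ban_recomputation(node_meta):
    # MUST_SAVE bans recomputation; any other (or missing) policy leaves
    # the decision to the partitioner's heuristics.
    if node_meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE:
        return True
    return False

print(should_ban_recomputation({"recompute": CheckpointPolicy.MUST_SAVE}))  # True
print(should_ban_recomputation({}))                                          # False
```

Comparing against an enum member instead of `0` also means an untagged node (`None`) can no longer be confused with a tagged one.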
It looks like right above this condition, if the op is a view or a `lift_fresh_copy`, then we will not ban recomputation, even if the user marked it as MUST_SAVE.
Do you think we should either error here, or prefer the user annotation instead?
Maybe a more general question: for ops that the partitioner is already strongly opinionated about saving (e.g. randomness or view ops), should we error when a user tries to change the partitioner's behavior for them? Or promise to always respect the user's intent?
This is a good point. I think we should error when the user marks a view or lift_fresh, etc. as MUST_SAVE, rather than ignore the annotation or silently do something less optimal.
However, it is generally tricky to absolutely respect the MUST_SAVE condition, e.g. in the case where some code gets DCE'd.
Another case is if the user chooses to MUST_SAVE some tensor, but the backward formula reduces this value before using it, so maybe we'd rather save the value post-reduction.
But these two are really just a consequence of the fact that MUST_NOT_RECOMPUTE became MUST_SAVE... maybe it's worth renaming it back lol.
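A toy illustration of the DCE caveat mentioned above (hypothetical and heavily simplified — `dce` is an illustrative dead-code-elimination pass, not PyTorch's): a node the user tagged MUST_SAVE can end up with no live users after other rewrites, and a straightforward DCE pass drops it regardless of the tag, so MUST_SAVE cannot be honored unconditionally.

```python
def dce(nodes, deps):
    """Keep only nodes that the graph output transitively depends on.

    nodes: list of node names in graph order.
    deps: mapping from node name to the names it depends on.
    """
    live = set()
    def mark(n):
        if n not in live:
            live.add(n)
            for d in deps.get(n, []):
                mark(d)
    mark("output")
    return [n for n in nodes if n in live]

nodes = ["x", "dead_but_must_save", "output"]
deps = {"output": ["x"]}      # the output depends only on x
print(dce(nodes, deps))       # the MUST_SAVE-tagged node is eliminated anyway
```

In other words, "save this node" is only meaningful for nodes that survive to the final graph, which is why the guarantee has to be best-effort.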
Yeah, I feel we should bias more toward respecting user intent - if the user wants MUST_SAVE, we should try to give them MUST_SAVE as much as possible, since the user explicitly expressed it.
…checkpointing (SAC) ghstack-source-id: 7a9f0491de59fc37fb98c8b9e6848b6f7524a624 Pull Request resolved: #129295
@pytorchbot merge -f "unrelated failures"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This replaces #114036.
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang