Raise exception if torch.func.* calls torch.compile functions #128736
base: gh/guilhermeleobas/54/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128736
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (3 unrelated failures) as of commit 2edd494 with merge base f389bca.
BROKEN TRUNK: the following jobs failed but were also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 790460103b03d70b99855662d2c3f58e57ca2908 Pull Request resolved: #128736
ghstack-source-id: ed45aa82f7c8477867ebcd0d8e7768049209f3c1 Pull Request resolved: #128736
torch/_dynamo/symbolic_convert.py
Outdated
# Is there a better way to detect this?
from_eager = counters.get("graph_break", None) is None
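For context, a minimal sketch of why this heuristic is fragile. It assumes dynamo's `counters` is a `defaultdict` of `Counter`s (as in `torch/_dynamo/utils.py`); the check only tells you whether a graph break has been recorded yet in this process, which is also true on the first compiled run, so it cannot reliably distinguish an eager caller:

```python
import collections

# dynamo-style counters: a defaultdict mapping category -> Counter (assumed structure)
counters = collections.defaultdict(collections.Counter)

# Before any graph break is recorded, the key is absent -> heuristic says "eager".
from_eager = counters.get("graph_break", None) is None
print(from_eager)  # True

# After any graph break anywhere in the process, the heuristic flips,
# even for code paths that really are eager.
counters["graph_break"]["some reason"] += 1
from_eager = counters.get("graph_break", None) is None
print(from_eager)  # False
```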
This is 100% wrong.
If any of the inputs is a functorch tensor and the backend is not eager, raise a graph break.
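A minimal, self-contained sketch of the suggested guard logic. `is_functorch_wrapped` and this `Unsupported` class are stand-ins, not the real dynamo APIs; the real check would inspect the actual inputs and the configured compile backend:

```python
# Hedged sketch: graph-break only when some input is a functorch-wrapped
# tensor AND the backend is not "eager".

class Unsupported(Exception):
    """Stand-in for torch._dynamo.exc.Unsupported."""

def maybe_graph_break(inputs, backend, is_functorch_wrapped):
    # Raise only in the problematic combination described in the review.
    if backend != "eager" and any(is_functorch_wrapped(x) for x in inputs):
        raise Unsupported(
            "Calling a torch.compile'd function from inside a torch.func "
            "transform is not supported"
        )

# Usage with a toy predicate: pretend strings tagged "functorch:" are wrapped tensors.
wrapped = lambda x: isinstance(x, str) and x.startswith("functorch:")

maybe_graph_break(["plain"], "aot_eager", wrapped)  # no error: nothing is wrapped
try:
    maybe_graph_break(["functorch:t"], "aot_eager", wrapped)
    raised = False
except Unsupported:
    raised = True
print(raised)  # True
```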
def test_vmap_call_torch_compile_fn(self):
    def wrapped_fn(x):
        return x.sin()

    x = torch.randn(3, 4)
    fn = torch.compile(backend="aot_eager", fullgraph=True)(wrapped_fn)

    with self.assertRaisesRegex(
        torch._dynamo.exc.Unsupported,
        "Calling torch.func.vmap\\(compiled_fn\\) function from eager mode is not supported",
    ):
        torch.func.vmap(fn)(x)
This graph breaks today:
import torch

def wrapped_fn(x):
    return x.sin()

x = torch.randn(3, 4)
fn = torch.compile(backend="aot_eager")(wrapped_fn)
torch.vmap(fn)(x)
ghstack-source-id: 409c03af178c6e4b3ac121210f3a1d2e59d156d9 Pull Request resolved: #128736
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang