
Raise exception if torch.func.* calls torch.compile functions #128736

Draft: guilhermeleobas wants to merge 3 commits into base: gh/guilhermeleobas/54/base
Conversation

guilhermeleobas (Collaborator) commented Jun 14, 2024

[ghstack-poisoned]

pytorch-bot bot commented Jun 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128736

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit 2edd494 with merge base f389bca:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

guilhermeleobas added a commit that referenced this pull request Jun 14, 2024
ghstack-source-id: 790460103b03d70b99855662d2c3f58e57ca2908
Pull Request resolved: #128736
[ghstack-poisoned]
guilhermeleobas added a commit that referenced this pull request Jun 18, 2024
ghstack-source-id: ed45aa82f7c8477867ebcd0d8e7768049209f3c1
Pull Request resolved: #128736
@guilhermeleobas guilhermeleobas changed the title Let functorch call torch.compile functions Raise exception if torch.func.* calls torch.compile functions Jun 18, 2024
Comment on lines 2449 to 2450
# Is there a better way to detect this?
from_eager = counters.get("graph_break", None) is None
guilhermeleobas (Collaborator, Author) commented:
This is 100% wrong.
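A minimal sketch of why this heuristic is fragile, assuming `counters` behaves like `torch._dynamo.utils.counters` (a `defaultdict` mapping a category string to a `Counter`); the scenario below is illustrative, not taken from the PR:

```python
from collections import Counter, defaultdict

# Stand-in for torch._dynamo.utils.counters (assumption: it is a
# defaultdict(Counter) keyed by category name).
counters = defaultdict(Counter)

# Before any graph break has been recorded anywhere in the process,
# the "graph_break" key is absent, so the heuristic concludes "eager".
from_eager = counters.get("graph_break", None) is None
print(from_eager)  # True

# But once *any* compiled function records a graph break, the key exists
# for the rest of the process lifetime, regardless of the current mode.
counters["graph_break"]["some earlier, unrelated break"] += 1

from_eager = counters.get("graph_break", None) is None
print(from_eager)  # False, even if this call site is still in eager mode
```

In other words, the check conflates "no graph break has ever happened in this process" with "we were called from eager mode", which is why it was replaced by an input-based check below.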

guilhermeleobas (Collaborator, Author) commented Jul 29, 2024:

If any of the inputs is a functorch tensor and the backend is not eager, raise a graph break.
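The check described above can be sketched as follows. This is a hypothetical illustration: the helper names (`check_functorch_inputs`, `is_functorch_wrapped`, `FakeTensor`) and the local `Unsupported` class are stand-ins, not real torch._dynamo APIs; only the error message is taken from the test in this PR.

```python
class Unsupported(Exception):
    """Stand-in for torch._dynamo.exc.Unsupported."""


def check_functorch_inputs(inputs, backend_name, is_functorch_wrapped):
    # If any input is a functorch-wrapped tensor and the backend is not
    # eager, tracing cannot proceed safely: raise a graph break instead.
    if backend_name != "eager" and any(is_functorch_wrapped(x) for x in inputs):
        raise Unsupported(
            "Calling torch.func.vmap(compiled_fn) function from eager mode "
            "is not supported"
        )


# Toy stand-in for a tensor, marked via an attribute.
class FakeTensor:
    def __init__(self, functorch=False):
        self.functorch = functorch


try:
    check_functorch_inputs(
        [FakeTensor(functorch=True)], "aot_eager", lambda t: t.functorch
    )
except Unsupported as e:
    print("graph break:", e)
```

With the `eager` backend, or with no functorch-wrapped inputs, the same call returns without raising.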

@guilhermeleobas guilhermeleobas linked an issue Jul 29, 2024 that may be closed by this pull request
Comment on lines +5288 to +5300
def test_vmap_call_torch_compile_fn(self):
    def wrapped_fn(x):
        return x.sin()

    x = torch.randn(3, 4)
    fn = torch.compile(backend="aot_eager", fullgraph=True)(wrapped_fn)

    with self.assertRaisesRegex(
        torch._dynamo.exc.Unsupported,
        "Calling torch.func.vmap\\(compiled_fn\\) function from eager mode is not supported",
    ):
        torch.func.vmap(fn)(x)

A Contributor commented:

This graph breaks today:

import torch

def wrapped_fn(x):
    return x.sin()

x = torch.randn(3, 4)
fn = torch.compile(backend="aot_eager")(wrapped_fn)

torch.vmap(fn)(x)

[ghstack-poisoned]
guilhermeleobas added a commit that referenced this pull request Jul 29, 2024
ghstack-source-id: 409c03af178c6e4b3ac121210f3a1d2e59d156d9
Pull Request resolved: #128736
Development

Successfully merging this pull request may close these issues.

vmap fails to call torch.compiled function
3 participants