Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cond] inlining into one of the branches when pred is a python constant #128709

Closed
wants to merge 15 commits into from

Conversation

ydwu4
Copy link
Contributor

@ydwu4 ydwu4 commented Jun 14, 2024

Stack from ghstack (oldest at bottom):

When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism. The previous behavior is that we baked in True/False in the cond operator. This can be confusing. In this PR, we change it to be specializing into one of the branches when the inputs are constants.

We additionally change the naming of cond operator to default one without overriding its name. This allows better testing on de-serialized graph.

Test Plan:
The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches,

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

Differential Revision: D59589709

Copy link

pytorch-bot bot commented Jun 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128709

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 131 Cancelled Jobs, 3 Unrelated Failures

As of commit e06f55e with merge base 0beeac3 (image):

CANCELLED JOBS - The following jobs were cancelled. Please retry:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…thon constant"


When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism.

The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches, 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Jun 14, 2024
ghstack-source-id: 6204b597c43c266c4c4bd42dfe232e25c924f8a1
Pull Request resolved: #128709
…thon constant"


When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism.

The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches, 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Jun 14, 2024
ghstack-source-id: 22661aa10d3efb61ab5b5fe840e3dac4d31fd683
Pull Request resolved: #128709
…thon constant"


When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism.

The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches, 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
…thon constant"


When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism.

The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches, 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Jun 17, 2024
ghstack-source-id: 289481af89844ad078f5c2540e1eac53792d1960
Pull Request resolved: #128709
if isinstance(pred, (bool, int, float)):
log.warning(
"Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
" If you want torch.cond to perserve two branches, please make the predicate a boolean tensor."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can cond take symbool?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, add a test_export_cond_symbool_pred to reflect this. cond won't specialize on symbool.

…thon constant"


When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism.

The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches, 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Jun 20, 2024
ghstack-source-id: ee91520c5d4cb73cdd9cb98404313d2616f59cff
Pull Request resolved: #128709
…thon constant"


When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism.

The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches, 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Jun 25, 2024
ghstack-source-id: a4d25d25d63eec7c7a0f1bdfdb6a3a025af510ed
Pull Request resolved: #128709
@ydwu4 ydwu4 requested a review from zou3519 June 25, 2024 18:34
…thon constant"


When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism. The previous behavior is that we baked in True/False in the cond operator. This can be confusing. In this PR, we change it to be specializing into one of the branches when the inputs are constants.

Test Plan:
The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches, 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
@ydwu4 ydwu4 added topic: not user facing topic category and removed topic: not user facing topic category labels Jun 26, 2024
@ydwu4
Copy link
Contributor Author

ydwu4 commented Jun 26, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator

Details for Dev Infra team Raised by workflow job

…thon constant"


When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism. The previous behavior is that we baked in True/False in the cond operator. This can be confusing. In this PR, we change it to be specializing into one of the branches when the inputs are constants.

We additionally change the naming of cond operator to default one without overriding its name. This allows better testing on de-serialized graph.

Test Plan:
The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches, 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

Differential Revision: [D59589709](https://our.internmc.facebook.com/intern/diff/D59589709)

[ghstack-poisoned]
@ydwu4
Copy link
Contributor Author

ydwu4 commented Jul 10, 2024

@ydwu4 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

…thon constant"


When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism. The previous behavior is that we baked in True/False in the cond operator. This can be confusing. In this PR, we change it to be specializing into one of the branches when the inputs are constants.

We additionally change the naming of cond operator to default one without overriding its name. This allows better testing on de-serialized graph.

Test Plan:
The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches, 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

Differential Revision: [D59589709](https://our.internmc.facebook.com/intern/diff/D59589709)

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Jul 10, 2024
ghstack-source-id: ad68a8daacb741710843b72bec69fd3ce0dc4247
Pull Request resolved: #128709
ydwu4 added a commit that referenced this pull request Jul 10, 2024
…thon constant"


Reland #128709.

When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism. The previous behavior is that we baked in True/False in the cond operator. This can be confusing. In this PR, we change it to be specializing into one of the branches when the inputs are constants.

We additionally change the naming of cond operator to default one without overriding its name. This allows better testing on de-serialized graph.

Test Plan:
The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches,





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
@ydwu4
Copy link
Contributor Author

ydwu4 commented Jul 10, 2024

replaced by #130493

@ydwu4 ydwu4 closed this Jul 10, 2024
ydwu4 added a commit that referenced this pull request Jul 11, 2024
…hen pred is a python constant"


Reland #128709.

When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism. The previous behavior is that we baked in True/False in the cond operator. This can be confusing. In this PR, we change it to be specializing into one of the branches when the inputs are constants.

We additionally change the naming of cond operator to default one without overriding its name. This allows better testing on de-serialized graph.

Test Plan:
The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches,





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Jul 11, 2024
…thon constant"


Reland #128709.

When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism. The previous behavior is that we baked in True/False in the cond operator. This can be confusing. In this PR, we change it to be specializing into one of the branches when the inputs are constants.

We additionally change the naming of cond operator to default one without overriding its name. This allows better testing on de-serialized graph.

Test Plan:
The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches,





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Jul 11, 2024
…hen pred is a python constant"


Reland #128709.

When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism. The previous behavior is that we baked in True/False in the cond operator. This can be confusing. In this PR, we change it to be specializing into one of the branches when the inputs are constants.

We additionally change the naming of cond operator to default one without overriding its name. This allows better testing on de-serialized graph.

Test Plan:
The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches,





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Jul 11, 2024
…thon constant"


Reland #128709.

When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism. The previous behavior is that we baked in True/False in the cond operator. This can be confusing. In this PR, we change it to be specializing into one of the branches when the inputs are constants.

We additionally change the naming of cond operator to default one without overriding its name. This allows better testing on de-serialized graph.

Test Plan:
The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches,





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Jul 12, 2024
…nt (#130493)

Reland #128709.

When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism. The previous behavior is that we baked in True/False in the cond operator. This can be confusing. In this PR, we change it to be specializing into one of the branches when the inputs are constants.

We additionally change the naming of cond operator to default one without overriding its name. This allows better testing on de-serialized graph.

Test Plan:
The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches,

Pull Request resolved: #130493
Approved by: https://github.com/BoyuanFeng
pytorchmergebot pushed a commit that referenced this pull request Jul 17, 2024
…, and prim::Exit (#129416)

- Support raise exception. It's behavior matches non-strict export now, thanks to @ydwu4's [PR](#128709).
- Support prim::Unitialized, prim::Enter, and prim::Exit
Pull Request resolved: #129416
Approved by: https://github.com/angelayi
pytorchmergebot pushed a commit that referenced this pull request Jul 17, 2024
…, and prim::Exit (#129416)

- Support raise exception. It's behavior matches non-strict export now, thanks to @ydwu4's [PR](#128709).
- Support prim::Unitialized, prim::Enter, and prim::Exit
Pull Request resolved: #129416
Approved by: https://github.com/angelayi
pytorchmergebot pushed a commit that referenced this pull request Jul 17, 2024
…, and prim::Exit (#129416)

- Support raise exception. It's behavior matches non-strict export now, thanks to @ydwu4's [PR](#128709).
- Support prim::Unitialized, prim::Enter, and prim::Exit
Pull Request resolved: #129416
Approved by: https://github.com/angelayi
mlazos pushed a commit that referenced this pull request Jul 18, 2024
…, and prim::Exit (#129416)

- Support raise exception. It's behavior matches non-strict export now, thanks to @ydwu4's [PR](#128709).
- Support prim::Unitialized, prim::Enter, and prim::Exit
Pull Request resolved: #129416
Approved by: https://github.com/angelayi
DiweiSun pushed a commit to DiweiSun/pytorch that referenced this pull request Jul 22, 2024
…, and prim::Exit (pytorch#129416)

- Support raise exception. It's behavior matches non-strict export now, thanks to @ydwu4's [PR](pytorch#128709).
- Support prim::Unitialized, prim::Enter, and prim::Exit
Pull Request resolved: pytorch#129416
Approved by: https://github.com/angelayi
DiweiSun pushed a commit to DiweiSun/pytorch that referenced this pull request Jul 22, 2024
…, and prim::Exit (pytorch#129416)

- Support raise exception. It's behavior matches non-strict export now, thanks to @ydwu4's [PR](pytorch#128709).
- Support prim::Unitialized, prim::Enter, and prim::Exit
Pull Request resolved: pytorch#129416
Approved by: https://github.com/angelayi
DiweiSun pushed a commit to DiweiSun/pytorch that referenced this pull request Jul 22, 2024
…, and prim::Exit (pytorch#129416)

- Support raise exception. It's behavior matches non-strict export now, thanks to @ydwu4's [PR](pytorch#128709).
- Support prim::Unitialized, prim::Enter, and prim::Exit
Pull Request resolved: pytorch#129416
Approved by: https://github.com/angelayi
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
…nt (pytorch#128709)

When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism. The previous behavior is that we baked in True/False in the cond operator. This can be confusing. In this PR, we change it to be specializing into one of the branches when the inputs are constants.

We additionally change the naming of cond operator to default one without overriding its name. This allows better testing on de-serialized graph.

Test Plan:
The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches,

Differential Revision: [D59589709](https://our.internmc.facebook.com/intern/diff/D59589709)
Pull Request resolved: pytorch#128709
Approved by: https://github.com/zou3519
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
…nt (pytorch#130493)

Reland pytorch#128709.

When the input predicate is a python constant, we specialize into one of the branches and warn users that torch.cond is not preserving the dynamism. The previous behavior is that we baked in True/False in the cond operator. This can be confusing. In this PR, we change it to be specializing into one of the branches when the inputs are constants.

We additionally change the naming of cond operator to default one without overriding its name. This allows better testing on de-serialized graph.

Test Plan:
The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, the predicate is a python bool. To fix them, we either change the predicate to be some data-dependent tensor or change the test to check cond is specialized as one of the branches,

Pull Request resolved: pytorch#130493
Approved by: https://github.com/BoyuanFeng
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
…, and prim::Exit (pytorch#129416)

- Support raise exception. It's behavior matches non-strict export now, thanks to @ydwu4's [PR](pytorch#128709).
- Support prim::Unitialized, prim::Enter, and prim::Exit
Pull Request resolved: pytorch#129416
Approved by: https://github.com/angelayi
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
…, and prim::Exit (pytorch#129416)

- Support raise exception. It's behavior matches non-strict export now, thanks to @ydwu4's [PR](pytorch#128709).
- Support prim::Unitialized, prim::Enter, and prim::Exit
Pull Request resolved: pytorch#129416
Approved by: https://github.com/angelayi
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
…, and prim::Exit (pytorch#129416)

- Support raise exception. It's behavior matches non-strict export now, thanks to @ydwu4's [PR](pytorch#128709).
- Support prim::Unitialized, prim::Enter, and prim::Exit
Pull Request resolved: pytorch#129416
Approved by: https://github.com/angelayi
@github-actions github-actions bot deleted the gh/ydwu4/125/head branch August 10, 2024 01:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants