FakeTensor cache SymInt support #127596
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127596. Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 1b8be29 with merge base 73d0f48. This comment was automatically generated by Dr. CI and updates every 15 minutes.
…ache key" This is part of #127596, pulled out to make reviewing a little easier. Flatten the FakeTensor cache key - so it's a list of singular elements and pointing at one requires a single index rather than a PyTree path. This is used in the next PR to allow us to have the cache entry refer to an input SymInt that it needs to copy directly into the output. [ghstack-poisoned]
This is part of #127596, pulled out to make reviewing a little easier. Flatten the FakeTensor cache key - so it's a list of singular elements and pointing at one requires a single index rather than a PyTree path. This is used in the next PR to allow us to have the cache entry refer to an input SymInt that it needs to copy directly into the output. [ghstack-poisoned]
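A rough sketch of what the flattening buys (an editor's illustration with made-up values, not the PR's actual key layout): `tree_flatten` turns the nested (args, kwargs) structure into a flat list of leaves, so a cache entry can point at an input with one integer index instead of a nested PyTree path.

```python
import torch.utils._pytree as pytree

# A nested (args, kwargs)-like structure and its flattened leaves.
args = ((1, [2, 3]), {"k": 4})
leaves, spec = pytree.tree_flatten(args)
assert leaves == [1, 2, 3, 4]

# A cache entry can now say "this output is input leaf 2" rather than
# carrying a path like args[0][1][1].
idx = leaves.index(3)
assert idx == 2
```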
Adds support for SymInts in the FakeTensor cache.

A couple notes:

1. When a SymInt is present in the input key for a FakeTensor operation, we cache on the ShapeEnv instead of using the FakeTensorMode cache. This is necessary so we don't have to remember and check the guards. It reduces the cache hits, but there's a diminishing return on how much work we can do before the cache becomes more of a burden than a gain.
2. We need to be careful: when we cache an output SymInt that is a direct copy from the input, then on a cache hit we must copy the SymNode from the input to the output. This is important because the fx-graph building code actually uses SymNode ids in the process of building the graph, so constructing a same-content-but-different-id SymNode will fail.
3. In the cache key we store SymInts as a `_PySymInputStub`. These represent SymInt (and friends) but support `__hash__` and `__eq__` (which SymInt does not); see the sketch after this description.
4. In the cache entry we store SymInts as a `_SymIntOutputStub`.

Perf example:

```
python benchmarks/dynamo/timm_models.py --ci --accuracy --timing --explain --inductor --dynamic-shapes --dynamic-batch-only --device cuda --training --amp --total-partitions 2 --partition-id 0 --output /tmp/training_timm_models.csv --filter crossvit_9_240
```

fake tensor cache before:

```
INFO: FakeTensor cache stats:
INFO:   cache_hits: 68137
INFO:   cache_misses: 837
INFO:   cache_bypasses:
INFO:     symbolic shape: 48224
INFO:     CompositeImplicitAutograd: 917
INFO:     non-fake tensor: 70
INFO:     non-FakeTensor output: 62
INFO:     non-builtin: 8
INFO:     dynamic output shape: 1
```

and after:

```
INFO: FakeTensor cache stats:
INFO:   cache_hits: 88187
INFO:   cache_misses: 14233
INFO:   cache_bypasses:
INFO:     CompositeImplicitAutograd: 1037
INFO:     non-FakeTensor output: 602
INFO:     non-fake tensor: 70
INFO:     unsafe view: 36
INFO:     non-builtin: 8
INFO:     dynamic output shape: 1
```

Stack from ghstack (oldest at bottom):

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @tianyu-l
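To make note 3 above concrete, here is a minimal sketch of the stub idea (an editorial illustration; the class name is invented and the PR's real `_PySymInputStub` is more involved): wrap the SymInt's underlying SymNode and key hashing and equality on its sympy expression, since SymInt itself overloads `==` to build symbolic expressions rather than return a plain bool.

```python
import torch

class PySymInputStubSketch:
    """Illustrative stand-in for _PySymInputStub: makes a SymInt usable
    as a dict-based cache key by hashing/comparing the sympy expression
    held by its SymNode."""

    def __init__(self, sym: torch.SymInt) -> None:
        self.node = sym.node  # the underlying SymNode

    def __hash__(self) -> int:
        return hash(str(self.node.expr))

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, PySymInputStubSketch)
            and self.node.expr == other.node.expr
        )
```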
…ache key" This is part of #127596, pulled out to make reviewing a little easier. Flatten the FakeTensor cache key - so it's a list of singular elements and pointing at one requires a single index rather than a PyTree path. This is used in the next PR to allow us to have the cache entry refer to an input SymInt that it needs to copy directly into the output. [ghstack-poisoned]
This is part of #127596, pulled out to make reviewing a little easier. Flatten the FakeTensor cache key - so it's a list of singular elements and pointing at one requires a single index rather than a PyTree path. This is used in the next PR to allow us to have the cache entry refer to an input SymInt that it needs to copy directly into the output. [ghstack-poisoned]
🚢
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
```diff
@@ -123,7 +123,7 @@ def compute_hint():
                 "Cannot create SymNode of type "
                 f"{pytype} with incompatible hint of type {type(hint)}"
             )
-        if self.shape_env._translation_validation_enabled:
+        if self.shape_env and self.shape_env._translation_validation_enabled:
```
Why can shape_env be None here now?
I think that was debug cruft - I'll remove in a follow-up PR.
```python
and self.pytype == other.pytype
and self._hint == other._hint
and self.constant == other.constant
and self.fx_node == other.fx_node
```
I actually think you probably only need expr and pytype; the rest are derived quantities that don't matter for the purpose of cache matching. Actually, pytype is probably not needed either, since you're probably only ever looking at integral quantities.
If we were going to do that I'd want to move this check over to fake_tensor.py instead of having it on SymNode. It seems kind of weird to have a function on SymNode that's specific to caching fake tensors (and not have it named something like `value_eq_for_fake_tensor_cache`).
I think I had to use pytype for SymBool previously.
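A hedged sketch of what the suggestion above might look like (hypothetical helper; only the name `value_eq_for_fake_tensor_cache` comes from this thread, and the PR may have settled on something different): compare just the expression and `pytype`, keeping `pytype` so that, say, a SymBool and a SymInt over the same expression don't compare equal.

```python
def value_eq_for_fake_tensor_cache(a, b) -> bool:
    # Compare only the fields that determine cache identity; _hint,
    # constant, and fx_node are derived quantities and ignored here.
    return a.pytype == b.pytype and a._expr == b._expr
```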
```python
    zero_bytes = guard_size_oblivious(storage_bytes == 0)
else:
    zero_bytes = storage_bytes == 0
if zero_bytes:
    empty.untyped_storage().resize_(0)
```
What is going on here? I'm more asking about the preexisting code. Why do we need a special case for storage bytes zero?
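For context on the snippet under discussion, here is an editor's sketch (the helper name is hypothetical; only `guard_size_oblivious` and the branch structure come from the diff): a symbolic byte count is evaluated size-obliviously, answering `storage_bytes == 0` while treating size-like symbols as never 0 or 1 instead of installing a guard, whereas a plain int just compares directly.

```python
import torch
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

def storage_is_zero(storage_bytes) -> bool:
    # Symbolic sizes: evaluate the comparison without guarding on the
    # degenerate zero/one cases of size-like symbols.
    if isinstance(storage_bytes, torch.SymInt):
        return guard_size_oblivious(storage_bytes == 0)
    # Concrete sizes: an ordinary comparison.
    return storage_bytes == 0
```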
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #129780. Approved by: https://github.com/oulgen, https://github.com/eellison. ghstack dependencies: #131014.

Pull Request resolved: #127596. Approved by: https://github.com/eellison. ghstack dependencies: #131014, #129780.