FakeTensor cache SymInt support #127596

Closed · wants to merge 36 commits

Conversation

@aorenste (Contributor) commented May 31, 2024

Adds support for SymInts in the FakeTensor cache.

A couple of notes:

  1. When a SymInt is present in the input key for a FakeTensor operation, we cache on the ShapeEnv instead of on the FakeTensorMode cache. This is necessary so we don't have to remember and check the guards. It reduces cache hits, but there are diminishing returns on how much work we can do before the cache becomes more of a burden than a gain (see the lookup sketch after these notes).
  2. We need to be careful that when a cached output SymInt is a direct copy of an input, a cache hit copies the SymNode from the input to the output. This is important because the fx-graph building code uses SymNode ids while building the graph, so constructing a same-content-but-different-id SymNode will fail.
  3. In the cache key we store SymInts as a `_PySymInputStub`. These represent SymInt (and friends) but support `__hash__` and `__eq__` (which SymInt does not); a minimal sketch follows these notes.
  4. In the cache entry we store SymInts as a `_SymIntOutputStub`.
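
To make note 3 concrete, here is a minimal, self-contained sketch of the stub idea; it is illustrative, not the PR's actual code. `FakeSymNode`/`FakeSymInt` are stand-ins for SymNode/SymInt, and keying on `(expr, pytype)` is an assumption (see the review discussion further down):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeSymNode:
    """Stand-in for SymNode: just the fields the cache key depends on."""
    expr: str     # stand-in for the underlying sympy expression
    pytype: type  # int, bool, or float


class FakeSymInt:
    """Stand-in for torch.SymInt, whose real __eq__ returns a symbolic
    bool (not a plain bool), making it unusable as a dict key."""
    def __init__(self, node: FakeSymNode) -> None:
        self.node = node


class _PySymInputStub:
    """Wraps a symbolic value so it can sit in a hash-based cache key."""
    __slots__ = ("node",)

    def __init__(self, sym: FakeSymInt) -> None:
        self.node = sym.node

    def __hash__(self) -> int:
        return hash((self.node.expr, self.node.pytype))

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, _PySymInputStub)
            and self.node.expr == other.node.expr
            and self.node.pytype == other.node.pytype
        )


# Usage: two stubs over equal expressions collide in a dict, as a cache needs.
s0 = _PySymInputStub(FakeSymInt(FakeSymNode("s0 + 1", int)))
s1 = _PySymInputStub(FakeSymInt(FakeSymNode("s0 + 1", int)))
cache = {s0: "cached output"}
assert cache[s1] == "cached output"
```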

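Notes 1 and 2 together suggest a cache-hit path roughly like the following hedged sketch; `has_symbols`, `fake_cache`, and `output_is_input_index` are hypothetical names, not the PR's internals:

```python
from typing import Any, Optional


class CacheEntrySketch:
    """Hypothetical cache entry: either a concrete output, or a marker that
    says 'the output is a verbatim copy of input leaf #i'."""

    def __init__(self, output: Any = None,
                 output_is_input_index: Optional[int] = None) -> None:
        self.output = output
        self.output_is_input_index = output_is_input_index


def cache_lookup(fake_mode, shape_env, key, flat_inputs):
    # Note 1: keys containing symbolic values live on the ShapeEnv, so the
    # entries share the ShapeEnv's lifetime (and hence its guards); plain
    # keys keep using the FakeTensorMode cache.
    cache = shape_env.fake_cache if key.has_symbols else fake_mode.cache
    entry = cache.get(key)
    if entry is None:
        return None  # cache miss: caller computes and stores the real output
    if entry.output_is_input_index is not None:
        # Note 2: hand back the input object itself so the SymNode id is
        # preserved; fx graph building keys off SymNode ids, and an
        # equal-but-distinct SymNode would fail.
        return flat_inputs[entry.output_is_input_index]
    return entry.output
```
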
Perf example:

```
python benchmarks/dynamo/timm_models.py --ci --accuracy --timing \
  --explain --inductor --dynamic-shapes --dynamic-batch-only --device cuda \
  --training --amp --total-partitions 2 --partition-id 0 --output \
  /tmp/training_timm_models.csv --filter crossvit_9_240
```

fake tensor cache before:

```
INFO: FakeTensor cache stats:
INFO:   cache_hits: 68137
INFO:   cache_misses: 837
INFO:   cache_bypasses:
INFO:     symbolic shape:            48224
INFO:     CompositeImplicitAutograd: 917
INFO:     non-fake tensor:           70
INFO:     non-FakeTensor output:     62
INFO:     non-builtin:               8
INFO:     dynamic output shape:      1
```

and after (note the "symbolic shape" bypasses are gone):

```
INFO: FakeTensor cache stats:
INFO:   cache_hits: 88187
INFO:   cache_misses: 14233
INFO:   cache_bypasses:
INFO:     CompositeImplicitAutograd: 1037
INFO:     non-FakeTensor output:     602
INFO:     non-fake tensor:           70
INFO:     unsafe view:               36
INFO:     non-builtin:               8
INFO:     dynamic output shape:      1
```

Stack from ghstack (oldest at bottom):

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @tianyu-l

@pytorch-bot bot commented May 31, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127596

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1b8be29 with merge base 73d0f48:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

aorenste added a commit that referenced this pull request May 31, 2024
ghstack-source-id: 9c96e1f1f765ce72fc1c145006317ae683db5b52
Pull Request resolved: #127596

aorenste added a commit that referenced this pull request Jun 1, 2024
ghstack-source-id: cfbd3620309d030873243aa1ba7e5c2c99264956
Pull Request resolved: #127596

@aorenste mentioned this pull request Jun 7, 2024

aorenste added a commit that referenced this pull request Jun 7, 2024
ghstack-source-id: 491552e4ad6c761d9e8c2fe2f9481d34ef49df7d
Pull Request resolved: #127596

aorenste added a commit that referenced this pull request Jun 11, 2024
ghstack-source-id: 3a371c61bab631fb62cbc3a1ff13fc973ee27947
Pull Request resolved: #127596

aorenste added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: b790910ee1acf561ea7238d58ae315716f71b8ae
Pull Request resolved: #127596

aorenste added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: 355002077db7bc67ac7f0386210b2a154aac1c77
Pull Request resolved: #127596

aorenste added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: 0e3707cb9c8ea9a462266b1695824cc612fc6ee1
Pull Request resolved: #127596
aorenste added a commit that referenced this pull request Jul 18, 2024
…ache key"

This is part of #127596, pulled out to make reviewing a little easier.

Flatten the FakeTensor cache key - so it's a list of singular elements and pointing at one requires a single index rather than a PyTree path.  This is used in the next PR to allow us to have the cache entry refer to an input SymInt that it needs to copy directly into the output.




[ghstack-poisoned]
aorenste added a commit that referenced this pull request Jul 18, 2024
This is part of #127596, pulled out to make reviewing a little easier.

Flatten the FakeTensor cache key - so it's a list of singular elements and pointing at one requires a single index rather than a PyTree path.  This is used in the next PR to allow us to have the cache entry refer to an input SymInt that it needs to copy directly into the output.




[ghstack-poisoned]
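
A hedged illustration of what the flattening buys, using the real `torch.utils._pytree` helper (the sample values and the index bookkeeping in the comments are made up):

```python
import torch.utils._pytree as pytree

nested = (["s0", 7], {"dim": 0})        # stand-in for an op's (args, kwargs)
leaves, spec = pytree.tree_flatten(nested)
print(leaves)  # ['s0', 7, 0]
# With a flat key, a cache entry can record "the output is input leaf #0"
# as a bare integer index instead of a PyTree path like args[0][0].
```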
@aorenste requested a review from eellison July 18, 2024 18:21
@eellison (Contributor) left a comment:

🚢
@aorenste (Contributor, Author) commented:

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk label Jul 19, 2024
@aorenste added the topic: not user facing label Jul 19, 2024
@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

```diff
@@ -123,7 +123,7 @@ def compute_hint():
                 "Cannot create SymNode of type "
                 f"{pytype} with incompatible hint of type {type(hint)}"
             )
-        if self.shape_env._translation_validation_enabled:
+        if self.shape_env and self.shape_env._translation_validation_enabled:
```
Contributor:

Why can shape_env be None here now?

@aorenste (Contributor, Author):

I think that was debug cruft - I'll remove in a follow-up PR.

```python
and self.pytype == other.pytype
and self._hint == other._hint
and self.constant == other.constant
and self.fx_node == other.fx_node
```
Contributor:

I actually think you probably only need expr and pytype, the rest are derived quantities that don't matter for the purpose of cache matching. Actually, pytype also probably not needed either since you're probably only ever looking at integral quantities.

@aorenste (Contributor, Author):

If we were going to do that I'd want to move this check over to fake_tensor.py instead of having it on SymNode. It seems kind of weird to have a function on SymNode that's specific to caching fake tensors (and not have it named something like `value_eq_for_fake_tensor_cache`).

Contributor:

I think I had to use pytype for SymBool previously.
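
For concreteness, a hedged sketch of the narrowed comparison being discussed, written as a free function the way the reply above suggests; `NodeLike` is a stand-in for SymNode, and none of this is the PR's actual code:

```python
from dataclasses import dataclass


@dataclass
class NodeLike:
    """Stand-in for SymNode, reduced to the fields cache identity needs."""
    expr: str     # stand-in for the underlying sympy expression
    pytype: type  # int, float, or bool


# Hypothetical helper (the name is the one floated in the reply above).
# _hint, constant, and fx_node are treated as derived and ignored; pytype is
# kept because it distinguishes e.g. a SymBool from a SymInt over one expr.
def value_eq_for_fake_tensor_cache(a: NodeLike, b: NodeLike) -> bool:
    return a.pytype is b.pytype and a.expr == b.expr


assert value_eq_for_fake_tensor_cache(NodeLike("s0", int), NodeLike("s0", int))
assert not value_eq_for_fake_tensor_cache(NodeLike("s0", int), NodeLike("s0", bool))
```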

```python
    zero_bytes = guard_size_oblivious(storage_bytes == 0)
else:
    zero_bytes = storage_bytes == 0
if zero_bytes:
    empty.untyped_storage().resize_(0)
```
Contributor:

What is going on here? I'm more asking about the preexisting code. Why do we need a special case for storage bytes zero?

@pytorchmergebot (Collaborator):

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see the pytorch-bot wiki.

@aorenste (Contributor, Author) commented:
@pytorchbot merge

@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Jul 21, 2024

This is part of #127596, pulled out to make reviewing a little easier: flatten the FakeTensor cache key (full description in the Jul 18 commit above).

Pull Request resolved: #129780
Approved by: https://github.com/oulgen, https://github.com/eellison
ghstack dependencies: #131014
DiweiSun pushed a commit to DiweiSun/pytorch that referenced this pull request Jul 22, 2024

(Same flatten-the-cache-key change as above.)

Pull Request resolved: pytorch#129780
Approved by: https://github.com/oulgen, https://github.com/eellison
ghstack dependencies: pytorch#131014
DiweiSun pushed a commit to DiweiSun/pytorch that referenced this pull request Jul 22, 2024

(Same SymInt-support change as described at the top of this PR.)

Pull Request resolved: pytorch#127596
Approved by: https://github.com/eellison
ghstack dependencies: pytorch#131014, pytorch#129780
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024

(Same flatten-the-cache-key change as above.)

Pull Request resolved: pytorch#129780
Approved by: https://github.com/oulgen, https://github.com/eellison
ghstack dependencies: pytorch#131014
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024

(Same SymInt-support change as described at the top of this PR.)

Pull Request resolved: pytorch#127596
Approved by: https://github.com/eellison
ghstack dependencies: pytorch#131014, pytorch#129780
Labels: ciflow/inductor, ciflow/trunk, Merged, module: dynamo, module: inductor, oncall: distributed, release notes: fx, topic: not user facing

5 participants