[Inductor][CPP] Enable Quantized Linear GEMM Template with FP32 output #128825
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128825
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (3 unrelated failures) As of commit 7334f44 with merge base dabaebd:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk.
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 1876e45d603fa56589b9d83895770b05f780577a
Pull Request resolved: #128825

… FP32 output"

**Summary**
Support the int8 GEMM template with a reference MicroInt8GEMM kernel for the case:
- Activation: uint8; Weight: int8; Output: float32/bfloat16
- Without unary post-operator fusion

**Test Plan**
```
clear && python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_quantized_linear_with_pointwise
```

**Next Step**
- [ ] Unary post op fusion
- [ ] Int8 output
- [ ] AMX int8 MicroGEMM Kernel
- [ ] Binary Fusion

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang
Hi @jansel, could you kindly take a look at this PR? I have modified
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…t and Unary Post Op (#129048)

**Summary**
Based on the previous PR, add the config to support int8 output and unary post-op fusion with `ReLU` and `GeLU`:
- Activation dtype: uint8
- Weight dtype: int8
- Output dtype: float32/bfloat16/uint8
- Post Op Fusion: with unary post operator fusion

**Test Plan**
```
clear && python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_quantized_linear_with_pointwise
```

**Next Step**
- [✓] Unary post op fusion
- [✓] Int8 output
- [ ] Binary Fusion
- [ ] AMX int8 MicroGEMM Kernel

Pull Request resolved: #129048
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #128825
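The uint8-output path with a fused unary post op described above reduces to: accumulate, apply the post op in fp32, then requantize. A minimal NumPy sketch of that epilogue (the function name, rounding mode, and fusion point are assumptions for illustration, not the template's actual code):

```python
import numpy as np

def requant_epilogue(acc_fp32, out_scale, out_zp, post_op=None):
    """Apply an optional fp32 unary post op, then requantize to uint8.

    `post_op` is any elementwise fp32 function, e.g. ReLU or an
    approximate GeLU, fused before the output requantization.
    """
    y = acc_fp32 if post_op is None else post_op(acc_fp32)
    q = np.rint(y / out_scale) + out_zp   # scale, round, shift by zero point
    return np.clip(q, 0, 255).astype(np.uint8)
```

For example, with `post_op=lambda v: np.maximum(v, 0.0)` this fuses a ReLU into the uint8 output conversion instead of materializing an intermediate fp32 tensor.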
**Summary**
We change the schema of QLinear Binary so that it is easier to enable the corresponding GEMM template:
- The extra input of the binary post op is a tensor that needs to be an input node for autotuning, so we move it in front of `output_scale`, which is a scalar.
- We also move it in front of `bias`, since `bias` is an optional tensor for this fusion while `other` is required for linear binary fusion.

**Test Plan**
```
python -u -m pytest -s -v test/quantization/core/test_quantized_op.py -k qlinear
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k qlinear
```

Pull Request resolved: #129049
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #128825, #129048
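The reordering can be sketched with hypothetical Python signatures (illustrative names only; the real op's exact schema is not reproduced here):

```python
# Before (hypothetical): the required binary input `other` trailed the
# scalar `output_scale` and the optional `bias`.
def qlinear_binary_old(x, w, bias, output_scale, output_zero_point, other):
    return (x, w, bias, output_scale, output_zero_point, other)

# After (hypothetical): `other` moves in front of both `output_scale`
# (a scalar) and `bias` (an optional tensor), so every required tensor
# input sits at the front -- the positions autotuning treats as input nodes.
def qlinear_binary_new(x, w, other, bias, output_scale, output_zero_point):
    tensor_inputs = [t for t in (x, w, other) if t is not None]
    return tensor_inputs, (bias, output_scale, output_zero_point)
```

The design point is that autotuning enumerates leading tensor arguments as input nodes, so required tensors must precede scalars and optionals.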
…ion (#129103)

**Summary**
Based on the previous PR, add the config to support quantized linear binary + optional(unary) post-op fusion:
- Activation dtype: uint8
- Weight dtype: int8
- Output dtype: float32/bfloat16/uint8
- Post Op Fusion: with binary and optional[unary] post operator fusion

**Test Plan**
```
clear && python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_quantized_linear_with_pointwise_binary
```

**Next Step**
- [✓] Unary post op fusion
- [✓] Int8 output
- [✓] Binary Fusion
- [ ] AMX int8 MicroGEMM Kernel

Pull Request resolved: #129103
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #128825, #129048, #129049
**Summary**
Add the AMX micro GEMM kernel with int8 data type.

**Test Plan**
```
clear && python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_quantized_linear_amx
```

**Next Step**
- [✓] Unary post op fusion
- [✓] Int8 output
- [✓] Binary Fusion
- [✓] AMX int8 MicroGEMM Kernel

Pull Request resolved: #129220
Approved by: https://github.com/jgong5
ghstack dependencies: #128825, #129048, #129049, #129103
…129221)

**Summary**
This PR mainly refactors two things:
1. Pass in the weight's data type explicitly in `create_micro_gemm` as `input2.dtype`. When registering `CppMicroGemmConfig`, we reuse `input.dtype` if `input2.dtype` is not explicitly registered.
2. Add a utility function to derive the output data type and compute data type from the input data type.

Pull Request resolved: #129221
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #128825, #129048, #129049, #129103, #129220
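The utility in point 2 can be pictured as a small dtype-mapping helper. This is a sketch under assumed mappings (e.g. int32 accumulation for uint8 activations, fp32 accumulation for reduced-precision floats) and is not the actual function in the codebase:

```python
def get_gemm_dtypes(input_dtype):
    """Map an activation dtype name to (output_dtype, compute_dtype).

    Plain strings keep the sketch dependency-free; the real utility
    operates on torch dtypes inside the CPP GEMM template code.
    """
    if input_dtype == "uint8":
        # int8 GEMM accumulates in int32 and emits fp32 before any epilogue.
        return "float32", "int32"
    if input_dtype in ("bfloat16", "float16"):
        # Reduced-precision float GEMMs accumulate in fp32.
        return input_dtype, "float32"
    # fp32 in: fp32 out, fp32 accumulation.
    return input_dtype, input_dtype
```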
… template (#129470)

**Summary**
Remove redundant INT8-specific logic in the INT8 GEMM template to unify the code structure with the FP32/BF16/FP16 GEMM template.

**Test Plan**
```
numactl -C 56-111 -m 1 python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_quantized_linear
```

Pull Request resolved: #129470
Approved by: https://github.com/jgong5
ghstack dependencies: #128825, #129048, #129049, #129103, #129220, #129221
Stack from ghstack (oldest at bottom):

**Summary**
Support the int8 GEMM template with a reference MicroInt8GEMM kernel for the case:
- Activation dtype: uint8
- Weight dtype: int8
- Output dtype: float32/bfloat16
- Post Op Fusion: without unary post operator fusion

**Test Plan**
```
clear && python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_quantized_linear_with_pointwise
```

**Next Step**
- [ ] Unary post op fusion
- [ ] Int8 output
- [ ] Binary Fusion
- [ ] AMX int8 MicroGEMM Kernel

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang
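The reference semantics of the uint8-activation / int8-weight linear with fp32 output can be sketched in NumPy. This is a hedged illustration of the math, not the C++ template itself; the function name and the per-output-channel weight-scale layout are assumptions:

```python
import numpy as np

def qlinear_ref(x_q, x_scale, x_zp, w_q, w_scales, bias=None):
    """Reference uint8-activation / int8-weight linear with float32 output.

    x_q:      (M, K) uint8 quantized activation, scale `x_scale`, zero point `x_zp`
    w_q:      (K, N) int8 quantized weight with per-output-channel scales `w_scales`
    """
    # Integer GEMM in int32 with the activation zero point folded out.
    acc = (x_q.astype(np.int32) - x_zp) @ w_q.astype(np.int32)
    # Dequantize: one activation scale times per-column weight scales.
    out = acc.astype(np.float32) * (x_scale * w_scales)
    if bias is not None:
        out = out + bias
    return out
```

A quick way to validate such a kernel is to compare it against an fp32 matmul on the dequantized inputs, which is what the integer path is meant to reproduce up to rounding.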