
[Optimizer] DQ + MatMul to MatMulNBits support #21180

Closed
wants to merge 36 commits

Conversation

@fajin-corp (Contributor) commented Jun 26, 2024

Description

MatMulNBits is a heavily optimized matmul operation. Currently, a MatMul can be converted to MatMulNBits to speed up model inference. However, MatMulNBits is an ORT-only op. To make the graph compatible with ONNX ops and utilize MatMulNBits at the same time, we introduce Q/DQ support for MatMulNBits.

To convert MatMul ops in a model to MatMulNBits (sketched below):

  1. Use matmul_4bits_quantizer.py to convert MatMul to DQ + MatMul using QDQ mode.
  2. In the ORT session, DQ + MatMul is fused into MatMulNBits.
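
Roughly, the two steps look like this in Python. This is a sketch only: the MatMul4BitsQuantizer / DefaultWeightOnlyQuantConfig usage and the quant_format argument are assumptions about the quantizer API and may differ by release.

```python
# Sketch: step 1 writes DQ + MatMul (QDQ mode) instead of MatMulNBits,
# step 2 lets the ORT optimizer fuse DQ + MatMul into MatMulNBits at session init.
import onnx
import onnxruntime as ort
from onnxruntime.quantization import QuantFormat
from onnxruntime.quantization.matmul_4bits_quantizer import (
    DefaultWeightOnlyQuantConfig,  # assumed config class / kwargs
    MatMul4BitsQuantizer,
)

model = onnx.load("model.onnx")

# Step 1: 4-bit weight-only quantization emitting DequantizeLinear + MatMul.
config = DefaultWeightOnlyQuantConfig(
    block_size=32,
    is_symmetric=True,
    quant_format=QuantFormat.QDQ,  # QDQ mode; parameter name is an assumption
)
quantizer = MatMul4BitsQuantizer(model, algo_config=config)
quantizer.process()
quantizer.model.save_model_to_file("model_qdq.onnx", use_external_data_format=True)

# Step 2: at session creation, DQ + MatMul pairs are fused into MatMulNBits.
session = ort.InferenceSession("model_qdq.onnx", providers=["CPUExecutionProvider"])
```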

Note

MatMulNBits assumes the B weight is uint4. When no zero point (zp) is provided, MatMulNBits defaults zp to 8, which differs from DQ: DQ defaults zp to 0 when none is provided, and DQ also supports int4. Therefore some conversions are introduced during the DQ + MatMul --> MatMulNBits step.
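
As a toy illustration of the zero-point shift (plain numpy, not ORT code): shifting an int4 weight and its zero point by 8 reproduces the same dequantized values under the uint4 convention that MatMulNBits expects.

```python
import numpy as np

scale = 0.05
w_int4 = np.array([-8, -3, 0, 5, 7], dtype=np.int8)  # signed int4 values as DQ stores them
zp_int4 = 0                                           # DQ default zero point

# MatMulNBits stores uint4 and defaults zp to 8, so shift weight and zp by 8.
w_uint4 = (w_int4 + 8).astype(np.uint8)               # values now in [0, 15]
zp_uint4 = zp_int4 + 8

dq_from_int4 = (w_int4.astype(np.float32) - zp_int4) * scale
dq_from_uint4 = (w_uint4.astype(np.float32) - zp_uint4) * scale
assert np.allclose(dq_from_int4, dq_from_uint4)       # identical dequantized weights
```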

Perf

Using the QDQ format increases model initialization time and memory consumption. With the current implementation, model init time increased from ~4s to ~9s, and memory consumption increased from ~2.8GB to ~4.8GB.
The memory increase is due to:

  1. In the optimizer, after transposing the B weight, an in-memory tensor proto is created using protobuf's arena.
  2. In the finalize step, when saving initializers and prepacking, the ORT arena is used to create buffers for the initializers.

The memory allocated by these arenas cannot be fully deallocated.
If ORT arena memory allocation is disabled (see the snippet below), the memory consumption of both the QDQ format and the original format is ~2.2GB.
The time increase is mainly due to multiple memory copies, but can be further optimized.
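
For reference, the CPU memory arena can be disabled through the standard session options when measuring memory:

```python
import onnxruntime as ort

so = ort.SessionOptions()
so.enable_cpu_mem_arena = False  # turn off the ORT CPU memory arena

session = ort.InferenceSession("model_qdq.onnx", sess_options=so,
                               providers=["CPUExecutionProvider"])
```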

Motivation and Context

Please see description for details.

@fajin-corp fajin-corp requested a review from a team as a code owner June 26, 2024 06:48
@fajin-corp (Contributor Author) commented:

/azp run Windows_CI / Windows-CUDA-12,Windows_SCA / Onnxruntime-SCA-training-CUDA


No pipelines are associated with this pull request.

}

OrtThreadPoolParams to;
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,

Contributor

not sure if we should create a new threadpool here. if we really need this to run with a threadpool, consider adjusting the APIs to pass in an existing threadpool.


Contributor Author

fix in next iter

onnxruntime/core/optimizer/graph_transformer_utils.cc (outdated; resolved)
// used together with DQMatMulNodeGroupSelector, which does the sanity check
struct DQMatMulReplaceWithMatMulNBits : public Action {
  explicit DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level);
  Status Run(Graph&, const NodesToOptimize& selected_nodes) const override;

Contributor

ReplaceWithNew would definitely be preferred.

Not sure why we seem to have a few things using Action directly (SplitReplaceWithQuant and MatMulReplaceWithQLinear), but the intention of the setup was to use something like ReplaceWithNew whenever possible so we're not duplicating logic (and unnecessarily increasing binary size) for the common functionality.

@fajin-corp fajin-corp force-pushed the fajin/qdqmatmulnbitstoolchain branch 2 times, most recently from b26e0ea to d2a0008 on July 8, 2024 23:33
@guschmue (Contributor) commented Jul 11, 2024

The CI error in ort-web is this test:
https://github.com/microsoft/onnxruntime/blob/main/js/web/test/e2e/browser-test-webgpu-external-data.js
It is a super simple test that checks whether we load external data correctly.
On main the test is passing.

Fails with:
ERROR: 'worker sent an error! https://localhost:9876/base/node_modules/onnxruntime-web/dist/ort-wasm-simd-threaded.jsep.mjs:112: Uncaught TypeError: Sa is not a function'

Unclear how we get there.

@fajin-corp fajin-corp force-pushed the fajin/qdqmatmulnbitstoolchain branch from 6596e54 to 1ecf5c5 on July 12, 2024 19:15
fajin-corp added a commit that referenced this pull request Jul 20, 2024
### Description

This is a partial change from
[fajin/qdqmatmulnbitstoolchain](#21180).
The original PR is blocked by Web CI failures.

@fajin-corp fajin-corp closed this Jul 23, 2024
@fajin-corp fajin-corp deleted the fajin/qdqmatmulnbitstoolchain branch July 23, 2024 20:42