Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AWQ quantization inference support (#1019) #1054

Merged
merged 12 commits into from
Sep 25, 2023
Merged

Add AWQ quantization inference support (#1019) #1054

merged 12 commits into from
Sep 25, 2023

Conversation

Narsil
Copy link
Collaborator

@Narsil Narsil commented Sep 25, 2023

Add AWQ quantization inference support

Fixes
#781

This PR (partially) adds support for AWQ quantization for inference.
More information on AWQ here. In
general, AWQ is faster and more accurate than GPTQ, which is currently
supported by TGI.

This PR installs 4-bit GEMM custom CUDA kernels released by AWQ authors
(in requirements.txt, just one line change).

Quick way to test this PR would be bring up TGI as follows:

text-generation-server download-weights abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq

text-generation-launcher \
--huggingface-hub-cache ~/.cache/huggingface/hub/ \
--model-id abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq \
--trust-remote-code --port 8080 \
--max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 \
--quantize awq

Please note:

  • This PR was tested with FlashAttention v2 and vLLM.
  • This PR adds support for AWQ inference, not quantizing the models.
    That needs to be done outside of TGI, instructions
    here.
  • This PR only adds support for FlashLlama models for now.
  • Multi-GPU setup has not been tested.
  • No integration tests have been added so far, will add later if
    maintainers are interested in this change.
  • This PR can be tested on any of the models released
    here.

Please refer to the linked issue for benchmarks for
abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq
vs
TheBloke/Llama-2-7b-Chat-GPTQ.

Please note, AWQ has released faster (and in case of Llama, fused)
kernels for 4-bit GEMM, currently at the top of the main branch at
https://github.com/mit-han-lab/llm-awq, but this PR uses an older commit
that has been tested to work. We can switch to latest commit later on.

Who can review?

@OlivierDehaene OR @Narsil


What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

# Add AWQ quantization inference support

Fixes
#781

This PR (partially) adds support for AWQ quantization for inference.
More information on AWQ [here](https://arxiv.org/abs/2306.00978). In
general, AWQ is faster and more accurate than GPTQ, which is currently
supported by TGI.

This PR installs 4-bit GEMM custom CUDA kernels released by AWQ authors
(in `requirements.txt`, just one line change).

Quick way to test this PR would be bring up TGI as follows:

```
text-generation-server download-weights abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq

text-generation-launcher \
--huggingface-hub-cache ~/.cache/huggingface/hub/ \
--model-id abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq \
--trust-remote-code --port 8080 \
--max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 \
--quantize awq
```

Please note:
* This PR was tested with FlashAttention v2 and vLLM.
* This PR adds support for AWQ inference, not quantizing the models.
That needs to be done outside of TGI, instructions
[here](https://github.com/mit-han-lab/llm-awq/tree/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa).
* This PR only adds support for `FlashLlama` models for now.
* Multi-GPU setup has not been tested. 
* No integration tests have been added so far, will add later if
maintainers are interested in this change.
* This PR can be tested on any of the models released
[here](https://huggingface.co/abhinavkulkarni?sort_models=downloads#models).

Please refer to the linked issue for benchmarks for
[abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq](https://huggingface.co/abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq)
vs
[TheBloke/Llama-2-7b-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ).

Please note, AWQ has released faster (and in case of Llama, fused)
kernels for 4-bit GEMM, currently at the top of the `main` branch at
https://github.com/mit-han-lab/llm-awq, but this PR uses an older commit
that has been tested to work. We can switch to latest commit later on.

## Who can review?

@OlivierDehaene OR @Narsil

---------

Co-authored-by: Abhinav Kulkarni <[email protected]>
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@Narsil
Copy link
Collaborator Author

Narsil commented Sep 25, 2023

@abhinavkulkarni
@casper-hansen

For visibility.

@casper-hansen
Copy link

Please note, AWQ has released faster (and in case of Llama, fused)
kernels for 4-bit GEMM, currently at the top of the main branch at
https://github.com/mit-han-lab/llm-awq, but this PR uses an older commit
that has been tested to work. We can switch to latest commit later on.

They have released one new GEMV kernel and one new GEMM kernel. On their main branch, they use GEMV for token generation and GEMM for context processing. The GEMV kernel is 20% faster than the old GEMM kernel, but most importantly, the new GEMM kernel is 5-6 slower than the old GEMM kernel.

My conclusion is that the current GEMM kernel already used in this PR is the optimal one for now.

@Narsil
Copy link
Collaborator Author

Narsil commented Sep 25, 2023

The GEMV kernel is 20% faster than the old GEMM kernel, but most importantly, the new GEMM kernel is 5-6 slower than the old GEMM kernel.

I didn't test the GEMV, but the GEMM seems to have similar speeds for me (A10G) which card did you test on ?

@casper-hansen
Copy link

casper-hansen commented Sep 25, 2023

The GEMV kernel is 20% faster than the old GEMM kernel, but most importantly, the new GEMM kernel is 5-6 slower than the old GEMM kernel.

I didn't test the GEMV, but the GEMM seems to have similar speeds for me (A10G) which card did you test on ?

There are multiple GEMM kernels. The new one is slower because it implements a different packed format for quantized models. I tested on RTX 3090, 4090, and A100.

On RTX 3090, speed of context processing on LLaMa 7B:

GEMM (original): 2400 tokens/s
GEMM (new): 440 tokens/s
GEMV: 234 tokens/s

EDIT: What I mean by "new" GEMM kernel is this pull request that is about to be merged into the original llm-awq: mit-han-lab/llm-awq#90

@Narsil
Copy link
Collaborator Author

Narsil commented Sep 25, 2023

GEMM (original): 2400 tokens/s
GEMM (new): 440 tokens/s
GEMV: 234 tokens/s

Those seem like throughput values, which are relatively not important in general. We really care about latency much more. Ideally the whole curves gives a better story since we do want throughput, but at fixed latency. (2x throughput for 2x latency, is usually unacceptable in our deployments for instance).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants