CUDA: int8 tensor cores for MMQ (q4_K, q5_K, q6_K) #7860

Merged: 1 commit merged into ggerganov:master on Jun 11, 2024

Conversation

JohannesGaessler (Collaborator)

This PR adds int8 tensor core support for the q4_K, q5_K, and q6_K mul_mat_q kernels. Originally I wanted to put all k-quants into the same PR, but in retrospect the MMQ code for q2_K and q3_K is kind of bad, so I think it needs a general refactor before I try to add int8 tensor core support for those formats.
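For readers who want a concrete picture of what "int8 tensor cores" means here: roughly speaking, the per-thread `__dp4a` integer dot products are replaced by warp-level int8 matrix multiplications with int32 accumulation. The PR implements this with hand-written PTX `mma` instructions in `ggml-cuda/mma.cuh`; the sketch below is *not* that code, just a minimal illustration of an int8 tensor core tile multiply using CUDA's higher-level WMMA API (available on Turing, compute capability 7.5, and newer).

```cuda
// Illustrative sketch only; the PR itself uses raw PTX mma instructions in
// ggml-cuda/mma.cuh. This kernel multiplies a 16x16 int8 tile of A by a
// 16x16 int8 tile of B on the tensor cores, accumulating into a 16x16 int32
// tile of C. One warp executes the whole tile operation.
#include <mma.h>

using namespace nvcuda;

__global__ void int8_tile_mma(const signed char * A, const signed char * B, int * C) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, signed char, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, signed char, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, int>                          c_frag;

    wmma::fill_fragment(c_frag, 0);                   // start from a zero int32 accumulator
    wmma::load_matrix_sync(a_frag, A, 16);            // leading dimension 16 (int8 elements)
    wmma::load_matrix_sync(b_frag, B, 16);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);   // int8 x int8 -> int32 on tensor cores
    wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
}
```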

Performance vs. master MMQ
| GPU | Model | Microbatch size | Test | t/s master | t/s cuda-ptx-mma-12 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 4090 | llama 8B Q4_K_S | 16 | pp2048 | 1463.94 | 1946.01 | 1.33 |
| RTX 4090 | llama 8B Q4_K_S | 32 | pp2048 | 2341.23 | 3358.03 | 1.43 |
| RTX 4090 | llama 8B Q4_K_S | 64 | pp2048 | 3387.63 | 5349.26 | 1.58 |
| RTX 4090 | llama 8B Q4_K_S | 128 | pp2048 | 4443.87 | 6901.00 | 1.55 |
| RTX 4090 | llama 8B Q4_K_S | 256 | pp2048 | 5222.84 | 8555.10 | 1.64 |
| RTX 4090 | llama 8B Q4_K_S | 512 | pp2048 | 5611.17 | 9278.55 | 1.65 |
| RTX 4090 | llama 8B Q4_K_S | 1024 | pp2048 | 5806.79 | 9265.84 | 1.60 |
| RTX 4090 | llama 8B Q4_K_S | 2048 | pp2048 | 5549.92 | 8655.44 | 1.56 |
| RTX 4090 | llama 8B Q5_K_S | 16 | pp2048 | 1344.52 | 1675.27 | 1.25 |
| RTX 4090 | llama 8B Q5_K_S | 32 | pp2048 | 2007.37 | 2792.81 | 1.39 |
| RTX 4090 | llama 8B Q5_K_S | 64 | pp2048 | 3135.71 | 4592.53 | 1.46 |
| RTX 4090 | llama 8B Q5_K_S | 128 | pp2048 | 3930.48 | 6058.36 | 1.54 |
| RTX 4090 | llama 8B Q5_K_S | 256 | pp2048 | 4785.05 | 7758.42 | 1.62 |
| RTX 4090 | llama 8B Q5_K_S | 512 | pp2048 | 5211.79 | 8671.69 | 1.66 |
| RTX 4090 | llama 8B Q5_K_S | 1024 | pp2048 | 5437.78 | 8783.84 | 1.62 |
| RTX 4090 | llama 8B Q5_K_S | 2048 | pp2048 | 5212.39 | 8263.49 | 1.59 |
| RTX 4090 | llama 8B Q6_K | 16 | pp2048 | 1163.29 | 1413.55 | 1.22 |
| RTX 4090 | llama 8B Q6_K | 32 | pp2048 | 1756.61 | 2460.78 | 1.40 |
| RTX 4090 | llama 8B Q6_K | 64 | pp2048 | 2838.87 | 4424.35 | 1.56 |
| RTX 4090 | llama 8B Q6_K | 128 | pp2048 | 3844.01 | 6071.93 | 1.58 |
| RTX 4090 | llama 8B Q6_K | 256 | pp2048 | 4654.52 | 7997.45 | 1.72 |
| RTX 4090 | llama 8B Q6_K | 512 | pp2048 | 5150.89 | 8796.22 | 1.71 |
| RTX 4090 | llama 8B Q6_K | 1024 | pp2048 | 5318.41 | 8910.36 | 1.68 |
| RTX 4090 | llama 8B Q6_K | 2048 | pp2048 | 5080.49 | 8255.21 | 1.62 |
| RTX 3090 | llama 8B Q4_K_S | 16 | pp2048 | 801.95 | 1234.40 | 1.54 |
| RTX 3090 | llama 8B Q4_K_S | 32 | pp2048 | 1074.81 | 1774.68 | 1.65 |
| RTX 3090 | llama 8B Q4_K_S | 64 | pp2048 | 1451.21 | 2341.89 | 1.61 |
| RTX 3090 | llama 8B Q4_K_S | 128 | pp2048 | 1747.08 | 2897.55 | 1.66 |
| RTX 3090 | llama 8B Q4_K_S | 256 | pp2048 | 2090.00 | 3460.06 | 1.66 |
| RTX 3090 | llama 8B Q4_K_S | 512 | pp2048 | 2214.67 | 3694.28 | 1.67 |
| RTX 3090 | llama 8B Q4_K_S | 1024 | pp2048 | 2318.00 | 3808.69 | 1.64 |
| RTX 3090 | llama 8B Q4_K_S | 2048 | pp2048 | 2251.89 | 3643.31 | 1.62 |
| RTX 3090 | llama 8B Q5_K_S | 16 | pp2048 | 702.35 | 1006.02 | 1.43 |
| RTX 3090 | llama 8B Q5_K_S | 32 | pp2048 | 944.24 | 1447.94 | 1.53 |
| RTX 3090 | llama 8B Q5_K_S | 64 | pp2048 | 1240.48 | 1944.35 | 1.57 |
| RTX 3090 | llama 8B Q5_K_S | 128 | pp2048 | 1583.12 | 2569.41 | 1.62 |
| RTX 3090 | llama 8B Q5_K_S | 256 | pp2048 | 1895.89 | 3250.03 | 1.71 |
| RTX 3090 | llama 8B Q5_K_S | 512 | pp2048 | 2060.21 | 3401.43 | 1.65 |
| RTX 3090 | llama 8B Q5_K_S | 1024 | pp2048 | 2152.24 | 3503.48 | 1.63 |
| RTX 3090 | llama 8B Q5_K_S | 2048 | pp2048 | 2106.41 | 3440.64 | 1.63 |
| RTX 3090 | llama 8B Q6_K | 16 | pp2048 | 590.82 | 884.87 | 1.50 |
| RTX 3090 | llama 8B Q6_K | 32 | pp2048 | 888.50 | 1379.53 | 1.55 |
| RTX 3090 | llama 8B Q6_K | 64 | pp2048 | 1198.96 | 2018.45 | 1.68 |
| RTX 3090 | llama 8B Q6_K | 128 | pp2048 | 1497.54 | 2756.86 | 1.84 |
| RTX 3090 | llama 8B Q6_K | 256 | pp2048 | 1891.89 | 3333.43 | 1.76 |
| RTX 3090 | llama 8B Q6_K | 512 | pp2048 | 2002.13 | 3491.20 | 1.74 |
| RTX 3090 | llama 8B Q6_K | 1024 | pp2048 | 2076.91 | 3593.81 | 1.73 |
| RTX 3090 | llama 8B Q6_K | 2048 | pp2048 | 2034.32 | 3525.21 | 1.73 |
Performance vs. master FP16 cuBLAS
| GPU | Model | Microbatch size | Test | t/s master | t/s cuda-ptx-mma-12 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 4090 | llama 8B Q4_K_S | 16 | pp2048 | 1464.96 | 1943.99 | 1.33 |
| RTX 4090 | llama 8B Q4_K_S | 32 | pp2048 | 2340.40 | 3356.56 | 1.43 |
| RTX 4090 | llama 8B Q4_K_S | 64 | pp2048 | 3391.27 | 5359.39 | 1.58 |
| RTX 4090 | llama 8B Q4_K_S | 128 | pp2048 | 3483.04 | 6901.05 | 1.98 |
| RTX 4090 | llama 8B Q4_K_S | 256 | pp2048 | 5750.04 | 8531.97 | 1.48 |
| RTX 4090 | llama 8B Q4_K_S | 512 | pp2048 | 7694.01 | 9267.19 | 1.20 |
| RTX 4090 | llama 8B Q4_K_S | 1024 | pp2048 | 9019.89 | 9280.11 | 1.03 |
| RTX 4090 | llama 8B Q4_K_S | 2048 | pp2048 | 9020.72 | 8650.88 | 0.96 |
| RTX 4090 | llama 8B Q5_K_S | 16 | pp2048 | 1339.82 | 1675.33 | 1.25 |
| RTX 4090 | llama 8B Q5_K_S | 32 | pp2048 | 1999.00 | 2788.10 | 1.39 |
| RTX 4090 | llama 8B Q5_K_S | 64 | pp2048 | 3136.95 | 4581.04 | 1.46 |
| RTX 4090 | llama 8B Q5_K_S | 128 | pp2048 | 3451.99 | 6037.27 | 1.75 |
| RTX 4090 | llama 8B Q5_K_S | 256 | pp2048 | 5674.22 | 7752.61 | 1.37 |
| RTX 4090 | llama 8B Q5_K_S | 512 | pp2048 | 7631.15 | 8676.09 | 1.14 |
| RTX 4090 | llama 8B Q5_K_S | 1024 | pp2048 | 8942.08 | 8780.76 | 0.98 |
| RTX 4090 | llama 8B Q5_K_S | 2048 | pp2048 | 8988.51 | 8256.40 | 0.92 |
| RTX 4090 | llama 8B Q6_K | 16 | pp2048 | 1160.29 | 1408.84 | 1.21 |
| RTX 4090 | llama 8B Q6_K | 32 | pp2048 | 1756.13 | 2459.62 | 1.40 |
| RTX 4090 | llama 8B Q6_K | 64 | pp2048 | 2847.17 | 4418.16 | 1.55 |
| RTX 4090 | llama 8B Q6_K | 128 | pp2048 | 3383.93 | 6068.10 | 1.79 |
| RTX 4090 | llama 8B Q6_K | 256 | pp2048 | 5576.11 | 8028.77 | 1.44 |
| RTX 4090 | llama 8B Q6_K | 512 | pp2048 | 7516.95 | 8795.13 | 1.17 |
| RTX 4090 | llama 8B Q6_K | 1024 | pp2048 | 8759.06 | 8900.03 | 1.02 |
| RTX 4090 | llama 8B Q6_K | 2048 | pp2048 | 8660.30 | 8151.61 | 0.94 |
| RTX 3090 | llama 8B Q4_K_S | 16 | pp2048 | 798.51 | 1226.10 | 1.54 |
| RTX 3090 | llama 8B Q4_K_S | 32 | pp2048 | 1073.65 | 1731.00 | 1.61 |
| RTX 3090 | llama 8B Q4_K_S | 64 | pp2048 | 1441.52 | 2279.13 | 1.58 |
| RTX 3090 | llama 8B Q4_K_S | 128 | pp2048 | 2259.12 | 2826.25 | 1.25 |
| RTX 3090 | llama 8B Q4_K_S | 256 | pp2048 | 3341.46 | 3411.38 | 1.02 |
| RTX 3090 | llama 8B Q4_K_S | 512 | pp2048 | 3963.55 | 3651.70 | 0.92 |
| RTX 3090 | llama 8B Q4_K_S | 1024 | pp2048 | 4626.58 | 3767.43 | 0.81 |
| RTX 3090 | llama 8B Q4_K_S | 2048 | pp2048 | 4687.16 | 3619.86 | 0.77 |
| RTX 3090 | llama 8B Q5_K_S | 16 | pp2048 | 700.59 | 989.79 | 1.41 |
| RTX 3090 | llama 8B Q5_K_S | 32 | pp2048 | 940.47 | 1431.55 | 1.52 |
| RTX 3090 | llama 8B Q5_K_S | 64 | pp2048 | 1229.62 | 1929.57 | 1.57 |
| RTX 3090 | llama 8B Q5_K_S | 128 | pp2048 | 2210.56 | 2536.07 | 1.15 |
| RTX 3090 | llama 8B Q5_K_S | 256 | pp2048 | 3272.96 | 3147.95 | 0.96 |
| RTX 3090 | llama 8B Q5_K_S | 512 | pp2048 | 3892.85 | 3321.21 | 0.85 |
| RTX 3090 | llama 8B Q5_K_S | 1024 | pp2048 | 4566.61 | 3431.34 | 0.75 |
| RTX 3090 | llama 8B Q5_K_S | 2048 | pp2048 | 4628.11 | 3393.07 | 0.73 |
| RTX 3090 | llama 8B Q6_K | 16 | pp2048 | 589.97 | 877.42 | 1.49 |
| RTX 3090 | llama 8B Q6_K | 32 | pp2048 | 887.02 | 1350.14 | 1.52 |
| RTX 3090 | llama 8B Q6_K | 64 | pp2048 | 1197.92 | 1949.14 | 1.63 |
| RTX 3090 | llama 8B Q6_K | 128 | pp2048 | 2221.86 | 2666.57 | 1.20 |
| RTX 3090 | llama 8B Q6_K | 256 | pp2048 | 3288.91 | 3240.44 | 0.99 |
| RTX 3090 | llama 8B Q6_K | 512 | pp2048 | 3891.56 | 3401.65 | 0.87 |
| RTX 3090 | llama 8B Q6_K | 1024 | pp2048 | 4551.18 | 3531.57 | 0.78 |
| RTX 3090 | llama 8B Q6_K | 2048 | pp2048 | 4575.31 | 3451.87 | 0.75 |

(A review comment on ggml-cuda/mma.cuh was marked outdated and resolved.)
The github-actions bot added the "Nvidia GPU" (issues specific to Nvidia GPUs) and "ggml" (changes relating to the ggml tensor library for machine learning) labels on Jun 10, 2024
@slaren (Collaborator) commented Jun 10, 2024

Looks very good. Should MMQ be the default again?

| GPU | Model | Microbatch size | Test | t/s master cuBLAS | t/s cuda-ptx-mma-12 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 3090 Ti | 7B Q4_K_M | 16 | pp1024 | 763.74 | 1106.51 | 1.45 |
| RTX 3090 Ti | 7B Q4_K_M | 32 | pp1024 | 1098.05 | 1824.08 | 1.66 |
| RTX 3090 Ti | 7B Q4_K_M | 64 | pp1024 | 1521.28 | 2423.70 | 1.59 |
| RTX 3090 Ti | 7B Q4_K_M | 128 | pp1024 | 2437.95 | 3051.37 | 1.25 |
| RTX 3090 Ti | 7B Q4_K_M | 256 | pp1024 | 3651.28 | 3724.66 | 1.02 |
| RTX 3090 Ti | 7B Q4_K_M | 512 | pp1024 | 4290.68 | 4208.15 | 0.98 |
| RTX 3090 Ti | 7B Q4_K_M | 1024 | pp1024 | 4730.21 | 4150.73 | 0.88 |
| RTX 3090 Ti | 7B Q5_K_M | 16 | pp1024 | 704.44 | 986.14 | 1.40 |
| RTX 3090 Ti | 7B Q5_K_M | 32 | pp1024 | 998.54 | 1557.86 | 1.56 |
| RTX 3090 Ti | 7B Q5_K_M | 64 | pp1024 | 1345.34 | 2128.63 | 1.58 |
| RTX 3090 Ti | 7B Q5_K_M | 128 | pp1024 | 2404.66 | 2833.24 | 1.18 |
| RTX 3090 Ti | 7B Q5_K_M | 256 | pp1024 | 3574.42 | 3543.32 | 0.99 |
| RTX 3090 Ti | 7B Q5_K_M | 512 | pp1024 | 4259.87 | 3906.15 | 0.92 |
| RTX 3090 Ti | 7B Q5_K_M | 1024 | pp1024 | 4695.34 | 3871.94 | 0.82 |
| RTX 3090 Ti | 7B Q6_K | 16 | pp1024 | 629.23 | 910.99 | 1.45 |
| RTX 3090 Ti | 7B Q6_K | 32 | pp1024 | 985.37 | 1523.31 | 1.55 |
| RTX 3090 Ti | 7B Q6_K | 64 | pp1024 | 1355.43 | 2203.05 | 1.63 |
| RTX 3090 Ti | 7B Q6_K | 128 | pp1024 | 2385.65 | 3010.48 | 1.26 |
| RTX 3090 Ti | 7B Q6_K | 256 | pp1024 | 3566.80 | 3654.09 | 1.02 |
| RTX 3090 Ti | 7B Q6_K | 512 | pp1024 | 4215.07 | 3999.10 | 0.95 |
| RTX 3090 Ti | 7B Q6_K | 1024 | pp1024 | 4617.20 | 3935.87 | 0.85 |

@Green-Sky (Collaborator)

@JohannesGaessler both your comparison tables are vs master mmq.

@JohannesGaessler (Collaborator, Author)

> Looks very good. Should MMQ be the default again?

Give me a bit more time to implement q2_K and q3_K and to optimize performance (particularly asynchronous data loading); then I think MMQ will be universally faster. Also, in case you're not aware, int8 tensor cores are only available on Turing and newer GPUs (Volta only has FP16 tensor cores), so for V100s FP16 cuBLAS should still be the fastest option.

> @JohannesGaessler both your comparison tables are vs master mmq.

Only for batch sizes 16, 32, and 64; I compared with vs. without LLAMA_CUDA_FORCE_MMQ, and currently MMQ is already used for batch sizes <= 64 on master even without that option.
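For context, whether the int8 tensor core path can be used at all is a property of the device's compute capability. The following is a hypothetical helper (not llama.cpp's actual dispatch logic), just a sketch of the kind of runtime check involved:

```cuda
// Hypothetical helper, not taken from llama.cpp: true if the device supports
// the int8 mma instructions used by this PR. Turing (compute capability 7.5)
// and newer qualify; Volta (7.0) only has FP16 tensor cores, so V100s would
// keep using FP16 cuBLAS for large batches.
#include <cuda_runtime.h>

bool device_has_int8_tensor_cores(int device) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    return prop.major > 7 || (prop.major == 7 && prop.minor >= 5);
}
```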

@mofosyne added the "Review Complexity : High" label (generally requires in-depth knowledge of LLMs or GPUs) on Jun 11, 2024
@ggerganov (Owner)

For BS >= 512 F16 cuBLAS is still faster even with tensor cores, is that correct?

@JohannesGaessler merged commit bdcb8f4 into ggerganov:master on Jun 11, 2024. 19 checks passed.
@JohannesGaessler (Collaborator, Author)

It depends on the quantization format and hardware; q4_0 on an RTX 4090 seems to be the best-case scenario, where MMQ already appears to be faster even for large batch sizes.

FP16 cuBLAS:

| model | size | params | backend | ngl | n_ubatch | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CUDA | 99 | 512 | 1 | pp4096 | 7430.94 ± 0.00 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CUDA | 99 | 1024 | 1 | pp4096 | 8655.73 ± 0.00 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CUDA | 99 | 2048 | 1 | pp4096 | 8798.88 ± 0.00 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CUDA | 99 | 4096 | 1 | pp4096 | 8790.74 ± 0.00 |

int8 tensor core MMQ:

| model | size | params | backend | ngl | n_ubatch | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CUDA | 99 | 512 | 1 | pp4096 | 9719.23 ± 0.00 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CUDA | 99 | 1024 | 1 | pp4096 | 9692.96 ± 0.00 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CUDA | 99 | 2048 | 1 | pp4096 | 9094.65 ± 0.00 |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | CUDA | 99 | 4096 | 1 | pp4096 | 9115.52 ± 0.00 |

@sorasoras

@JohannesGaessler Are you planning to do the same for the IQ quants? It would be nice to run int8 on my P40 instead of FP32; the IQ quants have been very slow on that card.

@JohannesGaessler (Collaborator, Author)

I will prioritize the quantization formats that already have MMQ implementations (legacy, k-quants) but long-term I plan to also implement kernels for the other quantization formats.

@Dampfinchen commented Jun 17, 2024

Benchmark on Turing (RTX 2060, FA enabled, default batch size, 4096 context, Q4_K_S)

cuBLAS

llama_print_timings:        load time =    3416.67 ms
llama_print_timings:      sample time =      21.96 ms /   180 runs   (    0.12 ms per token,  8195.23 tokens per second)
llama_print_timings: prompt eval time =    4760.33 ms /  3671 tokens (    1.30 ms per token,   771.17 tokens per second)
llama_print_timings:        eval time =    4922.92 ms /   179 runs   (   27.50 ms per token,    36.36 tokens per second)
llama_print_timings:       total time =    9861.65 ms /  3850 tokens

Force_MMQ

llama_print_timings:        load time =    3430.94 ms
llama_print_timings:      sample time =      22.20 ms /   180 runs   (    0.12 ms per token,  8109.57 tokens per second)
llama_print_timings: prompt eval time =    7682.43 ms /  3671 tokens (    2.09 ms per token,   477.84 tokens per second)
llama_print_timings:        eval time =    4949.56 ms /   179 runs   (   27.65 ms per token,    36.16 tokens per second)
llama_print_timings:       total time =   12801.45 ms /  3850 tokens

It looks like MMQ's prompt processing is still quite a bit slower. I tested this with the most up-to-date build at the time of writing.

@JohannesGaessler (Collaborator, Author)

Did you use make or CMake to build the project? Right now CMake compiles for the wrong CUDA architectures, so the int8 tensor cores aren't actually being used.
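To make the failure mode concrete, here is a simplified sketch of compile-time architecture gating (illustrative names only; the actual guards and constants in ggml-cuda differ): if the code is not compiled for a new enough architecture, the kernels silently fall back to the non-tensor-core path.

```cuda
// Simplified sketch of compile-time gating, not the exact guard used in
// ggml-cuda: the int8 mma path only exists in binaries compiled for
// compute capability 7.5 (Turing) or newer.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    // int8 tensor core (mma) path
#else
    // fallback path based on __dp4a integer dot products
#endif
```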

@Dampfinchen commented Jun 17, 2024

> Did you use make or CMake to build the project? Right now CMake compiles for the wrong CUDA architectures, so the int8 tensor cores aren't actually being used.

Yep, you are correct. I'm using the CMake GUI for Windows. Is there anything I can do to compile it for Turing? Then I'll rerun the test.

@JohannesGaessler (Collaborator, Author)

Lines 426 and 428 in CMakeLists.txt, replace 70 with 75.
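For anyone following along, a hedged sketch of what that edit looks like (the exact contents of those lines are assumed here and may differ in your checkout; the point is that the CUDA architecture list must include 75 or newer so the int8 tensor core kernels get compiled):

```cmake
# Hedged sketch only; the surrounding lines in CMakeLists.txt are assumed, not quoted.
# Replace the 70 (Volta) entry with 75 (Turing) in the CUDA architecture list:
set(CMAKE_CUDA_ARCHITECTURES "52;61;75")   # previously the list ended in 70
```

Depending on the surrounding logic, passing `-DCMAKE_CUDA_ARCHITECTURES=75` on the cmake command line may achieve the same thing without editing the file.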

@Dampfinchen commented Jun 17, 2024

> Lines 426 and 428 in CMakeLists.txt, replace 70 with 75.

llama_print_timings:        load time =    3430.78 ms
llama_print_timings:      sample time =      22.07 ms /   180 runs   (    0.12 ms per token,  8155.13 tokens per second)
llama_print_timings: prompt eval time =    3996.73 ms /  3671 tokens (    1.09 ms per token,   918.50 tokens per second)
llama_print_timings:        eval time =    4912.27 ms /   179 runs   (   27.44 ms per token,    36.44 tokens per second)
llama_print_timings:       total time =    9080.10 ms /  3850 tokens

Yep, that was it. It's faster than cuBLAS now, wow. Great result!

Also takes around 200 MB less VRAM, which is a great bonus. Thank you for the amazing work again!

@JohannesGaessler (Collaborator, Author)

You were using a batch size of 512, correct?
