
CUDA: use tensor cores for MMQ #7676

Merged: 4 commits into ggerganov:master on Jun 10, 2024

Conversation

JohannesGaessler
Collaborator

This PR aims to add int8 tensor core support for the mul_mat_q kernels (legacy quants only for now). The supported hardware will be Turing or newer. So far there is only a prototype for q8_0, which on its own is still slower than FP16 cuBLAS but faster for end-to-end performance because less data conversion is needed. Current performance:

| GPU | Model | Microbatch size | Test | t/s master | t/s PR | Speedup |
|---|---|---|---|---|---|---|
| RTX 4090 | llama 8B Q8_0 | 64 | pp4096 | 1688.32 | 3107.57 | 1.84 |
| RTX 4090 | llama 8B Q8_0 | 128 | pp4096 | 3208.06 | 5515.37 | 1.72 |
| RTX 4090 | llama 8B Q8_0 | 256 | pp4096 | 5297.47 | 8062.99 | 1.52 |
| RTX 4090 | llama 8B Q8_0 | 512 | pp4096 | 7095.77 | 9458.58 | 1.33 |
| RTX 4090 | llama 8B Q8_0 | 1024 | pp4096 | 8322.32 | 9533.70 | 1.15 |
| RTX 4090 | llama 8B Q8_0 | 2048 | pp4096 | 8594.14 | 9131.38 | 1.06 |
| RTX 4090 | llama 8B Q8_0 | 4096 | pp4096 | 8593.86 | 9127.01 | 1.06 |

As of right now this PR must be compiled with LLAMA_CUDA_FORCE_MMQ. scripts/compare_llama_bench.py needs the fix added by #7673.

The way to make int8 tensor cores work is to write PTX code (the CUDA equivalent of assembly), because with the "high-level" WMMA interface you do not have a defined memory layout, which makes it impossible to correctly apply the scales of ggml quantized data blocks. I plan to wrap the PTX code in simple CUDA functions in order to hopefully make it easier to understand what it does.
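To make the plan concrete, here is a minimal sketch (not the code in this PR) of what such a wrapper could look like, assuming the m8n8k16 int8 MMA shape that PTX exposes on Turing and newer; the function name and fragment layout are placeholders:

```cuda
// Hedged sketch: one warp-wide int8 tensor core MMA (8x8x16, s8 x s8 -> s32),
// wrapped so the surrounding kernel never has to touch inline PTX directly.
// Each thread contributes 4 packed int8 values of A in `a`, 4 of B in `b`,
// and holds 2 of the 64 int32 accumulators in d0/d1.
static __device__ __forceinline__ void mma_s8_8x8x16(int & d0, int & d1, const int a, const int b) {
#if __CUDA_ARCH__ >= 750
    asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
        : "+r"(d0), "+r"(d1)
        : "r"(a), "r"(b));
#else
    // int8 tensor core MMA requires Turing (compute capability 7.5) or newer.
    (void) d0; (void) d1; (void) a; (void) b;
#endif
}
```

Because the thread-to-fragment mapping of mma.sync is documented and fixed, the per-block scales of the quantized data can be applied to exactly the right accumulators, which the WMMA interface does not allow.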

@mofosyne added the "Review Complexity : High" label (Generally require in-depth knowledge of LLMs or GPUs) on May 31, 2024
@github-actions bot added the "Nvidia GPU" (Issues specific to Nvidia GPUs) and "ggml" (changes relating to the ggml tensor library for machine learning) labels on May 31, 2024
@slaren
Collaborator

slaren commented May 31, 2024

| model | fa | test | master cuBLAS t/s | PR MMQ t/s | speedup |
|---|---|---|---|---|---|
| llama 7B Q8_0 | 1 | pp32 | 462.62 ± 0.62 | 1179.08 ± 5.90 | 2.548 |
| llama 7B Q8_0 | 1 | pp64 | 1406.62 ± 3.72 | 2249.44 ± 9.03 | 1.599 |
| llama 7B Q8_0 | 1 | pp128 | 2475.62 ± 3.91 | 4036.57 ± 14.50 | 1.630 |
| llama 7B Q8_0 | 1 | pp256 | 3921.04 ± 7.08 | 5005.20 ± 7.61 | 1.276 |
| llama 7B Q8_0 | 1 | pp512 | 4816.36 ± 8.50 | 5237.32 ± 11.06 | 1.087 |
| llama 7B Q8_0 | 1 | pp1024 | 5702.84 ± 4.58 | 5483.00 ± 9.20 | 0.961 |
| llama 7B Q8_0 | 1 | pp2048 | 5768.32 ± 7.25 | 5367.63 ± 13.85 | 0.930 |
| llama 7B Q8_0 | 1 | pp4096 | 5300.56 ± 5.46 | 4818.36 ± 3.07 | 0.909 |

llama-bench still limits the batch size to the value of n_batch, even when it is smaller than n_ubatch. The default n_batch is 2048, so this may be why you are seeing the same performance for 2048 and 4096. It should probably be smarter about that, or at least use the value after it is adjusted by llama.cpp.
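For reference, the adjustment being referred to boils down to clamping the micro-batch to the logical batch; a simplified sketch (not the exact llama.cpp code) of the value llama-bench could report instead of the raw setting:

```cpp
#include <algorithm>

// Simplified sketch, not the actual llama.cpp logic: the effective micro-batch
// can never exceed the logical batch, so a benchmark should label its runs
// with this clamped value rather than the raw command-line setting.
static int effective_n_ubatch(int n_batch, int n_ubatch) {
    return n_ubatch > 0 ? std::min(n_batch, n_ubatch) : n_batch;
}
```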

@JohannesGaessler
Collaborator Author

JohannesGaessler commented Jun 8, 2024

I implemented support for q4_0, q4_1, q5_0, q5_1, and q8_0 based on #7824. The performance currently looks like this:

Vs. master MMQ

| GPU | Model | Microbatch size | Test | t/s master | t/s cuda-ptx-mma-7 | Speedup |
|---|---|---|---|---|---|---|
| RTX 4090 | llama 8B Q4_0 | 16 | pp512 | 1601.20 | 1937.21 | 1.21 |
| RTX 4090 | llama 8B Q4_0 | 32 | pp512 | 2620.80 | 3427.61 | 1.31 |
| RTX 4090 | llama 8B Q4_0 | 64 | pp512 | 3704.68 | 5420.06 | 1.46 |
| RTX 4090 | llama 8B Q4_0 | 128 | pp512 | 4991.25 | 7443.66 | 1.49 |
| RTX 4090 | llama 8B Q4_0 | 256 | pp512 | 5833.07 | 9601.53 | 1.65 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp512 | 6279.91 | 10358.75 | 1.65 |
| RTX 4090 | llama 8B Q4_1 | 16 | pp512 | 1558.18 | 1888.20 | 1.21 |
| RTX 4090 | llama 8B Q4_1 | 32 | pp512 | 2580.88 | 3252.02 | 1.26 |
| RTX 4090 | llama 8B Q4_1 | 64 | pp512 | 3685.14 | 5778.21 | 1.57 |
| RTX 4090 | llama 8B Q4_1 | 128 | pp512 | 4809.24 | 6765.41 | 1.41 |
| RTX 4090 | llama 8B Q4_1 | 256 | pp512 | 5606.55 | 9152.41 | 1.63 |
| RTX 4090 | llama 8B Q4_1 | 512 | pp512 | 5853.33 | 9795.29 | 1.67 |
| RTX 4090 | llama 8B Q5_0 | 16 | pp512 | 1252.35 | 1520.90 | 1.21 |
| RTX 4090 | llama 8B Q5_0 | 32 | pp512 | 1988.05 | 2564.28 | 1.29 |
| RTX 4090 | llama 8B Q5_0 | 64 | pp512 | 2916.94 | 4413.28 | 1.51 |
| RTX 4090 | llama 8B Q5_0 | 128 | pp512 | 4224.09 | 6479.58 | 1.53 |
| RTX 4090 | llama 8B Q5_0 | 256 | pp512 | 5138.20 | 8278.94 | 1.61 |
| RTX 4090 | llama 8B Q5_0 | 512 | pp512 | 5773.74 | 9434.57 | 1.63 |
| RTX 4090 | llama 8B Q5_1 | 16 | pp512 | 1300.38 | 1669.12 | 1.28 |
| RTX 4090 | llama 8B Q5_1 | 32 | pp512 | 2157.43 | 2692.43 | 1.25 |
| RTX 4090 | llama 8B Q5_1 | 64 | pp512 | 2989.86 | 4581.81 | 1.53 |
| RTX 4090 | llama 8B Q5_1 | 128 | pp512 | 4229.24 | 6192.37 | 1.46 |
| RTX 4090 | llama 8B Q5_1 | 256 | pp512 | 5015.76 | 8026.92 | 1.60 |
| RTX 4090 | llama 8B Q5_1 | 512 | pp512 | 5542.11 | 9070.31 | 1.64 |
| RTX 4090 | llama 8B Q8_0 | 16 | pp512 | 1035.64 | 1134.04 | 1.10 |
| RTX 4090 | llama 8B Q8_0 | 32 | pp512 | 1793.27 | 2216.12 | 1.24 |
| RTX 4090 | llama 8B Q8_0 | 64 | pp512 | 2980.14 | 4140.20 | 1.39 |
| RTX 4090 | llama 8B Q8_0 | 128 | pp512 | 4323.77 | 6573.77 | 1.52 |
| RTX 4090 | llama 8B Q8_0 | 256 | pp512 | 5308.85 | 8961.02 | 1.69 |
| RTX 4090 | llama 8B Q8_0 | 512 | pp512 | 5888.12 | 10139.99 | 1.72 |
| RTX 3090 | llama 8B Q4_0 | 16 | pp512 | 897.63 | 1178.88 | 1.31 |
| RTX 3090 | llama 8B Q4_0 | 32 | pp512 | 1194.11 | 1752.72 | 1.47 |
| RTX 3090 | llama 8B Q4_0 | 64 | pp512 | 1655.52 | 2470.20 | 1.49 |
| RTX 3090 | llama 8B Q4_0 | 128 | pp512 | 1938.21 | 3227.85 | 1.67 |
| RTX 3090 | llama 8B Q4_0 | 256 | pp512 | 2374.01 | 3892.61 | 1.64 |
| RTX 3090 | llama 8B Q4_0 | 512 | pp512 | 2483.79 | 4148.35 | 1.67 |
| RTX 3090 | llama 8B Q4_1 | 16 | pp512 | 883.36 | 1301.24 | 1.47 |
| RTX 3090 | llama 8B Q4_1 | 32 | pp512 | 1180.58 | 1847.37 | 1.56 |
| RTX 3090 | llama 8B Q4_1 | 64 | pp512 | 1544.92 | 2390.60 | 1.55 |
| RTX 3090 | llama 8B Q4_1 | 128 | pp512 | 1851.64 | 3130.67 | 1.69 |
| RTX 3090 | llama 8B Q4_1 | 256 | pp512 | 2238.07 | 3646.03 | 1.63 |
| RTX 3090 | llama 8B Q4_1 | 512 | pp512 | 2363.05 | 4068.86 | 1.72 |
| RTX 3090 | llama 8B Q5_0 | 16 | pp512 | 618.30 | 850.13 | 1.37 |
| RTX 3090 | llama 8B Q5_0 | 32 | pp512 | 879.04 | 1367.69 | 1.56 |
| RTX 3090 | llama 8B Q5_0 | 64 | pp512 | 1278.90 | 1951.02 | 1.53 |
| RTX 3090 | llama 8B Q5_0 | 128 | pp512 | 1683.43 | 2743.75 | 1.63 |
| RTX 3090 | llama 8B Q5_0 | 256 | pp512 | 2144.71 | 3518.48 | 1.64 |
| RTX 3090 | llama 8B Q5_0 | 512 | pp512 | 2256.40 | 3654.78 | 1.62 |
| RTX 3090 | llama 8B Q5_1 | 16 | pp512 | 686.36 | 1014.38 | 1.48 |
| RTX 3090 | llama 8B Q5_1 | 32 | pp512 | 894.08 | 1530.24 | 1.71 |
| RTX 3090 | llama 8B Q5_1 | 64 | pp512 | 1298.23 | 1986.79 | 1.53 |
| RTX 3090 | llama 8B Q5_1 | 128 | pp512 | 1638.68 | 2667.84 | 1.63 |
| RTX 3090 | llama 8B Q5_1 | 256 | pp512 | 2053.43 | 3355.29 | 1.63 |
| RTX 3090 | llama 8B Q5_1 | 512 | pp512 | 2190.26 | 3493.64 | 1.60 |
| RTX 3090 | llama 8B Q8_0 | 16 | pp512 | 575.23 | 825.82 | 1.44 |
| RTX 3090 | llama 8B Q8_0 | 32 | pp512 | 934.07 | 1437.37 | 1.54 |
| RTX 3090 | llama 8B Q8_0 | 64 | pp512 | 1357.47 | 2242.88 | 1.65 |
| RTX 3090 | llama 8B Q8_0 | 128 | pp512 | 1760.07 | 3132.53 | 1.78 |
| RTX 3090 | llama 8B Q8_0 | 256 | pp512 | 2168.84 | 3958.13 | 1.82 |
| RTX 3090 | llama 8B Q8_0 | 512 | pp512 | 2359.58 | 4185.76 | 1.77 |
Vs. master cuBLAS

| GPU | Model | Microbatch size | Test | t/s master | t/s cuda-ptx-mma-7 | Speedup |
|---|---|---|---|---|---|---|
| RTX 4090 | llama 8B Q4_0 | 16 | pp512 | 1596.75 | 1937.21 | 1.21 |
| RTX 4090 | llama 8B Q4_0 | 32 | pp512 | 2621.55 | 3427.61 | 1.31 |
| RTX 4090 | llama 8B Q4_0 | 64 | pp512 | 3704.48 | 5420.06 | 1.46 |
| RTX 4090 | llama 8B Q4_0 | 128 | pp512 | 3578.31 | 7443.66 | 2.08 |
| RTX 4090 | llama 8B Q4_0 | 256 | pp512 | 5838.52 | 9601.53 | 1.64 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp512 | 7776.11 | 10358.75 | 1.33 |
| RTX 4090 | llama 8B Q4_1 | 16 | pp512 | 1558.45 | 1888.20 | 1.21 |
| RTX 4090 | llama 8B Q4_1 | 32 | pp512 | 2578.79 | 3252.02 | 1.26 |
| RTX 4090 | llama 8B Q4_1 | 64 | pp512 | 3683.87 | 5778.21 | 1.57 |
| RTX 4090 | llama 8B Q4_1 | 128 | pp512 | 3540.69 | 6765.41 | 1.91 |
| RTX 4090 | llama 8B Q4_1 | 256 | pp512 | 5787.13 | 9152.41 | 1.58 |
| RTX 4090 | llama 8B Q4_1 | 512 | pp512 | 7671.16 | 9795.29 | 1.28 |
| RTX 4090 | llama 8B Q5_0 | 16 | pp512 | 1251.55 | 1520.90 | 1.22 |
| RTX 4090 | llama 8B Q5_0 | 32 | pp512 | 1987.14 | 2564.28 | 1.29 |
| RTX 4090 | llama 8B Q5_0 | 64 | pp512 | 2914.51 | 4413.28 | 1.51 |
| RTX 4090 | llama 8B Q5_0 | 128 | pp512 | 3477.98 | 6479.58 | 1.86 |
| RTX 4090 | llama 8B Q5_0 | 256 | pp512 | 5726.53 | 8278.94 | 1.45 |
| RTX 4090 | llama 8B Q5_0 | 512 | pp512 | 7677.37 | 9434.57 | 1.23 |
| RTX 4090 | llama 8B Q5_1 | 16 | pp512 | 1296.52 | 1669.12 | 1.29 |
| RTX 4090 | llama 8B Q5_1 | 32 | pp512 | 2152.98 | 2692.43 | 1.25 |
| RTX 4090 | llama 8B Q5_1 | 64 | pp512 | 2983.69 | 4581.81 | 1.54 |
| RTX 4090 | llama 8B Q5_1 | 128 | pp512 | 3490.47 | 6192.37 | 1.77 |
| RTX 4090 | llama 8B Q5_1 | 256 | pp512 | 5703.59 | 8026.92 | 1.41 |
| RTX 4090 | llama 8B Q5_1 | 512 | pp512 | 7632.39 | 9070.31 | 1.19 |
| RTX 4090 | llama 8B Q8_0 | 16 | pp512 | 1035.42 | 1134.04 | 1.10 |
| RTX 4090 | llama 8B Q8_0 | 32 | pp512 | 1793.94 | 2216.12 | 1.24 |
| RTX 4090 | llama 8B Q8_0 | 64 | pp512 | 2981.81 | 4140.20 | 1.39 |
| RTX 4090 | llama 8B Q8_0 | 128 | pp512 | 3376.19 | 6573.77 | 1.95 |
| RTX 4090 | llama 8B Q8_0 | 256 | pp512 | 5575.87 | 8961.02 | 1.61 |
| RTX 4090 | llama 8B Q8_0 | 512 | pp512 | 7511.04 | 10139.99 | 1.35 |
| RTX 3090 | llama 8B Q4_0 | 16 | pp512 | 899.84 | 1178.88 | 1.31 |
| RTX 3090 | llama 8B Q4_0 | 32 | pp512 | 1188.91 | 1752.72 | 1.47 |
| RTX 3090 | llama 8B Q4_0 | 64 | pp512 | 1653.41 | 2470.20 | 1.49 |
| RTX 3090 | llama 8B Q4_0 | 128 | pp512 | 2347.86 | 3227.85 | 1.37 |
| RTX 3090 | llama 8B Q4_0 | 256 | pp512 | 3516.58 | 3892.61 | 1.11 |
| RTX 3090 | llama 8B Q4_0 | 512 | pp512 | 4094.19 | 4148.35 | 1.01 |
| RTX 3090 | llama 8B Q4_1 | 16 | pp512 | 877.33 | 1301.24 | 1.48 |
| RTX 3090 | llama 8B Q4_1 | 32 | pp512 | 1169.76 | 1847.37 | 1.58 |
| RTX 3090 | llama 8B Q4_1 | 64 | pp512 | 1517.09 | 2390.60 | 1.58 |
| RTX 3090 | llama 8B Q4_1 | 128 | pp512 | 2323.90 | 3130.67 | 1.35 |
| RTX 3090 | llama 8B Q4_1 | 256 | pp512 | 3469.03 | 3646.03 | 1.05 |
| RTX 3090 | llama 8B Q4_1 | 512 | pp512 | 4024.70 | 4068.86 | 1.01 |
| RTX 3090 | llama 8B Q5_0 | 16 | pp512 | 621.64 | 850.13 | 1.37 |
| RTX 3090 | llama 8B Q5_0 | 32 | pp512 | 873.34 | 1367.69 | 1.57 |
| RTX 3090 | llama 8B Q5_0 | 64 | pp512 | 1257.98 | 1951.02 | 1.55 |
| RTX 3090 | llama 8B Q5_0 | 128 | pp512 | 2182.46 | 2743.75 | 1.26 |
| RTX 3090 | llama 8B Q5_0 | 256 | pp512 | 3286.80 | 3518.48 | 1.07 |
| RTX 3090 | llama 8B Q5_0 | 512 | pp512 | 3968.35 | 3654.78 | 0.92 |
| RTX 3090 | llama 8B Q5_1 | 16 | pp512 | 681.66 | 1014.38 | 1.49 |
| RTX 3090 | llama 8B Q5_1 | 32 | pp512 | 876.71 | 1530.24 | 1.75 |
| RTX 3090 | llama 8B Q5_1 | 64 | pp512 | 1275.26 | 1986.79 | 1.56 |
| RTX 3090 | llama 8B Q5_1 | 128 | pp512 | 2161.81 | 2667.84 | 1.23 |
| RTX 3090 | llama 8B Q5_1 | 256 | pp512 | 3264.62 | 3355.29 | 1.03 |
| RTX 3090 | llama 8B Q5_1 | 512 | pp512 | 3923.75 | 3493.64 | 0.89 |
| RTX 3090 | llama 8B Q8_0 | 16 | pp512 | 571.46 | 825.82 | 1.45 |
| RTX 3090 | llama 8B Q8_0 | 32 | pp512 | 918.53 | 1437.37 | 1.56 |
| RTX 3090 | llama 8B Q8_0 | 64 | pp512 | 1335.22 | 2242.88 | 1.68 |
| RTX 3090 | llama 8B Q8_0 | 128 | pp512 | 2201.87 | 3132.53 | 1.42 |
| RTX 3090 | llama 8B Q8_0 | 256 | pp512 | 3327.89 | 3958.13 | 1.19 |
| RTX 3090 | llama 8B Q8_0 | 512 | pp512 | 3979.69 | 4185.76 | 1.05 |

My immediate next goals will be to add support for k-quants, optimize the performance, and refactor and simplify the code (in that order).

Contributor

github-actions bot commented Jun 8, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 534 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8758.38ms p(95)=19966.03ms fails=, finish reason: stop=481 truncated=53
  • Prompt processing (pp): avg=104.27tk/s p(95)=461.66tk/s
  • Token generation (tg): avg=60.6tk/s p(95)=45.2tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=cuda-ptx-mma-2 commit=a64a81a2946bffa8f108fd3476565fccb885820e

[Charts omitted: prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, and requests_processing over the 10-minute benchmark run]

@JohannesGaessler
Collaborator Author

There was a bug with out-of-bounds writes. That's why the server bench performance was bad in terms of request throughput: the generations were garbage and never hit the EOS token.


```cuda
static __device__ __forceinline__ int get_j(const int /* l */) {
    const int ret = threadIdx.x / (K/2);
    __builtin_assume(ret >= 0);
```
There is a GGML_CUDA_ASSUME macro because this is not available on every version of the compiler.
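For readers unfamiliar with it, the macro is essentially a compiler-guarded wrapper around __builtin_assume; a rough sketch of the pattern (the exact guards in the ggml CUDA headers may differ):

```cuda
// Rough sketch of the pattern only; the real definition in ggml's CUDA headers
// may use different version checks.
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11100
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
#define GGML_CUDA_ASSUME(x) // no-op where __builtin_assume is unavailable
#endif

// The snippet above would then read:
//     GGML_CUDA_ASSUME(ret >= 0);
```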

@slaren
Collaborator

slaren commented Jun 10, 2024

A bit slower than f16 cuBLAS with large batch sizes, but well worth it for the lower batch sizes and memory savings.

| GPU | Model | Microbatch size | Test | t/s master | t/s cuda-ptx-mma-2 | Speedup |
|---|---|---|---|---|---|---|
| RTX 3090 Ti | 7B Q4_0 | 16 | pp1024 | 904.24 | 1115.29 | 1.23 |
| RTX 3090 Ti | 7B Q4_0 | 32 | pp1024 | 1263.72 | 1779.10 | 1.41 |
| RTX 3090 Ti | 7B Q4_0 | 64 | pp1024 | 1698.78 | 2492.78 | 1.47 |
| RTX 3090 Ti | 7B Q4_0 | 128 | pp1024 | 2445.14 | 3264.99 | 1.34 |
| RTX 3090 Ti | 7B Q4_0 | 256 | pp1024 | 3657.37 | 4039.07 | 1.10 |
| RTX 3090 Ti | 7B Q4_0 | 512 | pp1024 | 4321.33 | 4576.51 | 1.06 |
| RTX 3090 Ti | 7B Q4_0 | 1024 | pp1024 | 4756.03 | 4485.33 | 0.94 |
| RTX 3090 Ti | 7B Q4_1 | 16 | pp1024 | 898.22 | 1209.24 | 1.35 |
| RTX 3090 Ti | 7B Q4_1 | 32 | pp1024 | 1253.04 | 1900.19 | 1.52 |
| RTX 3090 Ti | 7B Q4_1 | 64 | pp1024 | 1654.76 | 2399.77 | 1.45 |
| RTX 3090 Ti | 7B Q4_1 | 128 | pp1024 | 2435.10 | 3173.25 | 1.30 |
| RTX 3090 Ti | 7B Q4_1 | 256 | pp1024 | 3597.73 | 3870.62 | 1.08 |
| RTX 3090 Ti | 7B Q4_1 | 512 | pp1024 | 4239.37 | 4472.32 | 1.05 |
| RTX 3090 Ti | 7B Q4_1 | 1024 | pp1024 | 4722.81 | 4390.55 | 0.93 |
| RTX 3090 Ti | 7B Q5_0 | 16 | pp1024 | 655.17 | 857.68 | 1.31 |
| RTX 3090 Ti | 7B Q5_0 | 32 | pp1024 | 955.18 | 1452.40 | 1.52 |
| RTX 3090 Ti | 7B Q5_0 | 64 | pp1024 | 1394.30 | 2073.34 | 1.49 |
| RTX 3090 Ti | 7B Q5_0 | 128 | pp1024 | 2354.17 | 2845.91 | 1.21 |
| RTX 3090 Ti | 7B Q5_0 | 256 | pp1024 | 3518.24 | 3667.58 | 1.04 |
| RTX 3090 Ti | 7B Q5_0 | 512 | pp1024 | 4188.13 | 4032.22 | 0.96 |
| RTX 3090 Ti | 7B Q5_0 | 1024 | pp1024 | 4662.08 | 3988.33 | 0.86 |
| RTX 3090 Ti | 7B Q5_1 | 16 | pp1024 | 722.31 | 1014.57 | 1.40 |
| RTX 3090 Ti | 7B Q5_1 | 32 | pp1024 | 981.44 | 1643.96 | 1.68 |
| RTX 3090 Ti | 7B Q5_1 | 64 | pp1024 | 1417.79 | 2075.32 | 1.46 |
| RTX 3090 Ti | 7B Q5_1 | 128 | pp1024 | 2334.16 | 2767.17 | 1.19 |
| RTX 3090 Ti | 7B Q5_1 | 256 | pp1024 | 3487.31 | 3571.26 | 1.02 |
| RTX 3090 Ti | 7B Q5_1 | 512 | pp1024 | 4173.30 | 3915.44 | 0.94 |
| RTX 3090 Ti | 7B Q5_1 | 1024 | pp1024 | 4648.29 | 3866.20 | 0.83 |

@ggerganov
Owner

> There was a bug with out-of-bounds writes. That's why the server bench performance was bad in terms of request throughput: the generations were garbage and never hit the EOS token.

Btw, the server bench still produces an unexpectedly low number of iterations - 283 in the last run, with 212 of them truncated. Maybe there is still some lingering issue.

@JohannesGaessler
Collaborator Author

Do you mean the run that the bot posted in this PR? That was prior to the fix. I was able to reproduce the issue when running the server benchmark locally and my fix worked to restore the performance in terms of iterations/time.

@ggerganov
Owner

The bot updates the post after each new successful commit. See the edit in this comment from ~an hour ago:

#7676 (comment)

@JohannesGaessler
Collaborator Author

> The bot updates the post after each new successful commit. See the edit in this comment from ~an hour ago:

Thank you, I wasn't aware that this is how the bot works. There still seems to have been an issue where (for some matrix shapes) the writeback returned too early. Because the exact kernel being run depends on the SM count of the GPU, I presumably just never encountered one of the problematic matrix shapes while testing.
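To illustrate the class of bug (with invented names and layout, not the PR's actual kernel): the writeback has to guard each out-of-range element individually rather than returning early, otherwise part of the output tile is left unwritten for some matrix shapes; the same guards also prevent the earlier out-of-bounds writes:

```cuda
// Hypothetical sketch: bounds-checked writeback of one output tile.
// Out-of-range rows/columns are skipped element by element; the function never
// returns before the loops over the whole tile have finished.
template <int tile_rows, int tile_cols>
static __device__ void write_back_tile(const float * tile, float * dst, const int stride,
                                       const int nrows, const int ncols,
                                       const int row0, const int col0) {
    for (int i = 0; i < tile_rows; ++i) {
        const int row = row0 + i;
        if (row >= nrows) {
            continue; // out of range: skip this row, but do not return from the function
        }
        for (int j = threadIdx.x; j < tile_cols; j += blockDim.x) {
            const int col = col0 + j;
            if (col >= ncols) {
                continue; // out of range: skip this column only
            }
            dst[row*stride + col] = tile[i*tile_cols + j];
        }
    }
}
```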

@JohannesGaessler JohannesGaessler merged commit 1f0dabd into ggerganov:master Jun 10, 2024
59 of 72 checks passed