CUDA: refactor and optimize IQ MMVQ #8215

Merged
JohannesGaessler merged 6 commits into ggerganov:master from cuda-iq-opt-3 on Jul 1, 2024

Conversation

JohannesGaessler
Collaborator

This PR refactors and optimizes the IQ MMVQ CUDA code. Notably, as part of these changes I'm changing some values in ggml-common.h. The "qr" values are meant to represent how many low-bit data values are packed into a single 8-bit integer. This value is used to derive "qi", which represents how many 32-bit integers are needed to hold the low-bit data values of a quantized block. These values are intended to be properties of the data type, independent of any particular kernel.
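
For illustration, the relationship between these constants can be written down as in the following minimal sketch, modeled on the Q4_0 definitions in ggml-common.h (the Q4_0 values shown here are the existing ones; the IQ-specific values are what this PR adjusts):

```cuda
// Sketch of the intended relationship between qk, qr and qi, using Q4_0 as an
// example (values as in ggml-common.h).
#define QK4_0 32                     // quantized values per block
#define QR4_0 2                      // low-bit values packed into one 8-bit integer
#define QI4_0 (QK4_0 / (4 * QR4_0))  // 32-bit integers holding one block's low-bit data

// In general: qi = qk / (4 * qr), since each 32-bit integer holds 4 bytes and
// each byte holds qr quantized values.
```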

In MMVQ, qr and qi determine how a load of 32-bit integers for the quantized weights needs to be aligned with the loads of the q8 activations. It is often beneficial to load more values at once; this is intended to be done via the "vdr" value, a factor that increases the number of simultaneous loads so that the total stride per invocation of vec_dot_q_cuda is qr*vdr. For the IQ quants, however, this was instead done by increasing QR. That does not matter for MMVQ, but it is a problem for MMQ, where the values of qr and qi determine how much shared memory needs to be allocated and how the activations need to be loaded. For this reason I'm changing the qr and qi values of the IQ quants to the originally intended values. Notably, this affects the SYCL backend, but I am not able to test the corresponding changes myself due to a lack of Intel hardware. @arthw @airMeng I don't know who to tag in terms of llama.cpp SYCL developers; please either test my changes or tell me who I should contact.
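
As a rough illustration of how vdr interacts with qr, here is a simplified sketch loosely based on the Q4_0 vec dot implementation (scale handling is omitted and names are approximate, so treat this as a sketch rather than the actual kernel):

```cuda
// Simplified sketch: with qr == 2 (two nibbles per byte), each 32-bit weight
// load pairs with qr q8_1 loads, and vdr such weight loads are done per
// invocation, so the q8_1 stride per call is qr*vdr ints.
// Requires compute capability >= 6.1 for __dp4a.
template <int vdr>
static __device__ __forceinline__ int vec_dot_sketch(const int * v, const int * u) {
    int sumi = 0;
#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; // low nibbles
        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; // high nibbles
        sumi = __dp4a(vi0, u[2*i + 0], sumi);     // qr loads of q8_1 per weight int
        sumi = __dp4a(vi1, u[2*i + 1], sumi);
    }
    return sumi;
}
```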

Performance changes
| Model | GPU | Microbatch size | Test | t/s master | t/s 59f5fe5 | Speedup |
| --- | --- | --: | --- | --: | --: | --: |
| llama 8B IQ1_M - 1.75 bpw | RX 6800 | 1 | pp512 | 53.40 | 49.22 | 0.92 |
| llama 8B IQ1_M - 1.75 bpw | RX 6800 | 2 | pp512 | 85.65 | 86.32 | 1.01 |
| llama 8B IQ1_M - 1.75 bpw | RX 6800 | 4 | pp512 | 113.24 | 112.06 | 0.99 |
| llama 8B IQ1_M - 1.75 bpw | RX 6800 | 8 | pp512 | 163.77 | 137.24 | 0.84 |
| llama 8B IQ1_M - 1.75 bpw | RTX 3090 | 1 | pp512 | 157.19 | 167.03 | 1.06 |
| llama 8B IQ1_M - 1.75 bpw | RTX 3090 | 2 | pp512 | 257.59 | 271.74 | 1.05 |
| llama 8B IQ1_M - 1.75 bpw | RTX 3090 | 4 | pp512 | 369.37 | 396.47 | 1.07 |
| llama 8B IQ1_M - 1.75 bpw | RTX 3090 | 8 | pp512 | 509.12 | 577.37 | 1.13 |
| llama 8B IQ1_M - 1.75 bpw | RTX 4090 | 1 | pp512 | 303.46 | 315.04 | 1.04 |
| llama 8B IQ1_M - 1.75 bpw | RTX 4090 | 2 | pp512 | 491.38 | 505.95 | 1.03 |
| llama 8B IQ1_M - 1.75 bpw | RTX 4090 | 4 | pp512 | 768.49 | 800.65 | 1.04 |
| llama 8B IQ1_M - 1.75 bpw | RTX 4090 | 8 | pp512 | 1104.59 | 1199.08 | 1.09 |
| llama 8B IQ1_M - 1.75 bpw | P40 | 1 | pp512 | 51.24 | 55.00 | 1.07 |
| llama 8B IQ1_M - 1.75 bpw | P40 | 2 | pp512 | 59.62 | 62.14 | 1.04 |
| llama 8B IQ1_M - 1.75 bpw | P40 | 4 | pp512 | 79.55 | 87.26 | 1.10 |
| llama 8B IQ1_M - 1.75 bpw | P40 | 8 | pp512 | 100.85 | 123.60 | 1.23 |
| llama 8B IQ1_S - 1.5625 bpw | RX 6800 | 1 | pp512 | 55.56 | 54.13 | 0.97 |
| llama 8B IQ1_S - 1.5625 bpw | RX 6800 | 2 | pp512 | 89.01 | 92.03 | 1.03 |
| llama 8B IQ1_S - 1.5625 bpw | RX 6800 | 4 | pp512 | 119.90 | 143.32 | 1.20 |
| llama 8B IQ1_S - 1.5625 bpw | RX 6800 | 8 | pp512 | 163.32 | 177.29 | 1.09 |
| llama 8B IQ1_S - 1.5625 bpw | RTX 3090 | 1 | pp512 | 171.96 | 182.08 | 1.06 |
| llama 8B IQ1_S - 1.5625 bpw | RTX 3090 | 2 | pp512 | 294.86 | 321.29 | 1.09 |
| llama 8B IQ1_S - 1.5625 bpw | RTX 3090 | 4 | pp512 | 400.41 | 454.23 | 1.13 |
| llama 8B IQ1_S - 1.5625 bpw | RTX 3090 | 8 | pp512 | 547.64 | 613.82 | 1.12 |
| llama 8B IQ1_S - 1.5625 bpw | RTX 4090 | 1 | pp512 | 314.54 | 324.53 | 1.03 |
| llama 8B IQ1_S - 1.5625 bpw | RTX 4090 | 2 | pp512 | 532.28 | 536.75 | 1.01 |
| llama 8B IQ1_S - 1.5625 bpw | RTX 4090 | 4 | pp512 | 802.89 | 869.56 | 1.08 |
| llama 8B IQ1_S - 1.5625 bpw | RTX 4090 | 8 | pp512 | 1135.76 | 1215.14 | 1.07 |
| llama 8B IQ1_S - 1.5625 bpw | P40 | 1 | pp512 | 54.77 | 60.16 | 1.10 |
| llama 8B IQ1_S - 1.5625 bpw | P40 | 2 | pp512 | 62.83 | 65.50 | 1.04 |
| llama 8B IQ1_S - 1.5625 bpw | P40 | 4 | pp512 | 83.18 | 91.52 | 1.10 |
| llama 8B IQ1_S - 1.5625 bpw | P40 | 8 | pp512 | 103.75 | 101.55 | 0.98 |
| llama 8B IQ2_M - 2.7 bpw | RX 6800 | 1 | pp512 | 37.80 | 40.21 | 1.06 |
| llama 8B IQ2_M - 2.7 bpw | RX 6800 | 2 | pp512 | 65.40 | 71.74 | 1.10 |
| llama 8B IQ2_M - 2.7 bpw | RX 6800 | 4 | pp512 | 89.33 | 108.04 | 1.21 |
| llama 8B IQ2_M - 2.7 bpw | RX 6800 | 8 | pp512 | 118.42 | 138.37 | 1.17 |
| llama 8B IQ2_M - 2.7 bpw | RTX 3090 | 1 | pp512 | 144.67 | 154.45 | 1.07 |
| llama 8B IQ2_M - 2.7 bpw | RTX 3090 | 2 | pp512 | 247.07 | 257.73 | 1.04 |
| llama 8B IQ2_M - 2.7 bpw | RTX 3090 | 4 | pp512 | 355.66 | 364.56 | 1.03 |
| llama 8B IQ2_M - 2.7 bpw | RTX 3090 | 8 | pp512 | 523.70 | 538.26 | 1.03 |
| llama 8B IQ2_M - 2.7 bpw | RTX 4090 | 1 | pp512 | 256.04 | 256.11 | 1.00 |
| llama 8B IQ2_M - 2.7 bpw | RTX 4090 | 2 | pp512 | 436.39 | 434.25 | 1.00 |
| llama 8B IQ2_M - 2.7 bpw | RTX 4090 | 4 | pp512 | 691.87 | 693.52 | 1.00 |
| llama 8B IQ2_M - 2.7 bpw | RTX 4090 | 8 | pp512 | 1149.50 | 1154.73 | 1.00 |
| llama 8B IQ2_M - 2.7 bpw | P40 | 1 | pp512 | 50.45 | 53.80 | 1.07 |
| llama 8B IQ2_M - 2.7 bpw | P40 | 2 | pp512 | 55.44 | 58.77 | 1.06 |
| llama 8B IQ2_M - 2.7 bpw | P40 | 4 | pp512 | 80.85 | 83.89 | 1.04 |
| llama 8B IQ2_M - 2.7 bpw | P40 | 8 | pp512 | 119.35 | 120.94 | 1.01 |
| llama 8B IQ2_S - 2.5 bpw | RX 6800 | 1 | pp512 | 46.08 | 46.31 | 1.00 |
| llama 8B IQ2_S - 2.5 bpw | RX 6800 | 2 | pp512 | 75.10 | 81.27 | 1.08 |
| llama 8B IQ2_S - 2.5 bpw | RX 6800 | 4 | pp512 | 96.45 | 123.49 | 1.28 |
| llama 8B IQ2_S - 2.5 bpw | RX 6800 | 8 | pp512 | 116.69 | 140.50 | 1.20 |
| llama 8B IQ2_S - 2.5 bpw | RTX 3090 | 1 | pp512 | 146.70 | 151.47 | 1.03 |
| llama 8B IQ2_S - 2.5 bpw | RTX 3090 | 2 | pp512 | 255.82 | 251.69 | 0.98 |
| llama 8B IQ2_S - 2.5 bpw | RTX 3090 | 4 | pp512 | 378.84 | 377.85 | 1.00 |
| llama 8B IQ2_S - 2.5 bpw | RTX 3090 | 8 | pp512 | 535.84 | 535.17 | 1.00 |
| llama 8B IQ2_S - 2.5 bpw | RTX 4090 | 1 | pp512 | 274.49 | 276.89 | 1.01 |
| llama 8B IQ2_S - 2.5 bpw | RTX 4090 | 2 | pp512 | 469.03 | 454.41 | 0.97 |
| llama 8B IQ2_S - 2.5 bpw | RTX 4090 | 4 | pp512 | 776.19 | 769.60 | 0.99 |
| llama 8B IQ2_S - 2.5 bpw | RTX 4090 | 8 | pp512 | 1184.92 | 1168.13 | 0.99 |
| llama 8B IQ2_S - 2.5 bpw | P40 | 1 | pp512 | 46.43 | 46.90 | 1.01 |
| llama 8B IQ2_S - 2.5 bpw | P40 | 2 | pp512 | 57.69 | 58.24 | 1.01 |
| llama 8B IQ2_S - 2.5 bpw | P40 | 4 | pp512 | 81.95 | 86.35 | 1.05 |
| llama 8B IQ2_S - 2.5 bpw | P40 | 8 | pp512 | 117.88 | 117.48 | 1.00 |
| llama 8B IQ2_XS - 2.3125 bpw | RX 6800 | 1 | pp512 | 48.85 | 47.87 | 0.98 |
| llama 8B IQ2_XS - 2.3125 bpw | RX 6800 | 2 | pp512 | 85.05 | 82.40 | 0.97 |
| llama 8B IQ2_XS - 2.3125 bpw | RX 6800 | 4 | pp512 | 113.93 | 122.77 | 1.08 |
| llama 8B IQ2_XS - 2.3125 bpw | RX 6800 | 8 | pp512 | 142.12 | 137.76 | 0.97 |
| llama 8B IQ2_XS - 2.3125 bpw | RTX 3090 | 1 | pp512 | 150.82 | 154.55 | 1.02 |
| llama 8B IQ2_XS - 2.3125 bpw | RTX 3090 | 2 | pp512 | 261.42 | 256.79 | 0.98 |
| llama 8B IQ2_XS - 2.3125 bpw | RTX 3090 | 4 | pp512 | 387.68 | 382.61 | 0.99 |
| llama 8B IQ2_XS - 2.3125 bpw | RTX 3090 | 8 | pp512 | 539.88 | 536.58 | 0.99 |
| llama 8B IQ2_XS - 2.3125 bpw | RTX 4090 | 1 | pp512 | 283.86 | 285.14 | 1.00 |
| llama 8B IQ2_XS - 2.3125 bpw | RTX 4090 | 2 | pp512 | 483.41 | 468.26 | 0.97 |
| llama 8B IQ2_XS - 2.3125 bpw | RTX 4090 | 4 | pp512 | 799.44 | 793.42 | 0.99 |
| llama 8B IQ2_XS - 2.3125 bpw | RTX 4090 | 8 | pp512 | 1181.49 | 1165.17 | 0.99 |
| llama 8B IQ2_XS - 2.3125 bpw | P40 | 1 | pp512 | 48.15 | 48.11 | 1.00 |
| llama 8B IQ2_XS - 2.3125 bpw | P40 | 2 | pp512 | 60.15 | 60.02 | 1.00 |
| llama 8B IQ2_XS - 2.3125 bpw | P40 | 4 | pp512 | 83.81 | 88.44 | 1.06 |
| llama 8B IQ2_XS - 2.3125 bpw | P40 | 8 | pp512 | 119.43 | 117.26 | 0.98 |
| llama 8B IQ2_XXS - 2.0625 bpw | RX 6800 | 1 | pp512 | 51.04 | 42.05 | 0.82 |
| llama 8B IQ2_XXS - 2.0625 bpw | RX 6800 | 2 | pp512 | 70.78 | 76.72 | 1.08 |
| llama 8B IQ2_XXS - 2.0625 bpw | RX 6800 | 4 | pp512 | 90.88 | 122.41 | 1.35 |
| llama 8B IQ2_XXS - 2.0625 bpw | RX 6800 | 8 | pp512 | 108.26 | 147.99 | 1.37 |
| llama 8B IQ2_XXS - 2.0625 bpw | RTX 3090 | 1 | pp512 | 108.48 | 160.28 | 1.48 |
| llama 8B IQ2_XXS - 2.0625 bpw | RTX 3090 | 2 | pp512 | 173.38 | 269.81 | 1.56 |
| llama 8B IQ2_XXS - 2.0625 bpw | RTX 3090 | 4 | pp512 | 203.23 | 381.82 | 1.88 |
| llama 8B IQ2_XXS - 2.0625 bpw | RTX 3090 | 8 | pp512 | 258.22 | 558.16 | 2.16 |
| llama 8B IQ2_XXS - 2.0625 bpw | RTX 4090 | 1 | pp512 | 238.67 | 303.88 | 1.27 |
| llama 8B IQ2_XXS - 2.0625 bpw | RTX 4090 | 2 | pp512 | 378.23 | 494.32 | 1.31 |
| llama 8B IQ2_XXS - 2.0625 bpw | RTX 4090 | 4 | pp512 | 470.38 | 723.03 | 1.54 |
| llama 8B IQ2_XXS - 2.0625 bpw | RTX 4090 | 8 | pp512 | 633.18 | 1195.46 | 1.89 |
| llama 8B IQ2_XXS - 2.0625 bpw | P40 | 1 | pp512 | 29.94 | 46.71 | 1.56 |
| llama 8B IQ2_XXS - 2.0625 bpw | P40 | 2 | pp512 | 40.33 | 54.33 | 1.35 |
| llama 8B IQ2_XXS - 2.0625 bpw | P40 | 4 | pp512 | 40.90 | 81.67 | 2.00 |
| llama 8B IQ2_XXS - 2.0625 bpw | P40 | 8 | pp512 | 37.94 | 116.89 | 3.08 |
| llama 8B IQ3_S - 3.4375 bpw | RX 6800 | 1 | pp512 | 33.10 | 38.16 | 1.15 |
| llama 8B IQ3_S - 3.4375 bpw | RX 6800 | 2 | pp512 | 40.60 | 74.50 | 1.83 |
| llama 8B IQ3_S - 3.4375 bpw | RX 6800 | 4 | pp512 | 45.80 | 123.68 | 2.70 |
| llama 8B IQ3_S - 3.4375 bpw | RX 6800 | 8 | pp512 | 48.94 | 151.62 | 3.10 |
| llama 8B IQ3_S - 3.4375 bpw | RTX 3090 | 1 | pp512 | 131.03 | 143.17 | 1.09 |
| llama 8B IQ3_S - 3.4375 bpw | RTX 3090 | 2 | pp512 | 221.41 | 234.53 | 1.06 |
| llama 8B IQ3_S - 3.4375 bpw | RTX 3090 | 4 | pp512 | 332.33 | 350.82 | 1.06 |
| llama 8B IQ3_S - 3.4375 bpw | RTX 3090 | 8 | pp512 | 501.54 | 526.60 | 1.05 |
| llama 8B IQ3_S - 3.4375 bpw | RTX 4090 | 1 | pp512 | 220.21 | 221.34 | 1.01 |
| llama 8B IQ3_S - 3.4375 bpw | RTX 4090 | 2 | pp512 | 368.67 | 369.48 | 1.00 |
| llama 8B IQ3_S - 3.4375 bpw | RTX 4090 | 4 | pp512 | 664.61 | 662.90 | 1.00 |
| llama 8B IQ3_S - 3.4375 bpw | RTX 4090 | 8 | pp512 | 1122.96 | 1144.50 | 1.02 |
| llama 8B IQ3_S - 3.4375 bpw | P40 | 1 | pp512 | 39.70 | 44.13 | 1.11 |
| llama 8B IQ3_S - 3.4375 bpw | P40 | 2 | pp512 | 47.98 | 52.10 | 1.09 |
| llama 8B IQ3_S - 3.4375 bpw | P40 | 4 | pp512 | 74.74 | 79.60 | 1.07 |
| llama 8B IQ3_S - 3.4375 bpw | P40 | 8 | pp512 | 111.09 | 115.68 | 1.04 |
| llama 8B IQ3_S mix - 3.66 bpw | RX 6800 | 1 | pp512 | 34.07 | 38.58 | 1.13 |
| llama 8B IQ3_S mix - 3.66 bpw | RX 6800 | 2 | pp512 | 42.81 | 73.89 | 1.73 |
| llama 8B IQ3_S mix - 3.66 bpw | RX 6800 | 4 | pp512 | 48.75 | 118.86 | 2.44 |
| llama 8B IQ3_S mix - 3.66 bpw | RX 6800 | 8 | pp512 | 52.12 | 142.37 | 2.73 |
| llama 8B IQ3_S mix - 3.66 bpw | RTX 3090 | 1 | pp512 | 133.40 | 144.93 | 1.09 |
| llama 8B IQ3_S mix - 3.66 bpw | RTX 3090 | 2 | pp512 | 223.57 | 235.91 | 1.06 |
| llama 8B IQ3_S mix - 3.66 bpw | RTX 3090 | 4 | pp512 | 331.13 | 347.12 | 1.05 |
| llama 8B IQ3_S mix - 3.66 bpw | RTX 3090 | 8 | pp512 | 485.86 | 504.17 | 1.04 |
| llama 8B IQ3_S mix - 3.66 bpw | RTX 4090 | 1 | pp512 | 216.78 | 217.79 | 1.00 |
| llama 8B IQ3_S mix - 3.66 bpw | RTX 4090 | 2 | pp512 | 365.57 | 368.64 | 1.01 |
| llama 8B IQ3_S mix - 3.66 bpw | RTX 4090 | 4 | pp512 | 661.20 | 663.19 | 1.00 |
| llama 8B IQ3_S mix - 3.66 bpw | RTX 4090 | 8 | pp512 | 1089.92 | 1113.22 | 1.02 |
| llama 8B IQ3_S mix - 3.66 bpw | P40 | 1 | pp512 | 40.91 | 45.28 | 1.11 |
| llama 8B IQ3_S mix - 3.66 bpw | P40 | 2 | pp512 | 49.27 | 53.23 | 1.08 |
| llama 8B IQ3_S mix - 3.66 bpw | P40 | 4 | pp512 | 76.73 | 81.46 | 1.06 |
| llama 8B IQ3_S mix - 3.66 bpw | P40 | 8 | pp512 | 111.40 | 115.70 | 1.04 |
| llama 8B IQ3_XS - 3.3 bpw | RX 6800 | 1 | pp512 | 39.95 | 42.88 | 1.07 |
| llama 8B IQ3_XS - 3.3 bpw | RX 6800 | 2 | pp512 | 56.50 | 80.17 | 1.42 |
| llama 8B IQ3_XS - 3.3 bpw | RX 6800 | 4 | pp512 | 69.21 | 127.19 | 1.84 |
| llama 8B IQ3_XS - 3.3 bpw | RX 6800 | 8 | pp512 | 75.96 | 154.91 | 2.04 |
| llama 8B IQ3_XS - 3.3 bpw | RTX 3090 | 1 | pp512 | 131.71 | 142.08 | 1.08 |
| llama 8B IQ3_XS - 3.3 bpw | RTX 3090 | 2 | pp512 | 225.58 | 236.88 | 1.05 |
| llama 8B IQ3_XS - 3.3 bpw | RTX 3090 | 4 | pp512 | 348.83 | 360.71 | 1.03 |
| llama 8B IQ3_XS - 3.3 bpw | RTX 3090 | 8 | pp512 | 511.77 | 527.72 | 1.03 |
| llama 8B IQ3_XS - 3.3 bpw | RTX 4090 | 1 | pp512 | 229.77 | 231.39 | 1.01 |
| llama 8B IQ3_XS - 3.3 bpw | RTX 4090 | 2 | pp512 | 391.05 | 392.01 | 1.00 |
| llama 8B IQ3_XS - 3.3 bpw | RTX 4090 | 4 | pp512 | 710.10 | 714.86 | 1.01 |
| llama 8B IQ3_XS - 3.3 bpw | RTX 4090 | 8 | pp512 | 1141.69 | 1152.60 | 1.01 |
| llama 8B IQ3_XS - 3.3 bpw | P40 | 1 | pp512 | 39.44 | 42.72 | 1.08 |
| llama 8B IQ3_XS - 3.3 bpw | P40 | 2 | pp512 | 50.67 | 50.65 | 1.00 |
| llama 8B IQ3_XS - 3.3 bpw | P40 | 4 | pp512 | 73.72 | 80.02 | 1.09 |
| llama 8B IQ3_XS - 3.3 bpw | P40 | 8 | pp512 | 110.53 | 112.94 | 1.02 |
| llama 8B IQ3_XXS - 3.0625 bpw | RX 6800 | 1 | pp512 | 46.62 | 46.60 | 1.00 |
| llama 8B IQ3_XXS - 3.0625 bpw | RX 6800 | 2 | pp512 | 78.05 | 83.07 | 1.06 |
| llama 8B IQ3_XXS - 3.0625 bpw | RX 6800 | 4 | pp512 | 110.96 | 126.46 | 1.14 |
| llama 8B IQ3_XXS - 3.0625 bpw | RX 6800 | 8 | pp512 | 131.35 | 155.94 | 1.19 |
| llama 8B IQ3_XXS - 3.0625 bpw | RTX 3090 | 1 | pp512 | 135.27 | 143.46 | 1.06 |
| llama 8B IQ3_XXS - 3.0625 bpw | RTX 3090 | 2 | pp512 | 233.96 | 242.70 | 1.04 |
| llama 8B IQ3_XXS - 3.0625 bpw | RTX 3090 | 4 | pp512 | 360.31 | 367.51 | 1.02 |
| llama 8B IQ3_XXS - 3.0625 bpw | RTX 3090 | 8 | pp512 | 519.38 | 527.16 | 1.01 |
| llama 8B IQ3_XXS - 3.0625 bpw | RTX 4090 | 1 | pp512 | 236.20 | 235.58 | 1.00 |
| llama 8B IQ3_XXS - 3.0625 bpw | RTX 4090 | 2 | pp512 | 404.28 | 403.91 | 1.00 |
| llama 8B IQ3_XXS - 3.0625 bpw | RTX 4090 | 4 | pp512 | 722.81 | 731.84 | 1.01 |
| llama 8B IQ3_XXS - 3.0625 bpw | RTX 4090 | 8 | pp512 | 1152.84 | 1162.57 | 1.01 |
| llama 8B IQ3_XXS - 3.0625 bpw | P40 | 1 | pp512 | 40.21 | 42.77 | 1.06 |
| llama 8B IQ3_XXS - 3.0625 bpw | P40 | 2 | pp512 | 51.78 | 50.33 | 0.97 |
| llama 8B IQ3_XXS - 3.0625 bpw | P40 | 4 | pp512 | 74.36 | 79.64 | 1.07 |
| llama 8B IQ3_XXS - 3.0625 bpw | P40 | 8 | pp512 | 111.46 | 112.72 | 1.01 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 1 | pp512 | 53.35 | 53.31 | 1.00 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 2 | pp512 | 98.53 | 95.45 | 0.97 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 4 | pp512 | 144.67 | 154.64 | 1.07 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 8 | pp512 | 155.23 | 137.54 | 0.89 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 1 | pp512 | 122.65 | 122.39 | 1.00 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 2 | pp512 | 213.90 | 214.49 | 1.00 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 4 | pp512 | 332.52 | 334.08 | 1.00 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 8 | pp512 | 474.78 | 473.78 | 1.00 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 1 | pp512 | 186.29 | 186.63 | 1.00 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 2 | pp512 | 330.35 | 339.99 | 1.03 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 4 | pp512 | 641.99 | 641.73 | 1.00 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 8 | pp512 | 1030.30 | 1037.84 | 1.01 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 1 | pp512 | 38.95 | 39.11 | 1.00 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 2 | pp512 | 52.43 | 52.37 | 1.00 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 4 | pp512 | 85.21 | 84.63 | 0.99 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 8 | pp512 | 133.29 | 133.45 | 1.00 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 1 | pp512 | 59.13 | 58.71 | 0.99 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 2 | pp512 | 100.45 | 96.94 | 0.97 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 4 | pp512 | 131.00 | 135.02 | 1.03 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 8 | pp512 | 185.85 | 167.36 | 0.90 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 1 | pp512 | 128.99 | 128.28 | 0.99 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 2 | pp512 | 208.93 | 210.96 | 1.01 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 4 | pp512 | 310.73 | 309.15 | 0.99 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 8 | pp512 | 480.94 | 475.20 | 0.99 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 1 | pp512 | 194.54 | 193.03 | 0.99 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 2 | pp512 | 343.13 | 343.17 | 1.00 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 4 | pp512 | 637.72 | 637.37 | 1.00 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 8 | pp512 | 1087.27 | 1075.96 | 0.99 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 1 | pp512 | 36.07 | 37.07 | 1.03 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 2 | pp512 | 40.50 | 44.89 | 1.11 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 4 | pp512 | 68.90 | 76.90 | 1.12 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 8 | pp512 | 113.47 | 110.05 | 0.97 |

@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Jun 29, 2024
@airMeng
Collaborator

airMeng commented Jun 30, 2024

Thank you for tagging me. I'm not quite familiar with IQ_XXX, so I tried Llama-3-Lumimaid-8B-v0.1-OAS-IQ4_XS-imat.gguf and this PR works fine. Is that model one of the targets of this PR? Overall, the code changes look good to me.

@OuadiElfarouki @luoyu-intel for awareness.

@JohannesGaessler
Collaborator Author

Is that model one of the targets of this PR? Overall, the code changes look good to me.

iq4_xs as well as all iq2 and iq3 models should be affected. Overall, the changes I made to SYCL should be very simple; I just can't test whether they actually work.

@mofosyne mofosyne added the Review Complexity : High (Generally require in-depth knowledge of LLMs or GPUs) label on Jun 30, 2024
@slaren
Collaborator

slaren commented Jun 30, 2024

Shouldn't the removed code allow the IQ quants to work with GPUs without dp4a? I understand that due to the CC check in ggml_cuda_mul_mat this code was never actually used, but if that were fixed it would allow the IQ quants to work rather than crash on old GPUs.

@JohannesGaessler
Collaborator Author

If __dp4a is unavailable, it would be better to use DMMV. The performance of casting 32-bit integers to arrays of 8-bit integers and doing regular arithmetic with them is quite poor.

@slaren
Collaborator

slaren commented Jun 30, 2024

Right, but the problem is that dmmv does not support IQ quants.

Alternatively, it can be reported correctly in supports_op that these quants do not work on some hardware. Then these ops would be run (a lot slower) on the CPU.
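
For context, the idea described here is roughly the following hypothetical sketch (not the actual llama.cpp code; the function name, type list, and exact check are assumptions for illustration):

```cuda
#include "ggml.h"

// Hypothetical sketch: report IQ matrix multiplications as unsupported on
// devices without __dp4a so that ggml schedules the op on the CPU instead
// of hitting an assert in the CUDA backend.
static bool cuda_mul_mat_supported(const ggml_tensor * op, const int cc) {
    const ggml_type type = op->src[0]->type;
    const bool is_iq =
        type == GGML_TYPE_IQ1_S  || type == GGML_TYPE_IQ2_XXS ||
        type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ3_XXS ||
        type == GGML_TYPE_IQ3_S  || type == GGML_TYPE_IQ4_NL  ||
        type == GGML_TYPE_IQ4_XS;
    return !is_iq || cc >= 610; // 610 == compute capability 6.1, the first with __dp4a
}
```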

@JohannesGaessler
Collaborator Author

How about this: replace __dp4a with ggml_cuda_dp4a, for which you can then implement a workaround if the instruction is unavailable.
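
For illustration, such a wrapper could look roughly like this (a minimal sketch of the idea; the actual implementation may differ, e.g. in how the architecture check and the AMD path are handled):

```cuda
// Sketch of a dp4a wrapper: use the hardware instruction where available,
// otherwise fall back to byte-wise arithmetic on the packed int8 values.
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
#if __CUDA_ARCH__ >= 610
    return __dp4a(a, b, c);
#else
    const int8_t * a8 = (const int8_t *) &a;
    const int8_t * b8 = (const int8_t *) &b;
    return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
#endif
}
```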

@JohannesGaessler
Collaborator Author

Performance with a __dp4a workaround was much better than I expected (~90% of regular speed on an RTX 3090), so I made MMVQ the default regardless of compute capability; I'll get someone to test this on a GPU for which the default change would actually make a difference.

@the-crypt-keeper

the-crypt-keeper commented Jul 1, 2024

Meta-Llama-3-8B-Instruct-Q6_K.gguf on main

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla P100-PCIE-16GB, compute capability 6.0, VMM: yes

| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | --: | ---: | ---: |
| llama 8B Q6_K | 6.14 GiB | 8.03 B | CUDA | 99 | pp512 | 511.07 ± 0.00 |
| llama 8B Q6_K | 6.14 GiB | 8.03 B | CUDA | 99 | tg128 | 16.01 ± 0.00 |

build: 0ddeff1 (3273)

Meta-Llama-3-8B-Instruct-Q6_K.gguf on branch

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla P100-PCIE-16GB, compute capability 6.0, VMM: yes

| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | --: | ---: | ---: |
| llama 8B Q6_K | 6.14 GiB | 8.03 B | CUDA | 99 | pp512 | 512.43 ± 0.00 |
| llama 8B Q6_K | 6.14 GiB | 8.03 B | CUDA | 99 | tg128 | 27.05 ± 0.00 |

build: 30f85eb (3271)

🚀

Meta-Llama-3-8B-Instruct-IQ4_NL.gguf main

GGML_ASSERT: ggml/src/ggml-cuda/dmmv.cu:665: false

Meta-Llama-3-8B-Instruct-IQ4_NL.gguf branch

Device 0: Tesla P100-PCIE-16GB, compute capability 6.0, VMM: yes

| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | --: | ---: | ---: |
| llama 8B IQ4_NL - 4.5 bpw | 4.35 GiB | 8.03 B | CUDA | 99 | pp512 | 450.43 ± 0.00 |
| llama 8B IQ4_NL - 4.5 bpw | 4.35 GiB | 8.03 B | CUDA | 99 | tg128 | 29.06 ± 0.00 |

Meta-Llama-3-8B-Instruct-Q8_0.gguf main

Device 0: Tesla P100-PCIE-16GB, compute capability 6.0, VMM: yes

| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | --: | ---: | ---: |
| llama 8B Q8_0 | 7.95 GiB | 8.03 B | CUDA | 99 | pp512 | 508.32 ± 0.00 |
| llama 8B Q8_0 | 7.95 GiB | 8.03 B | CUDA | 99 | tg128 | 14.65 ± 0.00 |

Meta-Llama-3-8B-Instruct-Q8_0.gguf branch

Device 0: Tesla P100-PCIE-16GB, compute capability 6.0, VMM: yes

| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | --: | ---: | ---: |
| llama 8B Q8_0 | 7.95 GiB | 8.03 B | CUDA | 99 | pp512 | 510.02 ± 0.00 |
| llama 8B Q8_0 | 7.95 GiB | 8.03 B | CUDA | 99 | tg128 | 35.43 ± 0.00 |

🚀 🐎

@the-crypt-keeper

the-crypt-keeper commented Jul 1, 2024

This PR improved batch performance by almost an order of magnitude at both 2 and 3 streams on my P100 as well:

Meta-Llama-3-8B-Instruct-Q8_0.gguf batch main

main: n_kv_max = 4096, n_batch = 512, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 14, n_threads_batch = 14

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --: | --: | --: | --: | --: | --: | --: | --: | --: | --: |
| 512 | 128 | 1 | 640 | 0.986 | 519.24 | 8.690 | 14.73 | 9.676 | 66.14 |
| 512 | 128 | 2 | 1280 | 1.970 | 519.86 | 20.093 | 12.74 | 22.063 | 58.02 |
| 512 | 128 | 3 | 1920 | 2.975 | 516.30 | 20.679 | 18.57 | 23.654 | 81.17 |
| 512 | 128 | 4 | 2560 | 3.995 | 512.66 | 21.081 | 24.29 | 25.075 | 102.09 |

(not sure why b=2 got WORSE here but I re-ran several times)

Meta-Llama-3-8B-Instruct-Q8_0.gguf batch branch

main: n_kv_max = 4096, n_batch = 512, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 14, n_threads_batch = 14

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --: | --: | --: | --: | --: | --: | --: | --: | --: | --: |
| 512 | 128 | 1 | 640 | 0.987 | 518.75 | 3.575 | 35.81 | 4.562 | 140.29 |
| 512 | 128 | 2 | 1280 | 1.969 | 520.15 | 3.765 | 68.00 | 5.733 | 223.25 |
| 512 | 128 | 3 | 1920 | 2.970 | 517.10 | 4.459 | 86.12 | 7.430 | 258.43 |
| 512 | 128 | 4 | 2560 | 3.991 | 513.13 | 6.111 | 83.78 | 10.102 | 253.41 |

b=2 is now double, and b=3 is the sweet spot for overall throughput.

Meta-Llama-3-8B-Instruct-IQ4_NL.gguf batch branch

main: n_kv_max = 4096, n_batch = 512, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 14, n_threads_batch = 14

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --: | --: | --: | --: | --: | --: | --: | --: | --: | --: |
| 512 | 128 | 1 | 640 | 1.115 | 459.29 | 4.367 | 29.31 | 5.482 | 116.74 |
| 512 | 128 | 2 | 1280 | 2.226 | 460.05 | 4.546 | 56.32 | 6.772 | 189.03 |
| 512 | 128 | 3 | 1920 | 3.355 | 457.77 | 5.315 | 72.24 | 8.671 | 221.44 |
| 512 | 128 | 4 | 2560 | 4.508 | 454.29 | 6.278 | 81.55 | 10.786 | 237.34 |

I am unable to compare with main; the same GGML_ASSERT: ggml/src/ggml-cuda/dmmv.cu:665: false pops out.

On the branch, IQ4 shows the same trend as Q8.

@JohannesGaessler JohannesGaessler merged commit cb5fad4 into ggerganov:master Jul 1, 2024
53 checks passed
@duaneking

I have a device in my home lab that CUDA lists as Device: cuda:0 NVIDIA GeForce GTX 1070 : native. Would it be useful?

@JohannesGaessler
Collaborator Author

No.

arthw pushed a commit to arthw/llama.cpp that referenced this pull request Jul 2, 2024
* CUDA: refactor and optimize IQ MMVQ

* uint -> uint32_t

* __dp4a -> ggml_cuda_dp4a

* remove MIN_CC_DP4A checks

* change default

* try CI fix
@smcnally

smcnally commented Jul 4, 2024

Understood cuda-iq-opt-3 is merged. Is test data from a "Maxwell 2.0" / Compute 5.2 GPU still helpful? llama-simple and llama-benchmark-matmult worked with a Tesla M40. I ran through these before the M40 reached 90°C:

  • codellama-7b-instruct.Q5_K_M.gguf
  • L3-8B-Lunaris-v1.i1-Q5_K_M.gguf
  • Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q6_K.gguf
  • alphamonarch-7b.Q6_K.gguf
  • capybarahermes-2.5-mistral-7b.Q5_K_M.gguf (a few variations with this one)

./llama-simple -n 128 -c 4096 -ngl 99 -m /mnt/models/gguf/capybarahermes-2.5-mistral-7b.Q5_K_M.gguf -p "List 5 benefits NVIDIA's Maxwell 2.0 architecture has over Maxwell 1.0. Step through the major benefits CUDA Compute Capability 5.2 has over 5.0."

llama-bench, -batched-bench, -server, and -cli all core dumped. I have detailed files.

@smcnally

smcnally commented Jul 4, 2024

Bench and server in ggerganov:master are working great with mixed compute capability 5.2 and 6.1 GPUs.

./llama-bench -m llava-v1.6-vicuna-13b.Q4_K_M.gguf 
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 3 CUDA devices:
  Device 0: NVIDIA P104-100, compute capability 6.1, VMM: yes
  Device 1: NVIDIA P106-100, compute capability 6.1, VMM: yes
  Device 2: Tesla M40, compute capability 5.2, VMM: yes
| model                          |       size |     params | backend    | ngl |          test |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | ---------------: |
| llama 13B Q4_K - Medium        |   7.33 GiB |    13.02 B | CUDA       |  99 |         pp512 |    193.24 ± 0.85 |
| llama 13B Q4_K - Medium        |   7.33 GiB |    13.02 B | CUDA       |  99 |         tg128 |     16.67 ± 0.31 |

build: f6190247 (3291)

llama-bench, -batched-bench, -server, and -cli all core dumped.

These are all working in master.

@JohannesGaessler
Collaborator Author

Understood cuda-iq-opt-3 is merged. Is test data from a "Maxwell 2.0" / Compute 5.2 GPU still helpful?

More data is always helpful. If it turns out that some changes in this PR were bad they can potentially be reverted.

llama-bench, -batched-bench, -server, and -cli all core dumped. I have detailed files.

They are all working in master.

Just to be clear, do you mean that they work on master prior to or after this PR?

@smcnally

smcnally commented Jul 4, 2024

Just to be clear, do you mean that they work on master prior to or after this PR?

These work in master after this PR was merged.

./llama-batched-bench --version
version: 3291 (f619024)

I'm running more tests against the M40 on its own now. Will gather and share more.

./llama-bench -m llava-v1.6-vicuna-13b.Q4_K_M.gguf -o md -fa 1 -ngl 41 -b 32 -ub 32 -p 32 -t 20
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla M40, compute capability 5.2, VMM: yes

| model | size | params | backend | ngl | threads | n_batch | n_ubatch | fa | test | t/s |
| --- | ---: | ---: | --- | --: | --: | --: | --: | --: | ---: | ---: |
| llama 13B Q4_K - Medium | 7.33 GiB | 13.02 B | CUDA | 41 | 20 | 32 | 32 | 1 | pp32 | 40.85 ± 0.03 |
| llama 13B Q4_K - Medium | 7.33 GiB | 13.02 B | CUDA | 41 | 20 | 32 | 32 | 1 | tg128 | 11.80 ± 4.84 |

build: f619024 (3291)

@smcnally

smcnally commented Jul 4, 2024

I've run llama-bench and llama-batched-bench several times against several models with only the Tesla M40 available. Verbose output is available.

./llama-batched-bench --version
version: 3291 (f619024)

Is running these same tests against earlier pre-merge builds helpful?

./llama-batched-bench -m llava-v1.6-vicuna-13b.Q4_K_M.gguf -c 4096 -b 32 -ub 512 -npp 128,256 -ntg 128,256 -npl 1,2,4 -ngl 99 --flash-attn

ggml_cuda_init: found 1 CUDA devices:
  Device 0: Tesla M40, compute capability 5.2, VMM: yes

main: n_kv_max = 4096, n_batch = 32, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = 99, n_threads = 12, n_threads_batch = 12

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --: | --: | --: | --: | --: | --: | --: | --: | --: | --: |
| 128 | 128 | 1 | 256 | 3.188 | 40.15 | 8.055 | 15.89 | 11.243 | 22.77 |
| 128 | 128 | 2 | 512 | 6.289 | 40.71 | 10.941 | 23.40 | 17.230 | 29.72 |
| 128 | 128 | 4 | 1024 | 12.582 | 40.69 | 18.982 | 26.97 | 31.563 | 32.44 |
| 128 | 256 | 1 | 384 | 3.152 | 40.60 | 16.725 | 15.31 | 19.878 | 19.32 |
| 128 | 256 | 2 | 768 | 6.308 | 40.58 | 23.176 | 22.09 | 29.485 | 26.05 |
| 128 | 256 | 4 | 1536 | 14.487 | 35.34 | 110.389 | 9.28 | 124.876 | 12.30 |
| 256 | 128 | 1 | 384 | 18.898 | 13.55 | 33.370 | 3.84 | 52.268 | 7.35 |
| 256 | 128 | 2 | 768 | 42.960 | 11.92 | 41.464 | 6.17 | 84.424 | 9.10 |
| 256 | 128 | 4 | 1536 | 90.833 | 11.27 | 66.340 | 7.72 | 157.173 | 9.77 |
| 256 | 256 | 1 | 512 | 21.503 | 11.91 | 70.335 | 3.64 | 91.838 | 5.58 |

@JohannesGaessler
Collaborator Author

I don't need the numbers in isolation. I only would have wanted to know whether there is a performance regression, since I changed one of the defaults. But if it didn't work prior to the PR anyway, there is no point.

@smcnally

smcnally commented Jul 5, 2024

These are llama-bench runs built against ggerganov:master tags/b3266 (pre-merge cuda-iq-opt-3 build: 1c5eba6 (3266)) and post-merge build: f619024 (3291)

  • Hathor-L3-8B-v.01-Q5_K_M-imat.gguf
  • replete-coder-llama3-8b-iq4_nl-imat.gguf
  • llava-v1.6-vicuna-13b.Q4_K_M.gguf

The -bench run times are much better in the new builds. I don't see huge t/s deltas. replete-coder core dumps on 3266.

build: 1c5eba6 (3266) - Hathor-L3-8B-v.01-Q5_K_M-imat.gguf

time ./llama-bench -m /mnt/models/gguf/Hathor-L3-8B-v.01-Q5_K_M-imat.gguf -t 20 -fa 1 -ngl 99 -b 512 -ub 512
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla M40, compute capability 5.2, VMM: yes

| model | size | params | backend | ngl | threads | n_batch | fa | test | t/s |
| --- | ---: | ---: | --- | --: | --: | --: | --: | ---: | ---: |
| llama 8B Q5_K - Medium | 5.33 GiB | 8.03 B | CUDA | 99 | 20 | 512 | 1 | pp512 | 249.11 ± 0.89 |
| llama 8B Q5_K - Medium | 5.33 GiB | 8.03 B | CUDA | 99 | 20 | 512 | 1 | tg128 | 13.10 ± 0.15 |

build: 1c5eba6 (3266)

real	1m9.391s
user	1m8.492s
sys	0m0.877s

build: 213701b (3324) - Hathor-L3-8B-v.01-Q5_K_M-imat.gguf

time ./llama-bench -m /mnt/models/gguf/Hathor-L3-8B-v.01-Q5_K_M-imat.gguf -t 20 -fa 1 -ngl 99 -b 512 -ub 512
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla M40, compute capability 5.2, VMM: yes

| model | size | params | backend | ngl | threads | n_batch | fa | test | t/s |
| --- | ---: | ---: | --- | --: | --: | --: | --: | ---: | ---: |
| llama 8B Q5_K - Medium | 5.33 GiB | 8.03 B | CUDA | 99 | 20 | 512 | 1 | pp512 | 250.31 ± 0.69 |
| llama 8B Q5_K - Medium | 5.33 GiB | 8.03 B | CUDA | 99 | 20 | 512 | 1 | tg128 | 21.77 ± 0.07 |

build: 213701b (3324)

real	0m56.774s
user	0m49.289s
sys	0m2.607s

build: 1c5eba6 (3266) - replete-coder-llama3-8b-iq4_nl-imat.gguf

GGML_ASSERT: ggml/src/ggml-cuda/dmmv.cu:665: false
Could not attach to process.  If your uid matches the uid of the target
process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try
again as the root user.  For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.
Aborted (core dumped)

build: 213701b (3324) - replete-coder-llama3-8b-iq4_nl-imat.gguf

time ./llama-bench -m /mnt/models/gguf/replete-coder-llama3-8b-iq4_nl-imat.gguf -t 20 -fa 1 -ngl 99 -b 512 -ub 512
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla M40, compute capability 5.2, VMM: yes

| model | size | params | backend | ngl | threads | n_batch | fa | test | t/s |
| --- | ---: | ---: | --- | --: | --: | --: | --: | ---: | ---: |
| llama 8B IQ4_NL - 4.5 bpw | 4.35 GiB | 8.03 B | CUDA | 99 | 20 | 512 | 1 | pp512 | 225.20 ± 1.97 |
| llama 8B IQ4_NL - 4.5 bpw | 4.35 GiB | 8.03 B | CUDA | 99 | 20 | 512 | 1 | tg128 | 20.40 ± 0.17 |

build: 213701b (3324)

real	0m58.026s
user	0m51.406s
sys	0m2.245s

build: 1c5eba6 (3266) - llava-v1.6-vicuna-13b.Q4_K_M.gguf

time ./llama-bench -m /mnt/models/gguf/llava-v1.6-vicuna-13b.Q4_K_M.gguf -t 20 -fa 1 -ngl 99 -b 512 -ub 512
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla M40, compute capability 5.2, VMM: yes

| model | size | params | backend | ngl | threads | n_batch | fa | test | t/s |
| --- | ---: | ---: | --- | --: | --: | --: | --: | ---: | ---: |
| llama 13B Q4_K - Medium | 7.33 GiB | 13.02 B | CUDA | 99 | 20 | 512 | 1 | pp512 | 127.27 ± 0.32 |
| llama 13B Q4_K - Medium | 7.33 GiB | 13.02 B | CUDA | 99 | 20 | 512 | 1 | tg128 | 9.05 ± 0.16 |

build: 1c5eba6 (3266)

real	2m2.376s
user	1m45.276s
sys	0m5.418s

build: 213701b (3324) - llava-v1.6-vicuna-13b.Q4_K_M.gguf

time ./llama-bench -m /mnt/models/gguf/llava-v1.6-vicuna-13b.Q4_K_M.gguf -t 20 -fa 1 -ngl 99 -b 512 -ub 512
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: Tesla M40, compute capability 5.2, VMM: yes

| model | size | params | backend | ngl | threads | n_batch | fa | test | t/s |
| --- | ---: | ---: | --- | --: | --: | --: | --: | ---: | ---: |
| llama 13B Q4_K - Medium | 7.33 GiB | 13.02 B | CUDA | 99 | 20 | 512 | 1 | pp512 | 127.88 ± 0.23 |
| llama 13B Q4_K - Medium | 7.33 GiB | 13.02 B | CUDA | 99 | 20 | 512 | 1 | tg128 | 14.08 ± 0.21 |

build: 213701b (3324)

real	1m28.005s
user	1m18.770s
sys	0m3.974s

@JohannesGaessler
Collaborator Author

Thanks, those numbers look good.
