
CUDA: Quantized matrix matrix multiplication #2160

Merged

Conversation

JohannesGaessler
Collaborator

After #2043 and #2067 I've tried implementing a matrix matrix multiplication kernel using integer intrinsics (currently only q4_0). The results are mixed:

| GPU | Model | Test | t/s master | t/s PR | Speedup |
|---|---|---|---|---|---|
| RTX 3090 | 7b q4_0 | pp | 1042 | 924 | 0.89 |
| RTX 3090 | 13b q4_0 | pp | 677 | 533 | 0.79 |
| RTX 3090 | 33b q4_0 | pp | 303 | 233 | 0.77 |
| P40 | 7b q4_0 | pp | 426 | 516 | 1.21 |
| P40 | 13b q4_0 | pp | 247 | 289 | 1.17 |
| P40 | 33b q4_0 | pp | 104 | 124 | 1.19 |

On my RTX 3090 the new kernel is slower, but on my P40 it's faster. Since matrix matrix multiplications are compute bound, I suspect the reason is that (unlike the P40) the RTX 3090 can execute floating point and integer (i.e. pointer) arithmetic in parallel, so using integer intrinsics instead of floating point operations leaves the floating point hardware underutilized. However, the GPUs that can execute floating point and integer arithmetic in parallel also have tensor cores, which should be much faster anyway, so I don't think this is a problem long-term.

Due to the use of shared memory the implementation has gotten rather ugly: using the structs for quantized data directly in shared memory had terrible performance (most likely due to memory bank conflicts), so the current implementation dissects the data into quants and scales. This has the unfortunate side effect of tying the allocation of shared memory closely to the quantization type. I very much do not want to implement an entire matrix matrix multiplication kernel for each quantization type, but creating a good template could also be tricky; I'll need to think about it some more.
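For illustration, the split described above might look roughly like this. This is a hypothetical sketch with made-up names and tile sizes, not the PR's actual kernel:

```cpp
// Hypothetical illustration (not the PR's actual kernel): the tile of a
// q4_0-like matrix is staged in shared memory as two separate arrays, one for
// the packed 4-bit quants and one for the per-block scales, instead of as an
// array of block structs, to avoid the bank conflicts the struct layout caused.
#include <cuda_fp16.h>

#define TILE_ROWS      64   // rows of the quantized matrix held per tile
#define BLOCKS_PER_ROW  4   // quantized blocks of one row held in the tile
#define INTS_PER_BLOCK  4   // a q4_0-style block packs 32 x 4 bits = 4 ints

__global__ void mul_mat_q_tile_demo() {
    // quants and scales are split into separate shared memory arrays:
    __shared__ int  tile_qs[TILE_ROWS][BLOCKS_PER_ROW * INTS_PER_BLOCK];
    __shared__ half tile_d [TILE_ROWS][BLOCKS_PER_ROW];

    // ... threads would cooperatively fill tile_qs / tile_d from global memory,
    // call __syncthreads(), and then accumulate dot products from the tile ...
    (void) tile_qs; (void) tile_d; // silence unused-variable warnings in this sketch
}
```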

@JohannesGaessler changed the title from "Cuda matrix matrix 6" to "CUDA: Quantized matrix matrix multiplication" on Jul 9, 2023
@slaren
Collaborator

slaren commented Jul 10, 2023

I think this looks good. I imagine that you are already planning on doing this, but as long as cuBLAS may be faster, it would be good to have the option to use it. But I think that at this performance level we could already use this by default.

@JohannesGaessler
Collaborator Author

My current plan is to implement matrix vector kernels based on integer intrinsics for k-quants and then try to come up with a good way to create a template for matrix matrix multiplication. Ideally using tensor cores won't be too difficult to add then and this PR can be merged as a universal upgrade (but still with the option to use cuBLAS).

@cmp-nct
Contributor

cmp-nct commented Jul 11, 2023

Those results are stunning!

@JohannesGaessler
Collaborator Author

After looking into it I think tensor cores should get their own kernel since the use of shared memory will need to be different. Instead I've done some performance optimizations:

| GPU | Model | Test | t/s master | t/s PR | Speedup |
|---|---|---|---|---|---|
| RTX 3090 | 7b q4_0 | pp | 1049 | 1050 | 1.00 |
| RTX 3090 | 13b q4_0 | pp | 676 | 625 | 0.92 |
| RTX 3090 | 33b q4_0 | pp | 298 | 275 | 0.92 |
| P40 | 7b q4_0 | pp | 425 | 629 | 1.48 |
| P40 | 13b q4_0 | pp | 244 | 360 | 1.48 |
| P40 | 33b q4_0 | pp | 103 | 157 | 1.52 |

I think this performance would be good enough to merge and use as the default; it's possible that the performance for other quantization formats will be worse though.

@JohannesGaessler
Collaborator Author

I pushed a template for matrix matrix multiplication. The template accepts three functions:

  1. a function that allocates the shared memory to hold the tile for the quantized matrix,
  2. a function that loads data from global memory to the aforementioned tile, and
  3. a function that does the actual matrix multiplication.
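A rough sketch of how such a template could be wired together is shown below; the typedefs, parameter lists, and names are illustrative assumptions rather than the PR's actual code:

```cpp
// Hypothetical sketch of the template structure described above; the typedefs,
// parameters, and names are illustrative, not the PR's actual identifiers.
#include <cuda_fp16.h>

// 1. allocates (points into) the shared memory tile for the quantized matrix
typedef void  (*allocate_tiles_cuda_t)(int ** x_ql, half ** x_d);
// 2. loads one tile of quantized data from global memory into shared memory
typedef void  (*load_tiles_cuda_t)(const void * vx, int * x_ql, half * x_d,
                                   int row, int k);
// 3. computes the dot product of a tile row with the quantized activations
typedef float (*vec_dot_q_cuda_t)(const int * x_ql, const half * x_d,
                                  const int * y_qs, const half * y_d, int k);

template <allocate_tiles_cuda_t allocate_tiles,
          load_tiles_cuda_t     load_tiles,
          vec_dot_q_cuda_t      vec_dot>
static __global__ void mul_mat_q(const void * vx, const void * vy, float * dst,
                                 const int ncols_x, const int nrows_x,
                                 const int ncols_y) {
    // Outline: allocate_tiles() once per block, then loop over tiles of the
    // input: load_tiles(), __syncthreads(), accumulate partial sums with
    // vec_dot(), __syncthreads(); finally write the accumulated results to dst.
}
```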

@abc-nix

abc-nix commented Jul 15, 2023

While I was testing your new changes (now including your last commit) on RTX 3060 12GB, I found an error while testing tulu-13B-q4_0 (43 layers offloaded) during the PP step that doesn't occur in master.

When I process less than 512 tokens (510t), I get:
CUDA error 700 at ggml-cuda.cu:3609: an illegal memory access was encountered
When I process a bit more (533 tokens), I get this error:
CUDA error 700 at ggml-cuda.cu:3609: an illegal memory access was encountered

I tested with pure llama-13b-q4_0 and I don't get this error, so maybe it is just this specific llama finetune. I will download others and test later.

Regarding PP speed results, I cannot continue using tulu, but this is what I got so far for a prompt 533 tokens long:

| Model | Quant | Layers | main PP (t/s) | PR PP (t/s) |
|---|---|---|---|---|
| 7B | q4_K_M | 35/35 | 537.18 | 487.31 |
| 7B | q4_0 | 35/35 | 505.31 | 524.49 |
| 13B | q4_K_M | 43/43 | 289.92 | 274.90 |
| 13B | q4_0 | 43/43 | 276.84 | error |
| 30B | q4_K_M | 29/63 | 58.96 | 57.44 |
| 30B | q4_0 | 29/63 | 57.16 | 60.16 |

For llama-13B q4_0, for the same prompt and same 43 layers offloaded to the RTX 3060 12GB:

main PP (533t): 279.55 tokens/s

PR PP (533t): 295.80 tokens/s

@JohannesGaessler
Collaborator Author

With base llama there are out-of-bounds memory accesses as well; I just haven't gotten around to fixing them yet because they randomly don't matter for my testing.

@cmp-nct
Contributor

cmp-nct commented Jul 15, 2023

This looks so great, I can't wait to get rid of cuBLAS entirely.
Maybe once a q8_0 variant is available it could already be used as a full cuBLAS replacement without sacrificing quality in virtually all scenarios?
After all, for cuBLAS we currently live-convert/dequantize the inputs to f32/f16 anyway. Going down to just 8 bits would already be quite an improvement.
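For context, dequantizing a q4_0 block is cheap per block: it is just a scale plus 32 packed 4-bit values. A simplified sketch of the conversion (struct and function names are made up for this example, and the exact ggml layout may differ in detail):

```cpp
// Simplified sketch of dequantizing one q4_0 block to f32 before handing the
// matrix to cuBLAS; the struct and function names are made up for this example
// and the exact ggml layout may differ in detail.
#include <cuda_fp16.h>
#include <stdint.h>

#define QK4_0 32

typedef struct {
    half    d;              // block scale
    uint8_t qs[QK4_0 / 2];  // 32 4-bit quants packed into 16 bytes
} block_q4_0_demo;

static __device__ void dequantize_block_q4_0_demo(const block_q4_0_demo * b, float * y) {
    const float d = __half2float(b->d);
    for (int j = 0; j < QK4_0 / 2; ++j) {
        const int x0 = (b->qs[j] & 0x0F) - 8;  // low nibble
        const int x1 = (b->qs[j] >>   4) - 8;  // high nibble
        y[j]             = x0 * d;
        y[j + QK4_0 / 2] = x1 * d;
    }
}
```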

@JohannesGaessler
Collaborator Author

The out-of-bounds memory accesses should be fixed now.

@abc-nix

abc-nix commented Jul 15, 2023

I am still getting a similar error. I will paste the output in case it is helpful. No layers are being offloaded, just to check whether the problem was related to the number of layers I was offloading to the GPU.

~/llama-cuda  cuda-matrix-matrix-6  llama-run 
main: build = 850 (a3b096b)
main: seed  = 1689449834
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3060, compute capability 8.6
llama.cpp: loading model from /models/tulu-13B-GGML/tulu-13b.ggmlv3.q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32001
llama_model_load_internal: n_ctx      = 2048
llama_model_load_internal: n_embd     = 5120
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 40
llama_model_load_internal: n_layer    = 40
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 13824
llama_model_load_internal: model size = 13B
llama_model_load_internal: ggml ctx size =    0.09 MB
llama_model_load_internal: using CUDA for GPU acceleration
llama_model_load_internal: mem required  = 9031.71 MB (+ 1608.00 MB per state)
llama_model_load_internal: offloading 0 repeating layers to GPU
llama_model_load_internal: offloaded 0/43 layers to GPU
llama_model_load_internal: total VRAM used: 480 MB
llama_new_context_with_model: kv self size  = 1600.00 MB

system_info: n_threads = 4 / 4 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | VSX = 0 | 
sampling: repeat_last_n = 256, repeat_penalty = 1.176470, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.500000, typical_p = 1.000000, temp = 0.000000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 2048, n_batch = 512, n_predict = -1, n_keep = 477


 USER: The following input text is a fragment of the description of a github project. Rewrite it correcting all grammar, punctuation, and spelling errors. Also make the text flow better so that anyone can understand what it is about.

### INPUT START ###
# GGML - Large Language Models for Everyone

[GGML](https://github.com/ggerganov/ggml) is a C library for machine learning
(ML) - the "GG" refers to the initials of its originator
([Georgi Gerganov](https://ggerganov.com/)). In addition to defining low-level
machine learning primitives (like a [tensor](#weights) type), GGML defines a
binary format for distributing large language models (LLMs) This crate provides
Rust [bindings](sys) into the reference implementation of GGML, as well as a
collection of [native](src) Rust helpers to provide safe, idiomatic access to
those bindings. GGML makes use of a technique called
"[quantization](<https://en.wikipedia.org/wiki/Quantization_(signal_processing)>)"
that allows for large language models to run on consumer hardware. This
documents describes the basics of the GGML format, including how
[quantization](#quantization) is used to democratize access to LLMs.

## Format

GGML files consists of binary-encoded data that is laid out according to a
specified format. The format specifies what kind of data is present in the file,
how it is represented, and the order in which it appears. The first piece of
information present in a valid GGML file is a GGML version number, followed by
three components that define a large language model: the model's
[hyperparameters](#hyperparameters), its [vocabulary](#vocabulary), and its
[weights](#weights). Continue reading to learn more about GGML versions and the
components of a GGML model.
### INPUT END ###

ASSISTANT:CUDA error 700 at ggml-cuda.cu:3618: an illegal memory access was encountered

I understand if you are not interested in this particular error and want to concentrate on the matrix x matrix multiplication code first. Also, this is a model with 32001 n_vocab, which also adds extra complexity to the already difficult task you are undertaking.

@JohannesGaessler
Collaborator Author

I implemented all of the older quantization formats and rebased onto master. These are the results:

| GPU | Model | Test | t/s master | t/s PR | Speedup |
|---|---|---|---|---|---|
| RTX 3090 | 7b q4_0 | pp | 1292 | 1266 | 0.98 |
| RTX 3090 | 7b q4_1 | pp | 1283 | 1263 | 0.98 |
| RTX 3090 | 7b q5_0 | pp | 1277 | 647 | 0.51 |
| RTX 3090 | 7b q5_1 | pp | 1269 | 634 | 0.50 |
| RTX 3090 | 7b q8_0 | pp | 1273 | 1039 | 0.82 |
| P40 | 7b q4_0 | pp | 461 | 627 | 1.36 |
| P40 | 7b q4_1 | pp | 459 | 368 | 0.80 |
| P40 | 7b q5_0 | pp | 460 | 245 | 0.53 |
| P40 | 7b q5_1 | pp | 451 | 254 | 0.56 |
| P40 | 7b q8_0 | pp | 453 | 532 | 1.17 |

Using cuBLAS the performance is very consistent; the performance of the new kernels varies a lot depending on the quantization format. I think the fundamental reason is that matrix matrix multiplication is compute bound and I/O matters much less. So it can be faster to dequantize the entire weight matrix once and then work with f16/f32 values that can be multiplied directly (which is comparatively fast on GPUs) than it is to do multiple integer/logical operations on the quantized data to get the result.

Notably, the performance for q5_0 and q5_1 is bad, presumably because the 5th bits are ordered in an inconvenient way that requires four bit shifts, bitwise ANDs, and bitwise ORs. This could be reduced to a single bit shift, bitwise AND, and bitwise OR (the reordering could be done when the weights are loaded into VRAM if the different bit order is bad for CPU performance). The performance of q4_1 on the P40 is also relatively bad; the problem is that good performance requires some f16 calculations, which are slow on the P40 (an f32 workaround is used instead).
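To illustrate the cost: reconstructing a pair of 5-bit values means fishing the high bits out of a separate bit field and OR-ing them onto the low nibbles, roughly as below. This is an illustrative sketch; the exact bit positions in ggml's q5_0 may differ, and it is that ordering which currently costs several shift/AND/OR operations per pair:

```cpp
// Illustrative only: reconstructing two 5-bit quants from a q5_0-style block,
// where the low 4 bits come from a byte of qs and the 5th bits are packed into
// a separate 32-bit field qh. The real ggml bit ordering may differ in detail.
#include <stdint.h>

static __device__ __forceinline__ void get_q5_pair(
        uint8_t qs_byte, uint32_t qh, int j, int * x0, int * x1) {
    const int xh_0 = ((qh >> (j +  0)) << 4) & 0x10;  // 5th bit of the first value
    const int xh_1 = ((qh >> (j + 12))     ) & 0x10;  // 5th bit of the second value
    *x0 = (int)((qs_byte & 0x0F) | xh_0) - 16;        // low nibble | high bit, minus offset
    *x1 = (int)((qs_byte >>   4) | xh_1) - 16;
}
```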

Overall I expect the performance for q2_k, q3_k, and q5_k to not be good due to the large number of operations per data value that will be necessary.

Caveats: the performance can probably still be optimized a lot. In particular I think there is still potential to optimize memory bank conflicts and tile sizes. For Ampere or newer it's also possible to utilize asynchronous data loading.
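As a side note on the last point, asynchronous loading on Ampere could look something like the following sketch using the CUDA pipeline primitives; this is not part of the PR, just an illustration of the mechanism:

```cpp
// Sketch of what asynchronous tile loading on Ampere could look like, using the
// CUDA pipeline primitives from <cuda_pipeline.h>; not part of the PR.
#include <cuda_pipeline.h>

#define TILE_INTS 1024  // 32-bit words of quantized data per tile

__global__ void load_tile_async_demo(const int * __restrict__ src) {
    __shared__ int tile[TILE_INTS];

    // each thread issues asynchronous 4-byte copies from global to shared memory
    for (int i = threadIdx.x; i < TILE_INTS; i += blockDim.x) {
        __pipeline_memcpy_async(&tile[i], &src[i], sizeof(int));
    }
    __pipeline_commit();       // group the copies issued so far into one stage
    __pipeline_wait_prior(0);  // wait until that stage has completed
    __syncthreads();

    // ... a real kernel would compute on this tile while the next one is loading ...
}
```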

@ggerganov @slaren What is your judgement regarding prompt processing speed vs. VRAM usage? I personally would prefer lower VRAM usage as the default because I find prompt processing to be fast enough either way. Also my findings may apply to a lesser extent to CPU matrix matrix multiplication as well. CPUs generally have comparatively fast integer arithmetic though and I think the memory bandwidth is a much more severe bottleneck for CPUs than for GPUs.

@slaren
Collaborator

slaren commented Jul 16, 2023

I think that for llama.cpp it is reasonable to prefer lower VRAM usage over prompt processing speed by default. In some cases, such as summarization or code completion, prompt processing may be more important. When the backends interface is completed, this could be an option selected at runtime when initializing the ggml-cuda backend (if the binary was built with cuBLAS support enabled).

@ggerganov
Owner

Looks like great progress so far.

What is your judgement regarding prompt processing speed vs. VRAM usage?

It's hard to say, since prompt processing has its applications and is also important during perplexity computations, so there is probably no right answer. But I still think that we can achieve at least parity with cuBLAS performance when using quantized matrix matrix multiplication for all types. So I'm inclined to say that we can accept the pp speed regression in some cases now and hope that we will solve it eventually.

If we don't succeed, then probably we'll provide options as @slaren suggested.

@cmp-nct
Contributor

cmp-nct commented Jul 16, 2023

I ran a test on Falcon 7B and 40B (7B with MATRIX_ROW_PADDING set to 64, 40B with the default).
Performance was quite good, but the result tensors were all filled with zeros.
I tried a few different input shapes by changing batch sizes; nothing ever produced an actual non-zero output.
I assume it's not yet generic enough, or only ready for specific shapes?

Here is the shape of a 7B multiplication:

+======================+======================+======================+======================+
| DST QQ:0
| qkv=W_qkv*cur                    [f32 type]
+----------------------+----------------------+----------------------+----------------------+
| Dimensions           | Strides              | Layer id             | Backend              |
| 2                    | 4x18688              | 29                   | CPU                  |
+----------------------+----------------------+----------------------+----------------------+
| Elements             | Src0                 | Src1                 | Operation            |
| 4672 x 64            | 4544 x 4672          | 4544 x 64            | MUL_MAT              |
+----------------------+----------------------+----------------------+----------------------+
| Transposed:      No  | Permuted:        No  | Contiguous:      Yes | Size:        1.14 MB |
| Src0 name:           | transformer.h.29.self_attention.query_key_value.weight             |
| Src1 name:           | inpFF                                                              |
+----------------------+----------------------+----------------------+----------------------+

Switching manually to cuBLAS instead of the new mat_q works.

@JohannesGaessler
Collaborator Author

JohannesGaessler commented Jul 17, 2023

ggllm.cpp is not setting cmake CUDA architectures correctly, see ggerganov/ggml#389 . The result is that on Pascal or newer the compiled PTX code does not match the runtime check.

@cmp-nct
Contributor

cmp-nct commented Jul 17, 2023

ggllm.cpp is not setting cmake CUDA architectures correctly, see ggerganov/ggml#389 . The result is that on Pascal or newer the compiled PTX code does not match the runtime check.

That's interesting. I fixed that for the next update, though it did not change the results on my system.
My tests are on a 4090 and 40+3090, in case it matters.

In total I have two problems with implementing the new CUDA code; I didn't want to oversaturate the report here with two issues yesterday, but they are linked: I noted that the QKV pure vector multiplication with the new direct dequantization kernels failed as well, though that did not always happen and only for that specific matmul operation, not the others.
I just tested
#define GGML_CUDA_DMMV_X 32 // instead of 64

This changed:

  1. The vector multiplication of QKV now works in my current tests
  2. The matrix-matrix multiplication now returns "some" correct information

To go into more detail on point 2:

  1. I ran a short test with a simple sentence as a prompt and I got a result using batched processing with the new kernel function.
  2. I ran a second test with a larger input (~2k tokens) and -b 1024 (and 256) and I got garbage again.
  3. Next I used cuBLAS for the lm_head and the first ff_up, and the new kernel only for the ff_down matmul:
    cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
    identical results to pure cuBLAS!
  4. Another test: this time cuBLAS only for lm_head, the new matmul for ff_up and ff_down: garbage output (;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;)

So changing DMMV_X from 64 to 32 made a big difference for both the dequantization vector matmul kernels and the matmul kernel, but it randomly works or fails depending on which shape you are feeding it.
Likely two problems that intersect?

@JohannesGaessler
Collaborator Author

Sorry but I don't see how this is related to the new kernels I'm implementing. They are only used for quantized data and the KV cache is f16.

@cmp-nct
Contributor

cmp-nct commented Jul 17, 2023

Sorry but I don't see how this is related to the new kernels I'm implementing. They are only used for quantized data and the KV cache is f16.

I've spent hours on it; I wouldn't report it as an issue if it was not relevant.
qkv = W_qkv * cur: W_qkv is q4_1 in my test, cur is a 32-bit tensor.
The KV cache is 32 bit in ggllm by default, so I can assure you it's a completely normal matmul operation, which is usually done with cuBLAS during batched ingestion and with a vector matmul kernel during single-token processing. Nothing special about it.

Also the other cases where I have similar failures are completely normal multiplications that work perfectly fine using cuBLAS and fail using the new kernel. That's quite easy to test; all you have to do is add "false" to the branch that differentiates between mat_q and cuBLAS in ggml_cuda_mul_mat().
From my understanding of the CUDA code the new matmul is built as a full replacement for q4-q8, so if it works with mat_cublas() it should work with mul_mat_q() as well?

DMMV_X = 64 causes ALL ggml_cuda_op_mul_mat_q calls to fail reliably (output tensor all zeros); at 32 it fails for only some of them.
DMMV_X = 64 also causes that one normal vector multiplication to fail.

I am aware that there is no priority or preference for getting this working on Falcon; it's llama.cpp after all. But in my opinion a full matmul replacement should work as reliably as the cuBLAS one, i.e. on all legal tensor input shapes, and any unsupported shape should assert.

I'll finish my work on upgrading and cleaning the current ggllm backend; I probably have to make a mix of the current and the previous version. Once that's done I can send you a branch to check out and verify the problem, if you are interested.
It will be easier if it can just be reproduced with a click instead of being described.

@JohannesGaessler
Collaborator Author

Sorry, but I don't intend to provide extended support for ggml-cuda.cu forks that have deviated substantially from upstream, especially for WIP draft PRs. I simply don't have the time for things like that. In this specific case the problem is very likely caused by the mul_mat_vec kernels not having support for all matrix dimensions and this only manifests as a bug because I am using placeholder logic to determine which kernels should be used.

@cmp-nct
Contributor

cmp-nct commented Jul 17, 2023

That's fine. I just wanted to point the problem out.
IMHO this should be solved before it is merged into ggml as a generic solution, once it's ready to move on from draft status.

Thanks for the update. Most likely it is as you said: the vector and matmul kernels just don't support those shapes.

@JohannesGaessler
Collaborator Author

I'll try to finally get this PR in a state that can be merged this weekend. k-quant support is currently still missing. cuBLAS will still be a mandatory dependency because the KV cache needs it. Despite that I plan to make the switch between cuBLAS and the new kernels a compile-time option since my understanding is that the long-term goal is to drop cuBLAS as a mandatory dependency.

@JohannesGaessler JohannesGaessler merged commit 11f3ca0 into ggerganov:master Jul 29, 2023
25 checks passed
@mirek190

...wait, it's slower now with an RTX 3090??

@JohannesGaessler
Collaborator Author

Yes, but the VRAM usage is reduced by 700/970/1430 MiB for 7b/13b/33b. Compile with LLAMA_CUDA_CUBLAS for the old implementation.

@mirek190

Losing even 50% performance with the awesome K_M quant models to gain 1 GB of VRAM?
Where is the benefit here?

@JohannesGaessler
Collaborator Author

Prompt processing usually takes up much less time than the actual generation so I think this tradeoff is worthwhile. On an RTX 3090 you can now run 33b with more context or with better quantization. On 16 GB RAM + 8 GB VRAM it should now be possible to run 33b q4 at 2048 context.

@Green-Sky
Collaborator

Would it be OK to ship the perplexity tool in the release builds using cuBLAS, while main/server use the non-cuBLAS kernels?

@JohannesGaessler
Collaborator Author

Yeah, that would make sense.

@Dampfinchen

Dampfinchen commented Jul 30, 2023

Prompt processing usually takes up much less time than the actual generation so I think this tradeoff is worthwhile. On an RTX 3090 you can now run 33b with more context or with better quantization. On 16 GB RAM + 8 GB VRAM it should now be possible to run 33b q4 at 2048 context.

That may be the case with high end hardware like P40s and 3090s, but not with more common hardware, especially not at a ctx of 2048 and over. For me, prompt processing even with cuBLAS takes a considerable amount of time. With a 2060 and a 13b model it is around 30 seconds (1800 tokens), plus 60 s of generation, resulting in a total of around 90 s (180 tokens generated), so prompt processing takes half as long as the generation itself. In my opinion prompt processing speed is far more important than generation speed, because the wait before the AI starts to answer feels longer, while during generation you can watch the tokens appear in real time thanks to token streaming, so slower generation is not a big deal in my eyes.

@Green-Sky
Collaborator

Yeah, that would make sense.

until you outperform it, that is 😄

@Dampfinchen

Dampfinchen commented Jul 30, 2023

I appreciate the work done here, but in my opinion this should only be the default if generation speed and, more importantly, prompt processing speed are on par with cuBLAS for all quantization formats. Otherwise most people will wonder why the AI suddenly takes longer to answer.

@Dampfinchen

Dampfinchen commented Jul 30, 2023

Alright, I have tested it now, @JohannesGaessler. This is the performance on midrange systems without 24 GB of VRAM.

RTX 2060, 32 GB RAM, Core i7 9750H

Old implementation, 13b q5_1, 13 GPU layers, VRAM usage 5.1 GB:

[screenshot: before]

New implementation, 17 layers, VRAM usage 5.2 GB:

[screenshot: after, 17 layers]

As I expected, the slower prompt processing time outweighs the faster generation time. (Note that I was being generous by allowing 100 MB more VRAM usage and using an older build of llama.cpp for the old implementation.) So overall time is indeed slower with the new implementation.

Still, if you chat with the model, which results in around 800 tokens processed on average instead of 1800, I guess it would end up being faster. But then again, you'd have to wait longer for the generation to start, which in my opinion is less than desirable. For me personally, prompt processing speed as it is with the old implementation is good for a 13b model on systems with 6 and 8 GB VRAM. But I wouldn't want to run 33b on it, because the prompt processing time would likely be twice as slow and thus wouldn't result in a good user experience, as you'd have to wait a long time for the generation to start. The new implementation would make it even less feasible for me to run 33b models at adequate speed.

@Dampfinchen

Dampfinchen commented Jul 30, 2023

Sorry for so many posts in a row, but I have some more data to share.

I've found a good use case for this implementation on systems with 8 and 6 GB VRAM: q4_0 7b now runs entirely in VRAM.

[screenshot: 7b]

So if you want to run lower quality 7b models on this kind of hardware (or higher quality 7b K_M models on GPUs with 8 GB VRAM), this PR is indeed an excellent option to have, provided it is exposed as a separate flag rather than the default used by popular inference programs like text generation webui and koboldcpp.

Likewise, GPUs with 12 GB of VRAM, and perhaps even 10 GB GPUs, should be able to run 13b models entirely in VRAM comfortably with this PR.

However, as many if not most people use llama.cpp to run models too big for their VRAM (and I'd argue partial offloading is llama.cpp's killer feature), my original point still stands. You won't get full GPU offloading on an 8 GB VRAM GPU with 13b, let alone 33b models, even at a ctx of 2048.

@LostRuins
Collaborator

Are there any perplexity value comparisons with the non-QxQ version? Unless I grossly misunderstand how this works, I'd imagine the output of q4 x q4 directly without dequantizing is likely to be severely degraded in precision. Or is the matmul done at some intermediate format?

@JohannesGaessler
Collaborator Author

Sorry for so many posts in a row, but I have some more data to share.

There is nothing to apologize for. I chose the default based on the overall goals of the project (to not use external BLAS libraries at all) and the way I use llama.cpp and what I care about when I do. I don't expect people to universally agree with my priorities.

Are there any perplexity value comparisons with the non-QxQ version? Unless I grossly misunderstand how this works, I'd imagine the output of q4 x q4 directly without dequantizing is likely to be severely degraded in precision. Or is the matmul done at some intermediate format?

The hidden state is quantized to 8 bits, so the matrix matrix multiplication is, for example, done as q4 x q8. This is the same way it's done on the CPU (or when using the mul_mat_vec_q kernels), and with neither of those nor with this implementation have I observed worse perplexity.
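In other words, the weights stay quantized and only the activations are quantized to 8 bits on the fly; the per-block products are then computed with integer dot products. A minimal sketch of the q4 x q8 inner step (illustrative names, not the PR's actual kernel code; the q4_0 offset and block-sum handling is omitted):

```cpp
// Minimal sketch of a q4 x q8 partial dot product using the __dp4a integer
// intrinsic (requires compute capability 6.1+); names are illustrative and the
// q4_0 offset/block-sum correction of the real kernels is omitted for brevity.
static __device__ __forceinline__ float vec_dot_q4_q8_demo(
        int v_q4,                  // 8 packed 4-bit weight quants (one 32-bit word)
        int u_q8_lo, int u_q8_hi,  // 2x4 packed 8-bit activation quants
        float d4, float d8) {      // block scales of the q4 and q8 blocks
    const int vi0 = (v_q4 >> 0) & 0x0F0F0F0F;  // low nibbles as 4 int8 lanes
    const int vi1 = (v_q4 >> 4) & 0x0F0F0F0F;  // high nibbles as 4 int8 lanes

    int sumi = __dp4a(vi0, u_q8_lo, 0);        // 4 int8 multiplies, accumulated in int32
    sumi     = __dp4a(vi1, u_q8_hi, sumi);

    return d4 * d8 * sumi;                     // scale the integer sum by both block scales
}
```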

@nauful

nauful commented Jul 30, 2023

The latest update does not run for me on Windows 11 with nvcc 12.2 and a 4090, nor with the build released on GitHub. I get the following error:
GGML_ASSERT: D:\a\llama.cpp\llama.cpp\ggml-cuda.cu:4387: i01_high == rows_per_iter || g_device_count > 1

As a workaround, I inserted the following code at ggml-cuda.cu:4381, which then ran fine:

if (g_device_count == 1) {
    i01_low = 0;
    i01_high = rows_per_iter;
}

The following did not work:

if (split) {
...
}
else if (g_device_count == 1) {
// Never reached even when split is always 0, compiler bug?
}

@dranger003
Contributor

The latest update does not run for me on Windows 11 with nvcc 12.2 and 4090, nor with the build released on github. I get the following error: GGML_ASSERT: D:\a\llama.cpp\llama.cpp\ggml-cuda.cu:4387: i01_high == rows_per_iter || g_device_count > 1

As a workaround, I inserted the following code at ggml-cuda.cu:4381, which then ran fine:

if (g_device_count == 1) {
    i01_low = 0;
    i01_high = rows_per_iter;
}

The following did not work:

if (split) {
...
}
else if (g_device_count == 1) {
// Never reached even when split is always 0, compiler bug?
}

I can confirm this behavior (and the fix), also running on Windows 11. What is interesting is that Llama-2-13b works fine but openchat_v3.2 does not and asserts; I see they have vocab sizes of 32000 and 32002 respectively, though I'm not sure whether that is relevant.

@LostRuins
Collaborator

I have this issue too on the latest commit; it happens when I try to offload the KV buffers.
ggml-cuda.cu:4741: i01_high == rows_per_iter || g_device_count > 1


However, @dranger003's fix does not seem to work for me; adding the fix at that offset, I get:
CUDA error 1 at C:\temp_project\ggml-cuda.cu:4836: invalid argument

@LostRuins
Collaborator

LostRuins commented Aug 1, 2023

Extra info: the model used is WizardLM-7B-uncensored.ggmlv3.q4_K_M.bin, which has vocab_size=32001.
Compiled and running on Win10.

Printing debug information for the associated variables:
i01_low:0, i01_high:32000, rows_per_iter:32001, g_device_count:1
thus failing the assert 32000!=32001
but this has not been an issue until a recent change. Running an older build reveals:

i01_low:0, i01_high:32001, rows_per_iter:32001, g_device_count:1

Edit: apologies for the spam. I have narrowed it down to commit 11f3ca0 as the cause of this issue.

Initial state after line 4720:
i01_low:0, i01_high:32001, rows_per_iter:32001, i0_offset_high:0, g_device_count:1 split:1 row_high:32000

It happens regardless of whether mul_mat_q is set or not. I believe it may be related to row_high -= row_high % GGML_CUDA_MMQ_Y; as that was not there previously. On further testing, any model with vocab != 32000 seems to be affected.
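For reference, a small worked example of the arithmetic being described, assuming GGML_CUDA_MMQ_Y is 64 in this build and that i01_high ends up clamped to row_high on a single GPU (as the debug output above suggests):

```cpp
// Worked example of the arithmetic above, assuming GGML_CUDA_MMQ_Y == 64 and
// that i01_high is clamped to row_high on a single GPU.
#include <cstdio>

int main() {
    const int rows_per_iter = 32001;   // equals the vocab size of the model
    int       row_high      = 32001;
    row_high -= row_high % 64;         // 32001 % 64 == 1  ->  row_high == 32000
    const int i01_high      = row_high;
    // 32000 != 32001 and g_device_count == 1, so the check
    // i01_high == rows_per_iter || g_device_count > 1 evaluates to false.
    printf("i01_high=%d rows_per_iter=%d\n", i01_high, rows_per_iter);
    return 0;
}
```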

@dranger003
Contributor

@LostRuins Thanks for the updates. I ran into this issue once more with another model, newhope.ggmlv3.q8_0.bin, which has a vocab_size of 32001. I'll take a look at recent issues, but we may need to open a new one.
