
Flash attention implementations do not handle case where value vectors have different dimension from query vectors #7343

Open
fairydreaming opened this issue May 17, 2024 · 6 comments
Labels: enhancement, stale

Comments

@fairydreaming
Collaborator

For example, in ggml.c the implementations of the flash-attention ops declare a variable D and use it both as the dimension of the value vectors and as the dimension of the key/query vectors. This will fail for models where query and value vectors have different lengths (for example, DeepSeek-V2).

Below are selected fragments of the GGML_OP_FLASH_ATTN_EXT implementation that illustrate the problem.

Creation of result tensor:

llama.cpp/ggml.c

Lines 6792 to 6793 in 51e9d02

int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

(Note that the query tensor dimensions are used everywhere, while in reality ne[0] should equal ne[0] of the value tensor, because the attention output is a linear combination of value vectors.)
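A minimal sketch of what a corrected result shape could look like, in plain C. The `shape4` struct and `fa_result_shape()` are hypothetical stand-ins for a ggml tensor's `ne[]` array and the tensor-creation code, not the ggml API. Q and K must still agree on `ne[0]`, but `ne[0]` of the result is taken from V:

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical stand-in for the ne[] field of a 4-D ggml tensor. */
typedef struct { int64_t ne[4]; } shape4;

/* Output shape of flash attention when Q/K and V head sizes may differ.
 * Q and K are dotted together, so their ne[0] must match; the output row
 * length comes from V because each output row is a weighted sum of value
 * vectors. The other dims keep the original permutation
 * { ne[0], q->ne[2], q->ne[1], q->ne[3] }. */
static shape4 fa_result_shape(const shape4 *q, const shape4 *k, const shape4 *v) {
    assert(q->ne[0] == k->ne[0]);  /* shared query/key length */
    shape4 r = { { v->ne[0], q->ne[2], q->ne[1], q->ne[3] } };  /* V length first */
    return r;
}
```

With illustrative head sizes of 192 for Q/K and 128 for V, the result rows would have 128 elements rather than 192.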

Definition of variable D:

llama.cpp/ggml.c

Line 15879 in 51e9d02

const int64_t D = neq0;

Assertions all expecting the same length:

llama.cpp/ggml.c

Lines 15889 to 15891 in 51e9d02

GGML_ASSERT(neq0 == D);
GGML_ASSERT(nek0 == D);
GGML_ASSERT(nev0 == D);
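One way the assertions could be split, sketched with a hypothetical helper `fa_head_sizes()` that reuses the `neq0`/`nek0`/`nev0` names from ggml.c and the Dq/Dv naming suggested later in this issue:

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: derive two head sizes instead of a single D.
 * neq0, nek0 and nev0 mirror the ne[0] values of q, k and v in ggml.c. */
static void fa_head_sizes(int64_t neq0, int64_t nek0, int64_t nev0,
                          int64_t *Dq, int64_t *Dv) {
    *Dq = neq0;  /* query/key vector length */
    *Dv = nev0;  /* value vector length, may differ from Dq */

    /* K is dotted with Q, so their lengths must still match. */
    assert(nek0 == *Dq);
    /* No assertion ties nev0 to Dq any more. */
}
```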

Usage of D as a dimension of a value vector:

llama.cpp/ggml.c

Line 15958 in 51e9d02

memset(V16, 0, D*sizeof(ggml_fp16_t));

Usage of D as a dimension of a query vector:

llama.cpp/ggml.c

Lines 15985 to 15987 in 51e9d02

for (int64_t d = 0; d < D; ++d) {
Q16[d] = GGML_FP32_TO_FP16(pq[d]);
}

Suggested solution: create two variables, Dq (length of the query vector) and Dv (length of the value vector), and use Dq as the query/key vector length and Dv as the value vector length. I fixed ggml_compute_forward_flash_attn_ext_f16() this way, and it produces correct results (confirmed by running DeepSeek-V2 with the -fa option).

I'm not 100% sure whether the CUDA and Metal implementations are also affected, but it's likely: I found the same variable D used in that code, as well as comments like "K and V have same shape".

@ggerganov
Owner

Thanks for reporting that. The CUDA and Metal kernels should also be affected. We should fix that, but maybe after DS2 support is merged, so we have something to test with.

@oldgithubman

Possibly relevant - #2445 (comment)

@bartowski1182
Copy link
Contributor

@JohannesGaessler I don't want to bother you, but I assume you'd be best suited to handle this, or at least to shed some light on how to handle it.

@JohannesGaessler
Collaborator

I don't think you need any special considerations in terms of program correctness - you would just have to implement it. The bigger challenge will be to implement this in such a way that the performance is good.

@ggerganov added the enhancement label and removed the bug label on Jun 18, 2024
@ggerganov
Owner

Btw, supporting different K/V head sizes might dramatically increase the number of FA kernels that we have to compile, so it's probably not really worth it.
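A rough count illustrates the concern, under the hypothetical assumption that a kernel is instantiated once per supported head size and per combination $C$ of the other compile-time parameters (types, tile sizes, and so on). Letting the K and V head sizes vary independently turns one factor into a product:

```latex
N_{\text{same}} = |\mathcal{D}| \cdot C
\qquad\text{vs.}\qquad
N_{\text{split}} = |\mathcal{D}_K| \cdot |\mathcal{D}_V| \cdot C
```

For example, with 6 supported head sizes and every $(D_K, D_V)$ pair allowed, that is $6C$ versus $36C$ instantiations. Supporting only a short list of specific $(D_K, D_V)$ pairs would instead add one set of instantiations per pair rather than multiplying the total.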

github-actions bot added the stale label on Jul 21, 2024
@oldgithubman

@ggerganov reopen? FA would be very useful for DeepSeek-V2.

github-actions bot removed the stale label on Jul 22, 2024
github-actions bot added the stale label on Aug 21, 2024