CUDA: quantized KV support for FA vec (llama/7527)
* CUDA: quantized KV support for FA vec

* try CI fix

* fix commented-out kernel variants

* add q8_0 q4_0 tests

* fix nwarps > batch size

* split fattn compile via extern templates (see the sketch after this list)

* fix flake8

* fix metal tests

* fix cmake

* make generate_cu_files.py executable

* add autogenerated .cu files

* fix AMD

* error if type_v != FP16 and not flash_attn

* remove obsolete code
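
The "extern templates" bullet refers to a standard C++ technique for splitting template-heavy compilation across translation units: the shared header defines the launcher template and declares each supported instantiation extern, and each autogenerated .cu file then emits exactly one explicit instantiation, so the flash attention variants compile in parallel instead of in one monolithic file. Below is a minimal sketch of the pattern; all identifiers are illustrative, not the exact ones from this commit.

```cuda
// Hedged sketch of the extern-template split; identifiers are hypothetical.

// ---- fattn-vec-common.cuh (shared header) ----
#include <cuda_fp16.h>

template <int head_size, typename type_K, typename type_V>
void launch_fattn_vec(const type_K * K, const type_V * V, float * dst) {
    // shared dispatch / kernel launch logic would live here
}

// Tell every file that includes this header NOT to instantiate this
// variant itself; some dedicated .cu file provides the definition:
extern template void launch_fattn_vec<128, half, half>(const half *, const half *, float *);

// ---- fattn-vec-instance-128-f16-f16.cu (autogenerated, one per variant) ----
// #include "fattn-vec-common.cuh"
template void launch_fattn_vec<128, half, half>(const half *, const half *, float *);
```

This reading is consistent with the "add autogenerated .cu files" and "make generate_cu_files.py executable" bullets: the generator writes one small .cu file per kernel variant.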
JohannesGaessler authored and ggerganov committed Jun 15, 2024
1 parent 174189e commit ebe7877
Showing 104 changed files with 2,595 additions and 540 deletions.
src/ggml-cuda.cu (8 changes: 6 additions & 2 deletions)
@@ -2905,10 +2905,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
             return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
 #else
-            if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
+            if (op->src[0]->ne[0] == 128) {
                 return true;
             }
-            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
+            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
+                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
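
For context on what the new checks gate: for GGML_OP_FLASH_ATTN_EXT, src[0] is Q, src[1] is K, and src[2] is V. As the hunk reads, head size 128 is always supported (the vector kernels now handle quantized K/V there), head size 64 additionally requires an F16 K cache, and any other case falls through to the tensor-core path, which needs Volta plus F16 K and V. Inside the vector kernels a quantized cache is read block by block. As a rough illustration for q8_0, whose layout in ggml-common.h is a half-precision scale plus 32 signed 8-bit quants per block, here is a hedged sketch; the helper name is hypothetical:

```cuda
#include <cuda_fp16.h>
#include <stdint.h>

#define QK8_0 32

// q8_0 block layout (matches ggml-common.h): one half scale, 32 int8 quants.
typedef struct {
    half   d;           // per-block scale
    int8_t qs[QK8_0];   // quantized values
} block_q8_0;

// Hypothetical device helper: dequantize element i of a q8_0 KV cache.
// Each value is simply scale * quant.
__device__ __forceinline__ half dequantize_q8_0(const block_q8_0 * x, int64_t i) {
    const int64_t ib  = i / QK8_0;   // which block
    const int     iqs = i % QK8_0;   // offset inside the block
    return __hmul(x[ib].d, __int2half_rn((int) x[ib].qs[iqs]));
}
```

q4_0 is analogous, except each byte packs two 4-bit quants and a fixed offset of 8 is subtracted before scaling; both formats are covered by the "add q8_0 q4_0 tests" bullet above.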