CUDA: use tensor cores for MMQ (llama/7676)
* CUDA: int8 tensor cores for MMQ (legacy quants)

* fix out-of-bounds writes

* __builtin_assume -> GGML_CUDA_ASSUME

* fix writeback returning too early
JohannesGaessler authored and ggerganov committed Jun 15, 2024
1 parent a32a2b8 commit c570abc
Showing 7 changed files with 550 additions and 55 deletions.
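One commit-message bullet swaps direct __builtin_assume calls for GGML_CUDA_ASSUME, the portability wrapper already used elsewhere in ggml-cuda; the new mma.cuh below relies on it to tell the compiler the value ranges returned by its index helpers. A minimal sketch of such a wrapper follows; the exact version guard here is an assumption, not taken from this commit:

// Hedged sketch of a GGML_CUDA_ASSUME-style wrapper: emit the optimizer hint
// where __builtin_assume is supported and compile it away to nothing elsewhere.
#if CUDART_VERSION >= 11100
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
#define GGML_CUDA_ASSUME(x)
#endif // CUDART_VERSION >= 11100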
21 changes: 17 additions & 4 deletions src/ggml-cuda/common.cuh
@@ -139,6 +139,7 @@
#define CC_PASCAL 600
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700
#define CC_TURING 750
#define CC_AMPERE 800
#define CC_OFFSET_AMD 1000000
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
@@ -326,9 +327,17 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
#endif // defined(GGML_USE_HIPBLAS)

#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
#define FP16_AVAILABLE
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL

#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#define FP16_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA

#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
#define INT8_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING

static bool fast_fp16_available(const int cc) {
return cc >= CC_PASCAL && cc != 610;
@@ -338,6 +347,10 @@ static bool fp16_mma_available(const int cc) {
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
}

static bool int8_mma_available(const int cc) {
return cc < CC_OFFSET_AMD && cc >= CC_TURING;
}

[[noreturn]]
static __device__ void no_device_code(
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
@@ -379,7 +392,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
}

static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#pragma unroll
@@ -412,7 +425,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
}

static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE

#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
return __float2half(fmaxf(__half2float(a), __half2float(b)));
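The net effect of the common.cuh changes: the availability macros (FP16_AVAILABLE, FP16_MMA_AVAILABLE and the new INT8_MMA_AVAILABLE) are now flags that are either defined or not, tested with #ifdef instead of #if; CC_TURING = 750 marks the first compute capability with int8 tensor cores; and int8_mma_available() gives host code a matching runtime check. A minimal sketch of how the compile-time and runtime guards are typically paired; the kernel and dispatch names are hypothetical, not from this commit:

#include "common.cuh"

// Device side: the tensor core path only exists when compiling for
// cc >= CC_TURING on NVIDIA hardware; otherwise the body is compiled out.
static __global__ void example_int8_mma_kernel(/* ... */) {
#ifdef INT8_MMA_AVAILABLE
    // int8 mma.sync path (see mma.cuh below)
#else
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}

// Host side: choose the kernel from the compute capability of the actual device.
static void example_dispatch(const int cc) {
    if (int8_mma_available(cc)) {
        // launch example_int8_mma_kernel<<<grid, block>>>(...);
    } else {
        // fall back to the __dp4a-based MMQ path
    }
}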
20 changes: 10 additions & 10 deletions src/ggml-cuda/fattn-common.cuh
@@ -74,7 +74,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(

const int sumi = __dp4a(v, u, 0);

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;

@@ -122,7 +122,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(

const int sumi = __dp4a(v, u, 0);

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;

@@ -181,7 +181,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(

const int sumi = __dp4a(v, u, 0);

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;

@@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(

const int sumi = __dp4a(v, u, 0);

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;

@@ -314,7 +314,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_h2 = (const half2 *) Q_v;

@@ -407,7 +407,7 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
const int q0 = x[ib].qs[iqs];
const int q = ((q0 >> (4*shift)) & 0x0F) - 8;

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return ((half) d)*((half) q);
}
@@ -428,7 +428,7 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__
const int q0 = x[ib].qs[iqs];
const int q = ((q0 >> (4*shift)) & 0x0F);

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return __low2half(dm)*((half) q) + __high2half(dm);
}
@@ -453,7 +453,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__
const int qh = ((qh0 >> idq) << 4) & 0x10;
const int q = (ql | qh) - 16;

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return ((half) d)*((half) q);
}
@@ -478,7 +478,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
const int qh = ((qh0 >> idq) << 4) & 0x10;
const int q = (ql | qh);

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return __low2half(dm)*((half) q) + __high2half(dm);
}
@@ -497,7 +497,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__
const T d = x[ib].d;
const int q = x[ib].qs[iqs];

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return ((half) d)*((half) q);
}
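Every change in this file (and in the flash-attention kernels below) is the mechanical consequence of the new macro style in common.cuh: FP16_AVAILABLE is now either defined or undefined rather than expanding to a 0/1 expression, so each #if FP16_AVAILABLE test has to become #ifdef FP16_AVAILABLE in lockstep with the definition. An illustrative comparison of the two styles, using hypothetical names, not taken from the diff:

#include "common.cuh"

// Old style: the macro expands to a compile-time expression and must be tested
// with #if; testing it with #ifdef would succeed even below CC_PASCAL.
#define FP16_AVAILABLE_OLD (__CUDA_ARCH__ >= CC_PASCAL)

// New style: the macro is only defined when the condition holds and is tested
// with #ifdef; #if would no longer work once it is defined to nothing.
#if __CUDA_ARCH__ >= CC_PASCAL
#define FP16_AVAILABLE_NEW
#endif // __CUDA_ARCH__ >= CC_PASCAL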
2 changes: 1 addition & 1 deletion src/ggml-cuda/fattn-tile-f16.cu
@@ -43,7 +43,7 @@ static __global__ void flash_attn_tile_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.

const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
2 changes: 1 addition & 1 deletion src/ggml-cuda/fattn-vec-f16.cuh
@@ -40,7 +40,7 @@ static __global__ void flash_attn_vec_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.

constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
6 changes: 3 additions & 3 deletions src/ggml-cuda/fattn-wmma-f16.cuh
@@ -1,9 +1,9 @@
#include "common.cuh"
#include "fattn-common.cuh"

#if FP16_MMA_AVAILABLE
#ifdef FP16_MMA_AVAILABLE
#include <mma.h>
#endif
#endif // FP16_MMA_AVAILABLE

// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
@@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#if FP16_MMA_AVAILABLE
#ifdef FP16_MMA_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.

const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
95 changes: 95 additions & 0 deletions src/ggml-cuda/mma.cuh
@@ -0,0 +1,95 @@
#include "common.cuh"

struct mma_int_A_I16K8 {
static constexpr int I = 16;
static constexpr int K = 8;
static constexpr int ne = 4;

int x[ne] = {0};

static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}

static __device__ __forceinline__ int get_k(const int l) {
const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
};

struct mma_int_B_J8K8 {
static constexpr int J = 8;
static constexpr int K = 8;
static constexpr int ne = 2;

int x[ne] = {0};

static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x / (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}

static __device__ __forceinline__ int get_k(const int l) {
const int ret = l * (K/2) + threadIdx.x % (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
};

struct mma_int_C_I16J8 {
static constexpr int I = 16;
static constexpr int J = 8;
static constexpr int ne = 4;

int x[ne] = {0};

static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}

static __device__ __forceinline__ int get_j(const int l) {
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}

__device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
#else
// On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[0]), "+r"(x[1])
: "r"(mma_A.x[0]), "r"(mma_B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[0]), "+r"(x[1])
: "r"(mma_A.x[2]), "r"(mma_B.x[1]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[3]), "r"(mma_B.x[1]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
};
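The three fragment types above map the 32 threads of a warp onto a 16x8 tile of A, an 8x8 tile of B and a 16x8 accumulator C, all in units of 32-bit ints; each int of A and B packs four int8 values, so mma_K8 accumulates C += A*B over K = 32 int8 elements per output. A hedged sketch of how the index helpers and mma_K8 fit together; the tile pointers, strides and function name are hypothetical, and the real consumer is the MMQ code updated elsewhere in this commit:

#include "mma.cuh"

// Sketch: load one A tile and one B tile from (hypothetical) shared-memory
// buffers into per-thread registers via the get_* helpers, then accumulate.
// Strides are in units of int.
static __device__ __forceinline__ void example_tile_mma(
        const int * tile_A, const int stride_A,
        const int * tile_B, const int stride_B,
        mma_int_C_I16J8 & C) {
    mma_int_A_I16K8 A;
    mma_int_B_J8K8  B;

#pragma unroll
    for (int l = 0; l < mma_int_A_I16K8::ne; ++l) {
        A.x[l] = tile_A[mma_int_A_I16K8::get_i(l)*stride_A + mma_int_A_I16K8::get_k(l)];
    }
#pragma unroll
    for (int l = 0; l < mma_int_B_J8K8::ne; ++l) {
        B.x[l] = tile_B[mma_int_B_J8K8::get_j(l)*stride_B + mma_int_B_J8K8::get_k(l)];
    }

    // After this call C.x[l] holds the int32 result for row get_i(l), column get_j(l).
    C.mma_K8(A, B);
}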
