Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA: use tensor cores for MMQ #7676

Merged
merged 4 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
#define CC_PASCAL 600
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700
#define CC_TURING 750
#define CC_AMPERE 800
#define CC_OFFSET_AMD 1000000
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
Expand Down Expand Up @@ -326,9 +327,17 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
#endif // defined(GGML_USE_HIPBLAS)

#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
#define FP16_AVAILABLE
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL

#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#define FP16_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA

#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
#define INT8_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING

static bool fast_fp16_available(const int cc) {
return cc >= CC_PASCAL && cc != 610;
Expand All @@ -338,6 +347,10 @@ static bool fp16_mma_available(const int cc) {
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
}

static bool int8_mma_available(const int cc) {
return cc < CC_OFFSET_AMD && cc >= CC_TURING;
}

[[noreturn]]
static __device__ void no_device_code(
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
Expand Down Expand Up @@ -379,7 +392,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
}

static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#pragma unroll
Expand Down Expand Up @@ -412,7 +425,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
}

static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE

#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
return __float2half(fmaxf(__half2float(a), __half2float(b)));
Expand Down
20 changes: 10 additions & 10 deletions ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(

const int sumi = __dp4a(v, u, 0);

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;

Expand Down Expand Up @@ -122,7 +122,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(

const int sumi = __dp4a(v, u, 0);

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;

Expand Down Expand Up @@ -181,7 +181,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(

const int sumi = __dp4a(v, u, 0);

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;

Expand Down Expand Up @@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(

const int sumi = __dp4a(v, u, 0);

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;

Expand Down Expand Up @@ -314,7 +314,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_h2 = (const half2 *) Q_v;

Expand Down Expand Up @@ -407,7 +407,7 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
const int q0 = x[ib].qs[iqs];
const int q = ((q0 >> (4*shift)) & 0x0F) - 8;

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return ((half) d)*((half) q);
}
Expand All @@ -428,7 +428,7 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__
const int q0 = x[ib].qs[iqs];
const int q = ((q0 >> (4*shift)) & 0x0F);

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return __low2half(dm)*((half) q) + __high2half(dm);
}
Expand All @@ -453,7 +453,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__
const int qh = ((qh0 >> idq) << 4) & 0x10;
const int q = (ql | qh) - 16;

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return ((half) d)*((half) q);
}
Expand All @@ -478,7 +478,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
const int qh = ((qh0 >> idq) << 4) & 0x10;
const int q = (ql | qh);

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return __low2half(dm)*((half) q) + __high2half(dm);
}
Expand All @@ -497,7 +497,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__
const T d = x[ib].d;
const int q = x[ib].qs[iqs];

#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return ((half) d)*((half) q);
}
Expand Down
2 changes: 1 addition & 1 deletion ggml-cuda/fattn-tile-f16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ static __global__ void flash_attn_tile_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.

const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
Expand Down
2 changes: 1 addition & 1 deletion ggml-cuda/fattn-vec-f16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ static __global__ void flash_attn_vec_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#if FP16_AVAILABLE
#ifdef FP16_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.

constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
Expand Down
6 changes: 3 additions & 3 deletions ggml-cuda/fattn-wmma-f16.cuh
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include "common.cuh"
#include "fattn-common.cuh"

#if FP16_MMA_AVAILABLE
#ifdef FP16_MMA_AVAILABLE
#include <mma.h>
#endif
#endif // FP16_MMA_AVAILABLE

// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
Expand Down Expand Up @@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#if FP16_MMA_AVAILABLE
#ifdef FP16_MMA_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.

const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
Expand Down
95 changes: 95 additions & 0 deletions ggml-cuda/mma.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#include "common.cuh"

struct mma_int_A_I16K8 {
static constexpr int I = 16;
static constexpr int K = 8;
static constexpr int ne = 4;

int x[ne] = {0};

static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}

static __device__ __forceinline__ int get_k(const int l) {
const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
};

struct mma_int_B_J8K8 {
static constexpr int J = 8;
static constexpr int K = 8;
static constexpr int ne = 2;

int x[ne] = {0};

static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x / (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}

static __device__ __forceinline__ int get_k(const int l) {
const int ret = l * (K/2) + threadIdx.x % (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
};

struct mma_int_C_I16J8 {
static constexpr int I = 16;
static constexpr int J = 8;
static constexpr int ne = 4;

int x[ne] = {0};

static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}

static __device__ __forceinline__ int get_j(const int l) {
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}

__device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
#else
// On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[0]), "+r"(x[1])
: "r"(mma_A.x[0]), "r"(mma_B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[0]), "+r"(x[1])
: "r"(mma_A.x[2]), "r"(mma_B.x[1]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[3]), "r"(mma_B.x[1]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
};
Loading
Loading