Skip to content

Commit

Permalink
CUDA: add __restrict__ to mul mat vec kernels (ggerganov#2140)
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler authored and YellowRoseCx committed Jul 8, 2023
1 parent 4539bc2 commit f864f60
Showing 1 changed file with 25 additions and 28 deletions.
53 changes: 25 additions & 28 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ typedef float2 dfloat2;
#endif //GGML_CUDA_DMMV_F16

typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
typedef void (*ggml_cuda_op_t)(
Expand Down Expand Up @@ -185,7 +185,7 @@ typedef struct {
} block_q8_1;
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");

typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs);
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs);

//================================= k-quants

Expand Down Expand Up @@ -461,7 +461,7 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in

//================================== k-quants

static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {

const int i = blockIdx.x;
const block_q2_K * x = (const block_q2_K *) vx;
Expand Down Expand Up @@ -494,7 +494,7 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {

}

static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {

const int i = blockIdx.x;
const block_q3_K * x = (const block_q3_K *) vx;
Expand Down Expand Up @@ -558,7 +558,7 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
}
#endif

static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
const block_q4_K * x = (const block_q4_K *) vx;

const int i = blockIdx.x;
Expand Down Expand Up @@ -598,7 +598,7 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
#endif
}

static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
const block_q5_K * x = (const block_q5_K *) vx;

const int i = blockIdx.x;
Expand Down Expand Up @@ -644,7 +644,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
#endif
}

static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
const block_q6_K * x = (const block_q6_K *) vx;

const int i = blockIdx.x;
Expand Down Expand Up @@ -688,7 +688,7 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
#endif
}

static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");

Expand Down Expand Up @@ -796,7 +796,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
}
}

static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
Expand Down Expand Up @@ -900,7 +900,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
}
}

static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
Expand Down Expand Up @@ -1003,7 +1003,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
}
}

static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {

const int row = blockIdx.x;
const int num_blocks_per_row = ncols / QK_K;
Expand Down Expand Up @@ -1107,7 +1107,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
}
}

static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");

Expand Down Expand Up @@ -1225,7 +1225,7 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
v.y = x[ib + iqs + 1];
}

static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
Expand Down Expand Up @@ -1261,7 +1261,7 @@ static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_block(const void * vx, float * y, const int k) {
static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) {
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;

if (i >= k) {
Expand All @@ -1281,7 +1281,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
y[iybs + iqs + y_offset] = v.y;
}

static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;

Expand All @@ -1306,7 +1306,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, cons
#endif // __CUDA_ARCH__ >= 600
}

static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;

Expand All @@ -1331,7 +1331,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, cons
#endif // __CUDA_ARCH__ >= 600
}

static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;

Expand Down Expand Up @@ -1366,7 +1366,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, cons
#endif // __CUDA_ARCH__ >= 600
}

static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;

Expand Down Expand Up @@ -1400,7 +1400,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, cons
#endif // __CUDA_ARCH__ >= 600
}

static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;

Expand All @@ -1420,7 +1420,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, cons
}

template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
const int row = blockIdx.y*blockDim.y + threadIdx.y;

if (row >= nrows) {
Expand Down Expand Up @@ -1458,7 +1458,7 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
// qk = quantized weights per x block
// qr = number of quantized weights per data value in x block
const int row = blockIdx.y*blockDim.y + threadIdx.y;
Expand Down Expand Up @@ -1525,7 +1525,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
}
}

static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
const half * x = (const half *) vx;

const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
Expand Down Expand Up @@ -1572,7 +1572,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
}

static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
const int row_stride_x, const int channel_stride_x) {

const half * x = (const half *) vx;
Expand Down Expand Up @@ -2434,10 +2434,7 @@ inline void ggml_cuda_op_mul_mat_vec(
src0->type == GGML_TYPE_Q5_1 ||
src0->type == GGML_TYPE_Q8_0;

// The integer intrinsics used in mul_mat_vec_q are available with compute capability 6.
// However, they have bad performance with Pascal cards.
// Therefore, in a multi GPU setting decide at runtime which GPUs should use mul_mat_vec_q.
const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented;
const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 600 && mul_mat_vec_q_implemented;
#endif

if (use_mul_mat_vec_q) {
Expand Down

0 comments on commit f864f60

Please sign in to comment.