CUDA: refactor mmq, dmmv, mmvq #7716

Merged
Changes from 1 commit
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -588,6 +588,8 @@ if (LLAMA_HIPBLAS)
list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu")
file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})

add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUDA)

1 change: 1 addition & 0 deletions Makefile
@@ -422,6 +422,7 @@ ifdef LLAMA_CUBLAS
endif

OBJS_CUDA_TEMP_INST = $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-wmma*.cu))
+OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/mmq*.cu))
ifdef LLAMA_CUDA_FA_ALL_QUANTS
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*.cu))
else
6 changes: 6 additions & 0 deletions ggml-common.h
@@ -123,12 +123,18 @@ typedef sycl::half2 ggml_half2;
#define QI1_S (QK_K / (4*QR1_S))
#define QR1_S 8

+#define QI1_M (QK_K / (4*QR1_M))
+#define QR1_M 8

#define QI4_NL (QK4_NL / (4*QR4_NL))
#define QR4_NL 2

#define QI4_XS (QK_K / (4*QR4_XS))
#define QR4_XS 8

+#define QI3_S (QK_K / (4*QR3_S))
+#define QR3_S 8

#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP

#define QK4_0 32
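Not part of the diff, just context for the two new macro pairs: following the existing QI*/QR* convention, QR* is how many quantized weights share one byte of packed data and QI* works out to the number of 32-bit ints those weights occupy per block. A minimal sketch of the arithmetic, assuming the usual QK_K == 256 super-block size (names suffixed _example are stand-ins, not ggml identifiers):

// QI1_M = QK_K / (4*QR1_M) = 256 / (4*8) = 8 -> 8 ints = 32 bytes of packed weights per block
// QI3_S follows the same pattern: 256 / (4*8) = 8
constexpr int QK_K_example  = 256;
constexpr int QR1_M_example = 8;
constexpr int QI1_M_example = QK_K_example / (4*QR1_M_example);
static_assert(QI1_M_example == 8, "one IQ1_M block packs its quantized weights into 8 32-bit ints");
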
84 changes: 9 additions & 75 deletions ggml-cuda.cu
@@ -633,88 +633,22 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {

// cuda split buffer

-static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
-    int64_t min_compute_capability = INT_MAX;
-    int64_t max_compute_capability = INT_MIN;
+static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
+    int64_t row_rounding = 0;
    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-        if (tensor_split[id] < (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
-            if (min_compute_capability > ggml_cuda_info().devices[id].cc) {
-                min_compute_capability = ggml_cuda_info().devices[id].cc;
-            }
-            if (max_compute_capability < ggml_cuda_info().devices[id].cc) {
-                max_compute_capability = ggml_cuda_info().devices[id].cc;
-            }
+        if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
+            continue;
        }
-    }

-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    switch(type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-            return max_compute_capability >= CC_RDNA2 ? 128 : 64;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_F32:
-            return 1;
-        case GGML_TYPE_Q2_K:
-            return max_compute_capability >= CC_RDNA2 ? 128 : 32;
-        case GGML_TYPE_Q3_K:
-            return min_compute_capability < CC_RDNA2 ? 128 : 64;
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-            return max_compute_capability >= CC_RDNA2 ? 128 : 64;
-        default:
-            GGML_ASSERT(false);
-    }
-#else
-    switch(type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-            return max_compute_capability >= CC_VOLTA ? 128 : 64;
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-            return 64;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_F32:
-            return 1;
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-            return max_compute_capability >= CC_VOLTA ? 128 : 64;
-        case GGML_TYPE_Q6_K:
-            return 64;
-        default:
-            GGML_ASSERT(false);
+        const int cc = ggml_cuda_info().devices[id].cc;
+        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
    }
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return row_rounding;
}

static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
    const int64_t nrows = ggml_nrows(tensor);
-    const int64_t rounding = get_row_rounding(tensor->type, tensor_split);
+    const int64_t rounding = get_row_rounding(tensor_split);

    *row_low = id == 0 ? 0 : nrows*tensor_split[id];
    *row_low -= *row_low % rounding;
@@ -1499,7 +1433,7 @@ static void ggml_cuda_op_mul_mat(
    // for multi GPU, get the row boundaries from tensor split
    // and round to mul_mat_q tile sizes
    if (split) {
-        const int64_t rounding = get_row_rounding(src0->type, tensor_split);
+        const int64_t rounding = get_row_rounding(tensor_split);

        if (id != 0) {
            dev[id].row_low = ne01*tensor_split[id];
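Not part of the diff: a small standalone sketch of what the rounding does to a per-device row boundary, assuming a hypothetical 4096-row tensor, a 0.30 split fraction, and a rounding of 128 (the value get_row_rounding would typically produce for a Volta-or-newer NVIDIA device; see the common.cuh helpers below):

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t nrows    = 4096;  // hypothetical number of weight-matrix rows
    const float   split    = 0.30f; // hypothetical tensor_split boundary for this device
    const int64_t rounding = 128;   // e.g. get_mmq_y_host(cc, get_mmq_x_max_host(cc))

    int64_t row_low = (int64_t)(nrows*split); // 1228
    row_low -= row_low % rounding;            // rounded down to 1152, a multiple of the MMQ tile height
    printf("row_low = %lld\n", (long long) row_low);
    return 0;
}
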
67 changes: 66 additions & 1 deletion ggml-cuda/common.cuh
@@ -160,7 +160,7 @@
#endif

#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels
-#define MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available
+#define MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available

#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

@@ -484,6 +484,71 @@ static __device__ __forceinline__ float get_alibi_slope(
return powf(base, exph);
}

+static constexpr __device__ int ggml_blck_size_device(ggml_type type) {
+    return type == GGML_TYPE_F16 ? 1 :
+        type == GGML_TYPE_Q4_0 ? QK4_0 :
+        type == GGML_TYPE_Q4_1 ? QK4_1 :
+        type == GGML_TYPE_Q5_0 ? QK5_0 :
+        type == GGML_TYPE_Q5_1 ? QK5_1 :
+        type == GGML_TYPE_Q8_0 ? QK8_0 :
+        type == GGML_TYPE_Q2_K ? QK_K :
+        type == GGML_TYPE_Q3_K ? QK_K :
+        type == GGML_TYPE_Q4_K ? QK_K :
+        type == GGML_TYPE_Q5_K ? QK_K :
+        type == GGML_TYPE_Q6_K ? QK_K :
+        type == GGML_TYPE_IQ2_XXS ? QK_K :
+        type == GGML_TYPE_IQ2_XS ? QK_K :
+        type == GGML_TYPE_IQ2_S ? QK_K :
+        type == GGML_TYPE_IQ3_XXS ? QK_K :
+        type == GGML_TYPE_IQ1_S ? QK_K :
+        type == GGML_TYPE_IQ1_M ? QK_K :
+        type == GGML_TYPE_IQ4_NL ? QK4_NL :
+        type == GGML_TYPE_IQ4_XS ? QK_K :
+        type == GGML_TYPE_IQ3_S ? QK_K :
+        0;
+}

+static constexpr __device__ int get_qr_device(ggml_type type) {
+    return type == GGML_TYPE_F16 ? 1 :
+        type == GGML_TYPE_Q4_0 ? QR4_0 :
+        type == GGML_TYPE_Q4_1 ? QR4_1 :
+        type == GGML_TYPE_Q5_0 ? QR5_0 :
+        type == GGML_TYPE_Q5_1 ? QR5_1 :
+        type == GGML_TYPE_Q8_0 ? QR8_0 :
+        type == GGML_TYPE_Q2_K ? QR2_K :
+        type == GGML_TYPE_Q3_K ? QR3_K :
+        type == GGML_TYPE_Q4_K ? QR4_K :
+        type == GGML_TYPE_Q5_K ? QR5_K :
+        type == GGML_TYPE_Q6_K ? QR6_K :
+        type == GGML_TYPE_IQ2_XXS ? QR2_XXS :
+        type == GGML_TYPE_IQ2_XS ? QR2_XS :
+        type == GGML_TYPE_IQ2_S ? QR2_S :
+        type == GGML_TYPE_IQ3_XXS ? QR3_XXS :
+        type == GGML_TYPE_IQ1_S ? QR1_S :
+        type == GGML_TYPE_IQ1_M ? QR1_M :
+        type == GGML_TYPE_IQ4_NL ? QR4_NL :
+        type == GGML_TYPE_IQ4_XS ? QR4_XS :
+        type == GGML_TYPE_IQ3_S ? QR3_S :
+        0;
+}

+static constexpr __device__ int get_qi_device(ggml_type type) {
+    return ggml_blck_size_device(type) / (sizeof(int)*get_qr_device(type));
+}

+static int get_mmq_x_max_host(const int cc) {
+#ifdef CUDA_USE_TENSOR_CORES
+    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
+#else
+    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
+#endif // CUDA_USE_TENSOR_CORES
+}
+
+// Round rows to this value for --split-mode row:
+static int get_mmq_y_host(const int cc, const int mmq_x) {
+    return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
+}

//////////////////////

struct ggml_cuda_device_info {
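Not part of the diff: a rough standalone model of what the two host helpers evaluate to on NVIDIA GPUs, assuming CUDA_USE_TENSOR_CORES is not defined and using the CC_VOLTA and CC_OFFSET_AMD values from common.cuh (700 and 1000000); the AMD path and the tensor-core build change the numbers. The _sketch names are stand-ins, not the real functions:

#include <cassert>

// simplified stand-ins for get_mmq_x_max_host / get_mmq_y_host
static int mmq_x_max_sketch(int cc)        { return cc >= 700 && cc < 1000000 ? 128 : 64; }
static int mmq_y_sketch(int cc, int mmq_x) { return cc >= 700 && mmq_x >= 32  ? 128 : 64; }

int main() {
    assert(mmq_y_sketch(610, mmq_x_max_sketch(610)) ==  64); // Pascal (cc 6.1): rows rounded to 64
    assert(mmq_y_sketch(860, mmq_x_max_sketch(860)) == 128); // Ampere (cc 8.6): rows rounded to 128
    return 0;
}
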
30 changes: 21 additions & 9 deletions ggml-cuda/dmmv.cu
@@ -422,10 +422,22 @@ static __device__ void convert_f16(const void * vx, const int64_t ib, const int
v.y = x[ib + iqs + 1];
}

-template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 :
+        type == GGML_TYPE_Q4_1 ? dequantize_q4_1 :
+        type == GGML_TYPE_Q5_0 ? dequantize_q5_0 :
+        type == GGML_TYPE_Q5_1 ? dequantize_q5_1 :
+        type == GGML_TYPE_Q8_0 ? dequantize_q8_0 :
+        type == GGML_TYPE_F16 ? convert_f16 :
+        nullptr;
+}
Comment on lines +425 to +433

Collaborator:
I think this could also be moved to ggml_cuda_type_traits.

Collaborator Author:
Are circular dependencies between common.cuh and dequantize.cuh okay?

Collaborator:
No, but there is probably too much in common.cuh.

Collaborator Author:
So should we for now just keep this as-is?

Collaborator:
Up to you, it's just a suggestion.

+template <ggml_type type>
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
-    // qk = quantized weights per x block
-    // qr = number of quantized weights per data value in x block
+    constexpr int qk = ggml_blck_size_device(type); // quantized weights per x block
+    constexpr int qr = get_qr_device(type); // number of quantized weights per data value in x block
+    constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);

const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;

if (row >= nrows) {
@@ -493,7 +505,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+    dequantize_mul_mat_vec<GGML_TYPE_Q4_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

Expand All @@ -502,7 +514,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+    dequantize_mul_mat_vec<GGML_TYPE_Q4_1>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

@@ -511,7 +523,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+    dequantize_mul_mat_vec<GGML_TYPE_Q5_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

@@ -520,7 +532,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+    dequantize_mul_mat_vec<GGML_TYPE_Q5_1>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

@@ -529,7 +541,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y,
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+    dequantize_mul_mat_vec<GGML_TYPE_Q8_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

@@ -580,7 +592,7 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<1, 1, convert_f16>
+    dequantize_mul_mat_vec<GGML_TYPE_F16>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

Expand Down