Skip to content

Commit

Permalink
struct for qk, qr, qi
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Jun 3, 2024
1 parent bd8422d commit 8b6962d
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 56 deletions.
188 changes: 139 additions & 49 deletions ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -484,57 +484,147 @@ static __device__ __forceinline__ float get_alibi_slope(
return powf(base, exph);
}

// Block size (number of quantized weights per block) for each supported type.
// Returns 0 for any type not listed. Uses a switch instead of a ternary chain
// (requires C++14 relaxed constexpr).
static constexpr __device__ int ggml_blck_size_device(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F16:     return 1;
        case GGML_TYPE_Q4_0:    return QK4_0;
        case GGML_TYPE_Q4_1:    return QK4_1;
        case GGML_TYPE_Q5_0:    return QK5_0;
        case GGML_TYPE_Q5_1:    return QK5_1;
        case GGML_TYPE_Q8_0:    return QK8_0;
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ3_S:   return QK_K;  // all K-quants and IQ formats share the QK_K super-block size
        case GGML_TYPE_IQ4_NL:  return QK4_NL;
        default:                return 0;     // unknown/unsupported type
    }
}
// Primary template for per-type compile-time constants (qk, qr, qi).
// Deliberately left undefined: only the explicit specializations below exist,
// so instantiating it for an unsupported type is a compile-time error.
template <ggml_type type>
struct ggml_cuda_type_traits;

// Quantization ratio (qr) for each supported type: the number of quantized
// weights per data value in a block. Returns 0 for any type not listed.
// Switch form requires C++14 relaxed constexpr.
static constexpr __device__ int get_qr_device(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F16:     return 1;       // not quantized: one value per value
        case GGML_TYPE_Q4_0:    return QR4_0;
        case GGML_TYPE_Q4_1:    return QR4_1;
        case GGML_TYPE_Q5_0:    return QR5_0;
        case GGML_TYPE_Q5_1:    return QR5_1;
        case GGML_TYPE_Q8_0:    return QR8_0;
        case GGML_TYPE_Q2_K:    return QR2_K;
        case GGML_TYPE_Q3_K:    return QR3_K;
        case GGML_TYPE_Q4_K:    return QR4_K;
        case GGML_TYPE_Q5_K:    return QR5_K;
        case GGML_TYPE_Q6_K:    return QR6_K;
        case GGML_TYPE_IQ2_XXS: return QR2_XXS;
        case GGML_TYPE_IQ2_XS:  return QR2_XS;
        case GGML_TYPE_IQ2_S:   return QR2_S;
        case GGML_TYPE_IQ3_XXS: return QR3_XXS;
        case GGML_TYPE_IQ1_S:   return QR1_S;
        case GGML_TYPE_IQ1_M:   return QR1_M;
        case GGML_TYPE_IQ4_NL:  return QR4_NL;
        case GGML_TYPE_IQ4_XS:  return QR4_XS;
        case GGML_TYPE_IQ3_S:   return QR3_S;
        default:                return 0;       // unknown/unsupported type
    }
}
// Constants for GGML_TYPE_F16 (not quantized): one value per "block" (qk == 1),
// one value per data value (qr == 1). Note: no qi member is defined here.
template<>
struct ggml_cuda_type_traits<GGML_TYPE_F16> {
static constexpr int qk = 1;
static constexpr int qr = 1;
};

// qi: number of 32-bit ints of quantized data per block,
// i.e. block size / (sizeof(int) * qr).
// NOTE(review): get_qr_device returns 0 for unhandled types, which would make
// this a constant-expression division by zero — presumably only ever evaluated
// for supported types; confirm at call sites.
static constexpr __device__ int get_qi_device(ggml_type type) {
return ggml_blck_size_device(type) / (sizeof(int)*get_qr_device(type));
}
// Constants for GGML_TYPE_Q4_0: qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
static constexpr int qk = QK4_0;
static constexpr int qr = QR4_0;
static constexpr int qi = QI4_0;
};

// Constants for GGML_TYPE_Q4_1: qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
static constexpr int qk = QK4_1;
static constexpr int qr = QR4_1;
static constexpr int qi = QI4_1;
};

// Constants for GGML_TYPE_Q5_0: qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
static constexpr int qk = QK5_0;
static constexpr int qr = QR5_0;
static constexpr int qi = QI5_0;
};

// Constants for GGML_TYPE_Q5_1: qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
static constexpr int qk = QK5_1;
static constexpr int qr = QR5_1;
static constexpr int qi = QI5_1;
};

// Constants for GGML_TYPE_Q8_0: qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
static constexpr int qk = QK8_0;
static constexpr int qr = QR8_0;
static constexpr int qi = QI8_0;
};

// Constants for GGML_TYPE_Q2_K (QK_K super-block): qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_K;
static constexpr int qi = QI2_K;
};

// Constants for GGML_TYPE_Q3_K (QK_K super-block): qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR3_K;
static constexpr int qi = QI3_K;
};

// Constants for GGML_TYPE_Q4_K (QK_K super-block): qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_K;
static constexpr int qi = QI4_K;
};

// Constants for GGML_TYPE_Q5_K (QK_K super-block): qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR5_K;
static constexpr int qi = QI5_K;
};

// Constants for GGML_TYPE_Q6_K (QK_K super-block): qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR6_K;
static constexpr int qi = QI6_K;
};

// Constants for GGML_TYPE_IQ2_XXS (QK_K super-block): qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_XXS;
static constexpr int qi = QI2_XXS;
};

// Constants for GGML_TYPE_IQ2_XS (QK_K super-block): qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_XS;
static constexpr int qi = QI2_XS;
};

// Constants for GGML_TYPE_IQ2_S (QK_K super-block): qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
static constexpr int qk = QK_K;
static constexpr int qr = QR2_S;
static constexpr int qi = QI2_S;
};

// Constants for GGML_TYPE_IQ3_XXS (QK_K super-block): qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR3_XXS;
static constexpr int qi = QI3_XXS;
};

// Constants for GGML_TYPE_IQ1_S (QK_K super-block): qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
static constexpr int qk = QK_K;
static constexpr int qr = QR1_S;
static constexpr int qi = QI1_S;
};

// Constants for GGML_TYPE_IQ1_M (QK_K super-block): qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
static constexpr int qk = QK_K;
static constexpr int qr = QR1_M;
static constexpr int qi = QI1_M;
};

// Constants for GGML_TYPE_IQ4_NL (QK4_NL block, not the QK_K super-block):
// qk = quantized weights per block, qr = quantized weights per data value,
// qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
static constexpr int qk = QK4_NL;
static constexpr int qr = QR4_NL;
static constexpr int qi = QI4_NL;
};

// Constants for GGML_TYPE_IQ4_XS (QK_K super-block): qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_XS;
static constexpr int qi = QI4_XS;
};

// Constants for GGML_TYPE_IQ3_S (QK_K super-block): qk = quantized weights per block,
// qr = quantized weights per data value, qi = ints of quantized data per block (qk / (4*qr)).
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qk = QK_K;
static constexpr int qr = QR3_S;
static constexpr int qi = QI3_S;
};

static int get_mmq_x_max_host(const int cc) {
#ifdef CUDA_USE_TENSOR_CORES
Expand Down
4 changes: 2 additions & 2 deletions ggml-cuda/dmmv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,8 @@ static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type

template <ggml_type type>
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
constexpr int qk = ggml_blck_size_device(type); // quantized weights per x block
constexpr int qr = get_qr_device(type); // number of quantized weights per data value in x block
constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
constexpr int qr = ggml_cuda_type_traits<type>::qr; // number of quantized weights per data value in x block
constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);

const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
Expand Down
6 changes: 3 additions & 3 deletions ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1033,10 +1033,10 @@ static __global__ void mul_mat_q(
return;
}

constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qr = ggml_cuda_type_traits<type>::qr;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int mmq_y = get_mmq_y_device(mmq_x);
constexpr int qk = ggml_blck_size_device(type);
constexpr int qr = get_qr_device(type);
constexpr int qi = get_qi_device(type);
constexpr bool need_sum = get_need_sum(type);
constexpr int vdr = get_vdr_mmq(type);

Expand Down
4 changes: 2 additions & 2 deletions ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {

constexpr int qk = ggml_blck_size_device(type);
constexpr int qi = get_qi_device(type);
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);

constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
Expand Down

0 comments on commit 8b6962d

Please sign in to comment.