Skip to content

Commit

Permalink
CUDA: revise q8_1 data layout for mul_mat_q (llama/7824)
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler authored and ggerganov committed Jun 15, 2024
1 parent a7d88a6 commit 0610ddd
Show file tree
Hide file tree
Showing 5 changed files with 281 additions and 150 deletions.
88 changes: 57 additions & 31 deletions src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1347,10 +1347,30 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
GGML_UNUSED(main_device);
}

static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {

#if !defined(GGML_USE_HIPBLAS)
// cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
cudaMemcpy3DPeerParms p = {};
p.dstDevice = dstDevice;
p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height);
p.srcDevice = srcDevice;
p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height);
p.extent = make_cudaExtent(width, height, 1);
return cudaMemcpy3DPeerAsync(&p, stream);
#else
// HIP does not support cudaMemcpy3DPeerAsync or vmm pools
GGML_UNUSED(dstDevice);
GGML_UNUSED(srcDevice);
return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
#endif // !defined(GGML_USE_HIPBLAS)
}

static void ggml_cuda_op_mul_mat(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
const bool convert_src1_to_q8_1) {
quantize_cuda_t quantize_src1) {

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
Expand Down Expand Up @@ -1407,7 +1427,9 @@ static void ggml_cuda_op_mul_mat(
}

struct dev_data {
ggml_cuda_pool_alloc<char> src0_dd_alloc;
int cc;

ggml_cuda_pool_alloc<char> src0_dd_alloc;
ggml_cuda_pool_alloc<float> src1_ddf_alloc;
ggml_cuda_pool_alloc<char> src1_ddq_alloc;
ggml_cuda_pool_alloc<float> dst_dd_alloc;
Expand All @@ -1426,6 +1448,8 @@ static void ggml_cuda_op_mul_mat(
int used_devices = 0;

for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
dev[id].cc = ggml_cuda_info().devices[id].cc;

// by default, use all rows
dev[id].row_low = 0;
dev[id].row_high = ne01;
Expand Down Expand Up @@ -1476,11 +1500,15 @@ static void ggml_cuda_op_mul_mat(
dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));
}

if (convert_src1_to_q8_1) {
dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
if (quantize_src1) {
size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
if (quantize_src1 == quantize_mmq_q8_1_cuda) {
src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq);
}
dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);

if (src1_on_device && src1_is_contiguous) {
quantize_row_q8_1_cuda(dev[id].src1_ddf, dev[id].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
CUDA_CHECK(cudaGetLastError());
}
}
Expand Down Expand Up @@ -1526,7 +1554,12 @@ static void ggml_cuda_op_mul_mat(
const int64_t i03 = i0 / ne12;
const int64_t i02 = i0 % ne12;

const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs;
if (quantize_src1 == quantize_mmq_q8_1_cuda) {
src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq);
} else {
src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs;
}

// for split tensors the data begins at i0 == i0_offset_low
char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
Expand All @@ -1543,10 +1576,17 @@ static void ggml_cuda_op_mul_mat(
// copy src0, src1 to device if necessary
if (src1_is_contiguous) {
if (id != ctx.device) {
if (convert_src1_to_q8_1) {
if (quantize_src1) {
char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddq_i, id, src1_ddq_i_source, ctx.device,
src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
if (quantize_src1 == quantize_mmq_q8_1_cuda) {
const size_t pitch = ne11*sizeof(block_q8_1_mmq);
const size_t width = src1_ncols*sizeof(block_q8_1_mmq);
const size_t height = src1_padded_col_size/(4*QK8_1);
CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream));
} else {
CUDA_CHECK(cudaMemcpyPeerAsync(
src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
}
} else {
float * src1_ddf_i_source = (float *) src1->data;
src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
Expand All @@ -1561,8 +1601,8 @@ static void ggml_cuda_op_mul_mat(
GGML_ASSERT(false);
}

if (convert_src1_to_q8_1 && !src1_is_contiguous) {
quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
if (quantize_src1 && !src1_is_contiguous) {
quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
CUDA_CHECK(cudaGetLastError());
}

Expand All @@ -1587,22 +1627,8 @@ static void ggml_cuda_op_mul_mat(
float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
#if !defined(GGML_USE_HIPBLAS)
// cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
cudaMemcpy3DPeerParms p = {};
p.dstDevice = ctx.device;
p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols);
p.srcDevice = id;
p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols);
p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1);
CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream));
#else
// HIP does not support cudaMemcpy3DPeerAsync or vmm pools
CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float),
dst_dd_i, row_diff*sizeof(float),
row_diff*sizeof(float), src1_ncols,
cudaMemcpyDeviceToDevice, stream));
#endif
CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(
dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream));
} else {
float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
Expand Down Expand Up @@ -1941,13 +1967,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
// KQ + KQV multi-batch
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
} else if (use_dequantize_mul_mat_vec) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
} else if (use_mul_mat_vec_q) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
} else if (use_mul_mat_q) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
} else {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/ggml-cuda/mmq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ void ggml_cuda_op_mul_mat_q(
const int64_t nb01 = src0->nb[1];

const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
GGML_ASSERT(ne10 % QK8_1 == 0);

const int64_t ne0 = dst->ne[0];
Expand All @@ -25,7 +26,7 @@ void ggml_cuda_op_mul_mat_q(
// nrows_dst == nrows of the matrix that the kernel writes into
const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;

const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, nrows_dst};
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst};

switch (src0->type) {
case GGML_TYPE_Q4_0:
Expand Down
Loading

0 comments on commit 0610ddd

Please sign in to comment.