From 81c983a4cfec434492ea9b3802637672bd310009 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 29 May 2024 09:45:33 +0300 Subject: [PATCH 1/2] tests : add non-cont concat tests --- tests/test-backend-ops.cpp | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b200ccccd51b0..5cde21c660514 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1262,22 +1262,37 @@ struct test_concat : public test_case { const std::array ne_a; const int64_t ne_b_d; const int dim; + const int v; // view (1 << 0: non-cont a, 1 << 1: non-cont b) std::string vars() override { - return VARS_TO_STR4(type, ne_a, ne_b_d, dim); + return VARS_TO_STR5(type, ne_a, ne_b_d, dim, v); } test_concat(ggml_type type = GGML_TYPE_F32, std::array ne_a = {10, 10, 10, 10}, int64_t ne_b_d = 10, - int dim = 2) - : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim) {} + int dim = 2, int v = 0) + : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim), v(v) {} ggml_tensor * build_graph(ggml_context * ctx) override { auto ne_b = ne_a; ne_b[dim] = ne_b_d; - ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); - ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data()); + ggml_tensor * a; + if (v & 1) { + auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3; + a = ggml_new_tensor(ctx, type, 4, ne.data()); + a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0); + } else { + a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + } + ggml_tensor * b; + if (v & 2) { + auto ne = ne_b; ne[0] *= 3; ne[1] *= 2; ne[2] *= 4; + b = ggml_new_tensor(ctx, type, 4, ne.data()); + b = ggml_view_4d(ctx, b, ne_b[0], ne_b[1], ne_b[2], ne_b[3], b->nb[1], b->nb[2], b->nb[3], 0); + } else { + b = ggml_new_tensor(ctx, type, 4, ne_b.data()); + } ggml_tensor * out = ggml_concat(ctx, a, b, dim); return out; } @@ -2215,9 +2230,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } - for (int dim : { 0, 1, 2, 3, }) { - test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim)); - test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim)); + for (int v : { 0, 1, 2, 3 }) { + for (int dim : { 0, 1, 2, 3, }) { + test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v)); + test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim, v)); + } } for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) { From 738008fbcccf5b5fd6cb93378e5f55cbc44b6b91 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 29 May 2024 10:14:34 +0300 Subject: [PATCH 2/2] cuda : non-cont concat support ggml-ci --- ggml-cuda/concat.cu | 110 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 88 insertions(+), 22 deletions(-) diff --git a/ggml-cuda/concat.cu b/ggml-cuda/concat.cu index fb9dee8f8cee5..dac10ec36b0bd 100644 --- a/ggml-cuda/concat.cu +++ b/ggml-cuda/concat.cu @@ -1,5 +1,6 @@ #include "concat.cuh" +// contiguous kernels static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) { int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { @@ -92,39 +93,104 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n concat_f32_dim2<<>>(x, y, dst, ne0, ne02); } +// non-contiguous kernel (slow) +static __global__ void concat_f32_non_cont( + const char * src0, + const char * src1, + char * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne03, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, + int64_t /*ne10*/, + int64_t /*ne11*/, + int64_t /*ne12*/, + int64_t /*ne13*/, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, + int64_t ne0, + int64_t /*ne1*/, + int64_t /*ne2*/, + int64_t /*ne3*/, + uint64_t nb0, + uint64_t nb1, + uint64_t nb2, + uint64_t nb3, + int32_t dim) { + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); + + const float * x; + + for (int i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + } else { + x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + } + + float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = *x; + } +} + + void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - - float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); const int32_t dim = ((int32_t *) dst->op_params)[0]; - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - if (dim != 3) { - for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_cuda( - src0_d + i3 * (src0->nb[3] / 4), - src1_d + i3 * (src1->nb[3] / 4), - dst_d + i3 * ( dst->nb[3] / 4), - src0->ne[0], src0->ne[1], src0->ne[2], - dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + + float * dst_d = (float *)dst->data; + + if (dim != 3) { + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_cuda( + src0_d + i3 * (src0->nb[3] / 4), + src1_d + i3 * (src1->nb[3] / 4), + dst_d + i3 * ( dst->nb[3] / 4), + src0->ne[0], src0->ne[1], src0->ne[2], + dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); + } + } else { + const size_t size0 = ggml_nbytes(src0); + const size_t size1 = ggml_nbytes(src1); + + CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); } } else { - const size_t size0 = ggml_nbytes(src0); - const size_t size1 = ggml_nbytes(src1); - - CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); + dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); + concat_f32_non_cont<<>>( + (const char *)src0->data, + (const char *)src1->data, + ( char *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim); } }