From d4b3b463a94ef50c6334d7bd6c4af23774ba90c9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 21 Aug 2023 15:06:03 +0300 Subject: [PATCH] ggml : add broadcast support for BLAS ggml_mul_mat() (#460) --- src/ggml.c | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/ggml.c b/src/ggml.c index 1182e76fb..6d4d13306 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -10788,6 +10788,10 @@ static void ggml_compute_forward_mul_mat( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + // nb01 >= nb00 - src0 is not transposed // compute by src0 rows @@ -10807,11 +10811,6 @@ static void ggml_compute_forward_mul_mat( #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { - // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension - // ref: https://github.com/ggerganov/ggml/pull/224 - GGML_ASSERT(ne02 == ne12); - GGML_ASSERT(ne03 == ne13); - if (params->ith != 0) { return; } @@ -10824,12 +10823,16 @@ static void ggml_compute_forward_mul_mat( return; } - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - const void * x = (char *) src0->data + i03*nb03 + i02*nb02; - const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + for (int64_t i13 = 0; i13 < ne13; i13++) { + for (int64_t i12 = 0; i12 < ne12; i12++) { + // broadcast src0 into src1 across 2nd,3rd dimension + const int64_t i03 = i13/r3; + const int64_t i02 = i12/r2; + + const void * x = (char *) src0->data + i02*nb02 + i03*nb03; + const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); - float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); if (type != GGML_TYPE_F32) { float * const wdata = params->wdata; @@ -10837,7 +10840,7 @@ static void ggml_compute_forward_mul_mat( size_t id = 0; for (int64_t i01 = 0; i01 < ne01; ++i01) { - to_float((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); + to_float((const char *) x + i01*nb01, wdata + id, ne00); id += ne00; } @@ -10917,10 +10920,6 @@ static void ggml_compute_forward_mul_mat( assert(ne12 % ne02 == 0); assert(ne13 % ne03 == 0); - // broadcast factors - const int64_t r2 = ne12/ne02; - const int64_t r3 = ne13/ne03; - // block-tiling attempt const int64_t blck_0 = 16; const int64_t blck_1 = 16;