Skip to content

Commit

Permalink
ggml : add broadcast support for BLAS ggml_mul_mat() (#460)
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Aug 21, 2023
1 parent 6e13e28 commit d4b3b46
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions src/ggml.c
Original file line number Diff line number Diff line change
@@ -10788,6 +10788,10 @@ static void ggml_compute_forward_mul_mat(
  GGML_ASSERT(nb1 <= nb2);
  GGML_ASSERT(nb2 <= nb3);

+ // broadcast factors
+ const int64_t r2 = ne12/ne02;
+ const int64_t r3 = ne13/ne03;
+
  // nb01 >= nb00 - src0 is not transposed
  // compute by src0 rows

@@ -10807,11 +10811,6 @@ static void ggml_compute_forward_mul_mat(

  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
- // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
- // ref: https://github.com/ggerganov/ggml/pull/224
- GGML_ASSERT(ne02 == ne12);
- GGML_ASSERT(ne03 == ne13);
-
  if (params->ith != 0) {
  return;
  }
Expand All @@ -10824,20 +10823,24 @@ static void ggml_compute_forward_mul_mat(
return;
}

for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
const void * x = (char *) src0->data + i03*nb03 + i02*nb02;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
for (int64_t i13 = 0; i13 < ne13; i13++) {
for (int64_t i12 = 0; i12 < ne12; i12++) {
// broadcast src0 into src1 across 2nd,3rd dimension
const int64_t i03 = i13/r3;
const int64_t i02 = i12/r2;

const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);

float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);

if (type != GGML_TYPE_F32) {
float * const wdata = params->wdata;
ggml_to_float_t const to_float = type_traits[type].to_float;

size_t id = 0;
for (int64_t i01 = 0; i01 < ne01; ++i01) {
to_float((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
to_float((const char *) x + i01*nb01, wdata + id, ne00);
id += ne00;
}

@@ -10917,10 +10920,6 @@ static void ggml_compute_forward_mul_mat(
  assert(ne12 % ne02 == 0);
  assert(ne13 % ne03 == 0);

- // broadcast factors
- const int64_t r2 = ne12/ne02;
- const int64_t r3 = ne13/ne03;
-
  // block-tiling attempt
  const int64_t blck_0 = 16;
  const int64_t blck_1 = 16;
Expand Down

0 comments on commit d4b3b46

Please sign in to comment.