Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Q4_K implementation for Metal #1733

Merged
merged 4 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Optimizing q4_K metal dot some more
For n = 256 it is now 28.1 ms/token compared to
27 ms/token for q4_0.
  • Loading branch information
Kawrakow committed Jun 7, 2023
commit a5becf764b198bbc848764e13d8ea7259c036dc7
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ endif

# keep standard at C11 and C++11
# -Ofast tends to produce faster code, but may not be available for some compilers.
OPT = -Ofast
#OPT = -O3
#OPT = -Ofast
OPT = -O3
CFLAGS = -I. $(OPT) -std=c11 -fPIC
CXXFLAGS = -I. -I./examples $(OPT) -std=c++11 -fPIC
LDFLAGS =
Expand Down
315 changes: 162 additions & 153 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// 34.7 ms / token
#include <metal_stdlib>

using namespace metal;
Expand All @@ -12,15 +11,6 @@ typedef struct {
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;

#define QK_K 256

typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_k;

static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
const int qk = QK4_0;

Expand All @@ -41,56 +31,6 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
}
}

static inline uchar2 get_scale_min_k4(int j, device const uint8_t * q) {
uchar2 r;
if (j < 4) {
r[0] = q[j] & 63; r[1] = q[j + 4] & 63;
} else {
r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
return r;
}
static inline uchar4 get_scale_min_k4_2(int j, device const uint8_t * q) {
uchar4 r;
if (j < 4) {
r[0] = q[j+0] & 63; r[1] = q[j+4] & 63;
r[2] = q[j+1] & 63; r[3] = q[j+5] & 63;
} else {
r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
}
return r;
}

static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;

for (int i = 0; i < nb; i++) {

const float d = x[i].d;
const float min = x[i].dmin;

device const uint8_t * q = x[i].qs;
device const uint8_t * scales = x[i].scales;

int is = 0;
for (int j = 0; j < QK_K; j += 64) {
const uchar2 sc1 = get_scale_min_k4(is, scales);
const float d1 = d * sc1[0]; const float m1 = min * sc1[1];
const uchar2 sc2 = get_scale_min_k4(is+1, scales);
const float d2 = d * sc2[0]; const float m2 = min * sc2[1];
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2;
}

}
}

kernel void kernel_add(
device const float * src0,
device const float * src1,
Expand Down Expand Up @@ -261,22 +201,6 @@ kernel void kernel_get_rows_q4_0(
(device float *) ((device char *) dst + i*nb1), ne00);
}

kernel void kernel_get_rows_q4_k(
device const void * src0,
device const int * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb1,
uint tpig[[thread_position_in_grid]]) {
const int i = tpig;
const int r = ((device int32_t *) src1)[i];

dequantize_row_q4_k(
(device const block_q4_k *) ((device char *) src0 + r*nb01),
(device float *) ((device char *) dst + i*nb1), ne00);
}

kernel void kernel_rms_norm(
device const void * src0,
device float * dst,
Expand Down Expand Up @@ -393,83 +317,6 @@ kernel void kernel_mul_mat_q4_0_f32(
}
}

kernel void kernel_mul_mat_q4_k_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
threadgroup float * sum [[threadgroup(0)]],
uint2 tgpig[[threadgroup_position_in_grid]],
uint2 tpig[[thread_position_in_grid]], // we don't use this for now
uint2 tpitg[[thread_position_in_threadgroup]],
uint2 tptg[[threads_per_threadgroup]]) {

const int nb = ne00/QK_K;

const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;

device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
device const float * yy = (device const float *) src1 + r1*ne10;

const uint nth = tptg.x*tptg.y;
const uint ith = tptg.y*tpitg.x + tpitg.y;

const int tid = tpitg.y; // 0...16
const int il = tid/4; // 0...3
const int ir = tid%4; // 0...3
const int n = 8;
const int is = 2*il;

sum[ith] = 0.0f;

float sumf = 0;
for (int i = tpitg.x; i < nb; i += tptg.x) {

device const uint8_t * q = (x + i)->qs + 32*il + n*ir;
device const float * y = yy + i*QK_K + 64*il + n*ir;
device const uint8_t * scales = (x + i)->scales;

const float dall = (float)((x + i)->d);
const float dmin = (float)((x + i)->dmin);

const uchar4 sc = get_scale_min_k4_2(is, scales);

float4 s = {0.f, 0.f, 0.f, 0.f};
for (int l = 0; l < n; ++l) {
s[0] += y[l] * (q[l] & 0xF); s[1] += y[l];
s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32];
}
sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);

}
sum[ith] = sumf;

// accumulate the sum from all threads in the threadgroup
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = nth/2; i > 0; i /= 2) {
if (ith < i) {
sum[ith] += sum[ith + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}

if (ith == 0) {
dst[r1*ne0 + r0] = sum[0];
}
}

kernel void kernel_mul_mat_f16_f32(
device const char * src0,
device const char * src1,
Expand Down Expand Up @@ -656,3 +503,165 @@ kernel void kernel_cpy_f32_f32(
dst_data[i00] = src[0];
}
}

//============================================ k-quants ======================================================

#define QK_K 256

typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_k;

static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
uchar4 r;
if (j < 4) {
r[0] = q[j+0] & 63; r[1] = q[j+4] & 63;
r[2] = q[j+1] & 63; r[3] = q[j+5] & 63;
} else {
r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
}
return r;
}

static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;

for (int i = 0; i < nb; i++) {

const float d = x[i].d;
const float min = x[i].dmin;

device const uint8_t * q = x[i].qs;
device const uint8_t * scales = x[i].scales;

int is = 0;
for (int j = 0; j < QK_K; j += 64) {
const uchar4 sc = get_scale_min_k4(is, scales);
const float d1 = d * sc[0]; const float m1 = min * sc[1];
const float d2 = d * sc[2]; const float m2 = min * sc[3];
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2;
}

}
}

kernel void kernel_get_rows_q4_k(
device const void * src0,
device const int * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb1,
uint tpig[[thread_position_in_grid]]) {
const int i = tpig;
const int r = ((device int32_t *) src1)[i];

dequantize_row_q4_k(
(device const block_q4_k *) ((device char *) src0 + r*nb01),
(device float *) ((device char *) dst + i*nb1), ne00);
}

kernel void kernel_mul_mat_q4_k_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
threadgroup float * sum [[threadgroup(0)]],
uint2 tgpig[[threadgroup_position_in_grid]],
uint2 tpig[[thread_position_in_grid]], // we don't use this for now
uint2 tpitg[[thread_position_in_threadgroup]],
uint2 tptg[[threads_per_threadgroup]]) {

const int nb = ne00/QK_K;

const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;

device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
device const float * yy = (device const float *) src1 + r1*ne10;

const uint nth = tptg.x*tptg.y;
const uint ith = tptg.y*tpitg.x + tpitg.y;

const int tid = tpitg.y; // 0...16
const int il = tid/4; // 0...3
const int ir = tid%4; // 0...3
const int n = 8;
const int is = 2*il;

sum[ith] = 0.0f;

float sumf = 0;
for (int i = tpitg.x; i < nb; i += tptg.x) {

device const uint8_t * q = (x + i)->qs + 32*il + n*ir;
device const float * y = yy + i*QK_K + 64*il + n*ir;
device const uint8_t * scales = (x + i)->scales;

const float dall = (float)((x + i)->d);
const float dmin = (float)((x + i)->dmin);

const uchar4 sc = get_scale_min_k4(is, scales);

float4 s = {0.f, 0.f, 0.f, 0.f};
for (int l = 0; l < n; ++l) {
s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0];
s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32];
}
sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);

}
sum[ith] = sumf;

//
// Accumulate the sum from all threads in the threadgroup
// This version is slightly faster than the commented out one below,
// which I copy-pasted from ggerganov's q4_0 dot product for metal.
//
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%4 == 0) {
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%16 == 0) {
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith == 0) {
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
dst[r1*ne0 + r0] = sum[0];
}

//// accumulate the sum from all threads in the threadgroup
//threadgroup_barrier(mem_flags::mem_threadgroup);
//for (uint i = nth/2; i > 0; i /= 2) {
// if (ith < i) {
// sum[ith] += sum[ith + i];
// }
// threadgroup_barrier(mem_flags::mem_threadgroup);
//}

//if (ith == 0) {
// dst[r1*ne0 + r0] = sum[0];
//}
}