-
Notifications
You must be signed in to change notification settings - Fork 971
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
tests : experiments with n-bit quantized matrix multiplication
- Loading branch information
Showing
4 changed files
with
287 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,3 +9,4 @@ compile_commands.json | |
.DS_Store | ||
|
||
src/arm_neon.h | ||
tests/arm_neon.h |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,272 @@ | ||
// quantized matrix multiplication | ||
|
||
#include <float.h> | ||
#include <stdint.h> | ||
#include <stdio.h> | ||
#include <assert.h> | ||
#include <stdlib.h> | ||
#include <string.h> | ||
#include <time.h> | ||
#include <math.h> | ||
|
||
#include <sys/time.h> | ||
|
||
#ifdef __ARM_NEON | ||
#include "arm_neon.h" | ||
#endif | ||
|
||
// Two-argument min/max helpers.
// Each macro is guarded by its own name, so a platform header that already provides only one
// of them neither suppresses nor clashes with the other.
// NOTE: an argument may be evaluated twice - do not pass expressions with side effects.
#ifndef MAX
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif

#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
|
||
// Benchmark problem size: main() multiplies an M x K matrix by an N x K (transposed) matrix.
//
// QK - number of consecutive values that share one (min, d) header when quantized
// QB - number of bits stored per quantized value
//
// These are enum constants rather than `const int` objects: C does not treat a `const int` as an
// integer constant expression, so an array such as `gq_t pp[QB]` would be a variable-length
// array, which is not allowed to have an initializer.
enum {
    M = 1280,
    N = 1536,
    K = 1280,

    QK = 64,
    QB = 7,
};

// word type that holds one bit-plane of quantized values, and its width in bits
#define gq_t uint64_t
#define gq_t_bits 64
|
||
// Returns the current time reported by gettimeofday(), expressed in microseconds.
uint64_t get_time_us(void) {
    struct timeval tv;
    gettimeofday(&tv, NULL);

    // widen before scaling so the multiplication cannot overflow a narrow time_t
    return (uint64_t) tv.tv_sec * 1000000u + (uint64_t) tv.tv_usec;
}
|
||
// | ||
// naive implementation | ||
// | ||
|
||
// Reference float matrix product: dst = src0 * src1^T.
//
//   src0 : m x k, row-major
//   src1 : n x k, row-major (the right-hand matrix, stored transposed)
//   dst  : m x n, row-major
//
// Each output element is the dot product of one row of src0 with one row of src1.
void mul_mat_vec_f32_0(
    const float * restrict src0, // M x K
    const float * restrict src1, // N x K (transposed)
    float * dst,
    int m, int n, int k) {
    for (int row = 0; row < m; ++row) {
        const float * lhs = src0 + row*k;
        float * out = dst + row*n;

        for (int col = 0; col < n; ++col) {
            const float * rhs = src1 + col*k;

            float acc = 0;
            for (int t = 0; t < k; ++t) {
                acc += lhs[t] * rhs[t];
            }

            out[col] = acc;
        }
    }
}
|
||
void quantize(const float * src, void * dst, int n, int k) { | ||
char * p0 = dst; | ||
|
||
for (int j = 0; j < n; j++) { | ||
for (int i = 0; i < k/QK; i++) { | ||
float min = FLT_MAX; | ||
float max = -FLT_MAX; | ||
|
||
// find min/max | ||
#ifdef __ARM_NEON | ||
{ | ||
float32x4_t minv = vdupq_n_f32(FLT_MAX); | ||
float32x4_t maxv = vdupq_n_f32(-FLT_MAX); | ||
|
||
for (int l = 0; l < QK; l += 4) { | ||
float32x4_t v = vld1q_f32(src + j*k + i*QK + l); | ||
minv = vminq_f32(minv, v); | ||
maxv = vmaxq_f32(maxv, v); | ||
} | ||
|
||
float32x2_t minv32 = vpmin_f32(vget_low_f32(minv), vget_high_f32(minv)); | ||
float32x2_t maxv32 = vpmax_f32(vget_low_f32(maxv), vget_high_f32(maxv)); | ||
|
||
min = MIN(vget_lane_f32(minv32, 0), vget_lane_f32(minv32, 1)); | ||
max = MAX(vget_lane_f32(maxv32, 0), vget_lane_f32(maxv32, 1)); | ||
|
||
//printf("SIMD min/max: %f %f\n", min, max); | ||
} | ||
#else | ||
{ | ||
for (int l = 0; l < QK; l++) { | ||
const float v = src[j*k + i*QK + l]; | ||
if (v < min) min = v; | ||
if (v > max) max = v; | ||
} | ||
|
||
//printf("NORM min/max: %f %f\n", min, max); | ||
} | ||
#endif | ||
|
||
const float d = (max - min) / ((1 << QB) - 1); | ||
const float id = d ? 1.0/d : 0.0; | ||
|
||
memcpy(p0, &min, sizeof(float)); p0 += sizeof(float); | ||
memcpy(p0, &d, sizeof(float)); p0 += sizeof(float); | ||
|
||
//printf("min/max/d/id: %f %f %f %f\n", min, max, d, id); | ||
|
||
for (int s = 0; s < QK/gq_t_bits; ++s) { | ||
gq_t pp[QB] = {0}; | ||
|
||
for (int l = 0; l < gq_t_bits; l++) { | ||
const float v = src[j*k + i*QK + s*gq_t_bits + l]; | ||
const uint8_t q = (v - min)*id; | ||
|
||
for (int b = 0; b < QB; b++) { | ||
pp[b] |= q & (1 << b) ? (1LL << l) : 0; | ||
} | ||
} | ||
|
||
for (int b = 0; b < QB; b++) { | ||
memcpy(p0, &pp[b], sizeof(gq_t)); p0 += sizeof(gq_t); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
void mul_mat_vec_gq_0( | ||
const void * src0, | ||
const void * src1, | ||
float * dst, | ||
int m, int n, int k) { | ||
const int kp = k & ~(gq_t_bits - 1); | ||
|
||
const char * restrict p0 = src0; | ||
const char * restrict p1 = src1; | ||
|
||
for (int ir0 = 0; ir0 < m; ir0++) { | ||
for (int ir1 = 0; ir1 < n; ir1++) { | ||
float sumf = 0.0; | ||
|
||
const char * restrict pp0 = p0 + ir0*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK)); | ||
const char * restrict pp1 = p1 + ir1*((2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(k/QK)); | ||
|
||
for (int i = 0; i < kp/QK; i++) { | ||
float min0, d0; | ||
memcpy(&min0, pp0, sizeof(float)); pp0 += sizeof(float); | ||
memcpy(&d0, pp0, sizeof(float)); pp0 += sizeof(float); | ||
|
||
float min1, d1; | ||
memcpy(&min1, pp1, sizeof(float)); pp1 += sizeof(float); | ||
memcpy(&d1, pp1, sizeof(float)); pp1 += sizeof(float); | ||
|
||
//printf("min0/d0 = %f %f | min1/d1 = %f %f\n", min0, d0, min1, d1); | ||
|
||
#if 1 | ||
// >>> General case for any QB | ||
|
||
float s0[QB + 1]; | ||
float s1[QB + 1]; | ||
|
||
s0[0] = min0; | ||
s1[0] = min1; | ||
|
||
for (int b = 0; b < QB; b++) { | ||
s0[b + 1] = d0*(1 << b); | ||
s1[b + 1] = d1*(1 << b); | ||
} | ||
|
||
gq_t m0[QB + 1]; | ||
gq_t m1[QB + 1]; | ||
|
||
m0[0] = -1LL; | ||
m1[0] = -1LL; | ||
|
||
for (int s = 0; s < QK/gq_t_bits; ++s) { | ||
for (int b = 0; b < QB; b++) { | ||
memcpy(&m0[b + 1], pp0, sizeof(gq_t)); pp0 += sizeof(gq_t); | ||
memcpy(&m1[b + 1], pp1, sizeof(gq_t)); pp1 += sizeof(gq_t); | ||
} | ||
|
||
for (int q0 = 0; q0 < QB + 1; q0++) { | ||
for (int q1 = 0; q1 < QB + 1; q1++) { | ||
sumf += s0[q0]*s1[q1]*__builtin_popcountll(m0[q0] & m1[q1]); | ||
} | ||
} | ||
} | ||
#else | ||
#endif | ||
} | ||
|
||
dst[ir0*n + ir1] = sumf; | ||
} | ||
} | ||
} | ||
|
||
int main(int argc, const char ** argv) { | ||
float * src0 = (float *)malloc(sizeof(float)*M*K); | ||
float * src1 = (float *)malloc(sizeof(float)*N*K); | ||
float * dst = (float *)malloc(sizeof(float)*M*N); | ||
|
||
for (int i = 0; i < M*K; i++) { | ||
src0[i] = rand() / (float)RAND_MAX; | ||
} | ||
|
||
for (int i = 0; i < N*K; i++) { | ||
src1[i] = rand() / (float)RAND_MAX; | ||
} | ||
|
||
void * src0_gq = calloc(1, (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*M); | ||
void * src1_gq = calloc(1, (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*N); | ||
|
||
const size_t sizef16 = sizeof(__fp16)*M*K + sizeof(__fp16)*N*K; | ||
const size_t sizegq = (2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*M + | ||
(2*sizeof(float) + (QK/gq_t_bits)*QB*sizeof(gq_t))*(K/QK)*N; | ||
|
||
printf("compression: %f\n", (float)sizegq/sizef16); | ||
|
||
// convert fp32 -> gq | ||
{ | ||
const uint64_t t_start = get_time_us(); | ||
|
||
quantize(src0, src0_gq, M, K); | ||
quantize(src1, src1_gq, N, K); | ||
|
||
const uint64_t t_end = get_time_us(); | ||
printf("convert time: %f ms\n", (t_end - t_start) / 1000.0); | ||
} | ||
|
||
int method = 0; | ||
if (argc > 1) { | ||
method = atoi(argv[1]); | ||
} | ||
|
||
const int nIter = 1; | ||
|
||
const clock_t start = clock(); | ||
const uint64_t start_us = get_time_us(); | ||
|
||
double iM = 1.0/M; | ||
double sum = 0.0f; | ||
for (int i = 0; i < nIter; i++) { | ||
if (method == 0) { | ||
mul_mat_vec_f32_0(src0, src1, dst, M, N, K); | ||
} | ||
|
||
if (method == 1) { | ||
mul_mat_vec_gq_0(src0_gq, src1_gq, dst, M, N, K); | ||
} | ||
} | ||
|
||
for (int i = 0; i < N; i++) { | ||
sum += dst[i]*iM; | ||
} | ||
|
||
{ | ||
const clock_t end = clock(); | ||
const uint64_t end_us = get_time_us(); | ||
printf("%s: elapsed ticks: %ld\n", __func__, end - start); | ||
printf("%s: elapsed us: %llu / %f ms\n", __func__, end_us - start_us, (end_us - start_us) / 1000.0 / nIter); | ||
} | ||
|
||
printf("%f\n", sum); | ||
|
||
free(src0); | ||
free(src1); | ||
free(dst); | ||
|
||
free(src0_gq); | ||
free(src1_gq); | ||
|
||
return 0; | ||
} |