Skip to content

Commit

Permalink
ggml : fix loongarch build (O2 issue) (llama/7636)
Browse files Browse the repository at this point in the history
  • Loading branch information
junchao-loongson authored and ggerganov committed Jun 15, 2024
1 parent 936fa11 commit 0acc2e5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
20 changes: 14 additions & 6 deletions src/ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -6828,6 +6828,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r

int bit = 0;
int is = 0;
__m256i xvbit;

const uint8_t * restrict q3 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
Expand All @@ -6836,21 +6837,25 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
// load low 2 bits
const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;

xvbit = __lasx_xvreplgr2vr_h(bit);
// prepare low and high bits
const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
++bit;

xvbit = __lasx_xvreplgr2vr_h(bit);
const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
++bit;

xvbit = __lasx_xvreplgr2vr_h(bit);
const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
++bit;

xvbit = __lasx_xvreplgr2vr_h(bit);
const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
++bit;

// load Q8 quants
Expand Down Expand Up @@ -8033,6 +8038,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
__m256i sumi = __lasx_xvldi(0);

int bit = 0;
__m256i xvbit;

for (int j = 0; j < QK_K/64; ++j) {

Expand All @@ -8041,13 +8047,15 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r

const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;

xvbit = __lasx_xvreplgr2vr_h(bit++);
const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4);
const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0);
hmask = __lasx_xvslli_h(hmask, 1);

xvbit = __lasx_xvreplgr2vr_h(bit++);
const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4);
const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
hmask = __lasx_xvslli_h(hmask, 1);

Expand Down
2 changes: 1 addition & 1 deletion src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1580,7 +1580,7 @@ do { \
#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))

static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t *x) {
float tmp[8];

for (int i = 0; i < 8; i++) {
Expand Down

0 comments on commit 0acc2e5

Please sign in to comment.