diff --git a/ggml.c b/ggml.c index 1556040b7..d8e1fbd4e 100644 --- a/ggml.c +++ b/ggml.c @@ -771,6 +771,40 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs); const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float)); +#if defined(__AVX2__) && QK % 32 == 0 + for (int i = 0; i < nb; i++) { + // scale factor + const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs)); + + const uint8_t * restrict pp = pb + i*bs; + + for (int l = 0; l < QK; l += 32) { + // Load 32x4-bit integers into 32x8-bit integers + __m256i vx8 = bytesFromNibbles(pp+l/2); + + // Subtract 8 from the integers + vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8)); + + // Convert to 16-bit int + const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0)); + const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1)); + + // Convert to 32-bit int -> float 32 + const __m256 vf[4] = { + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))), + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))), + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))), + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1))) + }; + + // Scale and store + for (int j = 0; j < 4; j++) { + __m256 result = _mm256_mul_ps(vf[j], d_v); + _mm256_storeu_ps(y + i * QK + l + j*8, result); + } + } + } +#else // scalar for (int i = 0; i < nb; i++) { const float d = *(const float *) (pd + i*bs); @@ -795,6 +829,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { assert(!isnan(y[i*QK + l + 1])); } } +#endif } void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {