Fix usage of F16C intrinsics in AVX code #563

Merged 2 commits on Mar 28, 2023
25 changes: 24 additions & 1 deletion ggml.c
@@ -1122,13 +1122,36 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
#define GGML_F16_EPR 8

// F16 arithmetic is not supported by AVX, so we use F32 instead
// we take advantage of the _mm256_cvt intrinsics to convert F16 <-> F32

#define GGML_F32Cx8 __m256
#define GGML_F32Cx8_ZERO _mm256_setzero_ps()
#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x)

#if defined(__F16C__)
// the _mm256_cvt intrinsics require F16C
#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
#else
static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
    float tmp[8];

    for (int i = 0; i < 8; i++)
        tmp[i] = GGML_FP16_TO_FP32(x[i]);

    return _mm256_loadu_ps(tmp);
}
static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
    float arr[8];

    _mm256_storeu_ps(arr, y);

    // storing goes F32 -> F16, the opposite direction from the load above
    for (int i = 0; i < 8; i++)
        x[i] = GGML_FP32_TO_FP16(arr[i]);
}
#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
#endif

#define GGML_F32Cx8_FMA GGML_F32x8_FMA
#define GGML_F32Cx8_ADD _mm256_add_ps
#define GGML_F32Cx8_MUL _mm256_mul_ps
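The non-F16C fallback in this diff converts eight half-precision values one at a time: F16 to F32 on load, F32 to F16 on store. The sketch below illustrates that round trip in portable C without any AVX intrinsics, so it can be compiled on any machine. Everything in it is an assumption made for illustration and is not ggml code: `half_t` stands in for `ggml_fp16_t`, `f32x8_t` stands in for `__m256`, and `half_to_float` / `float_to_half` are hypothetical replacements for `GGML_FP16_TO_FP32` / `GGML_FP32_TO_FP16`.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

typedef uint16_t half_t;                 /* hypothetical stand-in for ggml_fp16_t */
typedef struct { float v[8]; } f32x8_t;  /* hypothetical stand-in for __m256 */

/* F16 -> F32: the direction the load path needs. */
static float half_to_float(half_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    uint32_t exp  = (h >> 10) & 0x1Fu;
    uint32_t mant = h & 0x3FFu;
    uint32_t bits;
    if (exp == 0 && mant == 0) {
        bits = sign;                                  /* signed zero */
    } else if (exp == 0) {
        int e = -1;                                   /* subnormal: renormalize */
        do { mant <<= 1; e++; } while ((mant & 0x400u) == 0);
        bits = sign | ((uint32_t)(112 - e) << 23) | ((mant & 0x3FFu) << 13);
    } else if (exp == 31) {
        bits = sign | 0x7F800000u | (mant << 13);     /* infinity or NaN */
    } else {
        bits = sign | ((exp + 112u) << 23) | (mant << 13);
    }
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* F32 -> F16 with round-to-nearest-even: the direction the store path needs. */
static half_t float_to_half(float f) {
    uint32_t x;
    memcpy(&x, &f, sizeof x);
    uint32_t sign = (x >> 16) & 0x8000u;
    uint32_t exp  = (x >> 23) & 0xFFu;
    uint32_t mant = x & 0x7FFFFFu;
    if (exp == 0xFFu)                                 /* infinity or NaN */
        return (half_t)(sign | 0x7C00u | (mant ? 0x200u : 0u));
    int e = (int)exp - 112;                           /* rebias 127 -> 15 */
    if (e >= 31) return (half_t)(sign | 0x7C00u);     /* overflow to infinity */
    if (e <= 0) {                                     /* result is subnormal or zero */
        if (e < -10) return (half_t)sign;
        mant |= 0x800000u;                            /* restore the implicit 1 */
        int shift = 14 - e;
        uint32_t q    = mant >> shift;
        uint32_t rem  = mant & ((1u << shift) - 1u);
        uint32_t half = 1u << (shift - 1);
        if (rem > half || (rem == half && (q & 1u))) q++;
        return (half_t)(sign | q);
    }
    uint32_t h   = ((uint32_t)e << 10) | (mant >> 13);
    uint32_t rem = mant & 0x1FFFu;
    if (rem > 0x1000u || (rem == 0x1000u && (h & 1u))) h++;  /* carry may bump the exponent */
    return (half_t)(sign | h);
}

/* Scalar load: widen eight halves to floats. */
static f32x8_t f32cx8_load_fallback(const half_t *x) {
    f32x8_t r;
    for (int i = 0; i < 8; i++) r.v[i] = half_to_float(x[i]);
    return r;
}

/* Scalar store: narrow eight floats back to halves. */
static void f32cx8_store_fallback(half_t *x, f32x8_t y) {
    for (int i = 0; i < 8; i++) x[i] = float_to_half(y.v[i]);
}

/* Load followed by store should reproduce the original half-precision bits. */
static void check_round_trip(void) {
    const half_t in[8] = {0x0000, 0x3C00, 0xC000, 0x3800,
                          0x7BFF, 0x0001, 0x8400, 0x3555};
    half_t out[8];
    f32x8_t v = f32cx8_load_fallback(in);
    assert(v.v[1] == 1.0f && v.v[2] == -2.0f && v.v[3] == 0.5f);
    f32cx8_store_fallback(out, v);
    assert(memcmp(in, out, sizeof in) == 0);
}
```

Calling `check_round_trip()` from a test harness exercises both directions. Using the F16 to F32 conversion in the store loop would either fail to compile or silently write wrong values, depending on how the conversion macros are defined, which is why the two paths need opposite conversions.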