From e1b79130616ece8fb4b336aecf709d38e06b54b8 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Tue, 21 May 2024 02:02:25 +0000 Subject: [PATCH 1/2] backup --- ggml-sycl.cpp | 1015 ++++++++++++++--------------------------- ggml-sycl.h | 2 - ggml-sycl/backend.hpp | 1 - ggml-sycl/common.cpp | 116 ----- ggml-sycl/common.hpp | 200 ++++++-- ggml-sycl/convert.cpp | 2 +- ggml-sycl/dmmv.cpp | 58 ++- ggml-sycl/mmq.cpp | 20 +- ggml-sycl/pool.hpp | 233 ---------- ggml-sycl/presets.hpp | 1 + llama.cpp | 11 - 11 files changed, 555 insertions(+), 1104 deletions(-) delete mode 100644 ggml-sycl/pool.hpp diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 307289bfa91ad..e63503a21733c 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -43,21 +43,13 @@ void * ggml_sycl_host_malloc(size_t size); void ggml_sycl_host_free(void * ptr); bool ggml_sycl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); void ggml_sycl_free_data(struct ggml_tensor * tensor); -void ggml_sycl_assign_buffers(struct ggml_tensor * tensor); -void ggml_sycl_assign_buffers_no_scratch(struct ggml_tensor * tensor); -void ggml_sycl_assign_buffers_force_inplace(struct ggml_tensor * tensor); -void ggml_sycl_assign_buffers_no_alloc(struct ggml_tensor * tensor); void ggml_sycl_copy_to_device(struct ggml_tensor * tensor); void ggml_sycl_set_main_device(int main_device); void ggml_sycl_set_mul_mat_q(bool mul_mat_q); -void ggml_sycl_set_scratch_size(size_t scratch_size); -void ggml_sycl_free_scratch(void); void ggml_sycl_get_device_description(int device, char * description, size_t description_size); bool ggml_backend_is_sycl(ggml_backend_t backend); int ggml_backend_sycl_get_device(ggml_backend_t backend); int get_main_device(); -void print_ggml_tensor(const char*name, struct ggml_tensor *src); -void log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt); void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst, const void *ptr_src, size_t size) { @@ -67,39 +59,6 @@ void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst, free(host_buf); } -static __dpct_inline__ int get_int_from_int8(const int8_t *x8, const int &i32) { - const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment - - int x32 = 0; - x32 |= x16[0] << 0; - x32 |= x16[1] << 16; - - return x32; -} - -static __dpct_inline__ int get_int_from_uint8(const uint8_t *x8, - const int &i32) { - const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment - - int x32 = 0; - x32 |= x16[0] << 0; - x32 |= x16[1] << 16; - - return x32; -} - -static __dpct_inline__ int get_int_from_int8_aligned(const int8_t *x8, - const int &i32) { - return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment -} - -static __dpct_inline__ int get_int_from_uint8_aligned(const uint8_t *x8, - const int &i32) { - return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment -} - - -typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v); typedef void (*cpy_kernel_t)(const char * cx, char * cdst); typedef void (*ggml_sycl_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); typedef void (*ggml_sycl_op_mul_mat_t)( @@ -107,12 +66,12 @@ typedef void (*ggml_sycl_op_mul_mat_t)( const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, float *dst_dd_i, const int64_t 
row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, - const dpct::queue_ptr &stream); + const queue_ptr &stream); typedef void (*ggml_sycl_op_flatten_t)(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream); + const queue_ptr &main_stream); static __dpct_inline__ float warp_reduce_sum(float x, @@ -1519,7 +1478,7 @@ template static void get_rows_sycl(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const void *src0_dd, const int32_t *src1_dd, float *dst_dd, - dpct::queue_ptr stream) { + queue_ptr stream) { GGML_TENSOR_BINARY_OP_LOCALS @@ -1554,7 +1513,7 @@ template static void get_rows_sycl_float(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const src0_t *src0_dd, const int32_t *src1_dd, - float *dst_dd, dpct::queue_ptr stream) { + float *dst_dd, queue_ptr stream) { GGML_TENSOR_BINARY_OP_LOCALS @@ -1594,7 +1553,7 @@ struct bin_bcast_sycl { void operator()(const struct ggml_tensor *src0, const struct ggml_tensor *src1, struct ggml_tensor *dst, const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd, - dpct::queue_ptr stream) { + queue_ptr stream) { GGML_TENSOR_BINARY_OP_LOCALS @@ -1731,7 +1690,7 @@ struct bin_bcast_sycl { static void acc_f32_sycl(const float *x, const float *y, float *dst, const int n_elements, const int ne10, const int ne11, const int ne12, const int nb1, const int nb2, - const int offset, dpct::queue_ptr stream) { + const int offset, queue_ptr stream) { int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * @@ -1744,7 +1703,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst, } static void gelu_f32_sycl(const float *x, float *dst, const int k, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * @@ -1756,7 +1715,7 @@ static void gelu_f32_sycl(const float *x, float *dst, const int k, } static void silu_f32_sycl(const float *x, float *dst, const int k, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE; stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * @@ -1768,7 +1727,7 @@ static void silu_f32_sycl(const float *x, float *dst, const int k, } static void gelu_quick_f32_sycl(const float *x, float *dst, const int k, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * @@ -1780,7 +1739,7 @@ static void gelu_quick_f32_sycl(const float *x, float *dst, const int k, } static void tanh_f32_sycl(const float *x, float *dst, const int k, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE; stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * @@ -1792,7 +1751,7 @@ static void tanh_f32_sycl(const float *x, float *dst, const int k, } static void relu_f32_sycl(const float *x, float *dst, const int k, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, 
num_blocks) * @@ -1804,7 +1763,7 @@ static void relu_f32_sycl(const float *x, float *dst, const int k, } static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE; stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * @@ -1816,7 +1775,7 @@ static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k, } static void hardswish_f32_sycl(const float *x, float *dst, const int k, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE; stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * @@ -1829,7 +1788,7 @@ static void hardswish_f32_sycl(const float *x, float *dst, const int k, static void leaky_relu_f32_sycl(const float *x, float *dst, const int k, const float negative_slope, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * @@ -1841,7 +1800,7 @@ static void leaky_relu_f32_sycl(const float *x, float *dst, const int k, } static void sqr_f32_sycl(const float *x, float *dst, const int k, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE; stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * @@ -1854,7 +1813,7 @@ static void sqr_f32_sycl(const float *x, float *dst, const int k, static void norm_f32_sycl(const float *x, float *dst, const int ncols, const int nrows, const float eps, - dpct::queue_ptr stream) { + queue_ptr stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); @@ -1897,7 +1856,7 @@ static void norm_f32_sycl(const float *x, float *dst, const int ncols, static void group_norm_f32_sycl(const float *x, float *dst, const int num_groups, const int group_size, - const int ne_elements, dpct::queue_ptr stream) { + const int ne_elements, queue_ptr stream) { static const float eps = 1e-6f; if (group_size < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); @@ -1947,7 +1906,7 @@ static void group_norm_f32_sycl(const float *x, float *dst, static void concat_f32_sycl(const float *x, const float *y, float *dst, const int ne0, int ne1, int ne2, int ne02, - dpct::queue_ptr stream) { + queue_ptr stream) { int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE; sycl::range<3> gridDim(ne2, ne1, num_blocks); stream->parallel_for( @@ -1961,7 +1920,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst, static void upscale_f32_sycl(const float *x, float *dst, const int ne00, const int ne01, const int ne02, - const int scale_factor, dpct::queue_ptr stream) { + const int scale_factor, queue_ptr stream) { int ne0 = (ne00 * scale_factor); int num_blocks = (ne0 + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; sycl::range<3> gridDim(ne02, (ne01 * scale_factor), num_blocks); @@ -1976,7 +1935,7 @@ static void upscale_f32_sycl(const float *x, float *dst, const int ne00, static void pad_f32_sycl(const float *x, float *dst, const int ne00, const int ne01, const int ne02, const int ne0, - const int ne1, const int ne2, dpct::queue_ptr stream) { + const int ne1, const int ne2, queue_ptr stream) { int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; 
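// Every *_f32_sycl launcher in these hunks follows the same recipe: the block count is
// the ceiling division of the element count by the op's block size, and the kernel is
// submitted with stream->parallel_for over an nd_range<3> whose innermost dimension is
// num_blocks * BLOCK_SIZE. Schematically for the pad launcher above (the actual kernel
// invocation is elided in the context lines, so the lambda body here is a placeholder):
//
//     const int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
//     sycl::range<3> gridDim(ne2, ne1, num_blocks);
//     stream->parallel_for(
//         sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
//                           sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
//         [=](sycl::nd_item<3> item) { /* per-element pad_f32 work */ });
//
// The only change these hunks make is the stream handle type: queue_ptr is assumed to be
// an alias for sycl::queue * (replacing dpct::queue_ptr), so the launch code is untouched.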
sycl::range<3> gridDim(ne2, ne1, num_blocks); stream->parallel_for( @@ -1989,7 +1948,7 @@ static void pad_f32_sycl(const float *x, float *dst, const int ne00, static void rms_norm_f32_sycl(const float *x, float *dst, const int ncols, const int nrows, const float eps, - dpct::queue_ptr stream) { + queue_ptr stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); if (ncols < 1024) { @@ -2033,7 +1992,7 @@ static void rms_norm_f32_sycl(const float *x, float *dst, const int ncols, static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx, const int ky, const int kx_padded, - dpct::queue_ptr stream) { + queue_ptr stream) { const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE; const sycl::range<3> num_blocks(1, ky, block_num_x); const sycl::range<3> block_size(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE); @@ -2054,7 +2013,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y, const int nrows_x, const int nchannels_x, const int nchannels_y, - dpct::queue_ptr stream) { + queue_ptr stream) { const sycl::range<3> block_nums(nchannels_y, nrows_x, 1); const sycl::range<3> block_dims(1, 1, WARP_SIZE); @@ -2074,7 +2033,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y, static void ggml_mul_mat_vec_nc_f16_f32_sycl( const void *vx, const float *y, float *dst, const int ncols_x, const int nrows_x, const int row_stride_x, const int nchannels_x, - const int nchannels_y, const int channel_stride_x, dpct::queue_ptr stream) { + const int nchannels_y, const int channel_stride_x, queue_ptr stream) { const sycl::range<3> block_nums(nchannels_y, nrows_x, 1); const sycl::range<3> block_dims(1, 1, WARP_SIZE); @@ -2098,7 +2057,7 @@ ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, - const int nb13, dpct::queue_ptr stream) { + const int nb13, queue_ptr stream) { const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; { @@ -2125,7 +2084,7 @@ static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; { @@ -2152,7 +2111,7 @@ static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; { @@ -2179,7 +2138,7 @@ static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - dpct::queue_ptr stream) { + queue_ptr stream) { GGML_ASSERT(ne % QK8_0 == 0); const int num_blocks = ne / QK8_0; @@ -2201,7 +2160,7 @@ static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - dpct::queue_ptr stream) { + queue_ptr stream) { GGML_ASSERT(ne % QK4_0 == 0); const int num_blocks = ne / QK4_0; @@ -2223,7 +2182,7 @@ static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int 
ne, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - dpct::queue_ptr stream) { + queue_ptr stream) { GGML_ASSERT(ne % QK4_1 == 0); const int num_blocks = ne / QK4_1; @@ -2245,7 +2204,7 @@ static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; { @@ -2272,7 +2231,7 @@ static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; { @@ -2299,7 +2258,7 @@ static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE; { @@ -2319,7 +2278,7 @@ static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne, } static void scale_f32_sycl(const float *x, float *dst, const float scale, - const int k, dpct::queue_ptr stream) { + const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE; stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * @@ -2332,7 +2291,7 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale, static void clamp_f32_sycl(const float *x, float *dst, const float min, const float max, const int k, - dpct::queue_ptr stream) { + queue_ptr stream) { const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE; stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * @@ -2347,7 +2306,7 @@ template static void rope_sycl(const T *x, T *dst, int ncols, int nrows, const int32_t *pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, - rope_corr_dims corr_dims, dpct::queue_ptr stream) { + rope_corr_dims corr_dims, queue_ptr stream) { GGML_ASSERT(ncols % 2 == 0); const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); @@ -2392,7 +2351,7 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows, const int32_t *pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, - dpct::queue_ptr stream) { + queue_ptr stream) { GGML_ASSERT(ncols % 2 == 0); const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); @@ -2440,7 +2399,7 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows, static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows, const int32_t *pos, float freq_scale, int p_delta_rows, float freq_base, int n_ctx, - dpct::queue_ptr stream) { + queue_ptr stream) { GGML_ASSERT(ncols % 4 == 0); const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4); const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE; @@ -2456,7 +2415,7 @@ static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows, static 
void alibi_f32_sycl(const float *x, float *dst, const int ncols, const int nrows, const int k_rows, const int n_heads_log2_floor, const float m0, - const float m1, dpct::queue_ptr stream) { + const float m1, queue_ptr stream) { const sycl::range<3> block_dims(1, 1, SYCL_ALIBI_BLOCK_SIZE); const int num_blocks_x = (ncols + SYCL_ALIBI_BLOCK_SIZE - 1) / (SYCL_ALIBI_BLOCK_SIZE); const sycl::range<3> block_nums(1, nrows, num_blocks_x); @@ -2468,7 +2427,7 @@ static void alibi_f32_sycl(const float *x, float *dst, const int ncols, } static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols, - const int nrows, dpct::queue_ptr stream) { + const int nrows, queue_ptr stream) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); const sycl::range<3> block_nums(1, nrows, 1); stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), @@ -2488,7 +2447,7 @@ static int next_power_of_2(int x) { static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, const int nrows, ggml_sort_order order, - dpct::queue_ptr stream) { + queue_ptr stream) { // bitonic sort requires ncols to be power of 2 const int ncols_pad = next_power_of_2(ncols); @@ -2496,7 +2455,7 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, const sycl::range<3> block_nums(1, nrows, 1); const size_t shared_mem = ncols_pad * sizeof(int); - // GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); + // GGML_ASSERT(shared_mem <= ggml_sycl_info().devices[ggml_sycl_get_device()].smpb); if (order == GGML_SORT_ORDER_ASC) { stream->submit([&](sycl::handler &cgh) { @@ -2534,7 +2493,7 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, static void diag_mask_inf_f32_sycl(const float *x, float *dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, - dpct::queue_ptr stream) { + queue_ptr stream) { const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1); const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE; const sycl::range<3> block_nums(1, block_num_x, nrows_x); @@ -2550,7 +2509,7 @@ template static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, - const size_t n_local_scratch, dpct::queue_ptr stream) { + const size_t n_local_scratch, queue_ptr stream) { stream->submit([&](sycl::handler &cgh) { sycl::local_accessor local_buf_acc(n_local_scratch, cgh); @@ -2568,7 +2527,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, - dpct::queue_ptr stream) { + queue_ptr stream) { int nth = WARP_SIZE; int max_block_size = g_work_group_size; while (nth < ncols_x && nth < max_block_size) nth *= 2; @@ -2651,7 +2610,7 @@ static void im2col_sycl(const float *x, T *dst, int IW, int IH, int OW, int OH, int KW, int KH, int IC, int offset_delta, int s0, int s1, int p0, int p1, int d0, int d1, - dpct::queue_ptr stream) { + queue_ptr stream) { const int parallel_elements = OW * KW * KH; const int num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; sycl::range<3> 
block_nums(IC, OH, num_blocks); @@ -2759,11 +2718,11 @@ int get_work_group_size(int user_device_id) { return prop.get_max_work_group_size(); } -static void ggml_init_sycl() try { +static void ggml_check_sycl() try { static bool initialized = false; if (!initialized) { - fprintf(stderr, "[SYCL] call ggml_init_sycl\n"); + fprintf(stderr, "[SYCL] call ggml_check_sycl\n"); g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); fprintf(stderr, "%s: GGML_SYCL_DEBUG: %d\n", __func__, g_ggml_sycl_debug); @@ -2799,69 +2758,175 @@ catch (sycl::exception const &exc) { << ", line:" << __LINE__ << std::endl; std::exit(1); } +static ggml_sycl_device_info ggml_sycl_init() { + ggml_sycl_device_info info = {}; -void ggml_init_by_gpus(int device_count) try { - g_device_count = device_count; - g_work_group_size = g_sycl_gpu_mgr->work_group_size; - - int64_t total_vram = 0; + info.device_count = dpct::dev_mgr::instance().device_count(); + if (info.device_count == 0) { + fprintf(stderr, "%s: failed to initialize " GGML_SYCL_NAME ": %s\n", __func__); + return info; + } - print_gpu_device_list(); + GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES); - for (int id = 0; id < GGML_SYCL_MAX_DEVICES; ++id) { - g_device_caps[id].vmm = 0; - g_device_caps[id].device_id = -1; - g_device_caps[id].cc = 0; - g_tensor_split[id] = 0; - g_default_tensor_split[id] = 0; - } + int64_t total_vram = 0; +#if defined(GGML_SYCL_FORCE_MMQ) + fprintf(stderr, "%s: GGML_SYCL_FORCE_MMQ: yes\n", __func__); +#else + fprintf(stderr, "%s: GGML_SYCL_FORCE_MMQ: no\n", __func__); +#endif +#if defined(SYCL_USE_XMX) + fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__); +#else + fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__); +#endif + fprintf(stderr, "%s: found %d " GGML_SYCL_NAME " devices:\n", __func__, info.device_count); - for (int i = 0; i < g_device_count; ++i) { + for (int i = 0; i < info.device_count; ++i) { int device_id = g_sycl_gpu_mgr->gpus[i]; - g_device_caps[i].vmm = 0; + info.devices[i].vmm = 0; dpct::device_info prop; SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( prop, dpct::dev_mgr::instance().get_device(device_id)))); - g_default_tensor_split[i] = total_vram; + info.default_tensor_split[i] = total_vram; total_vram += prop.get_global_mem_size(); - g_device_caps[i].cc = + info.devices[i].cc = 100 * prop.get_major_version() + 10 * prop.get_minor_version(); } - - for (int i = 0; i < g_device_count; ++i) { - g_default_tensor_split[i] /= total_vram; + for (int id = 0; id < info.device_count; ++id) { + info.default_tensor_split[id] /= total_vram; } - for (int i = 0; i < g_device_count; ++i) { - SYCL_CHECK(ggml_sycl_set_device(i)); + // configure logging to stdout + // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); - // create sycl streams - for (int is = 0; is < MAX_STREAMS; ++is) { - SYCL_CHECK(CHECK_TRY_ERROR( - g_syclStreams[i][is] = - dpct::get_current_device().create_queue( - g_sycl_gpu_mgr->get_co_ctx(), dpct::get_current_device()))); + return info; +} + +const ggml_sycl_device_info & ggml_sycl_info() { + static ggml_sycl_device_info info = ggml_sycl_init(); + return info; +} + +// buffer pool for sycl (legacy) +struct ggml_sycl_pool_leg : public ggml_sycl_pool { + static const int MAX_SYCL_BUFFERS = 256; + + int device; + queue_ptr qptr; + struct ggml_sycl_buffer { + void * ptr = nullptr; + size_t size = 0; + }; + + ggml_sycl_buffer buffer_pool[MAX_SYCL_BUFFERS] = {}; + size_t pool_size = 0; + + explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : + qptr(stream_) + device(device_) { + } + + 
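// As spelled in the hunk above, the ggml_sycl_pool_leg constructor will not compile: the
// member-initializer list names a non-existent stream_ and omits the comma between the two
// initializers. Given the device/qptr members and the qptr_/device_ parameters declared
// above, the intended form is presumably:
//
//     explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) :
//         device(device_),
//         qptr(qptr_) {
//     }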
~ggml_sycl_pool_leg() { + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr != nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + pool_size -= b.size; + } } + GGML_ASSERT(pool_size == 0); + } - const dpct::queue_ptr stream = g_syclStreams[i][0]; - // create sycl handle - SYCL_CHECK(CHECK_TRY_ERROR(g_sycl_handles[i] = stream)); + void * alloc(size_t size, size_t * actual_size) override { +#ifdef DEBUG_sycl_MALLOC + int nnz = 0; + size_t max_size = 0; +#endif + size_t best_diff = 1ull << 36; + int ibest = -1; + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + ggml_sycl_buffer& b = buffer_pool[i]; + if (b.ptr != nullptr) { +#ifdef DEBUG_sycl_MALLOC + ++nnz; + if (b.size > max_size) max_size = b.size; +#endif + if (b.size >= size) { + size_t diff = b.size - size; + if (diff < best_diff) { + best_diff = diff; + ibest = i; + if (!best_diff) { + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } + } + } + } + } + if (ibest >= 0) { + ggml_sycl_buffer& b = buffer_pool[ibest]; + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } + void * ptr; + size_t look_ahead_size = (size_t) (1.05 * size); + + SYCL_CHECK( + CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device( + look_ahead_size, *qptr))); + *actual_size = look_ahead_size; + pool_size += look_ahead_size; + + #ifdef DEBUG_SYCL_MALLOC + fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz, + (uint32_t)(max_size/1024/1024), (uint32_t)(g_sycl_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024)); + #endif + // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg look_ahead_size=%lu, return %p\n", look_ahead_size, ptr); + return ptr; + } + + void free(void * ptr, size_t size) override { + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + ggml_sycl_buffer& b = buffer_pool[i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + fprintf(stderr, "WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n"); + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); + pool_size -= size; } +}; + +std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { + // TBD: NO VMM support + // if (ggml_cuda_info().devices[device].vmm) { + // return std::unique_ptr(new ggml_cuda_pool_vmm(device)); + // } + return std::unique_ptr(new ggml_sycl_pool_leg(qptr, device)); } -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} + +// TBD pool with virtual memory management +// struct ggml_sycl_pool_vmm : public ggml_sycl_pool static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst, const struct ggml_tensor *src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, - dpct::queue_ptr stream) try { + queue_ptr stream) try { dpct::memcpy_direction kind; char * src_ptr; @@ -2930,7 +2995,7 @@ catch (sycl::exception const &exc) { static void ggml_sycl_op_get_rows(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_d, const float *src1_d, - float *dst_d, const dpct::queue_ptr &stream) { + float *dst_d, const queue_ptr &stream) { GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(dst->type == GGML_TYPE_F32); @@ -2977,7 +3042,7 @@ inline void ggml_sycl_op_bin_bcast(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float 
*dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { op()(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); @@ -3004,7 +3069,7 @@ static void ggml_sycl_op_repeat(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_d, const float *src1_d, float *dst_d, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { ggml_sycl_op_bin_bcast>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream); @@ -3015,7 +3080,7 @@ static void ggml_sycl_op_repeat(const ggml_tensor *src0, inline void ggml_sycl_op_add(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { ggml_sycl_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } @@ -3023,7 +3088,7 @@ inline void ggml_sycl_op_add(const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_acc(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -3043,7 +3108,7 @@ inline void ggml_sycl_op_acc(const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_mul(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { ggml_sycl_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } @@ -3051,7 +3116,7 @@ inline void ggml_sycl_op_mul(const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_div(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { ggml_sycl_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } @@ -3059,7 +3124,7 @@ inline void ggml_sycl_op_div(const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_gelu(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3074,7 +3139,7 @@ inline void ggml_sycl_op_gelu(const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_silu(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3090,7 +3155,7 @@ inline void ggml_sycl_op_gelu_quick(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3105,7 +3170,7 @@ inline void ggml_sycl_op_gelu_quick(const ggml_tensor *src0, inline void ggml_sycl_op_tanh(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, 
const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3119,7 +3184,7 @@ inline void ggml_sycl_op_tanh(const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_relu(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3135,7 +3200,7 @@ static void ggml_sycl_op_hardsigmoid(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3150,7 +3215,7 @@ static void ggml_sycl_op_hardsigmoid(const ggml_tensor *src0, static void ggml_sycl_op_hardswish(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, - float *dst_dd, const dpct::queue_ptr &main_stream) { + float *dst_dd, const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3166,7 +3231,7 @@ inline void ggml_sycl_op_leaky_relu(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3184,7 +3249,7 @@ inline void ggml_sycl_op_leaky_relu(const ggml_tensor *src0, inline void ggml_sycl_op_sqr(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3199,7 +3264,7 @@ inline void ggml_sycl_op_sqr(const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_norm(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3221,7 +3286,7 @@ inline void ggml_sycl_op_group_norm(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3239,7 +3304,7 @@ inline void ggml_sycl_op_concat(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -3257,7 +3322,7 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); @@ 
-3275,7 +3340,7 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0, inline void ggml_sycl_op_pad(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); @@ -3294,7 +3359,7 @@ inline void ggml_sycl_op_rms_norm(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3317,11 +3382,11 @@ static int64_t get_row_rounding(ggml_type type, const std::array g_device_caps[i].cc) { - min_compute_capability = g_device_caps[i].cc; + if (min_compute_capability > ggml_sycl_info().devices[i].cc) { + min_compute_capability = ggml_sycl_info().devices[i].cc; } - if (max_compute_capability < g_device_caps[i].cc) { - max_compute_capability = g_device_caps[i].cc; + if (max_compute_capability < ggml_sycl_info().devices[i].cc) { + max_compute_capability = ggml_sycl_info().devices[i].cc; } } } @@ -3365,7 +3430,7 @@ inline void ggml_sycl_op_mul_mat_sycl( const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, float *dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, - const dpct::queue_ptr &stream) try { + const queue_ptr &stream) try { GGML_ASSERT(src0_dd_i != nullptr); GGML_ASSERT(src1_ddf_i != nullptr); @@ -3396,7 +3461,7 @@ inline void ggml_sycl_op_mul_mat_sycl( dst->op_params[0] == GGML_PREC_DEFAULT) { // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n"); - sycl_pool_alloc src0_as_f16; + ggml_sycl_pool_alloc src0_as_f16; if (src0->type != GGML_TYPE_F16) { const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type); GGML_ASSERT(to_fp16_sycl != nullptr); @@ -3408,7 +3473,7 @@ inline void ggml_sycl_op_mul_mat_sycl( ? (const sycl::half *)src0_dd_i : src0_as_f16.get(); - sycl_pool_alloc src1_as_f16; + ggml_sycl_pool_alloc src1_as_f16; if (src1->type != GGML_TYPE_F16) { const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type); GGML_ASSERT(to_fp16_sycl != nullptr); @@ -3419,7 +3484,7 @@ inline void ggml_sycl_op_mul_mat_sycl( const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16 ? 
(const sycl::half *)src1->data + src1_padded_row_size : src1_as_f16.get(); - sycl_pool_alloc dst_f16(row_diff * src1_ncols); + ggml_sycl_pool_alloc dst_f16(row_diff * src1_ncols); const sycl::half alpha_f16 = 1.0f; const sycl::half beta_f16 = 0.0f; @@ -3437,8 +3502,8 @@ inline void ggml_sycl_op_mul_mat_sycl( } else { // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n"); - sycl_pool_alloc src0_ddq_as_f32; - sycl_pool_alloc src1_ddq_as_f32; + ggml_sycl_pool_alloc src0_ddq_as_f32; + ggml_sycl_pool_alloc src1_ddq_as_f32; if (src0->type != GGML_TYPE_F32) { const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type); GGML_ASSERT(to_fp32_sycl != nullptr); @@ -3479,7 +3544,7 @@ catch (sycl::exception const &exc) { inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); @@ -3559,7 +3624,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3589,7 +3654,7 @@ inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1, static void ggml_sycl_op_pool2d(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, - float *dst_dd, const dpct::queue_ptr &main_stream) { + float *dst_dd, const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3632,7 +3697,7 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -3673,7 +3738,7 @@ inline void ggml_sycl_op_sum_rows(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3692,7 +3757,7 @@ inline void ggml_sycl_op_argsort(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_I32); @@ -3713,7 +3778,7 @@ inline void ggml_sycl_op_diag_mask_inf(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3735,7 +3800,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + 
const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3759,7 +3824,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, // positions tensor float * src2_dd = nullptr; - sycl_pool_alloc src2_f; + ggml_sycl_pool_alloc src2_f; const bool use_src2 = src2 != nullptr; @@ -3782,7 +3847,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, inline void ggml_sycl_op_scale(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3805,7 +3870,7 @@ inline void ggml_sycl_op_scale(const ggml_tensor *src0, const ggml_tensor *src1, inline void ggml_sycl_op_clamp(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { + const queue_ptr &main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3827,7 +3892,8 @@ inline void ggml_sycl_op_clamp(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -static void ggml_sycl_op_flatten(const ggml_tensor *src0, +static void ggml_sycl_op_flatten(const ggml_sycl_backend_context & ctx, + const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const ggml_sycl_op_flatten_t op) try { const int64_t nrows0 = ggml_nrows(src0); @@ -3851,12 +3917,12 @@ static void ggml_sycl_op_flatten(const ggml_tensor *src0, float * src1_ddf = nullptr; float * dst_ddf = nullptr; - sycl_pool_alloc src0_f; - sycl_pool_alloc src1_f; - sycl_pool_alloc dst_f; + ggml_sycl_pool_alloc src0_f; + ggml_sycl_pool_alloc src1_f; + ggml_sycl_pool_alloc dst_f; ggml_sycl_set_device(g_main_device); - dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; + queue_ptr main_stream = ctx.stream(g_main_device, 0); // GGML_SYCL_DEBUG("g_main_device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n", // g_main_device, main_stream, src0_on_device, src1_on_device, dst_on_device); @@ -4012,10 +4078,10 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, } struct dev_data { - sycl_pool_alloc src0_dd_alloc; - sycl_pool_alloc src1_ddf_alloc; - sycl_pool_alloc src1_ddq_alloc; - sycl_pool_alloc dst_dd_alloc; + ggml_sycl_pool_alloc src0_dd_alloc; + ggml_sycl_pool_alloc src1_ddf_alloc; + ggml_sycl_pool_alloc src1_ddq_alloc; + ggml_sycl_pool_alloc dst_dd_alloc; char *src0_dd = nullptr; float *src1_ddf = nullptr; // float @@ -4029,7 +4095,8 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, dev_data dev[GGML_SYCL_MAX_DEVICES]; int used_devices = 0; - dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; + auto ctx = ggml_sycl_info(); + queue_ptr main_stream = ctx.stream(g_main_device, 0); for (int i = 0; i < g_device_count; ++i) { // by default, use all rows @@ -4068,7 +4135,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU && i == g_main_device; ggml_sycl_set_device(i); - dpct::queue_ptr stream = g_syclStreams[i][0]; + queue_ptr main_stream = ctx.stream(i, 0); if (src0_on_device && src0_is_contiguous) { dev[i].src0_dd = (char *) src0_extra->data_device[i]; @@ -4115,7 +4182,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, */ SYCL_CHECK(CHECK_TRY_ERROR( 
*src0_extra->events[g_main_device][0] = - g_syclStreams[g_main_device][0]->ext_oneapi_submit_barrier())); + ctx.stream(g_main_device, 0)->ext_oneapi_submit_barrier())); } const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; @@ -4133,7 +4200,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, const int64_t row_diff = dev[i].row_high - dev[i].row_low; ggml_sycl_set_device(i); - dpct::queue_ptr stream = g_syclStreams[i][is]; + queue_ptr stream = xtx.stream(i, is); // wait for main GPU data if necessary if (split && (i != g_main_device || is != 0)) { @@ -4298,7 +4365,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, } for (int64_t is = 0; is < is_max; ++is) { SYCL_CHECK(CHECK_TRY_ERROR( - g_syclStreams[g_main_device][0]->ext_oneapi_submit_barrier( + stx.stream(g_main_device, 0)->ext_oneapi_submit_barrier( {*src0_extra->events[i][is]}))); } } @@ -4459,7 +4526,8 @@ bool ggml_sycl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te (ne0 >= 32 && ne1 >= 32 && ne10 >= 32); } -static void ggml_sycl_mul_mat_vec_p021(const ggml_tensor *src0, +static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, + const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) try { GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); @@ -4476,7 +4544,7 @@ static void ggml_sycl_mul_mat_vec_p021(const ggml_tensor *src0, const int64_t ne12 = src1->ne[2]; SYCL_CHECK(ggml_sycl_set_device(g_main_device)); - dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; + queue_ptr main_stream = ctx.stream(g_main_device, 0); ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; void * src0_ddq = src0_extra->data_device[g_main_device]; @@ -4515,7 +4583,8 @@ static void ggml_sycl_mul_mat_vec_nc(const ggml_tensor *src0, const int64_t ne12 = src1->ne[2]; SYCL_CHECK(ggml_sycl_set_device(g_main_device)); - dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; + auto ctx = ggml_sycl_info(); + queue_ptr main_stream = ctx.stream(g_main_device, 0); ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; void * src0_ddq = src0_extra->data_device[g_main_device]; @@ -4575,9 +4644,10 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, const int64_t ne_dst = ggml_nelements(dst); SYCL_CHECK(ggml_sycl_set_device(g_main_device)); - dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; + auto ctx = ggml_sycl_info(); + queue_ptr main_stream = ctx.stream(g_main_device, 0); - bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_cuda || + bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_sycl || main_stream->get_backend() == sycl::backend::ext_oneapi_hip; SYCL_CHECK( @@ -4594,7 +4664,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; // convert src1 to fp16 - sycl_pool_alloc src1_f16_alloc; + ggml_sycl_pool_alloc src1_f16_alloc; if (src1->type != GGML_TYPE_F16) { const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type); const int64_t ne_src1 = ggml_nelements(src1); @@ -4605,7 +4675,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? 
(sycl::half *)src1_ddf : src1_f16_alloc.get(); - sycl_pool_alloc dst_f16; + ggml_sycl_pool_alloc dst_f16; char * dst_t; dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float; @@ -4685,8 +4755,8 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, } else { const int ne23 = ne12*ne13; - sycl_pool_alloc ptrs_src(2*ne23); - sycl_pool_alloc< void *> ptrs_dst(1*ne23); + ggml_sycl_pool_alloc ptrs_src(2*ne23); + ggml_sycl_pool_alloc< void *> ptrs_dst(1*ne23); sycl::range<3> block_dims(1, ne12, ne13); /* @@ -4749,8 +4819,8 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 int64_t min_compute_capability = INT_MAX; for (int i = 0; i < g_device_count; ++i) { - if (min_compute_capability > g_device_caps[i].cc && g_tensor_split[i] < (i + 1 < g_device_count ? g_tensor_split[i + 1] : 1.0f)) { - min_compute_capability = g_device_caps[i].cc; + if (min_compute_capability > ggml_sycl_info().devices[i].cc && g_tensor_split[i] < (i + 1 < g_device_count ? g_tensor_split[i + 1] : 1.0f)) { + min_compute_capability = ggml_sycl_info().devices[i].cc; } } @@ -4826,186 +4896,14 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } } -#if 0 -template -static __global__ void k_compute_batched_ptrs_id( - const void ** ptrs_src, void ** ptrs_dst, - int ne12, int ne13, - int ne23, - int nb02, int nb03, - int nb12, int nb13, - int nb2, int nb3, - int r2, int r3, - ggml_type src0_type, half * src0_as_f16, int64_t src0_ne, - const half * src1_f16, half * dst_f16, - const int32_t * ids, const int id, - Srcs... src0s) { - - int i = ids[id]; - - half * src0_f16; - const void * srcs_ar[] = { (const half *) src0s... }; - if (src0_type == GGML_TYPE_F16) { - src0_f16 = (half *) srcs_ar[i]; - } else { - src0_f16 = src0_as_f16; - if (item_ct1.get_local_id(2) == 0 && threadIdx.y == 0) { - const to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(src0_type); - to_fp16(srcs_ar[i], src0_f16, src0_ne, syclStreamFireAndForget); - } - } - - int i13 = blockIdx.x * blockDim.x + item_ct1.get_local_id(2); - int i12 = blockIdx.y * blockDim.y + threadIdx.y; - - if (i13 >= ne13 || i12 >= ne12) { - return; - } - - int i03 = i13 / r3; - int i02 = i12 / r2; - - ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03; - ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2; - ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2; -} - -static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) { - const struct ggml_tensor * ids = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * src00 = dst->src[2]; - - const int id = dst->op_params[0]; - - GGML_ASSERT(!ggml_is_transposed(src00)); - GGML_ASSERT(!ggml_is_transposed(src1)); - - GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - GGML_TENSOR_LOCALS(int64_t, ne0, src00, ne); - - //const int64_t nb01 = src00->nb[1]; - GGML_TENSOR_LOCALS(int64_t, nb0, src00, nb); - - GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); - - GGML_TENSOR_LOCALS(int64_t, nb1, src1, nb); - //const int64_t nb11 = src1->nb[1]; - - const int64_t ne1 = ggml_nelements(src1); - const int64_t ne = ggml_nelements(dst); - - SYCL_CHECK(ggml_sycl_set_device(g_main_device)); - syclStream_t main_stream = g_syclStreams[g_main_device][0]; - - SYCL_CHECK(syclSetStream(g_sycl_handles[g_main_device], main_stream)); - - //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) 
src0->extra; - //void * src0_ddq = src0_extra->data_device[g_main_device]; - //half * src0_as_f16 = (half *) src0_ddq; - - ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; - float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; - - ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - - // convert src1 to fp16 - const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type); - GGML_ASSERT(to_fp16_sycl != nullptr); - - size_t src1_as = 0; - half * src1_as_f16 = (half *) ggml_sycl_pool_malloc(g_main_device, ne1 * sizeof(half), &src1_as); - to_fp16_sycl(src1_ddf, src1_as_f16, ne1, main_stream); - - size_t dst_as = 0; - half * dst_f16 = (half *) ggml_sycl_pool_malloc(g_main_device, ne * sizeof(half), &dst_as); - - GGML_ASSERT(ne12 % ne02 == 0); - GGML_ASSERT(ne13 % ne03 == 0); - - // broadcast factors - const int64_t r2 = ne12/ne02; - const int64_t r3 = ne13/ne03; - - const half alpha_f16 = 1.0f; - const half beta_f16 = 0.0f; - - // use syclGemmBatchedEx - const int ne23 = ne12*ne13; - - const void ** ptrs_src = nullptr; - void ** ptrs_dst = nullptr; - - size_t ptrs_src_s = 0; - size_t ptrs_dst_s = 0; - - ptrs_src = (const void **) ggml_sycl_pool_malloc(g_main_device, 2*ne23*sizeof(void *), &ptrs_src_s); - ptrs_dst = ( void **) ggml_sycl_pool_malloc(g_main_device, 1*ne23*sizeof(void *), &ptrs_dst_s); - - int64_t src0_ne = ggml_nelements(src00); - half * src0_as_f16 = nullptr; - size_t src0_as = 0; - if (src00->type != GGML_TYPE_F16) { - src0_as_f16 = (half *) ggml_sycl_pool_malloc(g_main_device, src0_ne * sizeof(half), &src0_as); - } - - static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6"); - dim3 block_dims(ne13, ne12); - k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>( - ptrs_src, ptrs_dst, - ne12, ne13, - ne23, - ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half), - nb12, nb13, - dst->nb[2], dst->nb[3], - r2, r3, - src00->type, src0_as_f16, src0_ne, - src1_as_f16, dst_f16, - (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id, - dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr, - dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr, - dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr, - dst->src[5] ? 
(const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr - ); - SYCL_CHECK(syclGetLastError()); - - SYCL_CHECK( - syclGemmBatchedEx(g_sycl_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - &alpha_f16, (const void **) (ptrs_src + 0*ne23), SYCL_R_16F, ne00, - (const void **) (ptrs_src + 1*ne23), SYCL_R_16F, ne10, - &beta_f16, ( void **) (ptrs_dst + 0*ne23), SYCL_R_16F, ne01, - ne23, - CUBLAS_COMPUTE_16F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - if (src0_as != 0) { - ggml_sycl_pool_free(g_main_device, src0_as_f16, src0_as); - } - if (ptrs_src_s != 0) { - ggml_sycl_pool_free(g_main_device, ptrs_src, ptrs_src_s); - } - if (ptrs_dst_s != 0) { - ggml_sycl_pool_free(g_main_device, ptrs_dst, ptrs_dst_s); - } - - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); - to_fp32_sycl(dst_f16, dst_ddf, ne, main_stream); - - ggml_sycl_pool_free(g_main_device, src1_as_f16, src1_as); - ggml_sycl_pool_free(g_main_device, dst_f16, dst_as); -} -#endif - -static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, +static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, + const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) try { GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT && "mul_mat_id does not support split buffers"); const ggml_tensor *ids = dst->src[2]; - const dpct::queue_ptr stream = g_syclStreams[g_main_device][0]; + const queue_ptr stream = ctx.streams(g_main_device, 0); const size_t nb11 = src1->nb[1]; const size_t nb1 = dst->nb[1]; @@ -5074,8 +4972,8 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row); } } else { - sycl_pool_alloc src1_contiguous(sizeof(float)*ggml_nelements(src1)); - sycl_pool_alloc dst_contiguous(sizeof(float)*ggml_nelements(dst)); + ggml_sycl_pool_alloc src1_contiguous(sizeof(float)*ggml_nelements(src1)); + ggml_sycl_pool_alloc dst_contiguous(sizeof(float)*ggml_nelements(dst)); src1_row_extra.data_device[g_main_device] = src1_contiguous.get(); dst_row_extra.data_device[g_main_device] = dst_contiguous.get(); @@ -5153,7 +5051,8 @@ static void ggml_sycl_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_clamp); } -static void ggml_sycl_cpy(const ggml_tensor *src0, const ggml_tensor *src1, +static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, + const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) try { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -5167,7 +5066,7 @@ static void ggml_sycl_cpy(const ggml_tensor *src0, const ggml_tensor *src1, GGML_TENSOR_BINARY_OP_LOCALS; SYCL_CHECK(ggml_sycl_set_device(g_main_device)); - dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; + queue_ptr main_stream = ctx.streams(g_main_device, 0); const ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; const ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; @@ -5260,172 +5159,9 @@ static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_spl return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]); } -void ggml_sycl_free_data(struct ggml_tensor *tensor) try { - if (!tensor || !tensor->extra || (tensor->backend != GGML_BACKEND_TYPE_GPU && tensor->backend != GGML_BACKEND_TYPE_GPU_SPLIT) ) { - return; - } - - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; - - for (int i = 0; i < g_device_count; ++i) { - 
const dpct::queue_ptr stream = g_syclStreams[i][0]; - if (extra->data_device[i] != nullptr) { - SYCL_CHECK(ggml_sycl_set_device(i)); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(extra->data_device[i], *stream))); - } - - for (int64_t is = 0; is < MAX_STREAMS; ++is) { - if (extra->events[i][is] != nullptr) { - SYCL_CHECK(ggml_sycl_set_device(i)); - SYCL_CHECK(CHECK_TRY_ERROR( - dpct::destroy_event(extra->events[i][is]))); - } - } - } - - delete extra; -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - static ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr; static size_t g_temp_tensor_extra_index = 0; -static ggml_tensor_extra_gpu * ggml_sycl_alloc_temp_tensor_extra() { - if (g_temp_tensor_extras == nullptr) { - g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_SYCL_MAX_NODES]; - } - - size_t alloc_index = g_temp_tensor_extra_index; - g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_SYCL_MAX_NODES; - ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index]; - memset(extra, 0, sizeof(*extra)); - - return extra; -} - -static void ggml_sycl_assign_buffers_impl(struct ggml_tensor *tensor, - bool scratch, bool force_inplace, - bool no_alloc) try { - if (scratch && g_scratch_size == 0) { - return; - } - - tensor->backend = GGML_BACKEND_TYPE_GPU; - - if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_TYPE_CPU) { - const ggml_op src0_op = tensor->src[0]->op; - if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) { - ggml_sycl_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc); - } - } - if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_TYPE_CPU) { - ggml_sycl_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc); - } - - if (scratch && no_alloc) { - return; - } - - ggml_tensor_extra_gpu * extra; - - const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) || - tensor->op == GGML_OP_VIEW || - force_inplace; - const size_t size = ggml_nbytes(tensor); - - SYCL_CHECK(ggml_sycl_set_device(g_main_device)); - const dpct::queue_ptr stream = g_syclStreams[g_main_device][0]; - - if (inplace && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT)) { - ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; - char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; - size_t offset = 0; - if (tensor->op == GGML_OP_VIEW) { - memcpy(&offset, tensor->op_params, sizeof(size_t)); - } - extra = ggml_sycl_alloc_temp_tensor_extra(); - extra->data_device[g_main_device] = src0_ddc + offset; - } else if (tensor->op == GGML_OP_CPY) { - ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra; - void * src1_ddv = src1_extra->data_device[g_main_device]; - extra = ggml_sycl_alloc_temp_tensor_extra(); - extra->data_device[g_main_device] = src1_ddv; - } else if (scratch) { - GGML_ASSERT(size <= g_scratch_size); - if (g_scratch_offset + size > g_scratch_size) { - g_scratch_offset = 0; - } - - char * data = (char *) g_scratch_buffer; - if (data == nullptr) { - SYCL_CHECK(CHECK_TRY_ERROR( - data = (char *)sycl::malloc_device( - g_scratch_size, *stream))); - g_scratch_buffer = data; - } - extra = ggml_sycl_alloc_temp_tensor_extra(); - extra->data_device[g_main_device] = data + 
g_scratch_offset; - - g_scratch_offset += size; - - GGML_ASSERT(g_scratch_offset <= g_scratch_size); - } else { // allocate new buffers outside of scratch - void * data; - SYCL_CHECK(CHECK_TRY_ERROR(data = (void *)sycl::malloc_device( - size, *stream))); - SYCL_CHECK(CHECK_TRY_ERROR( - (*stream).memset(data, 0, size).wait())); - extra = new ggml_tensor_extra_gpu; - memset(extra, 0, sizeof(*extra)); - extra->data_device[g_main_device] = data; - } - - tensor->extra = extra; -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -void ggml_sycl_copy_to_device(struct ggml_tensor *tensor) try { - GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); - GGML_ASSERT(ggml_is_contiguous(tensor)); - - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; - SYCL_CHECK(ggml_sycl_set_device(g_main_device)); - const dpct::queue_ptr stream = g_syclStreams[g_main_device][0]; - SYCL_CHECK(CHECK_TRY_ERROR((*stream) - .memcpy(extra->data_device[g_main_device], - tensor->data, ggml_nbytes(tensor)) - .wait())); -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -void ggml_sycl_assign_buffers(struct ggml_tensor * tensor) { - ggml_sycl_assign_buffers_impl(tensor, true, false, false); -} - -void ggml_sycl_assign_buffers_no_alloc(struct ggml_tensor * tensor) { - ggml_sycl_assign_buffers_impl(tensor, true, false, true); -} - -void ggml_sycl_assign_buffers_no_scratch(struct ggml_tensor * tensor) { - ggml_sycl_assign_buffers_impl(tensor, false, false, false); -} - -void ggml_sycl_assign_buffers_force_inplace(struct ggml_tensor * tensor) { - ggml_sycl_assign_buffers_impl(tensor, false, true, false); -} - void ggml_sycl_set_main_device(const int main_device) try { if (g_main_device == main_device) return; check_allow_gpu_index(main_device); @@ -5446,33 +5182,7 @@ catch (sycl::exception const &exc) { std::exit(1); } -void ggml_sycl_set_scratch_size(const size_t scratch_size) { - // this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously - // it still won't always work as expected, but it's better than nothing - if (scratch_size > g_scratch_size) { - ggml_sycl_free_scratch(); - } - g_scratch_size = std::max(g_scratch_size, scratch_size); -} - -void ggml_sycl_free_scratch() try { - if (g_scratch_buffer == nullptr) { - return; - } - ggml_sycl_set_device(g_main_device); - const dpct::queue_ptr stream = g_syclStreams[g_main_device][0]; - - SYCL_CHECK(CHECK_TRY_ERROR( - sycl::free(g_scratch_buffer, *stream))); - g_scratch_buffer = nullptr; -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { +bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * tensor) { if (!g_sycl_loaded) return false; ggml_sycl_func_t func; @@ -5629,12 +5339,6 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_ ggml_sycl_set_peer_access(tensor->src[1]->ne[1]); } - if (params->ith != 0) { - return true; - } - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return true; - } func(tensor->src[0], tensor->src[1], tensor); return true; } @@ 
-5699,7 +5403,7 @@ GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free, inserted. You need to rewrite this code. */ /* - DPCT1106:217: 'cudaMemGetInfo' was migrated with the Intel extensions for + DPCT1106:217: 'syclMemGetInfo' was migrated with the Intel extensions for device information which may not be supported by all compilers or runtimes. You may need to adjust the code. */ @@ -5724,8 +5428,7 @@ catch (sycl::exception const &exc) { struct ggml_backend_sycl_buffer_context { int device; void * dev_ptr = nullptr; - ggml_tensor_extra_gpu * temp_tensor_extras = nullptr; - size_t temp_tensor_extra_index = 0; + queue_ptr stream; std::string name; ggml_backend_sycl_buffer_context(int device, void * dev_ptr) : @@ -5735,21 +5438,11 @@ struct ggml_backend_sycl_buffer_context { name = (GGML_SYCL_NAME + std::to_string(id)); } - ~ ggml_backend_sycl_buffer_context() { - delete[] temp_tensor_extras; - } - - ggml_tensor_extra_gpu * ggml_sycl_alloc_temp_tensor_extra() { - if (temp_tensor_extras == nullptr) { - temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_SYCL_MAX_NODES]; + ~ggml_backend_sycl_buffer_context() { + if (dev_ptr != nullptr) { + ggml_sycl_set_device(device); + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream))); } - - size_t alloc_index = temp_tensor_extra_index; - temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_SYCL_MAX_NODES; - ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index]; - memset(extra, 0, sizeof(*extra)); - - return extra; } }; @@ -5766,10 +5459,6 @@ static void ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try { ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; ggml_sycl_set_device(ctx->device); - const dpct::queue_ptr stream = g_syclStreams[ctx->device][0]; - - SYCL_CHECK( - CHECK_TRY_ERROR(sycl::free(ctx->dev_ptr, *stream))); delete ctx; } catch (sycl::exception const &exc) { @@ -5795,19 +5484,13 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, return; } - ggml_tensor_extra_gpu * extra = ctx->ggml_sycl_alloc_temp_tensor_extra(); - - extra->data_device[ctx->device] = tensor->data; - tensor->backend = GGML_BACKEND_TYPE_GPU; - tensor->extra = extra; - if (ggml_is_quantized(tensor->type)) { // initialize padding to 0 to avoid possible NaN values size_t original_size = ggml_nbytes(tensor); size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor); if (padded_size > original_size && tensor->view_src == nullptr) { - SYCL_CHECK(CHECK_TRY_ERROR(g_syclStreams[ctx->device][0]->memset( + SYCL_CHECK(CHECK_TRY_ERROR((*(dpct->stream))->memset( (char *)tensor->data + original_size, 0, padded_size - original_size).wait())); } @@ -5828,14 +5511,13 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; ggml_sycl_set_device(ctx->device); - const dpct::queue_ptr stream = g_syclStreams[ctx->device][0]; + auto stream = dpct::get_in_order_queue(); SYCL_CHECK( CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw())); char* host_buf = (char*)malloc(size); memcpy(host_buf, data, size); SYCL_CHECK( - CHECK_TRY_ERROR((*stream) - .memcpy((char *)tensor->data + offset, host_buf, size) + CHECK_TRY_ERROR(stream.memcpy((char *)tensor->data + offset, host_buf, size) .wait())); free(host_buf); } @@ -5854,14 +5536,13 @@ static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t 
buffer, ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; ggml_sycl_set_device(ctx->device); - const dpct::queue_ptr stream = g_syclStreams[ctx->device][0]; + auto stream = dpct::get_in_order_queue(); SYCL_CHECK( CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw())); SYCL_CHECK(CHECK_TRY_ERROR( - (*stream) - .memcpy(data, (const char *)tensor->data + offset, size) + stream.memcpy(data, (const char *)tensor->data + offset, size) .wait())); } catch (sycl::exception const &exc) { @@ -5899,9 +5580,9 @@ ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer, error codes. The original code was commented out and a warning string was inserted. You need to rewrite this code. */ - - dpct::queue_ptr stream_dst = g_syclStreams[dst_ctx->device][0]; - dpct::queue_ptr stream_src = g_syclStreams[src_ctx->device][0]; + auto backend_ctx = ggml_sycl_info(); + queue_ptr stream_dst = backend_ctx.stream(dst_ctx->device, 0); + queue_ptr stream_src = backend_ctx.stream(src_ctx->device, 0); size_t size = ggml_nbytes(src); //todo. it's dirty solutino to walkaroud known issue:device2device cross GPUs. @@ -5936,7 +5617,8 @@ static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer, ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; ggml_sycl_set_device(ctx->device); - const dpct::queue_ptr stream = g_syclStreams[ctx->device][0]; + auto backend_ctx = ggml_sycl_info(); + queue_ptr stream = ctx->stream; SYCL_CHECK( CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw())); @@ -5966,11 +5648,8 @@ static struct ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { struct ggml_backend_sycl_buffer_type_context { int device; std::string name; -}; - -struct ggml_backend_sycl_context { - int device; - std::string name; + // each buffer type has its own stream + queue_ptr stream = nullptr; }; GGML_CALL static const char * ggml_backend_sycl_buffer_type_name(ggml_backend_buffer_type_t buft) { @@ -5983,13 +5662,14 @@ ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) try { ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context; ggml_sycl_set_device(buft_ctx->device); - const dpct::queue_ptr stream = g_syclStreams[buft_ctx->device][0]; + auto backend_ctx = ggml_sycl_info(); + const queue_ptr stream = backend_ctx.stream(buft_ctx->device, 0); size = std::max(size, (size_t)1); // syclMalloc returns null for size 0 void * dev_ptr; SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device( size, *stream))); - ggml_backend_sycl_buffer_context * ctx = new ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr); + ggml_backend_sycl_buffer_context * ctx = new ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr, buft_ctx->stream); return ggml_backend_buffer_init(buft, ggml_backend_sycl_buffer_interface, ctx, size); } catch (sycl::exception const &exc) { @@ -6043,13 +5723,13 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = { /* .is_host = */ nullptr, }; -ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device_index) { +ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context* ctx) { GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n"); - if (device_index>=g_device_count or device_index<0) { + if (ctx->device>=g_device_count or device_index<0) { printf("ggml_backend_sycl_buffer_type error: 
device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n", - device_index, g_device_count-1); - GGML_ASSERT(device_indexdevice, g_device_count-1); + GGML_ASSERT(ctx->devicegpus[i])}, + /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(g_sycl_gpu_mgr->gpus[i]), ctx.stream(g_sycl_gpu_mgr->gpus[i], 0)}, }; } g_ggml_backend_sycl_buffer_type_initialized = true; @@ -6104,9 +5784,10 @@ struct ggml_backend_sycl_split_buffer_context { and a warning string was inserted. You need to rewrite this code. */ - ggml_sycl_set_device(i); + ggml_sycl_set_device(i); + const queue_ptr stream = backend_ctx.stream(i, 0); SYCL_CHECK(CHECK_TRY_ERROR(sycl::free( - extra->data_device[i], *g_syclStreams[i][0]))); + extra->data_device[i], *stream))); } } delete extra; @@ -6176,9 +5857,11 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); } - // FIXME: do not crash if cudaMalloc fails + // FIXME: do not crash if syclMalloc fails // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first ggml_sycl_set_device(i); + auto backend_ctx = ggml_sycl_info(); + const queue_ptr stream = backend_ctx.stream(i, 0); char * buf; /* DPCT1009:208: SYCL uses exceptions to report errors and does not use the @@ -6186,7 +5869,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, was inserted. You need to rewrite this code. */ SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device( - size, *g_syclStreams[i][0]))); + size, *stream))); // set padding to 0 to avoid possible NaN values if (size > original_size) { @@ -6196,7 +5879,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, string was inserted. You need to rewrite this code. */ SYCL_CHECK(CHECK_TRY_ERROR( - (*g_syclStreams[i][0]) + (*stream) .memset(buf + original_size, 0, size - original_size) .wait())); } @@ -6262,8 +5945,10 @@ ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer, was inserted. You need to rewrite this code. */ ggml_sycl_set_device(i); + auto backend_ctx = ggml_sycl_info(); + const queue_ptr stream = backend_ctx.stream(i, 0); SYCL_CHECK(CHECK_TRY_ERROR( - (*g_syclStreams[i][0]) + (*stream) .memcpy(extra->data_device[i], buf_host, original_size) .wait())); } @@ -6314,8 +5999,10 @@ ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer, was inserted. You need to rewrite this code. 
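For orientation, a minimal usage sketch of the context-based buffer-type lookup introduced above, assuming ggml_backend_sycl_buffer_type(ggml_backend_sycl_context*) as changed by this patch and the standard ggml-backend.h allocation helper; the function name example_buffer_alloc is hypothetical and the snippet is illustrative only:

#include "ggml-backend.h"
// assumes ggml_backend_sycl_context from ggml-sycl/common.hpp is visible, as it is inside ggml-sycl.cpp

// sketch: the buffer type is now resolved through a backend context, which carries the
// device index and the per-device stream used when allocating and freeing buffer memory
static ggml_backend_buffer_t example_buffer_alloc(ggml_backend_sycl_context * ctx, size_t size) {
    ggml_backend_buffer_type_t buft = ggml_backend_sycl_buffer_type(ctx);
    return ggml_backend_buft_alloc_buffer(buft, size);   // device buffer on ctx->device
}
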
*/ ggml_sycl_set_device(i); + auto backend_ctx = ggml_sycl_info(); + const queue_ptr stream = backend_ctx.stream(i, 0); SYCL_CHECK(CHECK_TRY_ERROR( - (*g_syclStreams[i][0]) + (*stream) .memcpy(buf_host, extra->data_device[i], original_size) .wait())); } @@ -6416,7 +6103,7 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) { GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n"); - ggml_init_sycl(); + ggml_check_sycl(); // FIXME: this is not thread safe static std::map, struct ggml_backend_buffer_type> buft_map; @@ -6524,7 +6211,8 @@ GGML_CALL static void ggml_backend_sycl_free(ggml_backend_t backend) { GGML_CALL static ggml_backend_buffer_type_t ggml_backend_sycl_get_default_buffer_type(ggml_backend_t backend) { ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; - return ggml_backend_sycl_buffer_type(sycl_ctx->device); + // TODO: pass context + return ggml_backend_sycl_buffer_type(sycl_ctx); } GGML_CALL static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend, @@ -6532,9 +6220,11 @@ GGML_CALL static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend, const void *data, size_t offset, size_t size) try { ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; - GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type"); + GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && "unsupported buffer type"); GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); - SYCL_CHECK(CHECK_TRY_ERROR(g_syclStreams[sycl_ctx->device][0]->memcpy( + auto backend_ctx = backend->context; + const queue_ptr stream = backend_ctx.stream(sycl_ctx->device, 0); + SYCL_CHECK(CHECK_TRY_ERROR((*stream)->memcpy( (char *)tensor->data + offset, data, size).wait())); } catch (sycl::exception const &exc) { @@ -6548,9 +6238,11 @@ GGML_CALL static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend, void *data, size_t offset, size_t size) try { ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; - GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type"); + GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && "unsupported buffer type"); GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); - SYCL_CHECK(CHECK_TRY_ERROR(g_syclStreams[sycl_ctx->device][0]->memcpy( + auto backend_ctx = backend->context; + const queue_ptr stream = backend_ctx.stream(sycl_ctx->device, 0); + SYCL_CHECK(CHECK_TRY_ERROR((*stream)->memcpy( data, (const char *)tensor->data + offset, size).wait())); } catch (sycl::exception const &exc) { @@ -6563,13 +6255,15 @@ GGML_CALL static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor *src, ggml_tensor *dst) try { ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; - if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && ggml_backend_buffer_is_sycl(src->buffer)) { + if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && ggml_backend_buffer_is_sycl(src->buffer)) { /* DPCT1009:215: SYCL uses exceptions to report errors and does not use the error codes. The original code was commented out and a warning string was inserted. You need to rewrite this code. 
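The async set/get/cpy paths in this area all reduce to a memcpy on a context-owned queue; a minimal sketch of that pattern, assuming the ggml_backend_sycl_context::stream(device, index) accessor added to common.hpp by this patch (queue_ptr is a sycl::queue*; copy_to_device_example is a hypothetical name, and SYCL_CHECK/CHECK_TRY_ERROR are the macros already used throughout this file):

// sketch: submit the copy on the context's queue instead of the old g_syclStreams global
static void copy_to_device_example(ggml_backend_sycl_context & ctx,
                                   void * dst_dev, const void * src_host, size_t size) {
    queue_ptr stream = ctx.stream(ctx.device, 0);   // lazily created and cached per device/slot
    SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_dev, src_host, size).wait()));
}
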
*/ - SYCL_CHECK(CHECK_TRY_ERROR(g_syclStreams[sycl_ctx->device][0]->memcpy( + auto backend_ctx = backend->context; + const queue_ptr stream = backend_ctx.stream(sycl_ctx->device, 0); + SYCL_CHECK(CHECK_TRY_ERROR((*stream)->memcpy( dst->data, src->data, ggml_nbytes(dst)).wait())); return true; } @@ -6584,7 +6278,8 @@ catch (sycl::exception const &exc) { static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try { ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; - SYCL_CHECK(CHECK_TRY_ERROR(g_syclStreams[sycl_ctx->device][0]->wait())); + const queue_ptr stream = sycl_ctx.stream(sycl_ctx->device, 0); + SYCL_CHECK(CHECK_TRY_ERROR((*stream)->wait())); UNUSED(backend); } @@ -6608,13 +6303,13 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back } #ifndef NDEBUG assert(node->backend == GGML_BACKEND_TYPE_GPU || node->backend == GGML_BACKEND_TYPE_GPU_SPLIT); - assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device)); + assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx)); assert(node->extra != nullptr); for (int j = 0; j < GGML_MAX_SRC; j++) { if (node->src[j] != nullptr) { assert(node->src[j]->backend == GGML_BACKEND_TYPE_GPU || node->src[j]->backend == GGML_BACKEND_TYPE_GPU_SPLIT); - assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device)); + assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx)); assert(node->src[j]->extra != nullptr); } } @@ -6791,17 +6486,18 @@ static ggml_guid_t ggml_backend_sycl_guid() { GGML_CALL ggml_backend_t ggml_backend_sycl_init(int device) { GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_init\n"); - ggml_init_sycl(); + ggml_check_sycl(); check_allow_gpu_index(device); // not strictly necessary, but it may reduce the overhead of the first graph_compute ggml_sycl_set_main_device(device); int id = g_sycl_gpu_mgr->gpus[device]; - ggml_backend_sycl_context * ctx = new ggml_backend_sycl_context { - /* .device = */ device, - /* .name = */ GGML_SYCL_NAME + std::to_string(id), - }; + ggml_backend_sycl_context * ctx = new ggml_backend_sycl_context(device); + if (ctx == nullptr) { + fprintf(stderr, "%s: error: failed to allocate context\n", __func__); + return nullptr; + } ggml_backend_t sycl_backend = new ggml_backend { /* .guid = */ ggml_backend_sycl_guid(), @@ -6839,44 +6535,9 @@ GGML_API GGML_CALL int ggml_backend_sycl_get_device_id(int device_index) { return g_sycl_gpu_mgr->gpus[device_index]; } -GGML_API GGML_CALL void ggml_backend_sycl_set_single_device_mode(int main_gpu_id) { - ggml_init_sycl(); - GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_set_single_device_mode\n"); - fprintf(stderr, "ggml_backend_sycl_set_single_device: use single device: [%d]\n", main_gpu_id); - GGML_ASSERT(main_gpu_idget_gpu_count()); - g_ggml_backend_sycl_buffer_type_initialized = false; -} - -GGML_API GGML_CALL void ggml_backend_sycl_set_mul_device_mode() { - ggml_init_sycl(); - GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_set_mul_device_mode\n"); - - if (g_ggml_sycl_backend_gpu_mode == SYCL_MUL_GPU_MODE) { - return; - } - - fprintf(stderr, "ggml_backend_sycl_set_mul_device_mode: true\n"); - - if (g_sycl_gpu_mgr) { - delete g_sycl_gpu_mgr; - } - g_sycl_gpu_mgr = new sycl_gpu_mgr(); - g_ggml_sycl_backend_gpu_mode = SYCL_MUL_GPU_MODE; - ggml_init_by_gpus(g_sycl_gpu_mgr->get_gpu_count()); - g_ggml_backend_sycl_buffer_type_initialized = false; -} - extern "C" int ggml_backend_sycl_reg_devices(); int ggml_backend_sycl_reg_devices() 
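With the single/multi device mode setters removed above, backend setup reduces to per-device initialization; a minimal caller-side sketch, using ggml_backend_sycl_init as changed in this patch (example_init_sycl_backend is a hypothetical name):

#include <cstdio>
#include "ggml-sycl.h"

// sketch: one backend per SYCL device index; no global single/multi device mode switch remains
static ggml_backend_t example_init_sycl_backend(int device_index) {
    ggml_backend_t backend = ggml_backend_sycl_init(device_index);
    if (backend == nullptr) {
        fprintf(stderr, "example_init_sycl_backend: failed to initialize SYCL backend %d\n", device_index);
    }
    return backend;
}
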
{ - ggml_backend_sycl_set_mul_device_mode(); assert(g_device_count>0); for (int i = 0; i < g_device_count; i++) { int id = g_sycl_gpu_mgr->gpus[i]; diff --git a/ggml-sycl.h b/ggml-sycl.h index a63f65d8c5a7a..a37f27eaf39ee 100644 --- a/ggml-sycl.h +++ b/ggml-sycl.h @@ -36,8 +36,6 @@ GGML_API GGML_CALL int ggml_backend_sycl_get_device_index(int device_id); // TODO: these are temporary // ref: https://github.com/ggerganov/llama.cpp/pull/6022#issuecomment-1992615670 GGML_API GGML_CALL int ggml_backend_sycl_get_device_id(int device_index); -GGML_API GGML_CALL void ggml_backend_sycl_set_single_device_mode(int main_gpu_id); -GGML_API GGML_CALL void ggml_backend_sycl_set_mul_device_mode(); // SYCL doesn't support registering host memory, keep here for reference // GGML_API GGML_CALL bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size); diff --git a/ggml-sycl/backend.hpp b/ggml-sycl/backend.hpp index cd3481e41601e..2d37e271f9050 100644 --- a/ggml-sycl/backend.hpp +++ b/ggml-sycl/backend.hpp @@ -19,6 +19,5 @@ #include "dmmv.hpp" #include "mmq.hpp" #include "mmvq.hpp" -#include "pool.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml-sycl/common.cpp b/ggml-sycl/common.cpp index 29c1e8e833080..0320749e27863 100644 --- a/ggml-sycl/common.cpp +++ b/ggml-sycl/common.cpp @@ -20,122 +20,6 @@ int get_current_device_id() { return dpct::dev_mgr::instance().current_device_id(); } -void log_ggml_var_device( - const char* name, - float* src, - size_t total_elements, - bool src_on_device) { - if (!g_ggml_sycl_debug) - return; - if (!src) { - printf("GGML Tensor:%s skip to save for NULL pointer\n", name); - return; - } - char filename[1024]; - sprintf(filename, "%s.txt", name); - printf("GGML Tensor:%s save to %s\n", name, filename); - - size_t total_size = total_elements * sizeof(float); - float* local_buf = NULL; - if (src_on_device) { - local_buf = (float*)ggml_sycl_host_malloc(total_size); - ggml_sycl_set_device(g_main_device); - dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; - main_stream->memcpy(local_buf, src, total_size).wait(); - } else { - local_buf = (float*)src; - } - - std::ofstream logfile; - logfile.open(filename); - for (size_t i = 0; i < total_elements; i++) { - logfile << local_buf[i] << " "; - if ((i + 1) % 20 == 0) - logfile << std::endl; - } - logfile << std::endl; - logfile.close(); - - if (src_on_device) - ggml_sycl_host_free(local_buf); -} - -void log_ggml_var_device_fp16( - const char* name, - sycl::half* src, - size_t total_elements, - bool src_on_device) { - if (!g_ggml_sycl_debug) - return; - if (!src) { - printf("GGML Tensor:%s skip to save for NULL pointer\n", name); - return; - } - char filename[1024]; - sprintf(filename, "%s.txt", name); - printf("GGML Tensor:%s save to %s\n", name, filename); - - size_t total_size = total_elements * sizeof(sycl::half); - sycl::half* local_buf = NULL; - if (src_on_device) { - local_buf = (sycl::half*)ggml_sycl_host_malloc(total_size); - ggml_sycl_set_device(g_main_device); - dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; - main_stream->memcpy(local_buf, src, total_size).wait(); - } else { - local_buf = (sycl::half*)src; - } - - std::ofstream logfile; - logfile.open(filename); - for (size_t i = 0; i < total_elements; i++) { - logfile << local_buf[i] << " "; - if ((i + 1) % 20 == 0) - logfile << std::endl; - } - logfile << std::endl; - logfile.close(); - - if (src_on_device) - ggml_sycl_host_free(local_buf); -} - -void print_ggml_tensor(const char* name, struct ggml_tensor* src) { - if 
(!g_ggml_sycl_debug) - return; - if (!src) { - printf("GGML Tensor:%s skip to save for NULL pointer\n", name); - return; - } - - size_t total_elements = ggml_nelements(src); - - const bool src_on_device = src->backend == GGML_BACKEND_TYPE_GPU || - src->backend == GGML_BACKEND_TYPE_GPU_SPLIT; - float* src_data = NULL; - if (src_on_device) { - ggml_tensor_extra_gpu* src_extra = (ggml_tensor_extra_gpu*)src->extra; - src_data = (float*)src_extra->data_device[g_main_device]; - } else { - src_data = (float*)src->data; - } - - log_ggml_var_device(name, src_data, total_elements, src_on_device); -} - -void log_tensor_with_cnt( - const char* name, - struct ggml_tensor* src, - int stop_cnt) { - stop_cnt = 4; - if (log_file_name_idx >= stop_cnt) - return; - char filename[1280]; - sprintf(filename, "%s_%07d", name, log_file_name_idx); - log_file_name_idx++; - print_ggml_tensor(filename, src); -} - void* ggml_sycl_host_malloc(size_t size) try { if (getenv("GGML_SYCL_NO_PINNED") != nullptr) { return nullptr; diff --git a/ggml-sycl/common.hpp b/ggml-sycl/common.hpp index 2ce5f92f49dd8..febea98947d76 100644 --- a/ggml-sycl/common.hpp +++ b/ggml-sycl/common.hpp @@ -78,6 +78,9 @@ static int g_work_group_size = 0; #define GGML_SYCL_MMV_Y 1 #endif +typedef sycl::queue *queue_ptr; +typedef sycl::handler *handle_ptr; + enum ggml_sycl_backend_gpu_mode { SYCL_UNSET_GPU_MODE = -1, SYCL_SINGLE_GPU_MODE = 0, @@ -182,17 +185,6 @@ static_assert( #endif // GGML_SYCL_PEER_MAX_BATCH_SIZE #define MUL_MAT_SRC1_COL_STRIDE 128 -#define MAX_STREAMS 8 -#define SYCL_MAX_DEVICES 48 - -static dpct::queue_ptr g_syclStreams[SYCL_MAX_DEVICES][MAX_STREAMS] = {{0}}; - -struct ggml_tensor_extra_gpu { - void* data_device[SYCL_MAX_DEVICES]; // 1 pointer for each device for split - // tensors - dpct::event_ptr events[SYCL_MAX_DEVICES] - [MAX_STREAMS]; // events for synchronizing multiple GPUs -}; class sycl_gpu_mgr { public: @@ -320,7 +312,7 @@ class sycl_gpu_mgr { } }; -static sycl_gpu_mgr* g_sycl_gpu_mgr = NULL; +static sycl_gpu_mgr* g_sycl_gpu_mgr = new sycl_gpu_mgr(0); static int g_device_count = -1; static int g_all_sycl_device_count = -1; static int g_main_device = -1; @@ -329,31 +321,15 @@ static bool g_ggml_backend_sycl_buffer_type_initialized = false; static std::array g_default_tensor_split = {}; -static float g_tensor_split[SYCL_MAX_DEVICES] = {0}; +static float g_tensor_split[GGML_SYCL_MAX_DEVICES] = {0}; static ggml_sycl_backend_gpu_mode g_ggml_sycl_backend_gpu_mode = SYCL_UNSET_GPU_MODE; -struct sycl_device_capabilities { - int cc; // compute capability - bool vmm; // virtual memory support - size_t vmm_granularity; // granularity of virtual memory - int device_id; -}; - -static sycl_device_capabilities g_device_caps[SYCL_MAX_DEVICES] = { - {0, false, 0, -1}}; - -struct sycl_device_id2index { - int index; -}; - static void* g_scratch_buffer = nullptr; static size_t g_scratch_size = 0; // disabled by default static size_t g_scratch_offset = 0; -static dpct::queue_ptr g_sycl_handles[SYCL_MAX_DEVICES] = {nullptr}; - int get_main_device(); [[noreturn]] static inline void bad_arch(const sycl::stream& stream_ct1) { @@ -427,25 +403,151 @@ inline dpct::err0 ggml_sycl_set_device(const int device) try { std::exit(1); } -void log_ggml_var_device( - const char* name, - float* src, - size_t total_elements, - bool src_on_device); - -void log_ggml_var_device_fp16( - const char* name, - sycl::half* src, - size_t total_elements, - bool src_on_device); - -// todo: debug for crash in some case -void print_ggml_tensor(const char* name, struct 
ggml_tensor* src); - -static int log_file_name_idx = 0; -void log_tensor_with_cnt( - const char* name, - struct ggml_tensor* src, - int stop_cnt); +////////////////////// + +struct ggml_sycl_device_info { + int device_count; + + struct sycl_device_info { + int cc; // compute capability + // int nsm; // number of streaming multiprocessors + // size_t smpb; // max. shared memory per block + bool vmm; // virtual memory support + size_t total_vram; + }; + + sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {}; + + std::array default_tensor_split = {}; +}; + +const ggml_sycl_device_info & ggml_sycl_info(); + +struct ggml_sycl_pool { + virtual ~ggml_sycl_pool() = default; + + virtual void * alloc(size_t size, size_t * actual_size) = 0; + virtual void free(void * ptr, size_t size) = 0; +}; + +template +struct ggml_sycl_pool_alloc { + ggml_sycl_pool * pool = nullptr; + T * ptr = nullptr; + size_t actual_size = 0; + + explicit ggml_sycl_pool_alloc(ggml_sycl_pool & pool) : pool(&pool) { + } + + ggml_sycl_pool_alloc(ggml_sycl_pool & pool, size_t size) : pool(&pool) { + alloc(size); + } + + ~ggml_sycl_pool_alloc() { + if (ptr != nullptr) { + pool->free(ptr, actual_size); + } + } + + // size is in number of elements + T * alloc(size_t size) { + GGML_ASSERT(pool != nullptr); + GGML_ASSERT(ptr == nullptr); + ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size); + return ptr; + } + + T * alloc(ggml_sycl_pool & pool, size_t size) { + this->pool = &pool; + return alloc(size); + } + + T * get() { + return ptr; + } + + ggml_sycl_pool_alloc() = default; + ggml_sycl_pool_alloc(const ggml_sycl_pool_alloc &) = delete; + ggml_sycl_pool_alloc(ggml_sycl_pool_alloc &&) = delete; + ggml_sycl_pool_alloc& operator=(const ggml_sycl_pool_alloc &) = delete; + ggml_sycl_pool_alloc& operator=(ggml_sycl_pool_alloc &&) = delete; +}; + +// backend interface + +struct ggml_tensor_extra_gpu { + void* data_device[GGML_SYCL_MAX_DEVICES]; // 1 pointer for each device for split + // tensors + dpct::event_ptr events[GGML_SYCL_MAX_DEVICES] + [GGML_SYCL_MAX_STREAMS]; // events for synchronizing multiple GPUs +}; + +struct ggml_backend_sycl_context { + int device; + std::string name; + + queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } }; + static sycl::handler * sycl_handles[GGML_SYCL_MAX_DEVICES] = {nullptr}; + + explicit ggml_backend_sycl_context(int device) : + device(device), + name(GGML_SYCL_NAME + std::to_string(device)) { + } + + ~ggml_backend_sycl_context() { + for (int i = 0; i < GGML_SYCL_MAX_DEVICES; ++i) { + for (int j = 0; j < GGML_SYCL_MAX_STREAMS; ++j) { + if (qptrs[i][j] != nullptr) { + SYCL_CHECK(free(qptrs[i][j])); + } + } + if (cublas_handles[i] != nullptr) { + SYCL_CHECK(free(sycl_handles[i])); + } + } + } + + queue_ptr stream(int device, int stream) { + if (qptrs[device][stream] == nullptr) { + SYCL_CHECK(dpct::get_current_device().create_queue( + g_sycl_gpu_mgr->get_co_ctx(), dpct::get_current_device()))); + } + return qptrs[device][stream]; + } + + queue_ptr stream() { + return stream(device, 0); + } + + handle_ptr sycl_handle(int device) { + if (sycl_handles[device] == nullptr) { + const dpct::queue_ptr stream = qptrs[device][0]; + // create sycl handle + SYCL_CHECK(CHECK_TRY_ERROR(sycl_handles[device] = stream)); + } + return sycl_handles[device]; + } + + handle_ptr sycl_handle() { + return sycl_handle(device); + } + + // pool + std::unique_ptr pools[GGML_SYCL_MAX_DEVICES]; + + static std::unique_ptr new_pool_for_device(queue_ptr qptr, int device); + + ggml_sycl_pool & 
pool(int device) { + if (pools[device] == nullptr) { + pools[device] = new_pool_for_device(qptrs[device][0], device); + } + return *pools[device]; + } + + ggml_sycl_pool & pool() { + return pool(device); + } +}; + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml-sycl/convert.cpp b/ggml-sycl/convert.cpp index 4e4d214ed765c..ce9de2b42b722 100644 --- a/ggml-sycl/convert.cpp +++ b/ggml-sycl/convert.cpp @@ -1,6 +1,6 @@ #include "convert.hpp" #include "dequantize.hpp" - +#include "presets.hpp" template static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, diff --git a/ggml-sycl/dmmv.cpp b/ggml-sycl/dmmv.cpp index 754f84ca886f2..7b2849ff4c3a6 100644 --- a/ggml-sycl/dmmv.cpp +++ b/ggml-sycl/dmmv.cpp @@ -1,7 +1,7 @@ #include "convert.hpp" #include "dmmv.hpp" #include "dequantize.hpp" -#include "pool.hpp" +#include "presets.hpp" static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){ const sycl::half *x = (const sycl::half *)vx; @@ -967,10 +967,35 @@ void ggml_sycl_op_dequantize_mul_mat_vec( break; } GGML_ASSERT(src1->type == GGML_TYPE_F32); - + switch (src0->type) + { + case GGML_TYPE_F32: + printf("f32\n"); + break; + case GGML_TYPE_Q4_0: + printf("q4_0\n"); + break; + case GGML_TYPE_F16: + printf("f16\n"); + break; + case GGML_TYPE_Q4_1: + printf("q4_1\n"); + break; + case GGML_TYPE_Q5_0: + printf("q5_0\n"); + break; + case GGML_TYPE_Q5_1: + printf("q5_1\n"); + break; + case GGML_TYPE_Q5_K: + printf("q5_K\n"); + break; + default: + break; + } // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics #ifdef GGML_SYCL_F16 - sycl_pool_alloc src1_dfloat_a; + ggml_sycl_pool_alloc src1_dfloat_a; sycl::half *src1_dfloat = nullptr; // dfloat == half bool src1_convert_f16 = @@ -987,7 +1012,32 @@ void ggml_sycl_op_dequantize_mul_mat_vec( #else const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion #endif // GGML_SYCL_F16 - + switch (src0->type) + { + case GGML_TYPE_F32: + printf("f32\n"); + break; + case GGML_TYPE_Q4_0: + printf("q4_0\n"); + break; + case GGML_TYPE_F16: + printf("f16\n"); + break; + case GGML_TYPE_Q4_1: + printf("q4_1\n"); + break; + case GGML_TYPE_Q5_0: + printf("q5_0\n"); + break; + case GGML_TYPE_Q5_1: + printf("q5_1\n"); + break; + case GGML_TYPE_Q5_K: + printf("q5_K\n"); + break; + default: + break; + } switch (src0->type) { case GGML_TYPE_Q4_0: dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); diff --git a/ggml-sycl/mmq.cpp b/ggml-sycl/mmq.cpp index 3927ca467abfa..d98f406d99733 100644 --- a/ggml-sycl/mmq.cpp +++ b/ggml-sycl/mmq.cpp @@ -1779,7 +1779,7 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy, int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; + const int compute_capability = ggml_sycl_info().devices[id].cc; int mmq_x, mmq_y, nwarps; if (compute_capability >= VER_GEN13) { @@ -1894,7 +1894,7 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy, int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; + const int compute_capability = ggml_sycl_info().devices[id].cc; int mmq_x, mmq_y, nwarps; if (compute_capability >= VER_GEN13) { @@ -2009,7 +2009,7 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy, int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = 
g_device_caps[id].cc; + const int compute_capability = ggml_sycl_info().devices[id].cc; int mmq_x, mmq_y, nwarps; if (compute_capability >= VER_GEN13) { @@ -2124,7 +2124,7 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy, int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; + const int compute_capability = ggml_sycl_info().devices[id].cc; int mmq_x, mmq_y, nwarps; if (compute_capability >= VER_GEN13) { @@ -2239,7 +2239,7 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy, int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; + const int compute_capability = ggml_sycl_info().devices[id].cc; int mmq_x, mmq_y, nwarps; if (compute_capability >= VER_GEN13) { @@ -2354,7 +2354,7 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy, int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; + const int compute_capability = ggml_sycl_info().devices[id].cc; int mmq_x, mmq_y, nwarps; if (compute_capability >= VER_GEN13) { @@ -2477,7 +2477,7 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy, int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; + const int compute_capability = ggml_sycl_info().devices[id].cc; int mmq_x, mmq_y, nwarps; if (compute_capability >= VER_GEN13) { @@ -2605,7 +2605,7 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy, int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; + const int compute_capability = ggml_sycl_info().devices[id].cc; int mmq_x, mmq_y, nwarps; if (compute_capability >= VER_GEN13) { @@ -2726,7 +2726,7 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy, int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; + const int compute_capability = ggml_sycl_info().devices[id].cc; int mmq_x, mmq_y, nwarps; if (compute_capability >= VER_GEN13) { @@ -2847,7 +2847,7 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy, int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; + const int compute_capability = ggml_sycl_info().devices[id].cc; int mmq_x, mmq_y, nwarps; if (compute_capability >= VER_GEN13) { diff --git a/ggml-sycl/pool.hpp b/ggml-sycl/pool.hpp deleted file mode 100644 index 3308d3cb91d69..0000000000000 --- a/ggml-sycl/pool.hpp +++ /dev/null @@ -1,233 +0,0 @@ -// -// MIT license -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: MIT -// - -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
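The legacy pool deleted here is superseded by the ggml_sycl_pool interface and the RAII helper declared in common.hpp earlier in this patch; a brief sketch of the intended usage, assuming ggml_sycl_pool_alloc<T> and ggml_backend_sycl_context::pool() as declared there (scratch_example is a hypothetical name):

// sketch: scoped device scratch memory from the context-owned pool, replacing the manual
// ggml_sycl_pool_malloc / ggml_sycl_pool_free pairs that this patch removes
static void scratch_example(ggml_backend_sycl_context & ctx, size_t n_floats) {
    ggml_sycl_pool_alloc<float> tmp(ctx.pool(), n_floats);  // allocates n_floats elements on device
    float * dev_ptr = tmp.get();                            // raw device pointer for kernel launches
    (void) dev_ptr;
}   // memory is returned to the pool when tmp goes out of scope
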
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// - -#ifndef GGML_SYCL_POOL_HPP -#define GGML_SYCL_POOL_HPP - -// buffer pool for sycl -#define MAX_SYCL_BUFFERS 256 - -struct scoped_spin_lock { - std::atomic_flag& lock; - scoped_spin_lock(std::atomic_flag& lock) : lock(lock) { - while (lock.test_and_set(std::memory_order_acquire)) { - ; // spin - } - } - ~scoped_spin_lock() { - lock.clear(std::memory_order_release); - } - scoped_spin_lock(const scoped_spin_lock&) = delete; - scoped_spin_lock& operator=(const scoped_spin_lock&) = delete; -}; - -static std::atomic_flag g_sycl_pool_lock = ATOMIC_FLAG_INIT; - -// #define DEBUG_SYCL_MALLOC -struct sycl_buffer { - void * ptr = nullptr; - size_t size = 0; -}; - -static sycl_buffer g_sycl_buffer_pool[GGML_SYCL_MAX_DEVICES][MAX_SYCL_BUFFERS]; -static size_t g_sycl_pool_size[GGML_SYCL_MAX_DEVICES] = {0}; - -static void *ggml_sycl_pool_malloc_leg(int device_index, size_t size, size_t *actual_size) try { - scoped_spin_lock lock(g_sycl_pool_lock); - // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg device_index %d size=%lu\n", device_index, size); -#ifdef DEBUG_SYCL_MALLOC - int nnz = 0; - size_t max_size = 0; -#endif - size_t best_diff = 1ull << 36; - int ibest = -1; - for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { - sycl_buffer& b = g_sycl_buffer_pool[device_index][i]; - if (b.ptr != nullptr) { -#ifdef DEBUG_SYCL_MALLOC - ++nnz; - if (b.size > max_size) max_size = b.size; -#endif - if (b.size >= size) { - size_t diff = b.size - size; - if (diff < best_diff) { - best_diff = diff; - ibest = i; - if (!best_diff) { - void * ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; - // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg return 1 %p and rm in pool\n", ptr); - return ptr; - } - } - } - } - } - if (ibest >= 0) { - sycl_buffer& b = g_sycl_buffer_pool[device_index][ibest]; - void * ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; - // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg return 2 %p and rm in pool\n", ptr); - return ptr; - } - void * ptr; - size_t look_ahead_size = (size_t) (1.05 * size); - look_ahead_size = 256 * ((look_ahead_size + 255)/256); - - const dpct::queue_ptr stream = g_syclStreams[device_index][0]; - SYCL_CHECK( - CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device( - look_ahead_size, *stream))); - *actual_size = look_ahead_size; - g_sycl_pool_size[device_index] += look_ahead_size; - -#ifdef DEBUG_SYCL_MALLOC - fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz, - (uint32_t)(max_size/1024/1024), (uint32_t)(g_sycl_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024)); -#endif - // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg look_ahead_size=%lu, return %p\n", look_ahead_size, ptr); - return ptr; -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -static void ggml_sycl_pool_free_leg(int device_index, void *ptr, size_t size) try { - scoped_spin_lock lock(g_sycl_pool_lock); - const dpct::queue_ptr stream = g_syclStreams[device_index][0]; - for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { - sycl_buffer& b = g_sycl_buffer_pool[device_index][i]; - if (b.ptr == nullptr) { - b.ptr = ptr; - b.size = size; - return; - } - } - fprintf(stderr, "WARNING: sycl buffer pool full, increase MAX_SYCL_BUFFERS\n"); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *stream))); - g_sycl_pool_size[device_index] -= size; -} -catch (sycl::exception 
const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -// pool with virtual memory -/* -DPCT1082:64: Migration of CUmemGenericAllocationHandle type is not supported. -*/ -// static std::vector -// g_sycl_pool_handles[GGML_SYCL_MAX_DEVICES]; -static dpct::device_ptr g_sycl_pool_addr[GGML_SYCL_MAX_DEVICES] = {0}; -static size_t g_sycl_pool_used[GGML_SYCL_MAX_DEVICES] = {0}; - -static void *ggml_sycl_pool_malloc_vmm(int device_index, size_t size, size_t *actual_size) try { - GGML_UNUSED(device_index); - GGML_UNUSED(size); - GGML_UNUSED(actual_size); - return NULL; -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -static void ggml_sycl_pool_free_vmm(int device_index, void *ptr, size_t size) try { - scoped_spin_lock lock(g_sycl_pool_lock); -#ifdef DEBUG_SYCL_MALLOC - printf("sycl pool[%d]: freed %llu bytes at %llx\n", device_index, (unsigned long long) size, ptr); -#endif - - g_sycl_pool_used[device_index] -= size; - - // all deallocations must be in reverse order of the allocations - GGML_ASSERT(ptr == (void *) (g_sycl_pool_addr[device_index] + g_sycl_pool_used[device_index])); -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -static void *ggml_sycl_pool_malloc(int device_index, size_t size, size_t *actual_size) try { - if (g_device_caps[device_index].vmm) { - return ggml_sycl_pool_malloc_vmm(device_index, size, actual_size); - } else { - return ggml_sycl_pool_malloc_leg(device_index, size, actual_size); - } -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -static void ggml_sycl_pool_free(int device_index, void *ptr, size_t size) try { - if (g_device_caps[device_index].vmm) { - ggml_sycl_pool_free_vmm(device_index, ptr, size); - } else { - ggml_sycl_pool_free_leg(device_index, ptr, size); - } -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -template -struct sycl_pool_alloc { - int device_index = -1; - int device_id = -1; - T * ptr = nullptr; - size_t actual_size = 0; - - // size is in number of elements - T * alloc(size_t size) { - GGML_ASSERT(ptr == nullptr); - device_id = get_current_device_id(); - device_index = g_sycl_gpu_mgr->get_index(device_id); - ptr = (T *) ggml_sycl_pool_malloc(device_index, size * sizeof(T), &this->actual_size); - // GGML_SYCL_DEBUG("sycl_pool_alloc %lu return %p actual size=%lu\n", size * sizeof(T), ptr, this->actual_size); - return ptr; - } - - sycl_pool_alloc(size_t size) { - alloc(size); - } - - ~sycl_pool_alloc() { - if (ptr != nullptr) { - ggml_sycl_pool_free(device_index, ptr, actual_size); - } - } - - T * get() { - return ptr; - } - - sycl_pool_alloc() = default; - sycl_pool_alloc(const sycl_pool_alloc &) = delete; - sycl_pool_alloc(sycl_pool_alloc &&) = delete; - sycl_pool_alloc& operator=(const sycl_pool_alloc &) = delete; - sycl_pool_alloc& operator=(sycl_pool_alloc &&) = delete; -}; - -#endif // GGML_SYCL_POOL_HPP diff --git a/ggml-sycl/presets.hpp b/ggml-sycl/presets.hpp index ea337030dac10..357015db30fdd 100644 --- a/ggml-sycl/presets.hpp +++ b/ggml-sycl/presets.hpp @@ -14,6 
+14,7 @@ #define GGML_SYCL_PRESETS_HPP #define GGML_SYCL_MAX_DEVICES 48 +#define GGML_SYCL_MAX_STREAMS 8 #define GGML_SYCL_NAME "SYCL" #define WARP_SIZE 32 diff --git a/llama.cpp b/llama.cpp index aeb5c08df64a5..12c5003e00a50 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6135,17 +6135,6 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam params.n_gpu_layers = 0; } #endif - -#ifdef GGML_USE_SYCL - if (params.split_mode == LLAMA_SPLIT_MODE_NONE) { - ggml_backend_sycl_set_single_device_mode(params.main_gpu); - //SYCL use device index (0, 1, 2) directly, uer input device id, then convert to device index. - params.main_gpu = ggml_backend_sycl_get_device_index(params.main_gpu); - } else { - ggml_backend_sycl_set_mul_device_mode(); - } -#endif - if (!llm_load_tensors( ml, model, params.n_gpu_layers, params.split_mode, params.main_gpu, params.tensor_split, params.use_mlock, params.progress_callback, params.progress_callback_user_data From 27c3f2961a5ba21047a7204127ce5f8f9e28cbb5 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Wed, 22 May 2024 03:55:06 +0000 Subject: [PATCH 2/2] backup --- ggml-sycl.cpp | 528 ++++++++++++++++++++++--------------------- ggml-sycl/common.hpp | 56 +---- ggml-sycl/dmmv.cpp | 1 + ggml-sycl/dmmv.hpp | 1 + ggml-sycl/mmq.cpp | 1 + ggml-sycl/mmq.hpp | 1 + ggml-sycl/mmvq.cpp | 1 + ggml-sycl/mmvq.hpp | 1 + 8 files changed, 279 insertions(+), 311 deletions(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index e63503a21733c..0c2e3979cbb24 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -60,14 +60,15 @@ void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst, } typedef void (*cpy_kernel_t)(const char * cx, char * cdst); -typedef void (*ggml_sycl_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); +typedef void (*ggml_sycl_func_t)(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); typedef void (*ggml_sycl_op_mul_mat_t)( + ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, float *dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, const queue_ptr &stream); -typedef void (*ggml_sycl_op_flatten_t)(const ggml_tensor *src0, +typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, @@ -1475,7 +1476,7 @@ static void pool2d_nchw_kernel( } template -static void get_rows_sycl(const ggml_tensor *src0, const ggml_tensor *src1, +static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const void *src0_dd, const int32_t *src1_dd, float *dst_dd, queue_ptr stream) { @@ -1510,7 +1511,7 @@ static void get_rows_sycl(const ggml_tensor *src0, const ggml_tensor *src1, } template -static void get_rows_sycl_float(const ggml_tensor *src0, +static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const src0_t *src0_dd, const int32_t *src1_dd, float *dst_dd, queue_ptr stream) { @@ -2811,6 +2812,25 @@ const ggml_sycl_device_info & ggml_sycl_info() { return info; } +/* +device_index: device index from 0 to n (continue numbers). + It is used for device select/set in SYCL backend internal data structure. 
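The index check being introduced here validates devices against the cached ggml_sycl_device_info; a short sketch of the same accessor pattern used by the mmq/dmmv changes in this patch, assuming ggml_sycl_info() and its devices[]/device_count fields as defined earlier in this file (example_compute_capability is a hypothetical name):

// sketch: per-device capabilities come from the cached singleton instead of the removed
// g_device_caps array
static int example_compute_capability(int device_index) {
    GGML_ASSERT(device_index < ggml_sycl_info().device_count);
    return ggml_sycl_info().devices[device_index].cc;   // compute capability of that device
}
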
+*/ +inline void check_allow_gpu_index(const int device_index) { + if (device_index >= ggml_sycl_info().device_count) { + char error_buf[256]; + snprintf( + error_buf, + sizeof(error_buf), + "%s error: device_index:%d is out of range: [0-%d]", + __func__, + device_index, + ggml_sycl_info().device_count - 1); + fprintf(stderr, "%s\n", error_buf); + assert(false); + } +} + // buffer pool for sycl (legacy) struct ggml_sycl_pool_leg : public ggml_sycl_pool { static const int MAX_SYCL_BUFFERS = 256; @@ -2826,7 +2846,7 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { size_t pool_size = 0; explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : - qptr(stream_) + qptr(qptr_), device(device_) { } @@ -2992,7 +3012,7 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_op_get_rows(const ggml_tensor *src0, +static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_d, const float *src1_d, float *dst_d, const queue_ptr &stream) { @@ -3008,26 +3028,26 @@ static void ggml_sycl_op_get_rows(const ggml_tensor *src0, switch (src0->type) { case GGML_TYPE_F16: - get_rows_sycl_float(src0, src1, dst, (const sycl::half *)src0_d, + get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d, src1_i32, dst_d, stream); break; case GGML_TYPE_F32: - get_rows_sycl_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); break; case GGML_TYPE_Q4_0: - get_rows_sycl(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); break; case GGML_TYPE_Q4_1: - get_rows_sycl(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); break; case GGML_TYPE_Q5_0: - get_rows_sycl(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); break; case GGML_TYPE_Q5_1: - get_rows_sycl(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); break; case GGML_TYPE_Q8_0: - get_rows_sycl(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); break; default: // TODO: k-quants @@ -3038,7 +3058,7 @@ static void ggml_sycl_op_get_rows(const ggml_tensor *src0, } template -inline void ggml_sycl_op_bin_bcast(const ggml_tensor *src0, +inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, @@ -3065,27 +3085,27 @@ inline void ggml_sycl_op_bin_bcast(const ggml_tensor *src0, } } -static void ggml_sycl_op_repeat(const ggml_tensor *src0, +static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_d, const float *src1_d, float *dst_d, const queue_ptr &main_stream) { - ggml_sycl_op_bin_bcast>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream); + ggml_sycl_op_bin_bcast>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream); (void) src1; (void) src1_d; } -inline void ggml_sycl_op_add(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float 
*src1_dd, float *dst_dd, const queue_ptr &main_stream) { - ggml_sycl_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); + ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } -inline void ggml_sycl_op_acc(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { @@ -3105,23 +3125,23 @@ inline void ggml_sycl_op_acc(const ggml_tensor *src0, const ggml_tensor *src1, (void) dst; } -inline void ggml_sycl_op_mul(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { - ggml_sycl_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); + ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } -inline void ggml_sycl_op_div(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { - ggml_sycl_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); + ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } -inline void ggml_sycl_op_gelu(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { @@ -3136,7 +3156,7 @@ inline void ggml_sycl_op_gelu(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -inline void ggml_sycl_op_silu(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { @@ -3151,7 +3171,7 @@ inline void ggml_sycl_op_silu(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -inline void ggml_sycl_op_gelu_quick(const ggml_tensor *src0, +inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, @@ -3167,7 +3187,7 @@ inline void ggml_sycl_op_gelu_quick(const ggml_tensor *src0, (void) src1_dd; } -inline void ggml_sycl_op_tanh(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { @@ -3181,7 +3201,7 @@ inline void ggml_sycl_op_tanh(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -inline void ggml_sycl_op_relu(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { @@ -3196,7 +3216,7 
@@ inline void ggml_sycl_op_relu(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -static void ggml_sycl_op_hardsigmoid(const ggml_tensor *src0, +static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, @@ -3212,7 +3232,7 @@ static void ggml_sycl_op_hardsigmoid(const ggml_tensor *src0, (void) src1_dd; } -static void ggml_sycl_op_hardswish(const ggml_tensor *src0, +static void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { @@ -3227,7 +3247,7 @@ static void ggml_sycl_op_hardswish(const ggml_tensor *src0, (void) src1_dd; } -inline void ggml_sycl_op_leaky_relu(const ggml_tensor *src0, +inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, @@ -3246,7 +3266,7 @@ inline void ggml_sycl_op_leaky_relu(const ggml_tensor *src0, (void) src1_dd; } -inline void ggml_sycl_op_sqr(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { @@ -3261,7 +3281,7 @@ inline void ggml_sycl_op_sqr(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -inline void ggml_sycl_op_norm(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_norm(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { @@ -3282,7 +3302,7 @@ inline void ggml_sycl_op_norm(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -inline void ggml_sycl_op_group_norm(const ggml_tensor *src0, +inline void ggml_sycl_op_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, @@ -3300,7 +3320,7 @@ inline void ggml_sycl_op_group_norm(const ggml_tensor *src0, (void) src1_dd; } -inline void ggml_sycl_op_concat(const ggml_tensor *src0, +inline void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, @@ -3318,7 +3338,7 @@ inline void ggml_sycl_op_concat(const ggml_tensor *src0, (void) dst; } -inline void ggml_sycl_op_upscale(const ggml_tensor *src0, +inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, @@ -3337,7 +3357,7 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0, (void) src1_dd; } -inline void ggml_sycl_op_pad(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { @@ -3355,7 +3375,7 @@ inline void ggml_sycl_op_pad(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -inline void 
ggml_sycl_op_rms_norm(const ggml_tensor *src0, +inline void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, @@ -3380,8 +3400,8 @@ inline void ggml_sycl_op_rms_norm(const ggml_tensor *src0, static int64_t get_row_rounding(ggml_type type, const std::array & tensor_split) { int64_t min_compute_capability = INT_MAX; int64_t max_compute_capability = INT_MIN; - for (int i = 0; i < g_device_count; ++i) { - if (tensor_split[i] < (i + 1 < g_device_count ? tensor_split[i + 1] : 1.0f)) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + if (tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? tensor_split[i + 1] : 1.0f)) { if (min_compute_capability > ggml_sycl_info().devices[i].cc) { min_compute_capability = ggml_sycl_info().devices[i].cc; } @@ -3426,6 +3446,7 @@ static int64_t get_row_rounding(ggml_type type, const std::arrayop_params[0] == GGML_PREC_DEFAULT) { // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n"); - ggml_sycl_pool_alloc src0_as_f16; + ggml_sycl_pool_alloc src0_as_f16(ctx.pool()); if (src0->type != GGML_TYPE_F16) { const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type); GGML_ASSERT(to_fp16_sycl != nullptr); @@ -3473,7 +3494,7 @@ inline void ggml_sycl_op_mul_mat_sycl( ? (const sycl::half *)src0_dd_i : src0_as_f16.get(); - ggml_sycl_pool_alloc src1_as_f16; + ggml_sycl_pool_alloc src1_as_f16(ctx.pool()); if (src1->type != GGML_TYPE_F16) { const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type); GGML_ASSERT(to_fp16_sycl != nullptr); @@ -3484,26 +3505,24 @@ inline void ggml_sycl_op_mul_mat_sycl( const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16 ? (const sycl::half *)src1->data + src1_padded_row_size : src1_as_f16.get(); - ggml_sycl_pool_alloc dst_f16(row_diff * src1_ncols); + ggml_sycl_pool_alloc dst_f16(ctx.pool(), row_diff * src1_ncols); const sycl::half alpha_f16 = 1.0f; const sycl::half beta_f16 = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(g_sycl_handles[id] = stream)); SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( - *g_sycl_handles[id], oneapi::mkl::transpose::trans, + *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00, src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, dst_f16.get(), dpct::library_data_t::real_half, ldc, dpct::library_data_t::real_half))); - g_sycl_handles[id]->wait(); const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); } else { // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n"); - ggml_sycl_pool_alloc src0_ddq_as_f32; - ggml_sycl_pool_alloc src1_ddq_as_f32; + ggml_sycl_pool_alloc src0_ddq_as_f32(ctx.pool()); + ggml_sycl_pool_alloc src1_ddq_as_f32(ctx.pool()); if (src0->type != GGML_TYPE_F32) { const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type); GGML_ASSERT(to_fp32_sycl != nullptr); @@ -3522,14 +3541,12 @@ inline void ggml_sycl_op_mul_mat_sycl( const float alpha = 1.0f; const float beta = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(g_sycl_handles[id] = stream)); SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( - *g_sycl_handles[id], oneapi::mkl::transpose::trans, + *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, - dpct::get_value(&alpha, *g_sycl_handles[id]), src0_ddf_i, ne00, - 
src1_ddf1_i, ne10, dpct::get_value(&beta, *g_sycl_handles[id]), + dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, + src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc))); - g_sycl_handles[id]->wait(); } (void) dst; (void) src1_ddq_i; @@ -3541,7 +3558,7 @@ catch (sycl::exception const &exc) { std::exit(1); } -inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { @@ -3621,7 +3638,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_alibi(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { @@ -3651,7 +3668,7 @@ inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -static void ggml_sycl_op_pool2d(const ggml_tensor *src0, +static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { @@ -3693,7 +3710,7 @@ static void ggml_sycl_op_pool2d(const ggml_tensor *src0, (void) src1_dd; } -inline void ggml_sycl_op_im2col(const ggml_tensor *src0, +inline void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, @@ -3734,7 +3751,7 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0, (void) src0_dd; } -inline void ggml_sycl_op_sum_rows(const ggml_tensor *src0, +inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, @@ -3753,7 +3770,7 @@ inline void ggml_sycl_op_sum_rows(const ggml_tensor *src0, (void) src1_dd; } -inline void ggml_sycl_op_argsort(const ggml_tensor *src0, +inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, @@ -3774,7 +3791,7 @@ inline void ggml_sycl_op_argsort(const ggml_tensor *src0, (void) src1_dd; } -inline void ggml_sycl_op_diag_mask_inf(const ggml_tensor *src0, +inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, @@ -3796,7 +3813,8 @@ inline void ggml_sycl_op_diag_mask_inf(const ggml_tensor *src0, (void) src1_dd; } -inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, +inline void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, + const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, @@ -3824,7 +3842,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, // positions tensor float * src2_dd = nullptr; - ggml_sycl_pool_alloc src2_f; + ggml_sycl_pool_alloc src2_f(ctx.pool()); const bool use_src2 = src2 != nullptr; @@ -3844,7 +3862,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor 
*src0, nrows_x, nrows_y, scale, max_bias, main_stream); } -inline void ggml_sycl_op_scale(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { @@ -3867,7 +3885,7 @@ inline void ggml_sycl_op_scale(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -inline void ggml_sycl_op_clamp(const ggml_tensor *src0, const ggml_tensor *src1, +inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { @@ -3892,7 +3910,7 @@ inline void ggml_sycl_op_clamp(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -static void ggml_sycl_op_flatten(const ggml_sycl_backend_context & ctx, +static void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const ggml_sycl_op_flatten_t op) try { @@ -3917,9 +3935,9 @@ static void ggml_sycl_op_flatten(const ggml_sycl_backend_context & ctx, float * src1_ddf = nullptr; float * dst_ddf = nullptr; - ggml_sycl_pool_alloc src0_f; - ggml_sycl_pool_alloc src1_f; - ggml_sycl_pool_alloc dst_f; + ggml_sycl_pool_alloc src0_f(ctx.pool()); + ggml_sycl_pool_alloc src1_f(ctx.pool()); + ggml_sycl_pool_alloc dst_f(ctx.pool()); ggml_sycl_set_device(g_main_device); queue_ptr main_stream = ctx.stream(g_main_device, 0); @@ -3951,7 +3969,7 @@ static void ggml_sycl_op_flatten(const ggml_sycl_backend_context & ctx, // GGML_SYCL_DEBUG("op src0=%p, src1=%p, dst=%p, src0_ddf=%p, src1_ddf=%p, dst_ddf=%p, main_stream=%p\n", // src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream); // do the computation - op(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream); + op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream); /* DPCT1010:89: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to rewrite this code. 
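The hunks above all apply the same refactor: each ggml_sycl_op_* callback gains a leading ggml_backend_sycl_context & ctx parameter, temporary buffers are taken from ctx.pool() instead of the former global pool, the gemm calls run directly on the passed stream rather than through g_sycl_handles, and ggml_sycl_op_flatten forwards the context when it invokes the callback. The stand-alone C++ sketch below illustrates that injection pattern only; apart from the idea of a context owning a pool() and the op-with-ctx calling convention, every type and name here is a simplified placeholder, not the real ggml-sycl API.

// Sketch of the context-injection pattern; all types are simplified stand-ins.
#include <cstddef>
#include <cstdio>
#include <vector>

struct sketch_pool {                          // placeholder for ggml_sycl_pool
    std::vector<float> buf;
    float * alloc(std::size_t n) { buf.assign(n, 0.0f); return buf.data(); }
};

struct sketch_context {                       // placeholder for ggml_backend_sycl_context
    int device = 0;
    sketch_pool pool_;
    sketch_pool & pool() { return pool_; }    // per-context pool replaces the old global pool
};

// after the refactor every op callback receives the context explicitly
typedef void (*sketch_op_t)(sketch_context & ctx, const float * src, float * dst, std::size_t n);

static void sketch_op_relu(sketch_context &, const float * src, float * dst, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        dst[i] = src[i] > 0.0f ? src[i] : 0.0f;
    }
}

// analogue of ggml_sycl_op_flatten: scratch comes from ctx.pool(), ctx is forwarded to the op
static void sketch_flatten(sketch_context & ctx, const float * src, std::size_t n, sketch_op_t op) {
    float * dst = ctx.pool().alloc(n);
    op(ctx, src, dst, n);
    std::printf("dst[0]=%g dst[1]=%g\n", dst[0], dst[1]);
}

int main() {
    sketch_context ctx;
    const float src[4] = { -1.0f, 2.0f, -3.0f, 4.0f };
    sketch_flatten(ctx, src, 4, sketch_op_relu);
    return 0;
}

This matches the other removals in the patch (g_sycl_handles, g_device_count, the backend_ctx lookups): per-device state moves off globals and onto the backend context that is threaded through the call chain.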
@@ -3987,15 +4005,15 @@ static void ggml_sycl_set_peer_access(const int n_tokens) { } #ifdef NDEBUG - for (int i = 0; i < g_device_count; ++i) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { SYCL_CHECK(ggml_sycl_set_device(i)); // SYCL_CHECK(syclDeviceSynchronize()); } - for (int i = 0; i < g_device_count; ++i) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { SYCL_CHECK(ggml_sycl_set_device(i)); - for (int id_other = 0; id_other < g_device_count; ++id_other) { + for (int id_other = 0; id_other < ggml_sycl_info().device_count; ++id_other) { if (i == id_other) { continue; } @@ -4023,7 +4041,8 @@ struct ggml_backend_sycl_split_buffer_type_context { std::array tensor_split; }; -static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, +static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, + const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, ggml_sycl_op_mul_mat_t op, const bool convert_src1_to_q8_1) try { @@ -4095,10 +4114,9 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, dev_data dev[GGML_SYCL_MAX_DEVICES]; int used_devices = 0; - auto ctx = ggml_sycl_info(); queue_ptr main_stream = ctx.stream(g_main_device, 0); - for (int i = 0; i < g_device_count; ++i) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { // by default, use all rows dev[i].row_low = 0; dev[i].row_high = ne01; @@ -4115,7 +4133,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, } } - if (i != g_device_count - 1) { + if (i != ggml_sycl_info().device_count - 1) { dev[i].row_high = ne01*tensor_split[i + 1]; if (dev[i].row_high < ne01) { dev[i].row_high -= dev[i].row_high % rounding; @@ -4124,7 +4142,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, } } - for (int i = 0; i < g_device_count; ++i) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { if ((!split && i != g_main_device) || dev[i].row_low == dev[i].row_high) { continue; } @@ -4153,7 +4171,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs); if (src1_on_device && src1_is_contiguous) { - quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream); + quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, main_stream); /* DPCT1010:90: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to @@ -4187,10 +4205,10 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) { - const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0; + const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0; const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? 
ne11 - src1_col_0 : src1_col_stride; - for (int i = 0; i < g_device_count; ++i) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { if ((!split && i != g_main_device) || dev[i].row_low == dev[i].row_high) { continue; } @@ -4200,7 +4218,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, const int64_t row_diff = dev[i].row_high - dev[i].row_low; ggml_sycl_set_device(i); - queue_ptr stream = xtx.stream(i, is); + queue_ptr stream = ctx.stream(i, is); // wait for main GPU data if necessary if (split && (i != g_main_device || is != 0)) { @@ -4274,7 +4292,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10; } // do the computation - SYCL_CHECK(CHECK_TRY_ERROR(op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, + SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream))); /* DPCT1010:93: SYCL uses exceptions to report errors and does not @@ -4354,18 +4372,18 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, } // main device waits for all other devices to be finished - if (split && g_device_count > 1) { + if (split && ggml_sycl_info().device_count > 1) { int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; - is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS; + is_max = is_max <= GGML_SYCL_MAX_STREAMS ? is_max : GGML_SYCL_MAX_STREAMS; ggml_sycl_set_device(g_main_device); - for (int i = 0; i < g_device_count; ++i) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { if (dev[i].row_low == dev[i].row_high) { continue; } for (int64_t is = 0; is < is_max; ++is) { SYCL_CHECK(CHECK_TRY_ERROR( - stx.stream(g_main_device, 0)->ext_oneapi_submit_barrier( + ctx.stream(g_main_device, 0)->ext_oneapi_submit_barrier( {*src0_extra->events[i][is]}))); } } @@ -4384,130 +4402,130 @@ catch (sycl::exception const &exc) { } -static void ggml_sycl_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_repeat); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_repeat); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_get_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_get_rows); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_get_rows); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_add); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_acc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor 
* src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_acc); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_mul); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_div(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_div); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_gelu); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_silu); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_gelu_quick(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_gelu_quick); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_tanh(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_tanh); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_relu); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_hardsigmoid(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & 
ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_hardsigmoid); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_hardswish(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_hardswish); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_leaky_relu); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_sqr); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_norm); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_group_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_group_norm); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_concat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_concat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_concat); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_concat); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_upscale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_upscale); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_pad(const 
ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pad); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rms_norm); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm); GGML_SYCL_DEBUG("call %s done\n", __func__); } @@ -4563,7 +4581,7 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_mul_mat_vec_nc(const ggml_tensor *src0, +static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) try { GGML_ASSERT(!ggml_is_transposed(src0)); @@ -4583,7 +4601,6 @@ static void ggml_sycl_mul_mat_vec_nc(const ggml_tensor *src0, const int64_t ne12 = src1->ne[2]; SYCL_CHECK(ggml_sycl_set_device(g_main_device)); - auto ctx = ggml_sycl_info(); queue_ptr main_stream = ctx.stream(g_main_device, 0); ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; @@ -4631,7 +4648,8 @@ static void k_compute_batched_ptrs(const sycl::half *src0_as_f16, ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3; } -static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, +static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, + const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst) try { GGML_ASSERT(!ggml_is_transposed(src0)); @@ -4644,15 +4662,11 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, const int64_t ne_dst = ggml_nelements(dst); SYCL_CHECK(ggml_sycl_set_device(g_main_device)); - auto ctx = ggml_sycl_info(); queue_ptr main_stream = ctx.stream(g_main_device, 0); - bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_sycl || + bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_level_zero || main_stream->get_backend() == sycl::backend::ext_oneapi_hip; - SYCL_CHECK( - CHECK_TRY_ERROR(g_sycl_handles[g_main_device] = main_stream)); - ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; void * src0_ddq = src0_extra->data_device[g_main_device]; sycl::half *src0_as_f16 = (sycl::half *)src0_ddq; @@ -4664,7 +4678,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; // convert src1 to fp16 - ggml_sycl_pool_alloc src1_f16_alloc; + ggml_sycl_pool_alloc src1_f16_alloc(ctx.pool()); if (src1->type != GGML_TYPE_F16) { const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type); const int64_t ne_src1 = ggml_nelements(src1); @@ -4675,7 +4689,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? 
(sycl::half *)src1_ddf : src1_f16_alloc.get(); - ggml_sycl_pool_alloc dst_f16; + ggml_sycl_pool_alloc dst_f16(ctx.pool()); char * dst_t; dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float; @@ -4720,30 +4734,10 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, const int64_t r2 = ne12/ne02; const int64_t r3 = ne13/ne03; -#if 0 - // use syclGemmEx - { - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - int i03 = i13 / r3; - int i02 = i12 / r2; - - SYCL_CHECK( - syclGemmEx(g_sycl_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , SYCL_R_16F, nb01/sizeof(half), - (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, SYCL_R_16F, nb11/sizeof(float), - beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01, - cu_compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - } - } - } -#else if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans, + *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, (const char *)src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00, @@ -4751,12 +4745,11 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, nb11 / nb10, nb12 / nb10, beta, (char *)dst_t, cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type))); - g_sycl_handles[g_main_device]->wait(); } else { const int ne23 = ne12*ne13; - ggml_sycl_pool_alloc ptrs_src(2*ne23); - ggml_sycl_pool_alloc< void *> ptrs_dst(1*ne23); + ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2*ne23); + ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); sycl::range<3> block_dims(1, ne12, ne13); /* @@ -4785,7 +4778,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, }).wait(); } SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans, + *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, (const void **)(ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, @@ -4793,9 +4786,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, dpct::library_data_t::real_half, nb11 / nb10, beta, (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type))); - g_sycl_handles[g_main_device]->wait(); } -#endif if (no_mixed_dtypes) { const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); @@ -4809,7 +4800,7 @@ catch (sycl::exception const &exc) { } -static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool all_on_device = (src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT) && (src1->backend == GGML_BACKEND_TYPE_GPU) && @@ -4818,8 +4809,8 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT; int64_t min_compute_capability = INT_MAX; - for (int i = 0; i < g_device_count; ++i) { - if (min_compute_capability > ggml_sycl_info().devices[i].cc 
&& g_tensor_split[i] < (i + 1 < g_device_count ? g_tensor_split[i + 1] : 1.0f)) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + if (min_compute_capability > ggml_sycl_info().devices[i].cc && g_tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? g_tensor_split[i + 1] : 1.0f)) { min_compute_capability = ggml_sycl_info().devices[i].cc; } } @@ -4841,18 +4832,18 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { // KQ single-batch // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n"); - ggml_sycl_mul_mat_vec_p021(src0, src1, dst); + ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst); } else if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n"); - ggml_sycl_mul_mat_vec_nc(src0, src1, dst); + ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst); } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) { // KQ + KQV multi-batch // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n"); - ggml_sycl_mul_mat_batched_sycl(src0, src1, dst); + ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } else if (src0->type == GGML_TYPE_F32) { // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n"); if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) { @@ -4871,10 +4862,10 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 if (use_mul_mat_vec_q) { // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); } else { // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); } } else { bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type); @@ -4885,10 +4876,10 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 if (use_mul_mat_q) { // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true); } else { // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); } } } else { @@ -4903,7 +4894,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT && "mul_mat_id does not support split buffers"); const ggml_tensor *ids = dst->src[2]; - const queue_ptr stream = ctx.streams(g_main_device, 0); + const queue_ptr stream = ctx.stream(g_main_device, 0); const 
size_t nb11 = src1->nb[1]; const size_t nb1 = dst->nb[1]; @@ -4969,11 +4960,11 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, dst_row_extra.data_device[g_main_device] = dst_original + i01 * dst->nb[1]; - ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row); + ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); } } else { - ggml_sycl_pool_alloc src1_contiguous(sizeof(float)*ggml_nelements(src1)); - ggml_sycl_pool_alloc dst_contiguous(sizeof(float)*ggml_nelements(dst)); + ggml_sycl_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); + ggml_sycl_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); src1_row_extra.data_device[g_main_device] = src1_contiguous.get(); dst_row_extra.data_device[g_main_device] = dst_contiguous.get(); @@ -5013,7 +5004,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, dst_row.nb[2] = num_src1_rows*nb1; dst_row.nb[3] = num_src1_rows*nb1; - ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row); + ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); num_src1_rows = 0; for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { @@ -5043,12 +5034,12 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_scale); +static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_scale); } -static void ggml_sycl_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_clamp); +static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_clamp); } static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, @@ -5066,7 +5057,7 @@ static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, GGML_TENSOR_BINARY_OP_LOCALS; SYCL_CHECK(ggml_sycl_set_device(g_main_device)); - queue_ptr main_stream = ctx.streams(g_main_device, 0); + queue_ptr main_stream = ctx.stream(g_main_device, 0); const ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; const ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; @@ -5106,48 +5097,48 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { // TODO: why do we pass dst as src1 here? 
- ggml_sycl_cpy(src0, dst, nullptr); + ggml_sycl_cpy(ctx, src0, dst, nullptr); (void) src1; } -static void ggml_sycl_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_diag_mask_inf); +static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_diag_mask_inf); } -static void ggml_sycl_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_soft_max); +static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_soft_max); } -static void ggml_sycl_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rope); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rope); } -static void ggml_sycl_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_alibi); +static void ggml_sycl_alibi(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_alibi); } -static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d); +static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pool2d); } -static void ggml_sycl_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_im2col); +static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col); } -static void ggml_sycl_sum_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(src0)); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_sum_rows); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows); } -static void ggml_sycl_argsort(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(src0)); - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_argsort); + ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort); } -static void ggml_sycl_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { (void) src0; (void) src1; (void) dst; 
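Each wrapper above changes mechanically in the same way: it takes ggml_backend_sycl_context & ctx and passes it through ggml_sycl_op_flatten to the matching ggml_sycl_op_* callback. The dispatch side of this appears in the next hunk, where ggml_sycl_compute_forward now calls func(ctx, ...). A hedged sketch of that dispatch shape follows; the tensor and op types are trimmed-down placeholders, and only the (ctx, src0, src1, dst) calling convention is taken from the patch.

// Sketch of the dispatch shape after the refactor; all types are placeholders.
#include <cstdio>

struct sketch_context { int device = 0; };    // placeholder for ggml_backend_sycl_context

enum sketch_op { SKETCH_OP_ADD, SKETCH_OP_RELU };

struct sketch_tensor {                        // trimmed-down placeholder for ggml_tensor
    sketch_op op;
    const sketch_tensor * src0;
    const sketch_tensor * src1;
};

// every forward function now takes the backend context as its first argument
typedef void (*sketch_func_t)(sketch_context & ctx, const sketch_tensor * src0,
                              const sketch_tensor * src1, sketch_tensor * dst);

static void sketch_add (sketch_context &, const sketch_tensor *, const sketch_tensor *, sketch_tensor *) { std::puts("add");  }
static void sketch_relu(sketch_context &, const sketch_tensor *, const sketch_tensor *, sketch_tensor *) { std::puts("relu"); }

static bool sketch_compute_forward(sketch_context & ctx, sketch_tensor * t) {
    sketch_func_t func = nullptr;
    switch (t->op) {
        case SKETCH_OP_ADD:  func = sketch_add;  break;
        case SKETCH_OP_RELU: func = sketch_relu; break;
        default: return false;
    }
    // the context is threaded through instead of consulting globals
    func(ctx, t->src0, t->src1, t);
    return true;
}

int main() {
    sketch_context ctx;
    sketch_tensor t { SKETCH_OP_RELU, nullptr, nullptr };
    return sketch_compute_forward(ctx, &t) ? 0 : 1;
}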
@@ -5339,7 +5330,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens ggml_sycl_set_peer_access(tensor->src[1]->ne[1]); } - func(tensor->src[0], tensor->src[1], tensor); + func(ctx, tensor->src[0], tensor->src[1], tensor); return true; } @@ -5431,8 +5422,8 @@ struct ggml_backend_sycl_buffer_context { queue_ptr stream; std::string name; - ggml_backend_sycl_buffer_context(int device, void * dev_ptr) : - device(device), dev_ptr(dev_ptr) { + ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) : + device(device), dev_ptr(dev_ptr), stream(stream) { check_allow_gpu_index(device); int id = g_sycl_gpu_mgr->gpus[device]; name = (GGML_SYCL_NAME + std::to_string(id)); @@ -5490,7 +5481,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor); if (padded_size > original_size && tensor->view_src == nullptr) { - SYCL_CHECK(CHECK_TRY_ERROR((*(dpct->stream))->memset( + SYCL_CHECK(CHECK_TRY_ERROR((ctx->stream)->memset( (char *)tensor->data + original_size, 0, padded_size - original_size).wait())); } @@ -5557,7 +5548,7 @@ ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer, ggml_tensor *dst) try { if (ggml_backend_buffer_is_sycl(src->buffer)) { ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context; - ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)buffer->context; + ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context; ggml_sycl_set_device(src_ctx->device); /* @@ -5580,9 +5571,8 @@ ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer, error codes. The original code was commented out and a warning string was inserted. You need to rewrite this code. */ - auto backend_ctx = ggml_sycl_info(); - queue_ptr stream_dst = backend_ctx.stream(dst_ctx->device, 0); - queue_ptr stream_src = backend_ctx.stream(src_ctx->device, 0); + queue_ptr stream_dst = dst_ctx->stream; + queue_ptr stream_src = src_ctx->stream; size_t size = ggml_nbytes(src); //todo. it's dirty solutino to walkaroud known issue:device2device cross GPUs. 
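In the buffer hunks above, the per-buffer context gains its own queue: the constructor now takes a queue_ptr, tensor copies resolve the destination context from dst->buffer->context, and the device-to-device path uses the streams stored on the two buffer contexts instead of a global lookup. A minimal sketch of that ownership layout is below; the queue type and the copy routine are placeholders standing in for sycl::queue and the real memcpy calls, not the backend code itself.

// Sketch of per-buffer stream ownership; "sketch_queue" is a placeholder, not sycl::queue.
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

struct sketch_queue {                                    // placeholder for sycl::queue
    void memcpy_async(void * dst, const void * src, std::size_t n) const {
        std::memcpy(dst, src, n);                        // stands in for a stream->memcpy(...)
    }
};
typedef sketch_queue * queue_ptr;

struct sketch_buffer_context {                           // placeholder for ggml_backend_sycl_buffer_context
    int         device;
    void *      dev_ptr;
    queue_ptr   stream;                                  // the stream is now stored on the buffer context
    std::string name;
    sketch_buffer_context(int device, void * dev_ptr, queue_ptr stream)
        : device(device), dev_ptr(dev_ptr), stream(stream), name("SYCL" + std::to_string(device)) {}
};

// a copy consults the streams owned by the source and destination contexts
static void sketch_cpy(const sketch_buffer_context & src_ctx, sketch_buffer_context & dst_ctx, std::size_t n) {
    (void) src_ctx.stream;                               // a real cross-device copy would involve both queues
    dst_ctx.stream->memcpy_async(dst_ctx.dev_ptr, src_ctx.dev_ptr, n);
}

int main() {
    sketch_queue q0, q1;
    std::vector<char> a(16, 'a'), b(16, 'b');
    sketch_buffer_context src(0, a.data(), &q0), dst(1, b.data(), &q1);
    sketch_cpy(src, dst, a.size());
    std::printf("%c\n", b[0]);                           // prints 'a'
    return 0;
}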
@@ -5617,7 +5607,6 @@ static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer, ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; ggml_sycl_set_device(ctx->device); - auto backend_ctx = ggml_sycl_info(); queue_ptr stream = ctx->stream; SYCL_CHECK( CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw())); @@ -5662,8 +5651,7 @@ ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) try { ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context; ggml_sycl_set_device(buft_ctx->device); - auto backend_ctx = ggml_sycl_info(); - const queue_ptr stream = backend_ctx.stream(buft_ctx->device, 0); + const queue_ptr stream = buft_ctx->stream; size = std::max(size, (size_t)1); // syclMalloc returns null for size 0 void * dev_ptr; @@ -5723,26 +5711,52 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = { /* .is_host = */ nullptr, }; -ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context* ctx) { +ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) { GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n"); - if (ctx->device>=g_device_count or device_index<0) { + if (device>=ggml_sycl_info().device_count or device<0) { printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n", - ctx->device, g_device_count-1); - GGML_ASSERT(ctx->deviceget_co_ctx(), device_i); ggml_backend_sycl_buffer_types[i] = { /* .iface = */ ggml_backend_sycl_buffer_type_interface, - /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(g_sycl_gpu_mgr->gpus[i]), ctx.stream(g_sycl_gpu_mgr->gpus[i], 0)}, + /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(g_sycl_gpu_mgr->gpus[i]), stream}, }; } g_ggml_backend_sycl_buffer_type_initialized = true; } - return &ggml_backend_sycl_buffer_types[device_index]; + return &ggml_backend_sycl_buffer_types[device]; +} + +ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) { + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n"); + + int device = ctx->device; + if (device>=ggml_sycl_info().device_count or device<0) { + printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n", + device, ggml_sycl_info().device_count-1); + GGML_ASSERT(devicegpus[i]), ctx->stream(i, 0)}, + }; + } + g_ggml_backend_sycl_buffer_type_initialized = true; + } + return &ggml_backend_sycl_buffer_types[device]; } // sycl split buffer type @@ -5752,7 +5766,7 @@ static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tens *row_low = id == 0 ? 
0 : nrows*tensor_split[id]; *row_low -= *row_low % rounding; - if (id == g_device_count - 1) { + if (id == ggml_sycl_info().device_count - 1) { *row_high = nrows; } else { *row_high = nrows*tensor_split[id + 1]; @@ -5763,9 +5777,9 @@ static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tens struct ggml_backend_sycl_split_buffer_context { ~ggml_backend_sycl_split_buffer_context() try { for (ggml_tensor_extra_gpu * extra : tensor_extras) { - for (int i = 0; i < g_device_count; ++i) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { // int id = g_sycl_gpu_mgr->gpus[i]; - for (int64_t is = 0; is < MAX_STREAMS; ++is) { + for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { if (extra->events[i][is] != nullptr) { /* DPCT1009:206: SYCL uses exceptions to report errors and @@ -5785,9 +5799,9 @@ struct ggml_backend_sycl_split_buffer_context { code. */ ggml_sycl_set_device(i); - const queue_ptr stream = backend_ctx.stream(i, 0); + const queue_ptr stream = streams[i]; SYCL_CHECK(CHECK_TRY_ERROR(sycl::free( - extra->data_device[i], *stream))); + extra->data_device[i], *(streams[i])))); } } delete extra; @@ -5800,6 +5814,7 @@ struct ggml_backend_sycl_split_buffer_context { } std::vector tensor_extras; + std::vector streams; }; GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backend_buffer_t buffer) { @@ -5838,8 +5853,10 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; ctx->tensor_extras.push_back(extra); + ctx->streams.push_back(dpct::get_current_device().create_queue( + g_sycl_gpu_mgr->get_co_ctx(), dpct::get_current_device())); - for (int i = 0; i < g_device_count; ++i) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { // int id = g_sycl_gpu_mgr->gpus[i]; int64_t row_low, row_high; get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); @@ -5860,8 +5877,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, // FIXME: do not crash if syclMalloc fails // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first ggml_sycl_set_device(i); - auto backend_ctx = ggml_sycl_info(); - const queue_ptr stream = backend_ctx.stream(i, 0); + const queue_ptr stream = ctx->streams[i]; char * buf; /* DPCT1009:208: SYCL uses exceptions to report errors and does not use the @@ -5886,7 +5902,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, extra->data_device[i] = buf; - for (int64_t is = 0; is < MAX_STREAMS; ++is) { + for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { /* DPCT1009:210: SYCL uses exceptions to report errors and does not use the error codes. 
The original code was commented out and a warning @@ -5913,13 +5929,14 @@ ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer, GGML_ASSERT(offset == 0); GGML_ASSERT(size == ggml_nbytes(tensor)); + ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context; const int64_t ne0 = tensor->ne[0]; const size_t nb1 = tensor->nb[1]; ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra; - for (int i = 0; i < g_device_count; ++i) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { // int id = g_sycl_gpu_mgr->gpus[i]; int64_t row_low, row_high; get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); @@ -5945,8 +5962,7 @@ ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer, was inserted. You need to rewrite this code. */ ggml_sycl_set_device(i); - auto backend_ctx = ggml_sycl_info(); - const queue_ptr stream = backend_ctx.stream(i, 0); + const queue_ptr stream = ctx->streams[i]; SYCL_CHECK(CHECK_TRY_ERROR( (*stream) .memcpy(extra->data_device[i], buf_host, original_size) @@ -5967,13 +5983,14 @@ ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer, GGML_ASSERT(offset == 0); GGML_ASSERT(size == ggml_nbytes(tensor)); + ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context; const int64_t ne0 = tensor->ne[0]; const size_t nb1 = tensor->nb[1]; ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra; - for (int i = 0; i < g_device_count; ++i) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { // int id = g_sycl_gpu_mgr->gpus[i]; int64_t row_low, row_high; get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); @@ -5999,8 +6016,7 @@ ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer, was inserted. You need to rewrite this code. 
*/ ggml_sycl_set_device(i); - auto backend_ctx = ggml_sycl_info(); - const queue_ptr stream = backend_ctx.stream(i, 0); + const queue_ptr stream = ctx->streams[i]; SYCL_CHECK(CHECK_TRY_ERROR( (*stream) .memcpy(buf_host, extra->data_device[i], original_size) @@ -6058,7 +6074,7 @@ GGML_CALL static size_t ggml_backend_sycl_split_buffer_type_get_alloc_size(ggml_ const int64_t ne0 = tensor->ne[0]; - for (int i = 0; i < g_device_count; ++i) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { // int id = g_sycl_gpu_mgr->gpus[i]; int64_t row_low, row_high; get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, i); @@ -6114,12 +6130,12 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const f tensor_split_arr = g_default_tensor_split; } else { float split_sum = 0.0f; - for (int i = 0; i < g_device_count; ++i) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { // int id = g_sycl_gpu_mgr->gpus[i]; tensor_split_arr[i] = split_sum; split_sum += tensor_split[i]; } - for (int i = 0; i < g_device_count; ++i) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { // int id = g_sycl_gpu_mgr->gpus[i]; tensor_split_arr[i] /= split_sum; } @@ -6222,9 +6238,8 @@ GGML_CALL static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend, ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && "unsupported buffer type"); GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); - auto backend_ctx = backend->context; - const queue_ptr stream = backend_ctx.stream(sycl_ctx->device, 0); - SYCL_CHECK(CHECK_TRY_ERROR((*stream)->memcpy( + const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); + SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy( (char *)tensor->data + offset, data, size).wait())); } catch (sycl::exception const &exc) { @@ -6240,9 +6255,8 @@ GGML_CALL static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend, ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && "unsupported buffer type"); GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); - auto backend_ctx = backend->context; - const queue_ptr stream = backend_ctx.stream(sycl_ctx->device, 0); - SYCL_CHECK(CHECK_TRY_ERROR((*stream)->memcpy( + const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); + SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy( data, (const char *)tensor->data + offset, size).wait())); } catch (sycl::exception const &exc) { @@ -6261,9 +6275,8 @@ GGML_CALL static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend, error codes. The original code was commented out and a warning string was inserted. You need to rewrite this code. 
*/ - auto backend_ctx = backend->context; - const queue_ptr stream = backend_ctx.stream(sycl_ctx->device, 0); - SYCL_CHECK(CHECK_TRY_ERROR((*stream)->memcpy( + const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); + SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy( dst->data, src->data, ggml_nbytes(dst)).wait())); return true; } @@ -6278,8 +6291,8 @@ catch (sycl::exception const &exc) { static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try { ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; - const queue_ptr stream = sycl_ctx.stream(sycl_ctx->device, 0); - SYCL_CHECK(CHECK_TRY_ERROR((*stream)->wait())); + const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); + SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait())); UNUSED(backend); } @@ -6293,9 +6306,6 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; ggml_sycl_set_main_device(sycl_ctx->device); - ggml_compute_params params = {}; - params.type = GGML_TASK_TYPE_COMPUTE; - params.ith = 0; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { @@ -6314,7 +6324,7 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back } } #endif - bool ok = ggml_sycl_compute_forward(¶ms, node); + bool ok = ggml_sycl_compute_forward(*sycl_ctx, node); if (!ok) { fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); } @@ -6538,12 +6548,12 @@ GGML_API GGML_CALL int ggml_backend_sycl_get_device_id(int device_index) { extern "C" int ggml_backend_sycl_reg_devices(); int ggml_backend_sycl_reg_devices() { - assert(g_device_count>0); - for (int i = 0; i < g_device_count; i++) { + assert(ggml_sycl_info().device_count>0); + for (int i = 0; i < ggml_sycl_info().device_count; i++) { int id = g_sycl_gpu_mgr->gpus[i]; char name[128]; snprintf(name, sizeof(name), "%s%d", GGML_SYCL_NAME, id); ggml_backend_register(name, ggml_backend_reg_sycl_init, ggml_backend_sycl_buffer_type(i), (void *) (intptr_t) i); } - return g_device_count; + return ggml_sycl_info().device_count; } diff --git a/ggml-sycl/common.hpp b/ggml-sycl/common.hpp index febea98947d76..e354e720715d6 100644 --- a/ggml-sycl/common.hpp +++ b/ggml-sycl/common.hpp @@ -79,7 +79,6 @@ static int g_work_group_size = 0; #endif typedef sycl::queue *queue_ptr; -typedef sycl::handler *handle_ptr; enum ggml_sycl_backend_gpu_mode { SYCL_UNSET_GPU_MODE = -1, @@ -313,13 +312,12 @@ class sycl_gpu_mgr { }; static sycl_gpu_mgr* g_sycl_gpu_mgr = new sycl_gpu_mgr(0); -static int g_device_count = -1; static int g_all_sycl_device_count = -1; static int g_main_device = -1; static int g_main_device_id = -1; static bool g_ggml_backend_sycl_buffer_type_initialized = false; -static std::array g_default_tensor_split = {}; +static std::array g_default_tensor_split = {}; static float g_tensor_split[GGML_SYCL_MAX_DEVICES] = {0}; @@ -341,25 +339,6 @@ int get_main_device(); (void)bad_arch; // suppress unused function warning } -/* -device_index: device index from 0 to n (continue numbers). - It is used for device select/set in SYCL backend internal data structure. 
-*/
-inline void check_allow_gpu_index(const int device_index) {
-  if (device_index >= g_device_count) {
-    char error_buf[256];
-    snprintf(
-        error_buf,
-        sizeof(error_buf),
-        "%s error: device_index:%d is out of range: [0-%d]",
-        __func__,
-        device_index,
-        g_device_count - 1);
-    fprintf(stderr, "%s\n", error_buf);
-    assert(false);
-  }
-}
-
 /*
 device_id: device ID is shown by ggml_backend_sycl_print_sycl_devices().
   It is only used to set current working device.
@@ -487,30 +466,16 @@ struct ggml_backend_sycl_context {
     std::string name;
 
     queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } };
-    static sycl::handler * sycl_handles[GGML_SYCL_MAX_DEVICES] = {nullptr};
 
     explicit ggml_backend_sycl_context(int device) :
         device(device),
         name(GGML_SYCL_NAME + std::to_string(device)) {
     }
 
-    ~ggml_backend_sycl_context() {
-        for (int i = 0; i < GGML_SYCL_MAX_DEVICES; ++i) {
-            for (int j = 0; j < GGML_SYCL_MAX_STREAMS; ++j) {
-                if (qptrs[i][j] != nullptr) {
-                    SYCL_CHECK(free(qptrs[i][j]));
-                }
-            }
-            if (cublas_handles[i] != nullptr) {
-                SYCL_CHECK(free(sycl_handles[i]));
-            }
-        }
-    }
-
     queue_ptr stream(int device, int stream) {
         if (qptrs[device][stream] == nullptr) {
-            SYCL_CHECK(dpct::get_current_device().create_queue(
-                g_sycl_gpu_mgr->get_co_ctx(), dpct::get_current_device())));
+            qptrs[device][stream] = (dpct::get_current_device().create_queue(
+                g_sycl_gpu_mgr->get_co_ctx(), dpct::get_current_device()));
         }
         return qptrs[device][stream];
     }
@@ -519,19 +484,6 @@ struct ggml_backend_sycl_context {
         return stream(device, 0);
     }
 
-    handle_ptr sycl_handle(int device) {
-        if (sycl_handles[device] == nullptr) {
-            const dpct::queue_ptr stream = qptrs[device][0];
-            // create sycl handle
-            SYCL_CHECK(CHECK_TRY_ERROR(sycl_handles[device] = stream));
-        }
-        return sycl_handles[device];
-    }
-
-    handle_ptr sycl_handle() {
-        return sycl_handle(device);
-    }
-
     // pool
     std::unique_ptr pools[GGML_SYCL_MAX_DEVICES];
@@ -539,7 +491,7 @@ struct ggml_backend_sycl_context {
 
     ggml_sycl_pool & pool(int device) {
         if (pools[device] == nullptr) {
-            pools[device] = new_pool_for_device(qptrs[device][0], device);
+            pools[device] = new_pool_for_device(stream(device,0), device);
         }
         return *pools[device];
     }
diff --git a/ggml-sycl/dmmv.cpp b/ggml-sycl/dmmv.cpp
index 7b2849ff4c3a6..244ffe4390a99 100644
--- a/ggml-sycl/dmmv.cpp
+++ b/ggml-sycl/dmmv.cpp
@@ -943,6 +943,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
 }
 
 void ggml_sycl_op_dequantize_mul_mat_vec(
+    ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
     const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
     float *dst_dd_i, const int64_t row_low, const int64_t row_high,
diff --git a/ggml-sycl/dmmv.hpp b/ggml-sycl/dmmv.hpp
index df555cc1cf264..bd83735641533 100644
--- a/ggml-sycl/dmmv.hpp
+++ b/ggml-sycl/dmmv.hpp
@@ -17,6 +17,7 @@
 
 void ggml_sycl_op_dequantize_mul_mat_vec(
+    ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
     const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
     float *dst_dd_i, const int64_t row_low, const int64_t row_high,
diff --git a/ggml-sycl/mmq.cpp b/ggml-sycl/mmq.cpp
index d98f406d99733..cb864d0062fbb 100644
--- a/ggml-sycl/mmq.cpp
+++ b/ggml-sycl/mmq.cpp
@@ -2960,6 +2960,7 @@ catch (sycl::exception const &exc) {
 }
 
 void ggml_sycl_op_mul_mat_q(
+    ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
     const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
     float *dst_dd_i, const int64_t row_low, const int64_t row_high,
diff --git a/ggml-sycl/mmq.hpp b/ggml-sycl/mmq.hpp
index 74edc16d2be95..3f5297aaa5373 100644
--- a/ggml-sycl/mmq.hpp
+++ b/ggml-sycl/mmq.hpp
@@ -16,6 +16,7 @@
 #include "common.hpp"
 
 void ggml_sycl_op_mul_mat_q(
+    ggml_backend_sycl_context & ctx,
     const ggml_tensor* src0,
     const ggml_tensor* src1,
     ggml_tensor* dst,
diff --git a/ggml-sycl/mmvq.cpp b/ggml-sycl/mmvq.cpp
index deee7b4bd7f5d..bc0cfaa8fdf53 100644
--- a/ggml-sycl/mmvq.cpp
+++ b/ggml-sycl/mmvq.cpp
@@ -932,6 +932,7 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
 }
 
 void ggml_sycl_op_mul_mat_vec_q(
+    ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
     const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
     float *dst_dd_i, const int64_t row_low, const int64_t row_high,
diff --git a/ggml-sycl/mmvq.hpp b/ggml-sycl/mmvq.hpp
index 5497828bb3e18..049b43d453532 100644
--- a/ggml-sycl/mmvq.hpp
+++ b/ggml-sycl/mmvq.hpp
@@ -17,6 +17,7 @@
 
 void ggml_sycl_op_mul_mat_vec_q(
+    ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
     const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
     float *dst_dd_i, const int64_t row_low, const int64_t row_high,