Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] remove global variables #7710

Merged
merged 10 commits into from
Jun 15, 2024
Prev Previous commit
Next Next commit
update mul_mat condition
  • Loading branch information
airMeng committed Jun 13, 2024
commit abe11feab650416ad2f2bde0cac3cf989cd1fb73
20 changes: 2 additions & 18 deletions ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ Following definition copied from DPCT head files, which are used by ggml-sycl.cp
#endif

bool ggml_sycl_loaded(void);
bool ggml_sycl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
void ggml_sycl_free_data(struct ggml_tensor * tensor);
void ggml_sycl_assign_buffers(struct ggml_tensor * tensor);
void ggml_sycl_assign_buffers_no_scratch(struct ggml_tensor * tensor);
Expand Down Expand Up @@ -11375,21 +11374,6 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tenso
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

bool ggml_sycl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
if (!g_sycl_loaded) return false;

const int64_t ne10 = src1->ne[0];

const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];

// TODO: find the optimal values for these
return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 &&
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
}

static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst) try {
Expand Down Expand Up @@ -12195,13 +12179,13 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
func = ggml_sycl_rms_norm;
break;
case GGML_OP_MUL_MAT:
if (ggml_sycl_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
return false;
}
func = ggml_sycl_mul_mat;
break;
case GGML_OP_MUL_MAT_ID:
if (ggml_sycl_can_mul_mat(tensor->src[2], tensor->src[1], tensor)) {
if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
return false;
}
func = ggml_sycl_mul_mat_id;
Expand Down