Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] remove global variables #7710

Merged
merged 10 commits into from
Jun 15, 2024
Prev Previous commit
Next Next commit
remove duplicate buft initialization
  • Loading branch information
airMeng committed Jun 13, 2024
commit 9c5476ead4247d526510d8b88fe5bbec87290ba0
24 changes: 16 additions & 8 deletions ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12606,6 +12606,9 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
};

ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);

GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");

if (device>=ggml_sycl_info().device_count or device<0) {
Expand Down Expand Up @@ -13000,6 +13003,9 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface
};

GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);

GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n");
ggml_check_sycl();
// FIXME: this is not thread safe
Expand Down Expand Up @@ -13106,16 +13112,17 @@ GGML_CALL static void ggml_backend_sycl_free(ggml_backend_t backend) {

GGML_CALL static ggml_backend_buffer_type_t ggml_backend_sycl_get_default_buffer_type(ggml_backend_t backend) {
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
return ggml_backend_sycl_buffer_type(sycl_ctx);
return ggml_backend_sycl_buffer_type(sycl_ctx->device);
}

GGML_CALL static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
ggml_tensor *tensor,
const void *data, size_t offset,
size_t size) try {
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && "unsupported buffer type");
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;

GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
(char *)tensor->data + offset, data, size).wait()));
Expand All @@ -13131,8 +13138,9 @@ GGML_CALL static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
void *data, size_t offset,
size_t size) try {
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && "unsupported buffer type");
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;

GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
data, (const char *)tensor->data + offset, size).wait()));
Expand All @@ -13147,7 +13155,7 @@ GGML_CALL static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
const ggml_tensor *src,
ggml_tensor *dst) try {
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && ggml_backend_buffer_is_sycl(src->buffer)) {
if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && ggml_backend_buffer_is_sycl(src->buffer)) {
/*
DPCT1009:215: SYCL uses exceptions to report errors and does not use the
error codes. The original code was commented out and a warning string
Expand Down Expand Up @@ -13191,10 +13199,10 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back
continue;
}
#ifndef NDEBUG
assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx));
assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
for (int j = 0; j < GGML_MAX_SRC; j++) {
if (node->src[j] != nullptr) {
assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx));
assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
}
}
#endif
Expand Down