Skip to content

Commit

Permalink
ggml : add error handling to graph_compute (whisper/1714)
Browse files Browse the repository at this point in the history
  • Loading branch information
finnvoor authored and ggerganov committed Jan 4, 2024
1 parent ae29248 commit 9a867f1
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 19 deletions.
24 changes: 16 additions & 8 deletions examples/whisper/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
// ggml helpers
//

static void ggml_graph_compute_helper(
static bool ggml_graph_compute_helper(
struct ggml_cgraph * graph,
std::vector<uint8_t> & buf,
int n_threads,
Expand All @@ -168,10 +168,10 @@ static void ggml_graph_compute_helper(
plan.work_data = buf.data();
}

ggml_graph_compute(graph, &plan);
return ggml_graph_compute(graph, &plan);
}

static void ggml_graph_compute_helper(
static bool ggml_graph_compute_helper(
struct ggml_backend * backend,
struct ggml_cgraph * graph,
int n_threads) {
Expand All @@ -183,7 +183,7 @@ static void ggml_graph_compute_helper(
ggml_backend_metal_set_n_cb(backend, n_threads);
}
#endif
ggml_backend_graph_compute(backend, graph);
return ggml_backend_graph_compute(backend, graph);
}

// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
Expand Down Expand Up @@ -2103,7 +2103,9 @@ static bool whisper_encode_internal(
ggml_allocr_alloc_graph(alloc, gf);

if (!whisper_encode_external(wstate)) {
ggml_graph_compute_helper(wstate.backend, gf, n_threads);
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
return false;
}
}
}

Expand All @@ -2117,7 +2119,9 @@ static bool whisper_encode_internal(

ggml_allocr_alloc_graph(alloc, gf);

ggml_graph_compute_helper(wstate.backend, gf, n_threads);
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
return false;
}
}

// cross
Expand All @@ -2130,7 +2134,9 @@ static bool whisper_encode_internal(

ggml_allocr_alloc_graph(alloc, gf);

ggml_graph_compute_helper(wstate.backend, gf, n_threads);
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
return false;
}
}

wstate.t_encode_us += ggml_time_us() - t_start_us;
Expand Down Expand Up @@ -2552,7 +2558,9 @@ static bool whisper_decode_internal(

logits = gf->nodes[gf->n_nodes - 1];

ggml_graph_compute_helper(wstate.backend, gf, n_threads);
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
return false;
}
}

logits_out.resize(n_tokens*n_vocab);
Expand Down
2 changes: 1 addition & 1 deletion include/ggml/ggml-backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ extern "C" {

GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API bool ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op);

// tensor copy between different backends
Expand Down
2 changes: 1 addition & 1 deletion src/ggml-backend-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ extern "C" {
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);

// compute graph without a plan
void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
bool (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);

// check if the backend supports an operation
bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
Expand Down
10 changes: 7 additions & 3 deletions src/ggml-backend.c
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,14 @@ void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_
ggml_backend_synchronize(backend);
}

void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
backend->iface.graph_compute(backend, cgraph);
bool ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
if (!backend->iface.graph_compute(backend, cgraph)) {
return false;
}

// TODO: optional sync
ggml_backend_synchronize(backend);
return true;
}

bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
Expand Down Expand Up @@ -597,7 +600,7 @@ static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_bac
GGML_UNUSED(backend);
}

static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;

struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
Expand All @@ -611,6 +614,7 @@ static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_c
cplan.work_data = cpu_ctx->work_data;

ggml_graph_compute(cgraph, &cplan);
return true;
}

static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
Expand Down
4 changes: 3 additions & 1 deletion src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9910,7 +9910,7 @@ static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_ba
UNUSED(plan);
}

static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;

ggml_cuda_set_main_device(cuda_ctx->device);
Expand Down Expand Up @@ -9967,6 +9967,8 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
}

UNUSED(backend);

return true;
}

static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
Expand Down
2 changes: 1 addition & 1 deletion src/ggml-metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);

// same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
bool ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);

//
// backend API
Expand Down
9 changes: 5 additions & 4 deletions src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
return false;
}
}
void ggml_metal_graph_compute(
bool ggml_metal_graph_compute(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
@autoreleasepool {
Expand Down Expand Up @@ -2405,10 +2405,11 @@ void ggml_metal_graph_compute(
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
GGML_ASSERT(false);
return false;
}
}

return true;
}
}

Expand Down Expand Up @@ -2688,10 +2689,10 @@ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggm
UNUSED(backend);
}

static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;

ggml_metal_graph_compute(metal_ctx, cgraph);
return ggml_metal_graph_compute(metal_ctx, cgraph);
}

static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
Expand Down

0 comments on commit 9a867f1

Please sign in to comment.