diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index d7c9e0f0e..8fe05d3a5 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -201,8 +201,13 @@ #define GGML_MAX_NAME 48 #define GGML_DEFAULT_N_THREADS 4 + +#define GGML_EXIT_SUCCESS 0 +#define GGML_EXIT_ABORTED 1 + #define GGML_UNUSED(x) (void)(x) + #define GGML_ASSERT(x) \ do { \ if (!(x)) { \ @@ -442,6 +447,10 @@ extern "C" { // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes int n_tasks[GGML_MAX_NODES]; + + // abort ggml_graph_compute when true + bool (*abort_callback)(void * data); + void * abort_callback_data; }; // computation graph @@ -1303,7 +1312,7 @@ extern "C" { // ggml_graph_plan() has to be called before ggml_graph_compute() // when plan.work_size > 0, caller must allocate memory for plan.work_data GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/); - GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); + GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // same as ggml_graph_compute() but the work data is allocated as a part of the context diff --git a/src/ggml.c b/src/ggml.c index 487ad9483..149e069f0 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -25,6 +25,7 @@ #include #include #include +#include #ifdef GGML_USE_METAL #include @@ -15946,6 +15947,9 @@ struct ggml_compute_state_shared { // synchronization primitives atomic_int n_active; // num active threads atomic_int node_n; // active graph node + + bool (*abort_callback)(void * data); // abort ggml_graph_compute when true + void * abort_callback_data; }; struct ggml_compute_state { @@ -15977,6 +15981,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { int node_n = -1; while (true) { + if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { + state->shared->node_n += 1; + return GGML_EXIT_ABORTED; + } if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { // all other threads are finished and spinning // do finalize and init here so we don't have synchronize again @@ -16030,6 +16038,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } else { break; } + + if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { + break; + } } atomic_store(&state->shared->n_active, n_threads); @@ -16063,9 +16075,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } } - return 0; + return GGML_EXIT_SUCCESS; } +static bool always_false(void * data) { return false; } struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { if (n_threads <= 0) { n_threads = GGML_DEFAULT_N_THREADS; @@ -16403,7 +16416,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { return cplan; } -void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { +int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { { GGML_ASSERT(cplan); GGML_ASSERT(cplan->n_threads > 0); @@ -16452,12 +16465,12 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) const int64_t perf_start_time_us = ggml_perf_time_us(); // this is a work thread too - ggml_graph_compute_thread(&workers[0]); + int compute_status = ggml_graph_compute_thread(&workers[0]); // don't leave affinity set on the main thread clear_numa_thread_affinity(); - // join thread pool + // join or kill thread pool if (n_threads > 1) { for (int j = 1; j < n_threads; j++) { const int rc = ggml_thread_join(workers[j].thrd, NULL); @@ -16481,6 +16494,8 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) (double) perf_time_us_cur / 1000.0, (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs); } + + return compute_status; } void ggml_graph_reset(struct ggml_cgraph * cgraph) {