Skip to content

Commit

Permalink
ggml : add callback to abort ggml_graph_compute() (#328)
Browse files Browse the repository at this point in the history
* mechanism to abort ggml_graph_compute

* use pthread_cancel

* forgot to commit ggml.h

* static always_false()

Co-authored-by: Georgi Gerganov <[email protected]>

* accept callback data

* proper function prototype

* return exit status

* remove pthread_cancel and join every thread

* put abort_callback onto cplan

* cplan abort_callback in ggml.c

* make sure all threads abort

---------

Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
CCLDArjun and ggerganov committed Jul 11, 2023
1 parent f5165d0 commit aded898
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
11 changes: 10 additions & 1 deletion include/ggml/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,13 @@
#define GGML_MAX_NAME 48
#define GGML_DEFAULT_N_THREADS 4


#define GGML_EXIT_SUCCESS 0
#define GGML_EXIT_ABORTED 1

#define GGML_UNUSED(x) (void)(x)


#define GGML_ASSERT(x) \
do { \
if (!(x)) { \
Expand Down Expand Up @@ -442,6 +447,10 @@ extern "C" {

// the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
int n_tasks[GGML_MAX_NODES];

// abort ggml_graph_compute when true
bool (*abort_callback)(void * data);
void * abort_callback_data;
};

// computation graph
Expand Down Expand Up @@ -1303,7 +1312,7 @@ extern "C" {
// ggml_graph_plan() has to be called before ggml_graph_compute()
// when plan.work_size > 0, caller must allocate memory for plan.work_data
GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);

// same as ggml_graph_compute() but the work data is allocated as a part of the context
Expand Down
23 changes: 19 additions & 4 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <float.h>
#include <limits.h>
#include <stdarg.h>
#include <signal.h>

#ifdef GGML_USE_METAL
#include <unistd.h>
Expand Down Expand Up @@ -15955,6 +15956,9 @@ struct ggml_compute_state_shared {
// synchronization primitives
atomic_int n_active; // num active threads
atomic_int node_n; // active graph node

bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
void * abort_callback_data;
};

struct ggml_compute_state {
Expand Down Expand Up @@ -15986,6 +15990,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
int node_n = -1;

while (true) {
if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
state->shared->node_n += 1;
return GGML_EXIT_ABORTED;
}
if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
// all other threads are finished and spinning
// do finalize and init here so we don't have synchronize again
Expand Down Expand Up @@ -16039,6 +16047,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
} else {
break;
}

if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
break;
}
}

atomic_store(&state->shared->n_active, n_threads);
Expand Down Expand Up @@ -16072,9 +16084,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
}
}

return 0;
return GGML_EXIT_SUCCESS;
}

// Default abort callback: never requests an abort, so graph computation
// always runs to completion when the caller supplies no callback.
static bool always_false(void * data) {
    (void) data; // parameter intentionally unused
    return false;
}
struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
if (n_threads <= 0) {
n_threads = GGML_DEFAULT_N_THREADS;
Expand Down Expand Up @@ -16412,7 +16425,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
return cplan;
}

void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
{
GGML_ASSERT(cplan);
GGML_ASSERT(cplan->n_threads > 0);
Expand Down Expand Up @@ -16461,12 +16474,12 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
const int64_t perf_start_time_us = ggml_perf_time_us();

// this is a work thread too
ggml_graph_compute_thread(&workers[0]);
int compute_status = ggml_graph_compute_thread(&workers[0]);

// don't leave affinity set on the main thread
clear_numa_thread_affinity();

// join thread pool
// join or kill thread pool
if (n_threads > 1) {
for (int j = 1; j < n_threads; j++) {
const int rc = ggml_thread_join(workers[j].thrd, NULL);
Expand All @@ -16490,6 +16503,8 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
(double) perf_time_us_cur / 1000.0,
(double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
}

return compute_status;
}

void ggml_graph_reset(struct ggml_cgraph * cgraph) {
Expand Down

0 comments on commit aded898

Please sign in to comment.