callback to abort ggml_graph_compute() #328

Merged · 13 commits · Jul 11, 2023
4 changes: 4 additions & 0 deletions include/ggml/ggml.h
@@ -201,6 +201,9 @@
#define GGML_MAX_NAME 48
#define GGML_DEFAULT_N_THREADS 4

#define GGML_EXIT_SUCCESS 0
#define GGML_EXIT_ABORTED 1

#define GGML_ASSERT(x) \
do { \
if (!(x)) { \
@@ -1271,6 +1274,7 @@ extern "C" {
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);

GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
GGML_API void ggml_graph_compute_with_abort(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool (*abort_callback)());
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);

GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
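
As a usage illustration (not part of the diff), here is a minimal caller-side sketch of how the new entry point could be wired to a Ctrl+C handler; the flag and handler names below are hypothetical:

// Minimal caller-side sketch (illustrative only): abort compute on SIGINT.
// Assumes `ctx` and `gf` have been built with the usual ggml calls.
#include <signal.h>
#include <stdbool.h>
#include "ggml/ggml.h"

static volatile sig_atomic_t g_should_abort = 0;   // hypothetical flag

static void on_sigint(int signo) {
    (void) signo;
    g_should_abort = 1;                             // request abort; polled between graph nodes
}

static bool should_abort(void) {
    return g_should_abort != 0;                     // returning true stops the computation
}

// signal(SIGINT, on_sigint);
// ggml_graph_compute_with_abort(ctx, gf, should_abort);

Since the callback is polled from the compute loop between graph nodes, it should be cheap and must not block.
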
33 changes: 28 additions & 5 deletions src/ggml.c
@@ -25,6 +25,7 @@
#include <float.h>
#include <limits.h>
#include <stdarg.h>
#include <signal.h>

#ifdef GGML_USE_METAL
#include <unistd.h>
@@ -16654,6 +16655,7 @@ typedef pthread_t ggml_thread_t;

#define ggml_thread_create pthread_create
#define ggml_thread_join pthread_join
#define ggml_thread_cancel pthread_cancel

#else

@@ -16749,6 +16751,8 @@ struct ggml_compute_state_shared {
// synchronization primitives
atomic_int n_active; // num active threads
atomic_int node_n; // active graph node

bool (*abort_callback)(); // abort ggml_graph_compute when true
};

struct ggml_compute_state {
@@ -16776,6 +16780,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
int node_n = -1;

while (true) {
if (state->ith == 0 && state->shared->abort_callback()) {
return GGML_EXIT_ABORTED;
}
if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
// all other threads are finished and spinning
// do finalize and init here so we don't have synchronize again
@@ -16793,6 +16800,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
params.nth = node->n_tasks;
ggml_compute_forward(&params, node);
ggml_graph_compute_perf_stats_node(node, state->shared);

}

// distribute new work or execute it direct if 1T
@@ -16821,6 +16829,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
} else {
break;
}

if (state->shared->abort_callback()) {
break;
}
}

atomic_store(&state->shared->n_active, n_threads);
@@ -16853,10 +16865,15 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
}
}

return 0;
return GGML_EXIT_SUCCESS;
}

bool always_false() { return false; }
void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
ggml_graph_compute_with_abort(ctx, cgraph, always_false);
}

void ggml_graph_compute_with_abort(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool (*abort_callback)(void)) {
const int n_threads = cgraph->n_threads;

struct ggml_compute_state_shared state_shared = {
@@ -16866,6 +16883,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
/*.n_threads =*/ n_threads,
/*.n_active =*/ n_threads,
/*.node_n =*/ -1,
/*.abort_callback =*/ abort_callback,
};
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);

@@ -17235,16 +17253,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
const int64_t perf_start_time_us = ggml_perf_time_us();

// this is a work thread too
ggml_graph_compute_thread(&workers[0]);
int compute_status = ggml_graph_compute_thread(&workers[0]);

// don't leave affinity set on the main thread
clear_numa_thread_affinity();

// join thread pool
// join or kill thread pool
if (n_threads > 1) {
for (int j = 1; j < n_threads; j++) {
const int rc = ggml_thread_join(workers[j].thrd, NULL);
GGML_ASSERT(rc == 0);
if (compute_status == GGML_EXIT_ABORTED) {
const int rc = ggml_thread_cancel(workers[j].thrd);
GGML_ASSERT(rc == 0);
} else if (compute_status == GGML_EXIT_SUCCESS) {
const int rc = ggml_thread_join(workers[j].thrd, NULL);
GGML_ASSERT(rc == 0);
}
}
}

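
For completeness, a second illustrative sketch (again with hypothetical names): because the callback is polled between graph nodes, it can also be used to enforce a wall-clock budget rather than react to a signal.

// Illustrative deadline-based abort callback (not part of this PR).
#include <stdbool.h>
#include <time.h>

static double g_deadline = 0.0;                     // hypothetical absolute deadline (monotonic clock, seconds)

static double now_s(void) {
    struct timespec ts;
    clock_gettime(CLOCK_MONOTONIC, &ts);
    return (double) ts.tv_sec + (double) ts.tv_nsec / 1e9;
}

static bool past_deadline(void) {
    return now_s() > g_deadline;                    // true => the graph computation is aborted
}

// g_deadline = now_s() + 5.0;                      // allow at most ~5 seconds
// ggml_graph_compute_with_abort(ctx, gf, past_deadline);

Note that on abort only the main thread returns GGML_EXIT_ABORTED; per this diff the remaining workers are terminated with ggml_thread_cancel (pthread_cancel on POSIX), so the callback itself should be fast and non-blocking.
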