Skip to content

Commit

Permalink
ggml : add abort_callback for cpu backend (#725)
Browse files Browse the repository at this point in the history
* a way to use abort_callback with the cpu backend

* whisper update
  • Loading branch information
Xarbirus committed Feb 9, 2024
1 parent 6b14d73 commit 2c7cf49
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 19 deletions.
8 changes: 4 additions & 4 deletions examples/whisper/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,11 @@ static bool ggml_graph_compute_helper(
struct ggml_cgraph * graph,
std::vector<uint8_t> & buf,
int n_threads,
whisper_abort_callback abort_callback,
ggml_abort_callback abort_callback,
void * abort_callback_data) {
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);

plan.abort_callback = abort_callback;
plan.abort_callback = abort_callback;
plan.abort_callback_data = abort_callback_data;

if (plan.work_size > 0) {
Expand Down Expand Up @@ -2130,7 +2130,7 @@ static bool whisper_encode_internal(
whisper_state & wstate,
const int mel_offset,
const int n_threads,
whisper_abort_callback abort_callback,
ggml_abort_callback abort_callback,
void * abort_callback_data) {
const int64_t t_start_us = ggml_time_us();

Expand Down Expand Up @@ -2561,7 +2561,7 @@ static bool whisper_decode_internal(
whisper_state & wstate,
const whisper_batch & batch,
const int n_threads,
whisper_abort_callback abort_callback,
ggml_abort_callback abort_callback,
void * abort_callback_data) {
const int64_t t_start_us = ggml_time_us();

Expand Down
7 changes: 1 addition & 6 deletions examples/whisper/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,11 +411,6 @@ extern "C" {
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);

// Abort callback
// If not NULL, called before ggml computation
// If it returns true, the computation is aborted
typedef bool (*whisper_abort_callback)(void * user_data);

// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
Expand Down Expand Up @@ -512,7 +507,7 @@ extern "C" {
void * encoder_begin_callback_user_data;

// called each time before ggml computation starts
whisper_abort_callback abort_callback;
ggml_abort_callback abort_callback;
void * abort_callback_user_data;

// called by each decoder to filter obtained logits
Expand Down
5 changes: 3 additions & 2 deletions include/ggml/ggml-backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ extern "C" {

GGML_API ggml_backend_t ggml_backend_cpu_init(void);

GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend);
GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend);
GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);

// Create a backend buffer from an existing pointer
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
Expand Down
9 changes: 7 additions & 2 deletions include/ggml/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,11 @@ extern "C" {

static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);

// Abort callback
// If not NULL, called before ggml computation
// If it returns true, the computation is aborted
typedef bool (*ggml_abort_callback)(void * data);

// the compute plan that needs to be prepared for ggml_graph_compute()
// since https://github.com/ggerganov/ggml/issues/287
struct ggml_cplan {
Expand All @@ -576,8 +581,8 @@ extern "C" {
int n_threads;

// abort ggml_graph_compute when true
bool (*abort_callback)(void * data);
void * abort_callback_data;
ggml_abort_callback abort_callback;
void * abort_callback_data;
};

enum ggml_cgraph_eval_order {
Expand Down
26 changes: 22 additions & 4 deletions src/ggml-backend.c
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,9 @@ struct ggml_backend_cpu_context {
int n_threads;
void * work_data;
size_t work_size;

ggml_abort_callback abort_callback;
void * abort_callback_data;
};

GGML_CALL static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
Expand Down Expand Up @@ -691,6 +694,9 @@ GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(gg
cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
}

cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback;
cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;

return cpu_plan;
}

Expand Down Expand Up @@ -721,9 +727,11 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size);
cpu_ctx->work_size = cplan.work_size;
}

cplan.work_data = cpu_ctx->work_data;

cplan.abort_callback = cpu_ctx->abort_callback;
cplan.abort_callback_data = cpu_ctx->abort_callback_data;

ggml_graph_compute(cgraph, &cplan);
return true;
}
Expand Down Expand Up @@ -759,9 +767,11 @@ static struct ggml_backend_i cpu_backend_i = {
ggml_backend_t ggml_backend_cpu_init(void) {
struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));

ctx->n_threads = GGML_DEFAULT_N_THREADS;
ctx->work_data = NULL;
ctx->work_size = 0;
ctx->n_threads = GGML_DEFAULT_N_THREADS;
ctx->work_data = NULL;
ctx->work_size = 0;
ctx->abort_callback = NULL;
ctx->abort_callback_data = NULL;

ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));

Expand All @@ -783,6 +793,14 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
ctx->n_threads = n_threads;
}

void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));

struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
ctx->abort_callback = abort_callback;
ctx->abort_callback_data = abort_callback_data;
}

GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
}
Expand Down
2 changes: 1 addition & 1 deletion src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -16560,7 +16560,7 @@ struct ggml_compute_state_shared {
atomic_int node_n; // active graph node
atomic_int node_task; // active graph node task phase

bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
void * abort_callback_data;
};

Expand Down

0 comments on commit 2c7cf49

Please sign in to comment.