Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

callback to abort ggml_graph_compute() #328

Merged
merged 13 commits into from
Jul 11, 2023
Prev Previous commit
Next Next commit
Merge branch 'master' into abort
  • Loading branch information
CCLDArjun committed Jul 11, 2023
commit e6a9fe6b3196f68ac28d2ecc39d7fa56592ab0bd
12 changes: 9 additions & 3 deletions include/ggml/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1305,9 +1305,15 @@ extern "C" {
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);

GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
GGML_API int ggml_graph_compute_with_abort(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool (*abort_callback)(void * data), void * abort_callback_data);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
// ggml_graph_plan() has to be called before ggml_graph_compute()
// when plan.work_size > 0, caller must allocate memory for plan.work_data
GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);

// same as ggml_graph_compute() but the work data is allocated as a part of the context
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);

GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);

Expand Down
358 changes: 351 additions & 7 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -16078,14 +16078,360 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
}

// Default abort callback: never requests an abort. Used so that the plain
// ggml_graph_compute() entry point can share the abortable implementation.
static bool always_false(void * data) { (void) data; return false; }
struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
if (n_threads <= 0) {
n_threads = GGML_DEFAULT_N_THREADS;
}

size_t work_size = 0;

struct ggml_cplan cplan;
memset(&cplan, 0, sizeof(struct ggml_cplan));

// thread scheduling for the different operations + work buffer size estimation
for (int i = 0; i < cgraph->n_nodes; i++) {
int n_tasks = 1;

struct ggml_tensor * node = cgraph->nodes[i];

switch (node->op) {
case GGML_OP_CPY:
case GGML_OP_DUP:
{
n_tasks = n_threads;

size_t cur = 0;
if (ggml_is_quantized(node->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks;
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_ADD:
case GGML_OP_ADD1:
{
n_tasks = n_threads;

size_t cur = 0;

if (ggml_is_quantized(node->src[0]->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[0]->ne[0] * n_tasks;
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_ACC:
{
n_tasks = n_threads;

size_t cur = 0;

if (ggml_is_quantized(node->src[0]->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[1]->ne[0] * n_tasks;
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_SUB:
case GGML_OP_DIV:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_LOG:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_ABS:
case GGML_OP_SGN:
case GGML_OP_NEG:
case GGML_OP_STEP:
case GGML_OP_TANH:
case GGML_OP_ELU:
case GGML_OP_RELU:
{
n_tasks = 1;
} break;
case GGML_OP_MUL:
case GGML_OP_GELU:
case GGML_OP_GELU_QUICK:
case GGML_OP_SILU:
case GGML_OP_SILU_BACK:
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
{
n_tasks = n_threads;
} break;
case GGML_OP_MUL_MAT:
case GGML_OP_OUT_PROD:
{
n_tasks = n_threads;

// TODO: use different scheduling for different matrix sizes
//const int nr0 = ggml_nrows(node->src[0]);
//const int nr1 = ggml_nrows(node->src[1]);

//n_tasks = MIN(n_threads, MAX(1, nr0/128));
//printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);

size_t cur = 0;
const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;

#if defined(GGML_USE_CUBLAS)
if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) {
n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
} else
#elif defined(GGML_USE_CLBLAST)
if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) {
n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node);
} else
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) {
n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
if (node->src[0]->type != GGML_TYPE_F32) {
// here we need memory just for single 2D matrix from src0
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src[0]->ne[0]*node->src[0]->ne[1]);
}
} else
#endif
if (node->src[1]->type != vec_dot_type) {
cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src[1])/GGML_BLCK_SIZE[vec_dot_type];
} else {
cur = 0;
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_SCALE:
{
n_tasks = 1;
} break;
case GGML_OP_SET:
case GGML_OP_CONT:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_GET_ROWS:
case GGML_OP_GET_ROWS_BACK:
case GGML_OP_DIAG:
case GGML_OP_DIAG_MASK_ZERO:
{
n_tasks = 1;
} break;
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
{
n_tasks = n_threads;
} break;
case GGML_OP_ALIBI:
{
n_tasks = 1; //TODO
} break;
case GGML_OP_CLAMP:
{
n_tasks = 1; //TODO
} break;
case GGML_OP_CONV_1D:
{
n_tasks = n_threads;

GGML_ASSERT(node->src[0]->ne[3] == 1);
GGML_ASSERT(node->src[1]->ne[2] == 1);
GGML_ASSERT(node->src[1]->ne[3] == 1);

size_t cur = 0;
const int nk = node->src[0]->ne[0];

if (node->src[0]->type == GGML_TYPE_F16 &&
node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(
nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
);
} else if (node->src[0]->type == GGML_TYPE_F32 &&
node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)*(
nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
);
} else {
GGML_ASSERT(false);
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_CONV_2D:
{
n_tasks = n_threads;

GGML_ASSERT(node->src[1]->ne[3] == 1);

const int64_t ne00 = node->src[0]->ne[0]; // W
const int64_t ne01 = node->src[0]->ne[1]; // H
const int64_t ne02 = node->src[0]->ne[2]; // C
const int64_t ne03 = node->src[0]->ne[3]; // N

const int64_t ne10 = node->src[1]->ne[0]; // W
const int64_t ne11 = node->src[1]->ne[1]; // H
const int64_t ne12 = node->src[1]->ne[2]; // C

const int64_t nk = ne00*ne01;

UNUSED(ne02);
UNUSED(ne03);
UNUSED(nk);

void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
ggml_graph_compute_with_abort(ctx, cgraph, always_false, NULL);
size_t cur = 0;

if (node->src[0]->type == GGML_TYPE_F16 &&
node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
} else if (node->src[0]->type == GGML_TYPE_F32 &&
node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)* (ne10*ne11*ne12);
} else {
GGML_ASSERT(false);
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_ATTN:
{
n_tasks = n_threads;

size_t cur = 0;

const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);

if (node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
}

if (node->src[1]->type == GGML_TYPE_F16) {
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_FF:
{
n_tasks = n_threads;

size_t cur = 0;

if (node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
}

if (node->src[1]->type == GGML_TYPE_F16) {
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
n_tasks = n_threads;

size_t cur = 0;

const int64_t D = node->src[0]->ne[0];
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
if (node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
}

if (node->src[1]->type == GGML_TYPE_F16) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
}

work_size = MAX(work_size, cur);
} break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1:
case GGML_OP_MAP_CUSTOM2:
case GGML_OP_MAP_CUSTOM3:
{
n_tasks = 1;
} break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
n_tasks = n_threads;

size_t cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);

work_size = MAX(work_size, cur);
} break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
{
n_tasks = n_threads;

size_t cur = ggml_type_size(node->type)*node->src[0]->ne[0]*n_tasks;

work_size = MAX(work_size, cur);
} break;
case GGML_OP_NONE:
{
n_tasks = 1;
} break;
case GGML_OP_COUNT:
{
GGML_ASSERT(false);
} break;
}

cplan.n_tasks[i] = n_tasks;
}

if (work_size > 0) {
work_size += CACHE_LINE_SIZE*(n_threads - 1);
}

cplan.n_threads = n_threads;
cplan.work_size = work_size;
cplan.work_data = NULL;

return cplan;
}

int ggml_graph_compute_with_abort(struct ggml_context * ctx, struct ggml_cgraph * cgraph,
bool (*abort_callback)(void * data), void *abort_callback_data) {
const int n_threads = cgraph->n_threads;
void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
{
GGML_ASSERT(cplan);
GGML_ASSERT(cplan->n_threads > 0);

if (cplan->work_size > 0) {
GGML_ASSERT(cplan->work_data);
}

for (int i = 0; i < cgraph->n_nodes; ++i) {
if (cgraph->nodes[i]->op != GGML_OP_NONE) {
GGML_ASSERT(cplan->n_tasks[i] > 0);
}
}
}

const int n_threads = cplan->n_threads;

struct ggml_compute_state_shared state_shared = {
/*.cgraph =*/ cgraph,
Expand All @@ -16095,8 +16441,6 @@ int ggml_graph_compute_with_abort(struct ggml_context * ctx, struct ggml_cgraph
/*.n_threads =*/ n_threads,
/*.n_active =*/ n_threads,
/*.node_n =*/ -1,
/*.abort_callback =*/ abort_callback,
/*.abort_callback_data =*/ abort_callback_data,
};
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);

Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.