Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Numa #1556

Merged
merged 20 commits into from
Jun 26, 2023
Merged

Numa #1556

Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
ggml : fix handling of ops with n_threads > n_tasks > 1
  • Loading branch information
ggerganov committed Jun 26, 2023
commit 81a40e9d6176a1c40202e3705e3e1f14248ca4b2
16 changes: 7 additions & 9 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -16765,7 +16765,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
}

atomic_store(&state->shared->n_active, n_threads);
atomic_store(&state->shared->node_n, node_n);
atomic_store(&state->shared->node_n, node_n);
} else {
// wait for other threads to finish
const int last = node_n;
Expand All @@ -16774,11 +16774,13 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
node_n = atomic_load(&state->shared->node_n);
} while (node_n == last);
}

// check if we should stop
if (node_n >= cgraph->n_nodes) break;

/* COMPUTE */
struct ggml_tensor * node = cgraph->nodes[node_n];

struct ggml_compute_params params = {
/*.type =*/ GGML_TASK_COMPUTE,
/*.ith =*/ state->ith,
Expand All @@ -16787,10 +16789,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
};

if(state->ith < node->n_tasks) {
if (state->ith < node->n_tasks) {
ggml_compute_forward(&params, node);
} else {
break;
}
}

Expand Down Expand Up @@ -16952,7 +16952,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} break;
case GGML_OP_SCALE:
{
node->n_tasks = n_threads;
node->n_tasks = 1;
} break;
case GGML_OP_SET:
case GGML_OP_CONT:
Expand Down Expand Up @@ -17165,9 +17165,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.shared = &state_shared,
};

int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
GGML_ASSERT(rc == 0);
UNUSED(rc);
}
}
workers[0].ith = 0;
Expand All @@ -17185,9 +17184,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
// join thread pool
if (n_threads > 1) {
for (int j = 1; j < n_threads; j++) {
int rc = ggml_thread_join(workers[j].thrd, NULL);
const int rc = ggml_thread_join(workers[j].thrd, NULL);
GGML_ASSERT(rc == 0);
UNUSED(rc);
}
}

Expand Down
Loading