Skip to content

Commit

Permalink
Allow number of nodes in CUDA graph to change (llama/7738)
Browse files Browse the repository at this point in the history
Previously the code would have failed to cope in the case that the
number of nodes changes in an existing CUDA graph. This fixes the
issue by removing an unnecessary conditional.
  • Loading branch information
agray3 authored and ggerganov committed Jun 15, 2024
1 parent 9b05346 commit a524d3f
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2702,10 +2702,8 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t

if (cuda_graph_update_required) {
// Extract nodes from graph
if (cuda_ctx->cuda_graph->num_nodes == 0) {
// First call with null argument gets number of nodes in graph
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
}
// First call with null argument gets number of nodes in graph
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
// Subsequent call with non-null argument gets nodes
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
Expand Down

0 comments on commit a524d3f

Please sign in to comment.