From 8b185b703020e09ab9cac8c56832d93aa240e4d9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Dec 2023 13:01:42 +0200 Subject: [PATCH] llama : fix expert weighting in the FFN --- llama.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 6333af4aa2b37..3c4da6a1c3f60 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4250,11 +4250,13 @@ struct llm_build_context { ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok] - //ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [n_tokens, num_experts_per_tok, 1] - ggml_tensor * weights = ggml_get_rows(ctx0, - ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts); - weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok, 1] + ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok] + ggml_tensor * weights = + ggml_reshape_2d(ctx0, + ggml_get_rows(ctx0, + ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts), + n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok] + weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok] // compute expert outputs ggml_tensor * moe_out;