Skip to content

Commit

Permalink
llama : fix expert weighting in the FFN
Browse files (browse the repository at this point in the history)
  • Loading branch information
ggerganov committed Dec 9, 2023
1 parent 7ea3695 commit 8b185b7
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4250,11 +4250,13 @@ struct llm_build_context {
ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]

// select experts
- ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
- //ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [n_tokens, num_experts_per_tok, 1]
- ggml_tensor * weights = ggml_get_rows(ctx0,
- ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts);
- weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok, 1]
+ ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
+ ggml_tensor * weights =
+ ggml_reshape_2d(ctx0,
+ ggml_get_rows(ctx0,
+ ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts),
+ n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
+ weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok]

// compute expert outputs
ggml_tensor * moe_out;
Expand Down

0 comments on commit 8b185b7

Please sign in to comment.