
Add support for DeepseekV2ForCausalLM #7519

Merged: 30 commits into ggerganov:master on May 28, 2024

Conversation

fairydreaming (Collaborator)

This pull request adds support for DeepseekV2ForCausalLM-based models. Both lite and non-lite models are supported. Fixes #7118

Changes included in this pull request:

  • increase max number of experts to 160
  • add new LLM_ARCH_DEEPSEEK2 architecture
  • add new model header parameters:
    • leading_dense_block_count - number of leading dense layers (following layers are MoE)
    • expert_feed_forward_length - feed-forward length for experts (feed_forward_length is used for the dense FFN dimension in the leading dense layers)
    • expert_shared_count - number of shared experts
    • expert_weights_scale - in the DeepSeek-V2 MoE the gate_inp output is not normalized, but multiplied by this constant value instead
    • attention.q_lora_rank - latent compressed Q dimension
    • attention.kv_lora_rank - latent compressed KV dimension
    • rope.scaling.yarn_log_multiplier - log multiplier used to calculate YaRN mscale (0.0707 in DeepSeek-V2 instead of 0.1)
  • add new tensors (hidden representations are first mapped to a compressed latent representation with the _a tensors and normalized with the _a_norm tensors, then mapped to the Q/KV space with the _b tensors; see the sketch after this list):
    • attn_q_a
    • attn_q_b
    • attn_kv_a_mqa
    • attn_kv_b
    • attn_q_a_norm
    • attn_kv_a_norm
  • add MODEL_16B and MODEL_236B model types for lite and non-lite models (it would be great to have dense + expert parameter counts, but I have no idea how to count them properly considering the leading dense layers, shared experts, etc.)
  • add support for DeepseekV2ForCausalLM model conversion in convert-hf
  • add support for inference for models based on LLM_ARCH_DEEPSEEK2
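To make the latent-compression mapping above concrete, here is a minimal editorial sketch of the Q-side path using only public ggml ops (hypothetical helper and illustrative shapes; this is not the graph-building code added by this PR):

```cpp
// Editorial sketch of the latent-compression mapping described above, using only
// public ggml ops. wq_a, q_a_norm and wq_b stand in for the attn_q_a, attn_q_a_norm
// and attn_q_b tensors; the KV path (attn_kv_a_mqa -> attn_kv_a_norm -> attn_kv_b)
// is analogous. This is NOT the code added by this PR.
#include "ggml.h"

static struct ggml_tensor * build_q_from_latent(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,       // hidden states [n_embd, n_tokens]
        struct ggml_tensor  * wq_a,      // attn_q_a:      n_embd      -> q_lora_rank
        struct ggml_tensor  * q_a_norm,  // attn_q_a_norm: RMS norm weight [q_lora_rank]
        struct ggml_tensor  * wq_b,      // attn_q_b:      q_lora_rank -> n_head * n_embd_head_k
        float                 rms_eps) {
    // 1) compress the hidden state into the low-rank latent Q space
    struct ggml_tensor * q_latent = ggml_mul_mat(ctx, wq_a, cur);
    // 2) RMS-normalize the latent representation and apply the norm weight
    q_latent = ggml_rms_norm(ctx, q_latent, rms_eps);
    q_latent = ggml_mul(ctx, q_latent, q_a_norm);
    // 3) expand back into the per-head query space
    return ggml_mul_mat(ctx, wq_b, q_latent);
}
```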

This pull request also fixes #7331 by removing the failing assertion.

Note that I had to change the llm_build_moe_ffn() API to add scaling of the MoE gate output - I added a bool indicating whether to scale or not and a float scale value. Let me know if there is a better way.
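For clarity, a minimal editorial sketch of the gating semantics behind this flag (made-up helper names, not the llm_build_moe_ffn() code itself): the selected experts' softmax weights are either renormalized as before, or multiplied by the constant expert_weights_scale as DeepSeek-V2 expects.

```cpp
// Editorial sketch (not llama.cpp code): how the selected-expert weights are
// treated with and without the new scaling, per the description above.
#include <vector>

static std::vector<float> weight_selected_experts(
        std::vector<float> w,   // softmax weights of the top-k selected experts
        bool  scale_w,          // new flag: apply the constant scale (DeepSeek-V2)
        float w_scale) {        // expert_weights_scale
    if (scale_w) {
        for (float & v : w) v *= w_scale;   // DeepSeek-V2 path: constant scaling
    } else {
        float sum = 0.0f;                   // previous behaviour: renormalize over the selection
        for (float v : w) sum += v;
        for (float & v : w) v /= sum;
    }
    return w;
}
```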
The implementation also required a somewhat ugly workaround for problems with the YaRN implementation that I described in detail in #7416.

One thing I'm not sure about - shall I print all new parameter values in llm_load_print_meta() or print them only for LLM_ARCH_DEEPSEEK2 models (since they are specific to this architecture)? Or perhaps simply not print them at all?

I'll correct the whitespace formatting after we agree on the parameter names.

sszymczy and others added 16 commits May 16, 2024 20:16
- leading_dense_block_count => hparams.n_leading_dense_layer,
- expert_feed_forward_length => hparams.n_expert_ff,
- expert_shared_count => hparams.n_expert_shared,
- attention.q_lora_rank => hparams.n_lora_q,
- attention.kv_lora_rank => hparams.n_lora_kv
Added missing scaling of kq_scale parameter.
… multiplier of the ln(s) from the sqrt(1/t) = 0.1 ln(s) + 1 equation.
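For reference, the relation this commit refers to (the YaRN attention temperature fit) is

$$\sqrt{1/t} = m \ln s + 1,$$

where s is the context scaling factor, t is the attention temperature, and m is the log multiplier stored in rope.scaling.yarn_log_multiplier: m = 0.1 in the original YaRN formulation and m = 0.0707 (numerically close to 0.1/sqrt(2)) for DeepSeek-V2.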
@github-actions bot added the python (python script changes) label on May 24, 2024
github-actions bot commented May 24, 2024:

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 519 iterations 🚀

  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8999.54ms p(95)=21986.23ms fails=, finish reason: stop=462 truncated=57
  • Prompt processing (pp): avg=103.24tk/s p(95)=442.93tk/s
  • Token generation (tg): avg=33.3tk/s p(95)=43.7tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=deepseek-v2 commit=3efb6595ae6c4f5e10492c70172085e6827d30a7

[Four benchmark charts omitted: prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, requests_processing (llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 519 iterations)]

@mofosyne added the model (Model specific) and Review Complexity: Medium (generally requires more time to grok but manageable by beginner to medium expertise level) labels on May 25, 2024
@ggerganov (Owner) left a comment:

Awesome work!

I've added a few name suggestions in some places, but feel free to choose something else if you prefer.

> One thing I'm not sure about - shall I print all new parameter values in llm_load_print_meta() or print them only for LLM_ARCH_DEEPSEEK2 models (since they are specific to this architecture)? Or perhaps simply not print them at all?

Let's print them for the specific arch

llama.cpp (outdated), comment on lines 11300 to 11306:
k_pe = ggml_repeat(ctx0, k_pe, q_pe);
cb(k_pe, "k_pe", il);

struct ggml_tensor * key_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens);
cb(key_states, "key_states", il);
key_states = ggml_set_inplace(ctx0, key_states, k_nope, key_states->nb[1], key_states->nb[2], key_states->nb[3], 0);
key_states = ggml_set_inplace(ctx0, key_states, k_pe, key_states->nb[1], key_states->nb[2], key_states->nb[3], ggml_element_size(key_states) * qk_nope_head_dim);
@ggerganov (Owner):

Need to implement GGML_OP_REPEAT for Metal - I can try to do that tomorrow

@slaren (Collaborator) commented May 26, 2024:

And GGML_OP_SET as well. I am not quite sure what this op does. Could it be replaced with a ggml_cpy into views?

@ggerganov (Owner):

It should be possible, though I think it would require ggml_build_forward_expands which is not very nice

I think this OP does something like this:

k_pe:   kkk
q_pe:   qqq qqq qqq
k_nope: nnn

key_states: nnn kkk kkk kkk

@slaren (Collaborator):

Having ggml_new_tensor_3d here is not great either: these tensors will be treated as inputs in ggml-alloc and allocated at the beginning of the compute buffer, which will increase memory usage since there is one for each layer.

@slaren (Collaborator):

It might be better to spend the time to make a more general ggml_concat operation.

@ggerganov (Owner):

Hm, how would broadcasting work with ggml_concat? This is the current API:

    // concat a and b along dim
    // used in stable-diffusion
    GGML_API struct ggml_tensor * ggml_concat(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            int                   dim);

@ggerganov (Owner):

If you rebase on master and apply the following patch, it should work with CUDA and Metal:

diff --git a/llama.cpp b/llama.cpp
index cef5bfdd..9c80a621 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -11305,18 +11305,11 @@ struct llm_build_context {
                 );
                 cb(k_pe, "k_pe", il);
 
-                struct ggml_tensor * q_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens);
+                struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
                 cb(q_states, "q_states", il);
-                q_states = ggml_set_inplace(ctx0, q_states, q_nope, q_states->nb[1], q_states->nb[2], q_states->nb[3], 0);
-                q_states = ggml_set_inplace(ctx0, q_states, q_pe, q_states->nb[1], q_states->nb[2], q_states->nb[3], ggml_element_size(q_states) * n_embd_head_qk_nope);
 
-                k_pe = ggml_repeat(ctx0, k_pe, q_pe);
-                cb(k_pe, "k_pe", il);
-
-                struct ggml_tensor * k_states = ggml_new_tensor_3d(ctx0, q_nope->type, hparams.n_embd_head_k, n_head, n_tokens);
+                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
                 cb(k_states, "k_states", il);
-                k_states = ggml_set_inplace(ctx0, k_states, k_nope, k_states->nb[1], k_states->nb[2], k_states->nb[3], 0);
-                k_states = ggml_set_inplace(ctx0, k_states, k_pe, k_states->nb[1], k_states->nb[2], k_states->nb[3], ggml_element_size(k_states) * n_embd_head_qk_nope);
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,

@fairydreaming (Collaborator, Author):

> Hm, how would broadcasting work with ggml_concat? This is the current API:

@ggerganov forget it, it seems that even in NumPy and PyTorch concat doesn't support broadcasting

@fairydreaming (Collaborator, Author):

> If you rebase on master and apply the following patch, it should work with CUDA and Metal: [patch quoted above]

@ggerganov Done, looks much cleaner now, thank you! Works without issues (tested on CPU).

@ggerganov (Owner):

Nice, I also tested the Lite model with Metal. Let's merge after @slaren approval

@ggerganov ggerganov requested a review from slaren May 26, 2024 19:09
@foldl (Contributor) commented May 27, 2024:

I think these issues about YaRN are not addressed yet. Dismissed.

@fairydreaming (Collaborator, Author):

> I think these issues about YaRN are not addressed yet.

@foldl While I agree that there may be some issues with the YaRN implementation for models using NeoX-style RoPE (at least from looking at the code), I think DeepSeek-V2 is not affected by them, as it falls into the preceding conditional block (if (!is_neox)):

https://github.com/ggerganov/ggml/blob/0cbb7c0e053f5419cfbebb46fbf4d4ed60182cf5/src/ggml.c#L14059C19-L14059C39
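For readers who have not looked at that code: a simplified editorial illustration of the two branches as I understand them (not the actual ggml kernels; YaRN and frequency-factor corrections are omitted). The non-NeoX path rotates interleaved element pairs, while the NeoX path pairs each element with the one half a head away:

```cpp
// Simplified illustration only, not the ggml implementation: "normal" RoPE
// (the !is_neox branch, used by this PR) rotates interleaved pairs, while
// NeoX-style RoPE rotates elements split across the two halves of the head.
#include <cmath>

// rotate one head of dimension n_dims at position pos with base frequency theta_base
void rope_norm(float * x, int n_dims, int pos, float theta_base) {
    for (int i = 0; i < n_dims; i += 2) {
        const float theta = pos * std::pow(theta_base, -(float) i / n_dims);
        const float c = std::cos(theta), s = std::sin(theta);
        const float x0 = x[i], x1 = x[i + 1];            // interleaved pair
        x[i]     = x0 * c - x1 * s;
        x[i + 1] = x0 * s + x1 * c;
    }
}

void rope_neox(float * x, int n_dims, int pos, float theta_base) {
    for (int i = 0; i < n_dims / 2; ++i) {
        const float theta = pos * std::pow(theta_base, -(float) (2 * i) / n_dims);
        const float c = std::cos(theta), s = std::sin(theta);
        const float x0 = x[i], x1 = x[i + n_dims / 2];   // pair split across halves
        x[i]              = x0 * c - x1 * s;
        x[i + n_dims / 2] = x0 * s + x1 * c;
    }
}
```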

@foldl (Contributor) commented May 27, 2024:

@fairydreaming Got it. Your PR uses RoPE type 0.

@fairydreaming fairydreaming merged commit ee3dff6 into ggerganov:master May 28, 2024
72 checks passed
@bartowski1182 (Contributor):

Do we expect this to not work on CUDA?

Trying to make imatrix for https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat

Results in:

GGML_ASSERT: ggml-cuda/concat.cu:107: ggml_is_contiguous(src0)

@trislee02 commented May 29, 2024:

I got the same issue GGML_ASSERT: ggml-cuda/concat.cu:107: ggml_is_contiguous(src0) when trying this model mzwing/DeepSeek-V2-Lite-Chat-GGUF:

./main -m ~/DeepSeek-V2-Lite-Chat.IQ2_M.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -ngl 33

I built with CMake:

cmake -B build -DLLAMA_CUDA=ON
cmake --build build --config Release

The following is the log from running ./main:

Log start
main: build = 3030 (504f0c34)
main: built with cc (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0 for x86_64-linux-gnu
main: seed  = 1716952582
llama_model_loader: loaded meta data with 42 key-value pairs and 377 tensors from /home/jupyter-trile-new/DeepSeek-V2-Lite-Chat.IQ2_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = deepseek2
llama_model_loader: - kv   1:                               general.name str              = DeepSeek-V2-Lite-Chat
llama_model_loader: - kv   2:                      deepseek2.block_count u32              = 27
llama_model_loader: - kv   3:                   deepseek2.context_length u32              = 163840
llama_model_loader: - kv   4:                 deepseek2.embedding_length u32              = 2048
llama_model_loader: - kv   5:              deepseek2.feed_forward_length u32              = 10944
llama_model_loader: - kv   6:             deepseek2.attention.head_count u32              = 16
llama_model_loader: - kv   7:          deepseek2.attention.head_count_kv u32              = 16
llama_model_loader: - kv   8:                   deepseek2.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv   9: deepseek2.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  10:                deepseek2.expert_used_count u32              = 6
llama_model_loader: - kv  11:                          general.file_type u32              = 29
llama_model_loader: - kv  12:        deepseek2.leading_dense_block_count u32              = 1
llama_model_loader: - kv  13:                       deepseek2.vocab_size u32              = 102400
llama_model_loader: - kv  14:           deepseek2.attention.kv_lora_rank u32              = 512
llama_model_loader: - kv  15:             deepseek2.attention.key_length u32              = 192
llama_model_loader: - kv  16:           deepseek2.attention.value_length u32              = 128
llama_model_loader: - kv  17:       deepseek2.expert_feed_forward_length u32              = 1408
llama_model_loader: - kv  18:                     deepseek2.expert_count u32              = 64
llama_model_loader: - kv  19:              deepseek2.expert_shared_count u32              = 2
llama_model_loader: - kv  20:             deepseek2.expert_weights_scale f32              = 1.000000
llama_model_loader: - kv  21:             deepseek2.rope.dimension_count u32              = 64
llama_model_loader: - kv  22:                deepseek2.rope.scaling.type str              = yarn
llama_model_loader: - kv  23:              deepseek2.rope.scaling.factor f32              = 40.000000
llama_model_loader: - kv  24: deepseek2.rope.scaling.original_context_length u32              = 4096
llama_model_loader: - kv  25: deepseek2.rope.scaling.yarn_log_multiplier f32              = 0.070700
llama_model_loader: - kv  26:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  27:                         tokenizer.ggml.pre str              = deepseek-llm
llama_model_loader: - kv  28:                      tokenizer.ggml.tokens arr[str,102400]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  29:                  tokenizer.ggml.token_type arr[i32,102400]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  30:                      tokenizer.ggml.merges arr[str,99757]   = ["Ġ Ġ", "Ġ t", "Ġ a", "i n", "h e...
llama_model_loader: - kv  31:                tokenizer.ggml.bos_token_id u32              = 100000
llama_model_loader: - kv  32:                tokenizer.ggml.eos_token_id u32              = 100001
llama_model_loader: - kv  33:            tokenizer.ggml.padding_token_id u32              = 100001
llama_model_loader: - kv  34:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  35:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  36:                    tokenizer.chat_template str              = {% if not add_generation_prompt is de...
llama_model_loader: - kv  37:               general.quantization_version u32              = 2
llama_model_loader: - kv  38:                      quantize.imatrix.file str              = DeepSeek-V2-Lite-Chat-IMat-GGUF/imatr...
llama_model_loader: - kv  39:                   quantize.imatrix.dataset str              = DeepSeek-V2-Lite-Chat-IMat-GGUF/imatr...
llama_model_loader: - kv  40:             quantize.imatrix.entries_count i32              = 293
llama_model_loader: - kv  41:              quantize.imatrix.chunks_count i32              = 214
llama_model_loader: - type  f32:  108 tensors
llama_model_loader: - type q5_K:    1 tensors
llama_model_loader: - type iq4_nl:   27 tensors
llama_model_loader: - type iq3_s:   29 tensors
llama_model_loader: - type iq2_s:  212 tensors
llm_load_vocab: special tokens cache size = 2400.
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = deepseek2
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 102400
llm_load_print_meta: n_merges         = 99757
llm_load_print_meta: n_ctx_train      = 163840
llm_load_print_meta: n_embd           = 2048
llm_load_print_meta: n_head           = 16
llm_load_print_meta: n_head_kv        = 16
llm_load_print_meta: n_layer          = 27
llm_load_print_meta: n_rot            = 64
llm_load_print_meta: n_embd_head_k    = 192
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 3072
llm_load_print_meta: n_embd_v_gqa     = 2048
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 10944
llm_load_print_meta: n_expert         = 64
llm_load_print_meta: n_expert_used    = 6
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = yarn
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 0.025
llm_load_print_meta: n_yarn_orig_ctx  = 4096
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 16B
llm_load_print_meta: model ftype      = IQ2_M - 2.7 bpw
llm_load_print_meta: model params     = 15.71 B
llm_load_print_meta: model size       = 5.89 GiB (3.22 BPW) 
llm_load_print_meta: general.name     = DeepSeek-V2-Lite-Chat
llm_load_print_meta: BOS token        = 100000 '<|begin▁of▁sentence|>'
llm_load_print_meta: EOS token        = 100001 '<|end▁of▁sentence|>'
llm_load_print_meta: PAD token        = 100001 '<|end▁of▁sentence|>'
llm_load_print_meta: LF token         = 126 'Ä'
llm_load_print_meta: n_layer_dense_lead   = 1
llm_load_print_meta: n_lora_q             = 0
llm_load_print_meta: n_lora_kv            = 512
llm_load_print_meta: n_ff_exp             = 1408
llm_load_print_meta: n_expert_shared      = 2
llm_load_print_meta: expert_weights_scale = 1.0
llm_load_print_meta: rope_yarn_log_mul    = 0.0707
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
  Device 1: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
llm_load_tensors: ggml ctx size =    0.53 MiB
llm_load_tensors: offloading 27 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 28/28 layers to GPU
llm_load_tensors:        CPU buffer size =    85.94 MiB
llm_load_tensors:      CUDA0 buffer size =  2919.51 MiB
llm_load_tensors:      CUDA1 buffer size =  3026.04 MiB
..................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: n_batch    = 512
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 0.025
llama_kv_cache_init:      CUDA0 KV buffer size =    70.00 MiB
llama_kv_cache_init:      CUDA1 KV buffer size =    65.00 MiB
llama_new_context_with_model: KV self size  =  135.00 MiB, K (f16):   81.00 MiB, V (f16):   54.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.39 MiB
llama_new_context_with_model: pipeline parallelism enabled (n_copies=4)
llama_new_context_with_model:      CUDA0 compute buffer size =    70.76 MiB
llama_new_context_with_model:      CUDA1 compute buffer size =   232.02 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =     8.02 MiB
llama_new_context_with_model: graph nodes  = 1870
llama_new_context_with_model: graph splits = 3
GGML_ASSERT: /home/jupyter-trile-new/llama.cpp/ggml-cuda/concat.cu:107: ggml_is_contiguous(src0)
Could not attach to process.  If your uid matches the uid of the target
process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try
again as the root user.  For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.

@ggerganov (Owner):

I'll try to fix this now - I didn't realize the concat tensors in DS2 are not contiguous

@ggerganov (Owner):

Please test #7610 to see if it works now

@fairydreaming (Collaborator, Author):

@ggerganov I think there is still something wrong with the CUDA implementation. The more layers I offload to the GPU, the more broken the model output is.

Is there any way to print tensors from the CUDA code? I'd love to be able to add GGML graph nodes for printing tensors, something like a = ggml_print(ctx, a), that would allow me to compare selected tensor values between different backends. I have a crude implementation of something similar that I used to compare tensors on the CPU with the DeepSeek-V2 transformers implementation. Or perhaps there is a better way to debug this?

@ggerganov (Owner):

You can already do that with the eval-callback example
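For context, a rough editorial sketch of the mechanism that example is built on, as I understand the llama.cpp API at the time (treat the exact callback typedef and the cb_eval / cb_eval_user_data field names as assumptions): a scheduler eval callback is first asked which graph nodes it wants to observe, then called again when the node data is available, regardless of which backend computed it.

```cpp
// Editorial sketch, not the eval-callback example itself: a callback that is
// asked once whether it wants a node's data (ask == true) and then called again
// when the data is available (ask == false).
#include "llama.h"
#include <cstdio>

static bool print_tensor_cb(struct ggml_tensor * t, bool ask, void * /*user_data*/) {
    if (ask) {
        return true;  // request the data of every node
    }
    fprintf(stderr, "%-24s op=%s\n", t->name, ggml_op_name(t->op));
    return true;      // keep evaluating the graph
}

// Hypothetical wiring (assumed field names):
//   llama_context_params cparams = llama_context_default_params();
//   cparams.cb_eval           = print_tensor_cb;
//   cparams.cb_eval_user_data = nullptr;
```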

@fairydreaming (Collaborator, Author):

> You can already do that with the eval-callback example

@ggerganov The first meaningful difference is in q_pe-26 = (f32) ROPE(q_pe-26). Since that's the last layer of the lite model, I guess it's the one that is offloaded to the GPU with -ngl 1. Does the ggml_rope_ext() CUDA implementation expect contiguous tensors?

@fairydreaming (Collaborator, Author):

@ggerganov Another difference in tensor values I found was in norm-26 = (f32) RMS_NORM(compressed_kv-26). So it looks like normalization is also affected by non-contiguous tensors.

After adding the following changes:

diff --git a/llama.cpp b/llama.cpp
index dac81acc..8d71a1ac 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -11204,6 +11204,7 @@ struct llm_build_context {
                 struct ggml_tensor * k_pe = ggml_view_2d(ctx0, compressed_kv_pe, n_embd_head_qk_rope, n_tokens, compressed_kv_pe->nb[1], ggml_element_size(compressed_kv_pe)*kv_lora_rank);
                 cb(k_pe, "k_pe", il);
 
+                compressed_kv = ggml_cont(ctx0, compressed_kv);
                 compressed_kv = llm_build_norm(ctx0, compressed_kv, hparams,
                         model.layers[il].attn_kv_a_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -11227,6 +11228,7 @@ struct llm_build_context {
                 v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, ggml_element_size(kv) * hparams.n_embd_head_v * n_head, 0);
                 cb(v_states, "v_states", il);
 
+                q_pe = ggml_cont(ctx0, q_pe);
                 q_pe = ggml_rope_ext(
                     ctx0, q_pe, inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
@@ -11235,6 +11237,7 @@ struct llm_build_context {
                 cb(q_pe, "q_pe", il);
 
                 // shared RoPE key
+                k_pe = ggml_cont(ctx0, k_pe);
                 k_pe = ggml_rope_ext(
                     ctx0, ggml_view_3d(ctx0, k_pe, n_embd_head_qk_rope, 1, n_tokens, k_pe->nb[0], k_pe->nb[1], 0), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,

I can offload the whole model to GPU and it works correctly.

@ggerganov (Owner):

Hm, it's strange that we are missing the proper asserts for these. I thought at least for ggml_rope there should have been asserts at some point

But anyway, it would probably be better to support non-contiguous rope and norm. I'll also have to double-check the Metal backend and add tests

@ggerganov (Owner):

I've started adding tests and fixing some of the RoPE issues in: #7617

I think we should try to avoid ggml_cont and ggml_view and instead extend the NORM (mode = 0) RoPE to support partial head rotation, like NeoX (mode = 2) already does. This way, instead of ggml_rope(ggml_view(x, n_rot), n_rot), you could straight up do ggml_rope(x, n_rot) and the data will be contiguous. I will try to implement this within #7617

Comment on lines +11301 to +11303
k_pe = ggml_rope_ext(
ctx0, ggml_view_3d(ctx0, k_pe, n_embd_head_qk_rope, 1, n_tokens, k_pe->nb[0], k_pe->nb[1], 0), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
@ggerganov (Owner) commented May 29, 2024:

Is it correct here that we rotate just the first head of k_pe?
With DS2 Lite, there is a second head here which seems to be simply discarded, if I understand correctly

Edit: nvm got it

@ggerganov (Owner):

@fairydreaming Could you give #7617 a try and see if it works correctly?

I ended up implementing your patch because the idea that I had is not going to work. The partial RoPE requires that the rotated dimensions are at the beginning of the head, but here they are at the end. There also seem to be other limitations, so I guess the current view + cont implementation should be ok

The CUDA kernels could be extended to support non-contiguous RoPE and norm, but I'm not sure if this is going to be worth it - I guess it might even hurt the performance. So instead, I added asserts to prevent from using non-contiguous data with these ops

I've also simplified some of the stride and offset computations in the DS2 compute graph by using the ggml_row_size helper

@fairydreaming (Collaborator, Author):

> @fairydreaming Could you give #7617 a try and see if it works correctly?

@ggerganov Sure, seems to work OK both on CPU and on GPU (CUDA).

> I ended up implementing your patch because the idea that I had is not going to work. The partial RoPE requires that the rotated dimensions are at the beginning of the head, but here they are at the end. There also seem to be other limitations, so I guess the current view + cont implementation should be ok

Roger that. I checked the performance difference on the CPU caused by the added ggml_cont operations and it's negligible (around 2%):

without ggml_cont:

llama_print_timings: prompt eval time =     106.20 ms /    17 tokens (    6.25 ms per token,   160.07 tokens per second)
llama_print_timings:        eval time =   12712.43 ms /   472 runs   (   26.93 ms per token,    37.13 tokens per second)

with ggml_cont:

llama_print_timings: prompt eval time =     108.40 ms /    17 tokens (    6.38 ms per token,   156.83 tokens per second)
llama_print_timings:        eval time =   12875.21 ms /   472 runs   (   27.28 ms per token,    36.66 tokens per second)

> The CUDA kernels could be extended to support non-contiguous RoPE and norm, but I'm not sure if this is going to be worth it - I guess it might even hurt the performance. So instead, I added asserts to prevent from using non-contiguous data with these ops

If needed, I guess there could be separate implementations for contiguous and non-contiguous tensors. However, it looks like a lot of work and the efficiency gains won't be large (or, like you said, on CUDA the performance may be even worse), so IMHO it can stay this way.

@oldmanjk commented Jun 1, 2024:

Possibly relevant - #2445 (comment)

Labels: model (Model specific), python (python script changes), Review Complexity: Medium
Projects: none yet
Participants: 9