This repository contains the code and released models for our paper.
Our goal is to distill a large Transformer into a (Hybrid) Mamba model while preserving generation quality as far as possible. Reproducing our results typically takes only 8x80G A100 GPUs (a very limited budget) running for 3 to 4 days. Our approach works for both base models and chat models.
- Stepwise layer alignment (Optional). Replace the attention layers with Mamba2 one at a time, in a stepwise manner.
- End-to-end distillation (Most important). Minimize the KL divergence loss between the student and teacher models. Consider using a larger teacher model for better results.
- Instruction tuning (Optional). For simplicity, we use SFT + DPO for this step.
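The end-to-end distillation objective can be illustrated with a minimal sketch. It is written in pure Python for readability and is not the training code from this repository: the function names are hypothetical, and the direction of the divergence, KL(teacher || student), is an assumption, since the text above only says the KL divergence between the two models is minimized. In practice the same quantity would be computed on batched tensors over the full vocabulary.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vector of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_divergence(teacher_logits, student_logits):
    """KL(teacher || student) at a single token position (assumed direction)."""
    t = log_softmax(teacher_logits)
    s = log_softmax(student_logits)
    return sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t, s))


def distillation_loss(teacher_seq, student_seq):
    """Mean per-token KL over a sequence of logit vectors (hypothetical helper)."""
    per_token = [kl_divergence(t, s) for t, s in zip(teacher_seq, student_seq)]
    return sum(per_token) / len(per_token)
```

The loss is zero when the student reproduces the teacher's distribution at every position and grows as the two distributions diverge, which is why it can serve as the training signal for the student.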
Please follow the instructions here. Our evaluation covers: (a) standard tasks in LM Eval, (b) chat benchmarks, (c) reasoning tasks (math and code reasoning benchmarks), and (d) long-range tasks (NeedleInAHaystack). Our goal is to provide a thorough evaluation and study.
- [2024.10.06] We simplified the procedure and distilled the Hybrid Mamba2 3B model, using Llama-3.1-8B-Instruct as the teacher model and Llama-3.2-3B-Instruct as the initialization for the student model. Check this for more details.
- [2024.08.26] Hybrid Mamba models and Hybrid Mamba2 models distilled from meta-llama/Meta-Llama-3-8B-Instruct are available.
- [2024.07.18] We released the first version of the code and models. We are distilling meta-llama/Meta-Llama-3-8B-Instruct. Stay tuned for updates.
Check this for more details.
Models are available here.
Model | MMLU | AlpacaEval (LC win against GPT-4) | MT-Bench (scored by GPT-4) | GSM8K (0-shot) | CRUX (0-shot) |
---|---|---|---|---|---|
Llama-3.2-Mamba2-0.5-3B-dpo-v2 | 53.12 | 22.08 | 6.81 | 50.37 | 20.12 |
Teacher Model | Hybrid Mamba Model - DPO | Hybrid Mamba2 Model - DPO |
---|---|---|
Meta-Llama-3-8B-Instruct | Mamba (1/2 attention) | Mamba2 (1/2 attention) |
Meta-Llama-3-8B-Instruct | Mamba (1/4 attention) | Mamba2 (1/4 attention) |
Meta-Llama-3-8B-Instruct | Mamba (1/8 attention) | Mamba2 (1/8 attention) |
Meta-Llama-3-8B-Instruct | | Mamba2 (0 attention) |
Model | MMLU (5 shots) | AlpacaEval (LC win against GPT-4) | MT-Bench (scored by GPT-4) |
---|---|---|---|
Mamba (1/2 attention) | 59.26 | 29.61 | 7.35 |
Mamba2 (1/2 attention) | 56.67 | 25.00 | 7.32 |
Mamba (1/4 attention) | 52.68 | 25.85 | 6.86 |
Mamba2 (1/4 attention) | 53.94 | 20.25 | 6.74 |
Mamba (1/8 attention) | 49.20 | 20.76 | 6.46 |
Mamba2 (1/8 attention) | 50.85 | 20.25 | 6.48 |
Mamba2 (0 attention) | 43.19 | 14.49 | 5.64 |
For reproduction, please follow the instructions here.
Teacher Model | Hybrid Mamba Model - SFT | Hybrid Mamba Model - DPO | Hybrid Mamba Model - DPO |
---|---|---|---|
Zephyr | Mamba (1/2 attention) | Mamba (1/2 attention) | Mamba (1/2 attention) |
Zephyr | Mamba (1/4 attention) | Mamba (1/4 attention) | Mamba (1/4 attention) |
Zephyr | Mamba (1/8 attention) | Mamba (1/8 attention) | Mamba (1/8 attention) |
Model | MMLU (5 shots) | AlpacaEval (LC win against GPT-4) | MT-Bench (scored by GPT-4) |
---|---|---|---|
Zephyr | 61.44 | 13.20 | 7.34 |
Mamba DPO 1 (1/2 attention) | 55.23 | 20.66 | 7.12 |
Mamba DPO 3 (1/2 attention) | 55.38 | 17.48 | 7.31 |
Mamba DPO 1 (1/4 attention) | 50.94 | 17.16 | 7.03 |
Mamba DPO 3 (1/4 attention) | 51.19 | 13.89 | 6.58 |
Mamba DPO 1 (1/8 attention) | 48.35 | 15.32 | 6.39 |
Mamba DPO 3 (1/8 attention) | 48.44 | 12.67 | 6.37 |
For reproduction, please follow the instructions here.
We provide an environment file that lists the specific Python package versions used in our experiments. For the best reproducibility, we suggest using the same versions, although other versions may also work. The alignment-handbook version that we use is here. The following script sets up an environment for mamba-ssm==2.2.2 with cuda-11.8.0.
# CUDA>=11.6 needed for `mamba-ssm` and `causal-conv1d`.
conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
# Install PyTorch (with CUDA 11.8) before everything else. The steps below assume cu118.
pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cu118
pip install causal-conv1d==1.4.0
pip install flash-attn==2.6.3
# Make sure you use this alignment-handbook commit.
git clone https://github.com/huggingface/alignment-handbook.git
cd alignment-handbook/
git checkout 606d2e9
git clone https://github.com/huggingface/transformers.git --branch v4.43.1
# Check that your installed versions match these:
# deepspeed==0.12.2
# torch==2.1.1+cu118
# transformers==4.43.1
# trl==0.8.6
# accelerate==0.33.0
If you install mamba-ssm using pip install mamba-ssm==2.2.2, you will need to manually change CONDA_ENV_PATH/site-packages/mamba_ssm/modules/mha.py to this version to support GQA, since GQA is used in Llama3. The mamba-ssm used in our experiments is from this commit.
Alternatively, you can build mamba-ssm from source, but make sure the commit is later than this one, which fixes the GQA bugs in generation.
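As a convenience, the pinned versions listed in the install script can be checked programmatically. The snippet below is a minimal sketch, not part of this repository: the helper names `installed_version` and `check_versions` are hypothetical, and it assumes the packages were installed under the distribution names shown in the `EXPECTED` table. It relies only on the standard-library `importlib.metadata` module.

```python
from importlib import metadata

# Versions taken from the install script; adjust if you use a different setup.
EXPECTED = {
    "deepspeed": "0.12.2",
    "torch": "2.1.1+cu118",
    "transformers": "4.43.1",
    "trl": "0.8.6",
    "accelerate": "0.33.0",
}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_versions(expected, lookup=installed_version):
    """Return {package: (expected, found)} for every package that does not match."""
    mismatches = {}
    for package, want in expected.items():
        found = lookup(package)
        if found != want:
            mismatches[package] = (want, found)
    return mismatches


if __name__ == "__main__":
    for package, (want, found) in check_versions(EXPECTED).items():
        print(f"{package}: expected {want}, found {found or 'not installed'}")
```

Running the script prints one line per mismatch and nothing when the environment matches. The `lookup` parameter exists only so the comparison logic can be exercised without installing anything.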
Mamba:
import torch
from transformers import AutoTokenizer
from mamba_inference.hybrid_wrapper import MambaTransformerHybridModelWrapper

pretrained_model_name = "JunxiongWang/MambaInLlama_0_50"  # change this to the model you want to test
model = MambaTransformerHybridModelWrapper.from_pretrained(pretrained_model_name, torch_dtype=torch.bfloat16)
model.eval()

# A batch of conversations; each conversation is a list of chat messages.
messages = [[
    {
        "role": "user",
        "content": "Farmer Brown has 20 animals on his farm, all either chickens or cows. They have a total of 70 legs, all together. How many of the animals are chickens?",
    },
]]

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
# Render each conversation with the model's chat template, then tokenize it.
formatted_prompts = [
    tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    for message in messages
]
prompts = [
    tokenizer.encode(formatted_prompt, return_tensors="pt", truncation=True, max_length=200)
    for formatted_prompt in formatted_prompts
]
# torch.cat needs every prompt to have the same token length; a single prompt satisfies this trivially.
batch_prompts = torch.cat(prompts, dim=0).cuda()

outputs = model.generate(
    input_ids=batch_prompts,
    max_length=1000,
    cg=True,
    return_dict_in_generate=True,
    output_scores=True,
    enable_timing=True,
    top_k=1,  # keep only the most likely token at each step (greedy decoding)
    eos_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.batch_decode(outputs.sequences.tolist())
print(generated_text[0])
#output:
#Let's use algebra to solve this problem. We'll use the variable \( c \) for the number of chickens and \( k \) for the number of cows. We know two things from the problem statement:
#1. The total number of animals is 20: \( c + k = 20 \)
#2. The total number of legs is 70: Chickens have 2 legs each, and cows have 4 legs each. So, \( 2c + 4k = 70 \).
#Now, we'll solve the system of equations:
#From the first equation, we can express \( k \) in terms of \( c \):
#\( k = 20 - c \)
#Now, substitute \( k \) in the second equation:
#\( 2c + 4(20 - c) = 70 \)
#Solve for \( c \):
#\( 2c + 80 - 4c = 70 \)
#\( -2c = 70 - 80 \)
#\( -2c = -10 \)
#\( c = 5 \)
#So, there are 5 chickens on Farmer Brown's farm.
Mamba 2:
import torch
from transformers import AutoTokenizer
from mamba2_inference.hybrid_wrapper import MambaTransformerHybridModelWrapper

pretrained_model_name = "JunxiongWang/Mamba2InLlama_0_50"  # change this to the model you want to test
model = MambaTransformerHybridModelWrapper.from_pretrained(pretrained_model_name, torch_dtype=torch.bfloat16)
model.eval()

# A batch of conversations; each conversation is a list of chat messages.
messages = [[
    {
        "role": "user",
        "content": "Farmer Brown has 20 animals on his farm, all either chickens or cows. They have a total of 70 legs, all together. How many of the animals are chickens?",
    },
]]

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
# Render each conversation with the model's chat template, then tokenize it.
formatted_prompts = [
    tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    for message in messages
]
prompts = [
    tokenizer.encode(formatted_prompt, return_tensors="pt", truncation=True, max_length=200)
    for formatted_prompt in formatted_prompts
]
# torch.cat needs every prompt to have the same token length; a single prompt satisfies this trivially.
batch_prompts = torch.cat(prompts, dim=0).cuda()

outputs = model.generate(
    input_ids=batch_prompts,
    max_length=1000,
    cg=True,
    return_dict_in_generate=True,
    output_scores=True,
    enable_timing=True,
    top_k=1,  # keep only the most likely token at each step (greedy decoding)
    eos_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.batch_decode(outputs.sequences.tolist())
print(generated_text[0])
#output:
#Let's use algebra to solve this problem. Let \( c \) represent the number of chickens and \( k \) represent the number of cows.
#We know that:
#1. The total number of animals is 20: \( c + k = 20 \)
#2. Chickens have 2 legs each, and cows have 4 legs each, giving a total of 70 legs: \( 2c + 4k = 70 \)
#Now, we can solve these equations simultaneously.
#From equation 1, we can express \( k \) in terms of \( c \):
#\( k = 20 - c \)
#Substitute \( k \) in equation 2:
#\( 2c + 4(20 - c) = 70 \)
#Simplify and solve for \( c \):
#\( 2c + 80 - 4c = 70 \)
#\( -2c = -10 \)
#\( c = 5 \)
#So, there are 5 chickens on Farmer Brown's farm.
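Both sample outputs arrive at 5 chickens. As a quick sanity check of that arithmetic, independent of any model, the same system of equations can be solved directly. This is a small illustrative sketch; the function name `solve_heads_and_legs` is hypothetical and not part of this repository.

```python
def solve_heads_and_legs(animals, legs, legs_a=2, legs_b=4):
    """Solve a + b = animals and legs_a * a + legs_b * b = legs for integers a, b >= 0."""
    # Substituting b = animals - a gives (legs_b - legs_a) * a = legs_b * animals - legs.
    numerator = legs_b * animals - legs
    denominator = legs_b - legs_a
    if denominator == 0 or numerator % denominator != 0:
        raise ValueError("no integer solution")
    a = numerator // denominator
    b = animals - a
    if a < 0 or b < 0:
        raise ValueError("no non-negative solution")
    return a, b


chickens, cows = solve_heads_and_legs(animals=20, legs=70)
print(chickens, cows)
```

For the prompt used above this gives 5 chickens and 15 cows, matching the generated answers (2 * 5 + 4 * 15 = 70).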
If you use this codebase, or otherwise find our work valuable, please cite:
@article{junxiongdaniele2024mambainllama,
title = {The Mamba in the Llama: Distilling and Accelerating Hybrid Models},
author = {Junxiong Wang and Daniele Paliotta and Avner May and Alexander M. Rush and Tri Dao},
journal = {arXiv preprint arXiv:2408.15237},
year = {2024}
}