jxiw / MambaInLlama

Official Repository of The Mamba in the Llama: Distilling and Accelerating Hybrid Models
https://arxiv.org/abs/2408.15237
Apache License 2.0

MambaInLlama

This repository contains the code and released models for our paper.


Our goal is to distill a large Transformer into a hybrid (or pure) Mamba model while preserving generation quality as much as possible. Reproducing our results typically takes only 8x 80GB A100 GPUs (a fairly limited budget) running for 3 to 4 days. The approach works for both base models and chat models.

Approach

  1. Stepwise layer alignment (optional). Replace the attention layers with Mamba2 layers one at a time, in a stepwise manner.
  2. End-to-end distillation (most important). Minimize the KL divergence loss between the student and teacher models. Using a larger teacher model may give better results.
  3. Instruction tuning (optional). For simplicity, we use SFT followed by DPO for this step.
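Step 2 trains the student to match the teacher's next-token distribution at every position. The pure-Python sketch below computes an average per-token KL divergence between the teacher and student softmax distributions. It assumes the forward direction KL(teacher || student), one common choice for distillation, and a temperature of 1 by default; the function names are illustrative and are not this repository's API. A real training loop would compute the same quantity on batched logit tensors.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax over one vocabulary-sized logit vector."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for two discrete distributions over the same vocabulary."""
    return sum(
        pi * (math.log(pi) - math.log(max(qi, 1e-12)))
        for pi, qi in zip(p, q)
        if pi > 0.0  # terms with p_i = 0 contribute nothing
    )


def distillation_loss(
    teacher_logits: Sequence[Sequence[float]],
    student_logits: Sequence[Sequence[float]],
    temperature: float = 1.0,
) -> float:
    """Average per-token KL(teacher || student) over a sequence of positions."""
    if len(teacher_logits) != len(student_logits):
        raise ValueError("teacher and student must cover the same positions")
    per_token = [
        kl_divergence(softmax(t, temperature), softmax(s, temperature))
        for t, s in zip(teacher_logits, student_logits)
    ]
    return sum(per_token) / len(per_token)


if __name__ == "__main__":
    teacher = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.0]]  # 2 positions, vocabulary of 3
    student = [[1.5, 1.2, 0.3], [0.4, 2.0, 0.2]]
    print(f"distillation loss: {distillation_loss(teacher, student):.4f}")
```

The loss is zero only when the student reproduces the teacher's distribution at every position, which is why matching full distributions carries more signal than training on the teacher's sampled tokens alone.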

Evaluation

Please follow the instructions here. Our evaluation covers:
a. standard tasks in LM Eval,
b. chat benchmarks (also see here),
c. reasoning tasks: math and code reasoning benchmarks, and
d. long-range tasks: NeedleInAHaystack.
Our goal is to provide a thorough evaluation and study.

Changelog

Released Models

Hybrid Mamba (3B) distilled from Llama3.2 3B

Check this for more details.

Models are available here.

| Model | MMLU | AlpacaEval (LC win rate vs. GPT-4) | MT-Bench (scored by GPT-4) | GSM8K (0-shot) | CRUX (0-shot) |
|---|---|---|---|---|---|
| Llama-3.2-Mamba2-0.5-3B-dpo-v2 | 53.12 | 22.08 | 6.81 | 50.37 | 20.12 |

Hybrid Mamba distilled from Llama3

| Teacher Model | Hybrid Mamba Model - DPO | Hybrid Mamba2 Model - DPO |
|---|---|---|
| Meta-Llama-3-8B-Instruct | Mamba (1/2 attention) | Mamba2 (1/2 attention) |
| | Mamba (1/4 attention) | Mamba2 (1/4 attention) |
| | Mamba (1/8 attention) | Mamba2 (1/8 attention) |
| | | Mamba2 (0 attention) |

| Model | MMLU (5-shot) | AlpacaEval (LC win rate vs. GPT-4) | MT-Bench (scored by GPT-4) |
|---|---|---|---|
| Mamba (1/2 attention) | 59.26 | 29.61 | 7.35 |
| Mamba2 (1/2 attention) | 56.67 | 25.00 | 7.32 |
| Mamba (1/4 attention) | 52.68 | 25.85 | 6.86 |
| Mamba2 (1/4 attention) | 53.94 | 20.25 | 6.74 |
| Mamba (1/8 attention) | 49.20 | 20.76 | 6.46 |
| Mamba2 (1/8 attention) | 50.85 | 20.25 | 6.48 |
| Mamba2 (0 attention) | 43.19 | 14.49 | 5.64 |

For reproduction, please follow the instructions here.
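The "1/2", "1/4", "1/8", and "0 attention" labels give the fraction of the teacher's attention layers kept in the hybrid model; the remaining layers become Mamba or Mamba2 layers. Meta-Llama-3-8B-Instruct has 32 decoder layers, so these variants keep 16, 8, 4, and 0 attention layers. The sketch below assumes the kept attention layers are evenly spaced; that placement rule and the helper name attention_layer_indices are hypothetical, so check the released model configs for the actual layer indices.

```python
from typing import List


def attention_layer_indices(num_layers: int, attn_fraction: float) -> List[int]:
    """Hypothetical placement rule: keep attention in evenly spaced layers.

    Layers not returned here would be replaced by Mamba/Mamba2 layers.
    """
    if not 0.0 <= attn_fraction <= 1.0:
        raise ValueError("attn_fraction must be between 0 and 1")
    num_attn = round(num_layers * attn_fraction)
    if num_attn == 0:
        return []  # pure Mamba2 model, e.g. "Mamba2 (0 attention)"
    stride = num_layers / num_attn
    return [int(i * stride) for i in range(num_attn)]


if __name__ == "__main__":
    num_layers = 32  # Meta-Llama-3-8B-Instruct has 32 decoder layers
    for label, frac in [("1/2", 0.5), ("1/4", 0.25), ("1/8", 0.125), ("0", 0.0)]:
        attn = attention_layer_indices(num_layers, frac)
        print(f"{label} attention: {len(attn)} attention layers at {attn}, "
              f"{num_layers - len(attn)} Mamba layers")
```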

Hybrid Mamba distilled from Zephyr

| Teacher Model | Hybrid Mamba Model - SFT | Hybrid Mamba Model - DPO | Hybrid Mamba Model - DPO |
|---|---|---|---|
| Zephyr | Mamba (1/2 attention) | Mamba (1/2 attention) | Mamba (1/2 attention) |
| | Mamba (1/4 attention) | Mamba (1/4 attention) | Mamba (1/4 attention) |
| | Mamba (1/8 attention) | Mamba (1/8 attention) | Mamba (1/8 attention) |

| Model | MMLU (5-shot) | AlpacaEval (LC win rate vs. GPT-4) | MT-Bench (scored by GPT-4) |
|---|---|---|---|
| Zephyr | 61.44 | 13.20 | 7.34 |
| Mamba DPO 1 (1/2 attention) | 55.23 | 20.66 | 7.12 |
| Mamba DPO 3 (1/2 attention) | 55.38 | 17.48 | 7.31 |
| Mamba DPO 1 (1/4 attention) | 50.94 | 17.16 | 7.03 |
| Mamba DPO 3 (1/4 attention) | 51.19 | 13.89 | 6.58 |
| Mamba DPO 1 (1/8 attention) | 48.35 | 15.32 | 6.39 |
| Mamba DPO 3 (1/8 attention) | 48.44 | 12.67 | 6.37 |

For reproduction, please follow the instructions here.

Usage

Environment

We provide an environment file listing the exact Python package versions used in our experiments. For the best reproducibility, we suggest using the same versions, although other versions should still let you run the code. The alignment-handbook version we use is here. The following script sets up CUDA 11.8.0 and the main dependencies; see the note after it about mamba-ssm==2.2.2.

# CUDA>=11.6 is needed for `mamba-ssm` and `causal-conv1d`.
conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
# Install PyTorch (built for CUDA 11.8) before everything else. The commands below assume cu118.
pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cu118

pip install causal-conv1d==1.4.0
pip install flash-attn==2.6.3
# or build mamba-ssm from source instead (see the GQA note below)
pip install mamba-ssm==2.2.2

# make sure you use this alignment-handbook version
git clone https://github.com/huggingface/alignment-handbook.git
cd alignment-handbook/
git checkout 606d2e9

git clone https://github.com/huggingface/transformers.git --branch v4.43.1

# check that your versions match these
# deepspeed==0.12.2
# torch==2.1.1+cu118
# transformers==4.43.1
# trl==0.8.6
# accelerate==0.33.0
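The commented version list at the end of the script can also be checked from Python. The sketch below uses the standard-library importlib.metadata.version to compare installed distributions against those pins; the EXPECTED dictionary simply mirrors the list above, and the helper name find_mismatches is illustrative.

```python
from importlib.metadata import PackageNotFoundError, version

# Mirrors the pinned versions listed in the setup script's comments.
EXPECTED = {
    "deepspeed": "0.12.2",
    "torch": "2.1.1+cu118",
    "transformers": "4.43.1",
    "trl": "0.8.6",
    "accelerate": "0.33.0",
}


def find_mismatches(expected, get_version=version):
    """Return {name: (expected, found)} for every package that differs.

    `found` is None when the package is not installed.
    """
    mismatches = {}
    for name, want in expected.items():
        try:
            found = get_version(name)
        except PackageNotFoundError:
            found = None
        if found != want:
            mismatches[name] = (want, found)
    return mismatches


if __name__ == "__main__":
    problems = find_mismatches(EXPECTED)
    for name, (want, found) in problems.items():
        print(f"{name}: expected {want}, found {found or 'not installed'}")
    if not problems:
        print("All pinned versions match.")
```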

If you install mamba-ssm with pip install mamba-ssm==2.2.2, you will need to manually replace CONDA_ENV_PATH/site-packages/mamba_ssm/modules/mha.py with this version, which supports grouped-query attention (GQA); Llama 3 uses GQA. The mamba-ssm used in our experiments is from this commit.

Alternatively, you can build mamba-ssm from source; in that case, make sure your checkout includes this commit, which fixes GQA bugs during generation.
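The exact site-packages directory depends on your environment and Python version, so rather than guessing CONDA_ENV_PATH you can ask Python where the installed mamba_ssm package lives. The sketch below uses the standard-library importlib.util.find_spec, which locates a top-level package without importing it; the helper name installed_package_file is illustrative.

```python
import importlib.util
import os
from typing import Optional


def installed_package_file(package: str, *relative_parts: str) -> Optional[str]:
    """Path of a file inside an installed package, or None if unavailable.

    Returns None when the package is not installed or is a plain module
    (no package directory to look inside).
    """
    spec = importlib.util.find_spec(package)
    if spec is None or not spec.submodule_search_locations:
        return None
    package_dir = list(spec.submodule_search_locations)[0]
    return os.path.join(package_dir, *relative_parts)


if __name__ == "__main__":
    # The file the GQA note above asks you to replace.
    print(installed_package_file("mamba_ssm", "modules", "mha.py"))
```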

Generation Example

Mamba:

import torch
from transformers import AutoTokenizer
from mamba_inference.hybrid_wrapper import MambaTransformerHybridModelWrapper

pretrained_model_name = "JunxiongWang/MambaInLlama_0_50" # change the model that you want to test here
model = MambaTransformerHybridModelWrapper.from_pretrained(pretrained_model_name, torch_dtype=torch.bfloat16)
model.eval()

messages = [[
    {
        "role": "user",
        "content": "Farmer Brown has 20 animals on his farm, all either chickens or cows. They have a total of 70 legs, all together. How many of the animals are chickens?",
    },
]]

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
formatted_prompts = [
    tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages
]

prompts = [
    tokenizer.encode(formatted_prompt, return_tensors="pt", truncation=True, max_length=200)
    for formatted_prompt in formatted_prompts
]
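# torch.cat requires all prompts to have the same length; this example uses a single prompt
batch_prompts = torch.cat(prompts, dim=0).cuda()

outputs = model.generate(
    input_ids=batch_prompts,
    max_length=1000,
    cg=True,  # use CUDA graphs for faster decoding
    return_dict_in_generate=True,
    output_scores=True,
    enable_timing=True,  # print generation timing
    top_k=1,  # greedy decoding
    eos_token_id=tokenizer.eos_token_id
)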

generated_text = tokenizer.batch_decode(outputs.sequences.tolist())
print(generated_text[0])

#output:
#Let's use algebra to solve this problem. We'll use the variable \( c \) for the number of chickens and \( k \) for the number of cows. We know two things from the problem statement:

#1. The total number of animals is 20: \( c + k = 20 \)
#2. The total number of legs is 70: Chickens have 2 legs each, and cows have 4 legs each. So, \( 2c + 4k = 70 \).

#Now, we'll solve the system of equations:

#From the first equation, we can express \( k \) in terms of \( c \):

#\( k = 20 - c \)

#Now, substitute \( k \) in the second equation:

#\( 2c + 4(20 - c) = 70 \)

#Solve for \( c \):

#\( 2c + 80 - 4c = 70 \)
#\( -2c = 70 - 80 \)
#\( -2c = -10 \)
#\( c = 5 \)

#So, there are 5 chickens on Farmer Brown's farm.

Mamba 2:

import torch
from transformers import AutoTokenizer
from mamba2_inference.hybrid_wrapper import MambaTransformerHybridModelWrapper

pretrained_model_name = "JunxiongWang/Mamba2InLlama_0_50" # change the model that you want to test here
model = MambaTransformerHybridModelWrapper.from_pretrained(pretrained_model_name, torch_dtype=torch.bfloat16)
model.eval()

messages = [[
    {
        "role": "user",
        "content": "Farmer Brown has 20 animals on his farm, all either chickens or cows. They have a total of 70 legs, all together. How many of the animals are chickens?",
    },
]]

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
formatted_prompts = [
    tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages
]

prompts = [
    tokenizer.encode(formatted_prompt, return_tensors="pt", truncation=True, max_length=200)
    for formatted_prompt in formatted_prompts
]
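# torch.cat requires all prompts to have the same length; this example uses a single prompt
batch_prompts = torch.cat(prompts, dim=0).cuda()

outputs = model.generate(
    input_ids=batch_prompts,
    max_length=1000,
    cg=True,  # use CUDA graphs for faster decoding
    return_dict_in_generate=True,
    output_scores=True,
    enable_timing=True,  # print generation timing
    top_k=1,  # greedy decoding
    eos_token_id=tokenizer.eos_token_id
)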

generated_text = tokenizer.batch_decode(outputs.sequences.tolist())
print(generated_text[0])

#output:
#Let's use algebra to solve this problem. Let \( c \) represent the number of chickens and \( k \) represent the number of cows.

#We know that:
#1. The total number of animals is 20: \( c + k = 20 \)
#2. Chickens have 2 legs each, and cows have 4 legs each, giving a total of 70 legs: \( 2c + 4k = 70 \)

#Now, we can solve these equations simultaneously.

#From equation 1, we can express \( k \) in terms of \( c \):
#\( k = 20 - c \)

#Substitute \( k \) in equation 2:
#\( 2c + 4(20 - c) = 70 \)

#Simplify and solve for \( c \):
#\( 2c + 80 - 4c = 70 \)
#\( -2c = -10 \)
#\( c = 5 \)

#So, there are 5 chickens on Farmer Brown's farm.
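In both examples, outputs.sequences appears to hold the prompt tokens followed by the generated tokens (as in mamba_ssm's decoding utilities), so decoding it also reproduces the formatted prompt. A minimal sketch for keeping only the model's answer, assuming that layout: drop the first prompt-length tokens and stop at the first EOS token. The helper completion_tokens is hypothetical; with the variables from the examples you could call completion_tokens(outputs.sequences[0].tolist(), batch_prompts.shape[1], tokenizer.eos_token_id) and pass the result to tokenizer.decode.

```python
from typing import List, Optional


def completion_tokens(
    sequence: List[int], prompt_len: int, eos_token_id: Optional[int] = None
) -> List[int]:
    """Return the generated part of `sequence`, cut before the first EOS token.

    Assumes `sequence` is the prompt tokens followed by the generated tokens.
    """
    generated = sequence[prompt_len:]
    if eos_token_id is not None and eos_token_id in generated:
        generated = generated[: generated.index(eos_token_id)]
    return generated


if __name__ == "__main__":
    # Toy ids: a 3-token prompt, 2 generated tokens, EOS (id 9), then padding.
    print(completion_tokens([101, 102, 103, 7, 8, 9, 0, 0], prompt_len=3, eos_token_id=9))
```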

Citation

If you use this codebase, or otherwise found our work valuable, please cite:

@article{junxiongdaniele2024mambainllama,
  title   = {The Mamba in the Llama: Distilling and Accelerating Hybrid Models},
  author  = {Junxiong Wang and Daniele Paliotta and Avner May and Alexander M. Rush and Tri Dao},
  journal = {arXiv preprint arXiv:2408.15237},
  year    = {2024}
}