ml-explore / mlx-examples

Examples in the MLX framework

adding-support-for-mamba2 #1009

Open · Goekdeniz-Guelmez opened this issue 2 months ago

hg0428 commented 1 month ago

Codestral Mamba and other models rely on the Mamba2 architecture. Hopefully we can get this soon.

awni commented 4 weeks ago

How is it going here? Still very slow?

Goekdeniz-Guelmez commented 4 weeks ago

> How is it going here? Still very slow?

Unfortunately, yes. I looked into the transformers implementation and rewrote the slow Mamba2Mixer class, but I haven't had time to continue working on it. I'll pick it up again on the weekend.
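
For context, the heart of a Mamba2 mixer is the per-head SSD recurrence. The sketch below is not the code from this branch; it is a minimal single-token step in MLX, assuming a scalar A per head and grouped B/C broadcast to each head as in the transformers Mamba2Mixer (shapes and names are illustrative):

```python
import mlx.core as mx

def ssd_step(x, dt, A, B, C, D, state):
    """One recurrent Mamba2 (SSD) step for a single token (illustrative sketch).

    x:     (n_heads, head_dim)           token features after in_proj/conv
    dt:    (n_heads,)                    softplus-ed time step per head
    A:     (n_heads,)                    negative scalar per head, -exp(A_log)
    B, C:  (n_heads, d_state)            group B/C broadcast to each head
    D:     (n_heads,)                    skip-connection weight
    state: (n_heads, head_dim, d_state)  running SSM state
    """
    dA = mx.exp(dt * A)                                         # per-head decay
    dBx = (dt[:, None, None] * x[:, :, None]) * B[:, None, :]   # inject input
    state = dA[:, None, None] * state + dBx                     # update state
    y = (state * C[:, None, :]).sum(axis=-1) + D[:, None] * x   # read out + skip
    return y, state
```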

Goekdeniz-Guelmez commented 1 week ago

@awni I finally got it to work!

Inference:

python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello" --max-tokens 22 --ignore-chat-template
Fetching 5 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 65948.18it/s]
==========
Prompt: hello
, I am a little girl, I am a little girl, I am a little girl, I am a
==========
Prompt: 1 tokens, 7.499 tokens-per-sec
Generation: 22 tokens, 28.258 tokens-per-sec
Peak memory: 0.454 GB
python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello world" --max-tokens 22 --ignore-chat-template
Fetching 5 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 55043.36it/s]
==========
Prompt: hello world

hello world
hello world
hello world
hello world
hello world
hello world
hello world
==========
Prompt: 2 tokens, 5.552 tokens-per-sec
Generation: 22 tokens, 24.904 tokens-per-sec
Peak memory: 0.454 GB
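
The same run can also be driven from Python. A minimal sketch, assuming the mlx_lm load/generate API (exact generate keyword arguments may differ between mlx_lm versions):

```python
from mlx_lm import load, generate

# Load the converted Mamba2 checkpoint used in the CLI runs above.
model, tokenizer = load("rokyang/mamba2-130m-hf")

# Mirror the CLI call: 22 new tokens, printing the output as it streams.
text = generate(model, tokenizer, prompt="hello", max_tokens=22, verbose=True)
```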

Training:

python -m mlx_lm.lora \
    --model rokyang/mamba2-130m-hf \
    --train \
    --data /Users/gokdenizgulmez/Library/Mobile\ Documents/com\~apple\~CloudDocs/Datastes/data_tyni \
    --iters 5 \
    --batch-size 1 \
    --num-layers 1 \
    --val-batches 1 \
    --steps-per-report 1 \
    --adapter-path /Users/gokdenizgulmez/Desktop/mamba2-pretrain \
    --max-seq-length 12
Loading pretrained model
Fetching 5 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 87381.33it/s]
Loading datasets
Training
Trainable parameters: 0.956% (1.233M/128.988M)
Starting training..., iters: 5
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1508 will be truncated to 12. Consider pre-splitting your data to save memory.
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1250 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 1: Val loss 7.408, Val took 1.578s
Iter 1: Train loss 7.408, Learning Rate 1.000e-05, It/sec 0.405, Tokens/sec 4.450, Trained Tokens 11, Peak mem 2.173 GB
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1692 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 2: Train loss 7.275, Learning Rate 1.000e-05, It/sec 2.110, Tokens/sec 23.212, Trained Tokens 22, Peak mem 2.189 GB
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1397 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 3: Train loss 7.093, Learning Rate 1.000e-05, It/sec 2.694, Tokens/sec 29.637, Trained Tokens 33, Peak mem 2.189 GB
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1238 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 4: Train loss 6.880, Learning Rate 1.000e-05, It/sec 2.803, Tokens/sec 30.829, Trained Tokens 44, Peak mem 2.189 GB
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 1265 will be truncated to 12. Consider pre-splitting your data to save memory.
[WARNING] Some sequences are longer than 12 tokens. The longest sentence 802 will be truncated to 12. Consider pre-splitting your data to save memory.
Iter 5: Val loss 6.641, Val took 0.175s
Iter 5: Train loss 6.641, Learning Rate 1.000e-05, It/sec 2.754, Tokens/sec 30.298, Trained Tokens 55, Peak mem 2.189 GB
Saved final weights to /Users/gokdenizgulmez/Desktop/mamba2-pretrain/adapters.safetensors.
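
The saved adapters can then be applied at load time. A hedged sketch, assuming mlx_lm.load accepts an adapter_path argument (the local path is the one from the run above):

```python
from mlx_lm import load, generate

# Load the base model together with the LoRA adapters saved by mlx_lm.lora.
model, tokenizer = load(
    "rokyang/mamba2-130m-hf",
    adapter_path="/Users/gokdenizgulmez/Desktop/mamba2-pretrain",
)

text = generate(model, tokenizer, prompt="hello", max_tokens=22, verbose=True)
```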

awni commented 1 week ago

Very nice!! What's a good model to test with? The one you are using doesn't look like it generates high-quality responses.

hg0428 commented 1 week ago

> Very nice!! What's a good model to test with? The one you are using doesn't look like it generates high-quality responses.

Mamba Codestral or one of the larger base Mamba2 models.

awni commented 1 week ago

I tried running Codestral and it crashed with a weight-size mismatch error:

ValueError: Expected shape (16768, 4096) but received shape (18560, 4096) for parameter backbone.layers.0.mixer.in_proj.weight

Looks like the weight shape is not computed correctly for that model?

This is what I ran for reference:

mlx_lm.generate --model mistralai/Mamba-Codestral-7B-v0.1 --prompt "Write a quick sort in c++" -m 128
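
For what it's worth, the two numbers in the error are consistent with the Mamba2 in_proj packing if n_groups is being dropped. A back-of-the-envelope check (the Codestral hyperparameters below are assumptions, not read from its config):

```python
# Rough check of the Mamba2 in_proj output width (z, x, B, C and dt packed
# into one matrix). The Codestral values below are assumed, not confirmed.
d_model, expand = 4096, 2
d_state, head_dim = 128, 64
n_groups = 8                          # mamba2-130m uses n_groups = 1

d_inner = expand * d_model            # 8192
n_heads = d_inner // head_dim         # 128
rows = 2 * d_inner + 2 * n_groups * d_state + n_heads
print(rows)                           # 18560, the shape in the checkpoint
# With n_groups = 1 the same formula gives 16768, the shape the loader
# expected, so n_groups may not be picked up from the model config.
```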

Goekdeniz-Guelmez commented 1 week ago

Ah ok, yeah, I didn't try Codestral. The model I used is the safetensors conversion of the original state-spaces checkpoint, published as rokyang/mamba2-130m-hf. I'll look into the Codestral shape problem later today.