ml-explore / mlx-examples

Examples in the MLX framework
MIT License

Add support for falcon-mamba #1074

Closed ilyasch2 closed 3 weeks ago

ilyasch2 commented 1 month ago

This PR adds support for FalconMamba-7B. We apply RMS norms to the Mamba B, C, and delta parameters when use_bcdt_rms is set to True. FalconMamba uses non-learnable RMSNorms, but non-learnable RMSNorm is not supported in mlx, so we initialize the RMS weights in the Mamba forward pass. We tried initializing the RMSNorms in __init__, with a patch that loads them from the state dict with dummy weights on the fly, but generation was two times slower. The cleanest fix would be to add support for non-learnable RMSNorms in mlx.
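For reference, here is a minimal NumPy sketch of what a non-learnable RMSNorm computes (the function name and shapes are illustrative, not the PR's actual code): since the scale weight is fixed at 1, there is nothing to load from the checkpoint, and only the normalization itself is applied to B, C, and delta.

```python
import numpy as np

def rms_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Non-learnable RMSNorm: divide by the root-mean-square over the
    # last axis; the scale weight is implicitly all-ones, so no
    # parameters need to exist in the state dict.
    variance = np.mean(x ** 2, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps)

# Illustrative use: normalize a B-like projection inside the forward pass.
B = rms_norm(np.random.randn(2, 16))
```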

Here is a script to test it:

from mlx_lm import load, generate

model, tokenizer = load("tiiuae/falcon-mamba-7b")
response = generate(model, tokenizer, prompt="hello", verbose=True)

CC @awni

ilyasch2 commented 3 weeks ago

@awni Could you please take some time to review this PR and suggest any changes you think are necessary? Thank you!