This PR adds support for FalconMamba-7B. We add RMS norms on the Mamba B, C, and delta parameters when `use_bcdt_rms` is set to `True`. FalconMamba uses non-learnable RMSNorms; however, non-learnable RMSNorm is not supported in mlx, so we initialize the RMS weights in the Mamba forward pass. We tried initializing the RMSNorms in `__init__`, with a patch to load them from the state dict with dummy weights on the fly, but generation was two times slower. The cleanest fix would be to add support for non-learnable RMSNorms.
Here is a script to test it:
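For reference, the non-learnable variant is just standard RMSNorm with the weight fixed to ones (which is why dummy weights work at all). A minimal NumPy sketch of what gets applied to B, C, and delta (names are illustrative, not the mlx API):

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # Non-learnable RMSNorm: rescale by the root-mean-square over the
    # last axis. No learned weight; equivalent to a weight of all ones.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
```

After normalization, each row has an RMS of approximately 1 (up to `eps`), independent of the input scale.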
CC @awni