alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.
MIT License

MuP #49

Closed. norikazu99 closed this issue 2 months ago.

norikazu99 commented 2 months ago

Hello, and thanks for sharing your code. I stumbled on your repo while looking for how to implement muP for Mamba. It seems like you implemented muP without scaling any attn-like matrices. Does that mean that muP works with Mamba out of the box, as long as the right initializations (from the mup package, for example) are implemented?
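
To make the question concrete, here is roughly the setup I have in mind: a minimal sketch using the mup package on a toy model (ToyLM, its sizes, and the init choices are placeholders, not your actual code):

```python
# Minimal sketch: muP via the `mup` package, with no changes other than
# swapping the output layer for MuReadout and re-initializing with mup.init.
# ToyLM is a stand-in for a Mamba LM; only the width (d_model) matters here.
import torch
import torch.nn as nn
from mup import MuReadout, MuAdam, set_base_shapes
from mup import init as mup_init

class ToyLM(nn.Module):
    def __init__(self, d_model, vocab_size=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.mixer = nn.Linear(d_model, d_model)        # placeholder for the Mamba blocks
        self.lm_head = MuReadout(d_model, vocab_size, bias=False)

    def forward(self, x):
        return self.lm_head(torch.tanh(self.mixer(self.embed(x))))

# base/delta models only define how each dimension scales with width
target = ToyLM(d_model=1024)
set_base_shapes(target, ToyLM(d_model=64), delta=ToyLM(d_model=128))

# the "right initializations": re-init width-dependent weights with a mup-aware init
for p in target.parameters():
    if p.dim() >= 2:
        mup_init.normal_(p, mean=0.0, std=0.02)

optimizer = MuAdam(target.parameters(), lr=1e-3)  # per-layer LR scaling handled by muP
```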

Thanks for your help.

alxndrTL commented 2 months ago

Hello, yes indeed I applied muP out of the box and the first tests seem to work. I don't know if you've seen them, but the coord checks I've done look really good:

This is SP: [coord check plot: mon_image_mamba2_sp]

This is with muP: [coord check plot: mon_image_mamba2_mup]

A LR of 1e-3 was used, which is quite high, and 5e-3 still gives roughly OK results: [coord check plot: mamba2_mup_5e-3]
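
For context, a coord check just tracks the typical size of activations over the first few training steps at several widths; under muP the curves should stay roughly flat as width grows, while under SP they drift. A self-contained sketch of the measurement (toy model and dummy data, not the actual script in the repo):

```python
# Sketch of a coordinate check: record the mean |activation| of a monitored
# layer at each of the first few training steps, for several widths.
# The toy model and random data are placeholders for the real LM and dataset.
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, width, d_in=16):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(d_in, width), nn.Tanh())
        self.head = nn.Linear(width, 1)

def coord_check(widths, nsteps=4, lr=1e-3, batch=32, d_in=16):
    records = {}                                   # width -> per-step coordinate scale
    for w in widths:
        torch.manual_seed(0)
        model = Toy(w, d_in)
        opt = torch.optim.Adam(model.parameters(), lr=lr)
        scales = []
        for _ in range(nsteps):
            x = torch.randn(batch, d_in)
            h = model.body(x)                      # activation whose scale we track
            loss = model.head(h).pow(2).mean()     # dummy loss, just to get updates
            scales.append(h.abs().mean().item())
            opt.zero_grad(); loss.backward(); opt.step()
        records[w] = scales
    return records

print(coord_check([64, 256, 1024]))                # flat across widths is the good sign
```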

The next step is to actually test whether HPs transfer; that's what I'm working on right now. I will update here and in the README.

alxndrTL commented 2 months ago

Just pushed the muP implementation for Mamba-1 and Mamba-2 to the repo. It works really well: the LR successfully transfers from a 172k-parameter model to a 100M+ parameter model. See the PR for more details.
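
The transfer recipe itself is simple: sweep the LR on a small proxy width under muP, pick the best value, and reuse it unchanged at the large width. A rough self-contained sketch of that loop (toy model, random data, and the widths are placeholders, not the actual training setup):

```python
# Sketch of muTransfer for the learning rate: tune on a cheap proxy width,
# then reuse the same LR at the target width. Toy model + random data only.
import torch
import torch.nn as nn
from mup import MuReadout, MuAdam, set_base_shapes

def make_mup_model(width, d_in=16, d_out=10):
    def build(w):
        return nn.Sequential(nn.Linear(d_in, w), nn.Tanh(), MuReadout(w, d_out))
    model = build(width)
    set_base_shapes(model, build(8), delta=build(16))    # fixes the width-scaling rules
    return model

def short_run_loss(model, lr, steps=100):
    opt = MuAdam(model.parameters(), lr=lr)
    torch.manual_seed(0)                                 # same dummy data for every run
    for _ in range(steps):
        x, y = torch.randn(64, 16), torch.randint(0, 10, (64,))
        loss = nn.functional.cross_entropy(model(x), y)
        opt.zero_grad(); loss.backward(); opt.step()
    return loss.item()

# 1) LR sweep on the small proxy model
candidate_lrs = [1e-4, 3e-4, 1e-3, 3e-3, 1e-2]
best_lr = min(candidate_lrs, key=lambda lr: short_run_loss(make_mup_model(64), lr))

# 2) reuse it directly on the wide model, no re-tuning
big_model = make_mup_model(1024)
print(best_lr, short_run_loss(big_model, best_lr))
```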

norikazu99 commented 2 months ago

Thanks for sharing! Check out https://arxiv.org/pdf/2407.17465. It builds on muP and could save you some time in your future experiments. I'm still in the process of confirming how well it transfers for my use cases.

alxndrTL commented 2 months ago

Yes, I read it a few days ago. The implementation is quite heavy (as opposed to regular muP), it is only described for Transformers, and it needs fused ops to avoid slowing down training. That's why I started simple, with muP only. But it's definitely something to look into in the future.