ml-explore / mlx-examples

Examples in the MLX framework
MIT License

Su-RoPE(Rotary Position Embedding) for Phi-3 #813

Closed JosefAlbers closed 3 weeks ago

JosefAlbers commented 4 weeks ago

This Pull Request introduces an implementation of the su-RoPE used by some of the Phi-3 language models.

Key Changes:

Future Optimization:

JosefAlbers commented 4 weeks ago

Thanks for the suggestion, but I'm not yet sure it's directly feasible, since the mx.fast.rope applies a uniform scale across dimensions. I'll look into whether it is possible.
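To illustrate the concern, here is a minimal pure-Python sketch, not the code from this PR. It assumes the uniform scale behaves as a single scalar shared by every rotary frequency (equivalently, applied to the positions), and that su-RoPE instead divides each frequency by its own factor. The function names and the factor values are hypothetical, chosen only for illustration.

```python
import math

def rope_inv_freq(dims, base=10000.0, scale=1.0):
    """Plain RoPE frequencies: one scalar `scale` shared by every frequency."""
    return [scale / (base ** (i / dims)) for i in range(0, dims, 2)]

def su_rope_inv_freq(dims, factors, base=10000.0):
    """Su-style frequencies (assumed form): frequency k is divided by factors[k]."""
    if len(factors) != dims // 2:
        raise ValueError("need one factor per rotary frequency (dims // 2)")
    return [
        1.0 / (f * base ** (i / dims))
        for f, i in zip(factors, range(0, dims, 2))
    ]

dims = 8
factors = [1.0, 1.0, 2.0, 4.0]  # made-up values, for illustration only

uniform = rope_inv_freq(dims)
per_dim = su_rope_inv_freq(dims, factors)

# Ratio between the two frequency sets, one entry per frequency.
# If a single scalar could reproduce su-RoPE, every ratio would be equal.
ratios = [u / p for u, p in zip(uniform, per_dim)]
single_scalar_suffices = all(math.isclose(r, ratios[0]) for r in ratios)
```

With non-constant factors the ratios differ per frequency, so under these assumptions no single scale value reproduces the per-dimension behaviour.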

awni commented 4 weeks ago

since the mx.fast.rope applies a uniform scale across dimensions.

Indeed it may not be possible with the current API 🤔 . Have you seen this Su ROPE used with many other models or is Phi3 a bit of an exception?

JosefAlbers commented 3 weeks ago

Have you seen this Su ROPE used with many other models or is Phi3 a bit of an exception?

Phi-3 seems to be the only one I've come across using Su ROPE, but my knowledge of the broader model landscape is still very limited.

JosefAlbers commented 3 weeks ago

@awni Thanks for the suggestion, but I'm not sure I follow how the shapes would align for element-wise multiplication in this case. When I tried implementing it, I ran into a ValueError due to a shape mismatch (e.g., ValueError: Shapes (1,32,3,96) and (1,1,48,6) cannot be broadcast.). Would you mind elaborating a bit more on using multiply?
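For readers hitting the same error, the sketch below shows why those two shapes cannot be combined. It assumes NumPy-style broadcasting rules (shapes are compared from the trailing axis, and two sizes are compatible only when they are equal or one of them is 1). The helper `broadcast_shape` is hypothetical and is not part of MLX.

```python
def broadcast_shape(a, b):
    """Return the broadcast result shape, or raise ValueError if incompatible."""
    result = []
    # Walk both shapes from the trailing axis; missing leading axes count as 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"Shapes {a} and {b} cannot be broadcast.")
        # When one side is 1 the other side's size wins.
        result.append(db if da == 1 else da)
    return tuple(reversed(result))

# The pair from the error message: 96 vs 6 (and 3 vs 48) are unequal and neither is 1.
try:
    broadcast_shape((1, 32, 3, 96), (1, 1, 48, 6))
    failed = False
except ValueError:
    failed = True

# A (sequence, head_dim) table lines up with the trailing axes and broadcasts fine.
ok_shape = broadcast_shape((1, 32, 3, 96), (3, 96))
```

The second call succeeds because the smaller shape matches the last two axes of the larger one, which is the situation element-wise multiplication needs.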

awni commented 3 weeks ago

@JosefAlbers sorry my suggestion was not fully complete. Take a look at the diff I pushed to see what I meant: https://github.com/ml-explore/mlx-examples/pull/813/commits/ff569832d3a2c938ce33646ffe58bea96774a79f

The idea is mostly to rely on broadcasting and multiplication rather than the outer product, and to remove all the explicit shape expansions to simplify the code.
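The difference between the two approaches can be sketched as follows. This is not the linked diff: it uses NumPy as a stand-in for mlx.core, the shapes and variable names are assumptions, and it only applies the cosine term rather than a full rotary embedding.

```python
import numpy as np

B, H, L, D = 1, 32, 3, 96  # batch, heads, sequence length, head dim (assumed)
x = np.random.default_rng(0).standard_normal((B, H, L, D)).astype(np.float32)
positions = np.arange(L, dtype=np.float32)
inv_freq = 1.0 / (10000.0 ** (np.arange(0, D, 2, dtype=np.float32) / D))

# Variant 1: outer product, then explicit shape expansions to 4 dimensions.
freqs_outer = np.outer(positions, inv_freq)                       # (L, D/2)
emb_outer = np.concatenate([freqs_outer, freqs_outer], axis=-1)   # (L, D)
cos_outer = np.expand_dims(np.expand_dims(np.cos(emb_outer), 0), 0)  # (1, 1, L, D)
out_outer = x * cos_outer

# Variant 2: broadcasting and multiplication only.
freqs = positions[:, None] * inv_freq           # (L, 1) * (D/2,) -> (L, D/2)
emb = np.concatenate([freqs, freqs], axis=-1)   # (L, D)
out = x * np.cos(emb)                           # (L, D) broadcasts to (B, H, L, D)

same = np.allclose(out, out_outer)
```

Both variants produce the same values; the second simply lets the trailing `(L, D)` axes line up with `x` instead of building the leading singleton axes by hand.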

JosefAlbers commented 3 weeks ago

@awni Oh, I get it now. Wow.

JosefAlbers commented 3 weeks ago

@awni sorry for the silly error I made in su_rope.py (I ran the tests on Python 3.12, whereas CircleCI runs them on Python <3.9). I've fixed the bug now.

JosefAlbers commented 3 weeks ago

@awni sorry for the trouble, I didn't know about isort.

awni commented 3 weeks ago

@awni sorry for the trouble, I didn't know about isort.

No problem at all. You should set up the pre-commit hooks; they will run all the formatting for you: https://github.com/ml-explore/mlx-examples/blob/main/CONTRIBUTING.md#pull-requests

JosefAlbers commented 3 weeks ago

@awni You're most welcome! And thank you for your help and patience throughout.