alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.
MIT License

Can I translate your PScan in Jax? #45

Closed (clementpoiret closed this 3 months ago)

clementpoiret commented 3 months ago

Thanks a lot for your work and minimal implementation!

For work, I need to implement some models to benchmark, and I really want to include mamba-related models.

To do so, I created Jimmy (https://github.com/clementpoiret/jimmy), for Jax Image Model :) It's really not prod-ready at all yet. But porting the CUDA code into something that can be compiled by XLA is just beyond what I can do rn.

With credits of course, may I port your pscan to Jimmy?

Thanks!

alxndrTL commented 3 months ago

Hello, thanks for the comment!

Yes, no problem, good luck! You may also credit françoisfleuret, who is behind the very first version, which I built upon. Also, I don't know if you've seen it, but there is a doc that explains the pscan (almost finished :))

clementpoiret commented 3 months ago

Oh nice, thanks for the doc, I didn't see it :)