alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.
MIT License

support non-zero H[0] inputs #38

Closed Zizzzzzzz closed 3 months ago

Zizzzzzzz commented 3 months ago

Wonderful work! Does the current pscan code support non-zero H[0] inputs?

alxndrTL commented 3 months ago

Hello, thank you! No, the current pscan doesn't support it, since Mamba 1 always uses H[0]=0. But I'm sure only a small modification is needed to support an arbitrary H[0].

I will work on that and come back to you shortly.

alxndrTL commented 3 months ago

So it seems to me that one solution is to modify the A and X you give to pscan: prepend one time step of 1's to A, and prepend H[0] as the first time step of X. You would then pass A and X both of shape (B, L+1, D, N). Since the scan starts from a zero state, the first output is 1*0 + H[0] = H[0], the following steps proceed as usual, and you can drop that first output at the end. Maybe there is a nicer solution, but this is the first one I thought of (and it works with the current pscan).
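To make the idea concrete, here is a minimal sketch in plain Python. It is not the repo's code: `scan_from_zero` is a hypothetical sequential stand-in for pscan (same recurrence, no parallelism), and it works on a single scalar sequence, which is enough because the recurrence is applied elementwise over the B, D and N dimensions. With tensors, the prepend would be a concatenation along the time dimension.

```python
def scan_from_zero(A, X):
    """Hypothetical stand-in for pscan: H[t] = A[t] * H[t-1] + X[t], zero initial state.

    A and X are lists of floats of the same length (one entry per time step).
    """
    H, h = [], 0.0
    for a, x in zip(A, X):
        h = a * h + x
        H.append(h)
    return H


def scan_with_h0(A, X, h0):
    """Scan with a non-zero initial state h0, using only the zero-state scan."""
    # Prepend one time step: 1 for A and h0 for X, so the first output is 1*0 + h0 = h0.
    A_ext = [1.0] + list(A)
    X_ext = [h0] + list(X)
    H_ext = scan_from_zero(A_ext, X_ext)  # length L+1
    # The first output is just h0 itself; drop it to get back L outputs.
    return H_ext[1:]


def direct_with_h0(A, X, h0):
    """Direct loop starting from h0, used only to check the trick."""
    H, h = [], h0
    for a, x in zip(A, X):
        h = a * h + x
        H.append(h)
    return H


if __name__ == "__main__":
    A = [0.5, 2.0]
    X = [1.0, 3.0]
    print(scan_with_h0(A, X, 4.0))
    print(direct_with_h0(A, X, 4.0))
```

Both calls should print the same list. The cost of this approach is one extra time step in the scan, plus the copy needed to build the extended A and X.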

Hope this helps

Zizzzzzzz commented 3 months ago

Thank you very much!!! I will try this method.

alxndrTL commented 3 months ago

Hello, for your information, the new pscan now supports a custom H_0. It is not in the main branch yet, only in the mem-update branch of the repo. It's still in the pscan.py file; pscan now takes an optional H_0 argument.