alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.

Values of deltaA are very large #53

Closed. anhtienng closed this issue 2 months ago.

anhtienng commented 2 months ago

Hi,

The value of deltaA is very large after discretization: deltaA = torch.exp(delta.unsqueeze(-1) * A)

The large values make the loss go NaN. I found a similar problem reported in the original mamba repo, but I can't find a solution there. I have tried the ZOH discretization to avoid the exp function, but the problem still exists.
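
For reference, here is a minimal repro of the blow-up I'm seeing (the shapes and random values below are made up just to illustrate the numerics):

```python
import torch

# Toy shapes, purely for illustration: (batch, length, d_inner) and (d_inner, d_state)
b, l, d_inner, d_state = 2, 16, 8, 4

# In Mamba, A is kept negative (parameterized as A = -exp(A_log))
A = -torch.exp(torch.randn(d_inner, d_state))

# delta taken straight from a linear projection is unconstrained, so it can be very negative
delta = torch.randn(b, l, d_inner) * 50

# ZOH discretization of A
deltaA = torch.exp(delta.unsqueeze(-1) * A)   # (b, l, d_inner, d_state)

print(deltaA.max())  # overflows to inf once delta * A gets large and positive
```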

Do you know how to solve it? Thank you.

alxndrTL commented 2 months ago

The authors of Jamba (a hybrid of Mamba and attention) apply inner layernorms to dt (as well as to B and C). I've implemented this in the mamba.py file: https://github.com/alxndrTL/mamba.py/blob/eddec5da76da6594850ea86a7afa56c9ab6b5ac7/mambapy/mamba.py#L246C8-L246C58

Maybe this will help?
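
For anyone else hitting this, a rough sketch of what those inner norms look like (illustrative names and sizes, not a copy of the linked code):

```python
import torch
import torch.nn as nn

# After projecting x to (delta, B, C), normalize each piece before it enters the scan.
d_inner, dt_rank, d_state = 8, 2, 4

x_proj = nn.Linear(d_inner, dt_rank + 2 * d_state, bias=False)
dt_norm = nn.LayerNorm(dt_rank)
B_norm = nn.LayerNorm(d_state)
C_norm = nn.LayerNorm(d_state)

x = torch.randn(2, 16, d_inner)  # (batch, length, d_inner)
dbc = x_proj(x)
delta, B, C = torch.split(dbc, [dt_rank, d_state, d_state], dim=-1)

# The inner norms keep delta, B and C in a well-behaved range before discretization
delta, B, C = dt_norm(delta), B_norm(B), C_norm(C)
```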

anhtienng commented 2 months ago

The layernorm is not applied to A in the current code.

So you mean I could try applying it to A?

alxndrTL commented 2 months ago

No, but it is applied to delta, which is used to compute deltaA. Since deltaA is the value that is very large in your case, that's why I suggested it.

anhtienng commented 2 months ago

I found the problem: I forgot to apply softplus to delta after the projection. My bad.

Thank you very much.
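
For future readers, a minimal sketch of the fix (again with made-up shapes): keeping delta strictly positive with softplus means delta * A is never positive (since A is negative), so the exp stays bounded:

```python
import torch
import torch.nn.functional as F

b, l, d_inner, d_state = 2, 16, 8, 4

A = -torch.exp(torch.randn(d_inner, d_state))   # A is negative, as in Mamba
raw_delta = torch.randn(b, l, d_inner) * 50     # unconstrained output of the delta projection

# softplus keeps delta strictly positive, so delta * A <= 0 and exp(delta * A) <= 1
delta = F.softplus(raw_delta)
deltaA = torch.exp(delta.unsqueeze(-1) * A)

print(deltaA.max())   # now bounded above by 1, no more overflow / NaN loss
```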

alxndrTL commented 2 months ago

Cool, glad it worked out!