state-spaces / mamba

Mamba SSM architecture

Any suggestions for regularization? #88

Open drscotthawley opened 8 months ago

drscotthawley commented 8 months ago

Dear Mamba-SSM team, congratulations on your success! Obviously many of us are excited about exploring the applications of your work.

Since there's no dropout in your model, what do you suggest for imposing regularization? I applied the drop-in Mamba replacement used with Karpathy's GPT example to my MIDI Transformer Colab (which is paired with a pre-Mamba blog post).

Compared to the original multi-head attention version, the Mamba-powered version runs faster, but it also overfits a lot more: losses on the validation set bottom out sooner, and at a higher value, with Mamba than with the vanilla multi-head attention.

So far, I've tried

...trying values reported in your paper. But nothing I've tried seems to have any regularizing effect.
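
For concreteness, here is a rough sketch of the kind of thing I've been trying (the Mamba constructor arguments are the ones from this repo's usage example; the AdamW weight decay is plain PyTorch and nothing Mamba-specific, and the specific values below are just placeholders):

```python
import torch
from mamba_ssm import Mamba

# Mamba block hyperparameters I've been varying (placeholder values):
block = Mamba(
    d_model=256,  # model dimension
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # local convolution width
    expand=2,     # block expansion factor
).to("cuda")

# Plain PyTorch weight decay via AdamW -- nothing Mamba-specific:
optimizer = torch.optim.AdamW(block.parameters(), lr=3e-4, weight_decay=0.1)
```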

What Mamba parameters would you suggest tweaking to improve generalization?

Thanks!

chazzmoney commented 8 months ago

I'm very interested to see the authors' response here, but I'll say something that may be wrong.

I'm not familiar with the architecture used to create MIDI with the Transformer mechanism, but in the original Mamba paper they were pretty clear that continuous signals do not benefit from Mamba, and that something like S4 seemed to outperform it there. So the behavior you're seeing may have less to do with regularization and more to do with applying Mamba to this problem.

tridao commented 8 months ago

You can use dropout, just like Transformers. It's not implemented here but you can add it.
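
For example, something like the following rough sketch (the wrapper class below is not part of mamba_ssm, just one way to place dropout on the mixer output the way Transformer blocks do):

```python
import torch
import torch.nn as nn
from mamba_ssm import Mamba

class MambaBlockWithDropout(nn.Module):
    """Pre-norm residual block: x + dropout(mamba(norm(x))).
    Dropout is applied to the mixer output, as in standard Transformer blocks."""
    def __init__(self, d_model, dropout=0.1, **mamba_kwargs):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.mixer = Mamba(d_model=d_model, **mamba_kwargs)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return x + self.drop(self.mixer(self.norm(x)))

# Example: stack a few blocks like Transformer layers
model = nn.Sequential(
    *[MambaBlockWithDropout(d_model=256, dropout=0.1) for _ in range(4)]
).to("cuda")
y = model(torch.randn(2, 128, 256, device="cuda"))
```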

ElliottDyson commented 7 months ago

> You can use dropout, just like Transformers. It's not implemented here but you can add it.

This thread may be more useful if tagged as a feature request, then? (Just for traceability, to keep it from being lost.)

Anri-Lombard commented 2 months ago

This is still an issue; I trained the model for 10 epochs on a specific use case and the loss just decreases drastically as it quickly overfits. [Screenshot attached: loss curve, 2024-07-06]

Anri-Lombard commented 2 months ago

Has anyone implemented this with some success?

windsornguyen commented 2 months ago

I find that adding dropout decreases performance for state space models. Does anyone else also observe this phenomenon?

Anri-Lombard commented 2 months ago

I added dropout and it actually improved performance for me. [Plot attached: training convergence]

Anri-Lombard commented 2 months ago

Although Mamba still performs very poorly in my use case, it just learned better.