state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Small datasets #454

Open Anri-Lombard opened 3 months ago

Anri-Lombard commented 3 months ago

I'm curious: I've trained a 20M Mamba model for molecular generation, and it seems to fare quite badly when trained on small datasets. I added a dropout layer since it seems to overfit otherwise, but would Mamba perhaps need a lot of intricate optimisation and regularisation to work well with smaller datasets?

I know previous LSTM and RNN models needed this (https://arxiv.org/pdf/1708.02182v1) and am curious about your intuition.
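
For concreteness, the dropout I mention is roughly along these lines; an illustrative sketch with placeholder sizes and wiring, not my exact model:

```python
import torch
import torch.nn as nn
from mamba_ssm import Mamba


class MambaWithDropout(nn.Module):
    """Illustrative stack of Mamba blocks with dropout on each residual branch."""

    def __init__(self, d_model=256, n_layer=4, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(Mamba(d_model=d_model) for _ in range(n_layer))
        self.norms = nn.ModuleList(nn.LayerNorm(d_model) for _ in range(n_layer))
        self.drop = nn.Dropout(dropout)

    def forward(self, x):                      # x: (batch, seq_len, d_model)
        for mamba, norm in zip(self.layers, self.norms):
            x = x + self.drop(mamba(norm(x)))  # pre-norm residual with dropout on the branch
        return x


model = MambaWithDropout().to("cuda")          # the selective-scan kernel expects CUDA
y = model(torch.randn(2, 64, 256, device="cuda"))
print(y.shape)                                 # torch.Size([2, 64, 256])
```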

hhhhpaaa commented 3 months ago

I'm also doing similar work, applying Mamba to small datasets. You can try adding gradient clipping (truncation) or setting a weight decay on the optimizer; both worked for me. In addition, the Mamba model size needs to be carefully matched to the learning rate, otherwise it may not perform as well as a Transformer, and the gap can be quite large. In my experiments, Mamba was actually slightly better than a Transformer with the same parameter count (same hidden-layer size).
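
Roughly what I mean by those two knobs; a stand-in sketch with a dummy model and loss, not my actual training code:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)                       # stand-in for the Mamba model
data = [torch.randn(8, 16) for _ in range(10)]  # stand-in for the real dataset

# Weight decay lives on the optimizer; AdamW applies it decoupled from the gradients.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

for batch in data:
    optimizer.zero_grad()
    loss = model(batch).pow(2).mean()           # stand-in for the task loss
    loss.backward()
    # "Gradient truncation": clip the global gradient norm before each step.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
```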

Anri-Lombard commented 3 months ago

Interesting. I have weight decay set at 0.01 and max_grad_norm=1.0. I found that for the 20M model, a learning rate of 1e-5 makes it overfit and gives a jagged loss curve over time, whereas simply changing to 1e-6 or 5e-6 results in much better convergence. My conclusions don't quite match up with your results, though, since a similar 20M Transformer model seems to far outperform the Mamba model...

hhhhpaaa commented 3 months ago

Yeah, you're right; I apologise for the inaccuracy of my previous answer. A dynamically changing learning rate is more suitable for Mamba on small datasets, and my ablation experiments may have been too rough. My network uses a Transformer encoder and a GNN classifier, has fewer than 1M parameters, and includes some additional design of its own. I compared two datasets: after swapping Mamba into the encoder part of the network, one dataset improved by roughly 0.02 in accuracy and the other by less than 0.01. In my preliminary experiments Mamba does overfit more easily than a Transformer, and it needs comparatively more fine-tuning; I also have the gradient-clipping threshold set very large. One more observation: with the hidden size set to 64, on a dataset of only about 600 entries of dimension roughly 500 (with the sequence length set to 500 // 64), stacking two Mamba blocks produced a NaN loss, which could only be avoided by adding gradient clipping. This does not happen with the Transformer, and it has bothered me for a long time.
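
For what it's worth, here is roughly what I mean by a dynamically changing learning rate combined with gradient clipping; a stand-in sketch with a dummy model, warmup and step counts are assumptions rather than my exact recipe:

```python
import math
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(64, 64)                       # stand-in for the small Mamba encoder
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6, weight_decay=0.01)

warmup_steps, total_steps = 100, 2_000

def lr_lambda(step):
    if step < warmup_steps:                     # linear warmup
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay to zero

scheduler = LambdaLR(optimizer, lr_lambda)

for step in range(total_steps):
    batch = torch.randn(8, 64)                  # stand-in for a real batch
    optimizer.zero_grad()
    loss = model(batch).pow(2).mean()           # stand-in for the task loss
    loss.backward()
    # Clipping the gradient norm is also what guards against the NaN-loss blow-ups.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
```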

Anri-Lombard commented 3 months ago

Thanks for sharing @hhhhpaaa; very interesting! 🙌

YCHYZW commented 1 month ago

I had the same problem: I put Mamba into my RNN model and found that it did not converge and overfitted directly, with poor results. Thanks, your discussion has given me new directions to think about.

Anri-Lombard commented 1 month ago

Just a quick update on this. We trained small (20M) and larger (95M) models on a 1.6M-molecule dataset (considered relatively small) and a larger 22M-molecule dataset for de novo generation. In both cases Mamba matched the Transformer's performance, and as the size increased it outperformed the Transformer in terms of training time and speed. I'll be finishing my paper soon and could pop it here if possible 👍 For both the small and large models I used the authors' MixerModel without attention layers and added a dropout of 0.1 (although no dropout also worked well).
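
For anyone who wants to try a similar rough setup, here is a hedged sketch of instantiating the authors' model (MambaLMHeadModel, which wraps their MixerModel; by default it has no attention layers). The sizes and vocabulary below are illustrative rather than our exact configuration, and as far as I can tell the stock code does not expose a dropout argument, so the 0.1 dropout mentioned above needs a small modification or wrapping of the blocks:

```python
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Illustrative sizes only; tune d_model / n_layer to land near 20M or 95M parameters.
config = MambaConfig(
    d_model=512,
    n_layer=24,
    vocab_size=1024,       # e.g. the size of a SMILES tokenizer vocabulary
)
model = MambaLMHeadModel(config, device="cuda")

# Dummy token ids standing in for tokenised molecules.
input_ids = torch.randint(0, config.vocab_size, (2, 128), device="cuda")
logits = model(input_ids).logits               # (batch, seq_len, vocab_size)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
```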

YCHYZW commented 1 month ago

Congratulations and all the best.