state-spaces / mamba

Mamba SSM architecture

Vanishing gradient problem with more layers #527

Open yjdy opened 4 weeks ago

yjdy commented 4 weeks ago

Dear author, I stacked multiple Mamba layers to form a model and trained it from scratch. When I stacked just 4 layers, performance was very good, so I decided to increase the number of layers.

However, when I stacked 8 layers, I ran into a vanishing gradient problem. Specifically, the model stayed at low performance and did not improve with training. I increased the training data to 20 times its previous size, but the problem is still there. I have also tried several remedies, such as different learning rates and residual (ResNet-style) connections, but I just cannot solve the problem.
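For context, a minimal sketch of this kind of stacked setup, assuming the Mamba2 block from `mamba_ssm`; the dimensions, layer count, and pre-norm residual wiring below are illustrative placeholders, not the exact model being trained in this issue:

```python
import torch
import torch.nn as nn
from mamba_ssm import Mamba2


class StackedMamba(nn.Module):
    """Illustrative stack of Mamba2 blocks with pre-norm residual connections."""

    def __init__(self, d_model=256, n_layers=8):
        super().__init__()
        self.layers = nn.ModuleList(
            [Mamba2(d_model=d_model) for _ in range(n_layers)]
        )
        self.norms = nn.ModuleList(
            [nn.LayerNorm(d_model) for _ in range(n_layers)]
        )

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        for norm, layer in zip(self.norms, self.layers):
            x = x + layer(norm(x))  # residual around each Mamba2 block
        return x


model = StackedMamba(d_model=256, n_layers=8).cuda()
y = model(torch.randn(2, 128, 256, device="cuda"))  # (2, 128, 256)
```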

Have you encountered this problem before? Do you have any suggestions?

Best regards

younghoon020 commented 2 weeks ago

It's hard to say without knowing how you trained it and on what dataset. It would be great if you could share a snippet of your training code. For example, which optimizer did you use? Did you train in fp32? Have you tried gradient clipping? Have you plotted the gradient flow?
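A minimal sketch of those diagnostics in plain PyTorch: clip gradients before the optimizer step and print per-parameter gradient norms so vanishing gradients show up directly. The model, loss, and data below are stand-ins, not the setup from this issue.

```python
import torch
import torch.nn as nn

# Placeholder model standing in for the stacked Mamba network.
model = nn.Sequential(*[nn.Linear(256, 256) for _ in range(8)])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x = torch.randn(4, 256)
target = torch.randn(4, 256)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), target)
loss.backward()

# Gradient flow: print each parameter's gradient norm, first to last layer.
for name, p in model.named_parameters():
    if p.grad is not None:
        print(f"{name}: grad norm = {p.grad.norm().item():.3e}")

# Gradient clipping before the optimizer step to stabilize training.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```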

yjdy commented 1 week ago

I found the reason: it is the rmsnorm in Mamba2. When I set rmsnorm=False, training of the deeper network is normal.
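A sketch of what that change might look like when constructing each block, assuming the `Mamba2` module from `mamba_ssm` exposes the `rmsnorm` flag as described above; `d_model` and the other hyperparameters are placeholders.

```python
import torch
from mamba_ssm import Mamba2

# Default: the block's internal gated RMSNorm is enabled (rmsnorm=True).
block_default = Mamba2(d_model=256).cuda()

# Workaround from this thread: skip the internal gated RMSNorm.
block_no_norm = Mamba2(d_model=256, rmsnorm=False).cuda()

x = torch.randn(2, 128, 256, device="cuda")
y = block_no_norm(x)  # (2, 128, 256); same interface either way
```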