[Open] yjdy opened this issue 4 weeks ago
I don't know how you trained it or on which dataset. It would be great if you could share a training code snippet. For example, which optimizer did you use? Did you train in fp32? Have you tried gradient clipping? Have you plotted the gradient flow?
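On the gradient-clipping suggestion above: the idea is to rescale the whole gradient vector whenever its global L2 norm exceeds a threshold, which caps exploding updates without changing the gradient's direction. A minimal pure-Python sketch (the function name and flat-list representation are illustrative; in PyTorch you would use `torch.nn.utils.clip_grad_norm_` instead):

```python
import math

def clip_grad_norm(grads, max_norm):
    """Rescale a flat list of gradient values so their global
    L2 norm is at most max_norm; returns (grads, original_norm)."""
    total = math.sqrt(sum(g * g for g in grads))
    if total > max_norm:
        scale = max_norm / total
        grads = [g * scale for g in grads]
    return grads, total

clipped, norm = clip_grad_norm([3.0, 4.0], max_norm=1.0)
# norm is 5.0; clipped keeps the direction but has unit length: [0.6, 0.8]
```

Logging the returned pre-clip norm per step is also a cheap way to "plot the gradient flow" and spot vanishing or exploding gradients.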
I found the reason: it is the RMSNorm inside Mamba2. When rmsnorm=False is set, training of the deeper network proceeds normally.
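For anyone applying this fix: recent versions of `mamba_ssm`'s `Mamba2` block expose an `rmsnorm` keyword controlling the internal gated RMSNorm, so the workaround is just to pass `rmsnorm=False` when constructing each layer. A hedged sketch of stacking layers that way (`build_mamba_stack` is a hypothetical helper, not part of the library; `layer_cls` stands in for `Mamba2`):

```python
def build_mamba_stack(layer_cls, n_layers, d_model, rmsnorm=False, **kwargs):
    """Hypothetical helper: construct n_layers instances of layer_cls
    (e.g. mamba_ssm's Mamba2), forwarding the rmsnorm kwarg that the
    issue identifies as the culprit for deep stacks."""
    return [layer_cls(d_model=d_model, rmsnorm=rmsnorm, **kwargs)
            for _ in range(n_layers)]

# e.g.  layers = build_mamba_stack(Mamba2, n_layers=8, d_model=256)
```

The exact keyword name and default may differ between `mamba_ssm` releases, so check the `Mamba2` signature in your installed version.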
Dear author, I stacked multiple Mamba layers to form a model and trained it from scratch. With only 4 layers stacked, performance was very good, so I decided to increase the number of layers.
However, when I stacked 8 layers, I ran into a vanishing-gradient problem: the model stayed at low performance and did not improve with training. I increased the training data to 20 times the original amount, but the problem remained. I have also tried several remedies, such as different learning rates and residual connections, but I could not solve the problem.
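The depth dependence described here is the classic vanishing-gradient pattern: in backprop the gradient is multiplied by one Jacobian factor per layer, so a per-layer factor below 1 shrinks it geometrically, while a residual (skip) path contributes an identity term that keeps the product from collapsing. A toy scalar model of that product (the numbers are illustrative, not measurements from this model):

```python
def backprop_gain(n_layers, jacobian=0.5, residual=False):
    """Toy scalar model of the gradient magnitude reaching the first
    layer of an n_layers stack: each layer multiplies the gradient by
    `jacobian`, or by (1 + jacobian) when a skip connection adds an
    identity term to the layer Jacobian."""
    gain = 1.0
    for _ in range(n_layers):
        gain *= (1.0 + jacobian) if residual else jacobian
    return gain

# 4 layers: 0.5**4 = 0.0625 -- small but trainable.
# 8 layers: 0.5**8 ~= 0.004 -- the gradient has all but vanished,
# matching the jump from a working 4-layer model to a stuck 8-layer one.
```

This is only a one-dimensional caricature, but it shows why doubling depth can flip training from "works" to "stuck", and why normalization layers that scale activations (like the RMSNorm identified above) can push the per-layer factor below 1.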
Have you encountered this problem before? Do you have any suggestions?
Best regards