Open YicongHong opened 2 weeks ago
Whoops that's a great catch. I'll go ahead and push something in a couple minutes to fix it. All you need to do is pass D=D in the module layer when calling the function.
As for NaN gradients, I'm looking for parts of the kernel which cause this. Haven't identified anything yet, but I'll lyk.
Should be the same for z. My classes resumed so I was rather quick to push the layer out. Thanks for the catch :)
Alright, I pushed a fix. Lmk if NaNs still occur, as this is something I haven't been able to test for personally.
Thanks @Hprairie; after passing D=self.D, I got the following error:
bimamba2/src/ssd/bi/ssd_chunk_scan.py":138:11): error: operation scheduled before its operands
Also, it seems that self.fc_D is still not used? I thought self.fc_D is for Dx and self.D is a bias term?
1) The first error comes from Triton and, from what I have found, occurs when you compile for the first time. It shouldn't keep occurring and shouldn't affect anything from what I have seen. 2) You are right, Hydra does use self.fc_D, with self.D as a bias. I am away from a computer, but will make a fix tmrw. I will give two options, as using D without fc_D is more canonical to Mamba. If you want access to it immediately, just use F.linear as in Hydra and then don't pass D to the optimized kernel.
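A minimal sketch of that workaround, assuming the kernel is called without D and the skip term is added afterwards. The class name HydraStyleSkip and the tensor layout (batch, seqlen, d_inner) are assumptions for illustration, not code from this repo; the fc_D / D pairing is my reading of how Hydra applies its skip term.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HydraStyleSkip(nn.Module):
    """Input-dependent skip term applied outside the scan kernel (hypothetical helper)."""

    def __init__(self, d_inner: int, nheads: int):
        super().__init__()
        assert d_inner % nheads == 0, "d_inner must be divisible by nheads"
        self.headdim = d_inner // nheads
        self.fc_D = nn.Linear(d_inner, nheads, bias=False)  # per-head, input-dependent scale
        self.D = nn.Parameter(torch.ones(nheads))           # per-head bias term

    def forward(self, y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # y: kernel output computed WITHOUT D, shape (batch, seqlen, d_inner)
        # x: the input that was fed to the kernel, same shape
        scale = F.linear(x, self.fc_D.weight, bias=self.D)     # (batch, seqlen, nheads)
        scale = scale.repeat_interleave(self.headdim, dim=-1)  # (batch, seqlen, d_inner)
        return y + x * scale
```

With fc_D zeroed out, this reduces to the plain y + D * x skip, which is the more Mamba-like of the two options mentioned above.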
Thanks again for pointing this out, I learned something new.
Thanks @Hprairie, "error: operation scheduled before its operands" only occurs at the start and doesn't stop anything.
Hmmm okay I'll try to block out some time to look into the NaN problem.
Yes, I have the same problem: no matter how the batch size or learning rate is set, the result is the same. Colab link (a ViT with the attention removed and replaced with Bi-Mamba 2): https://colab.research.google.com/drive/1rgXkwnlevzZ0YPbefQS8qHRe7gFlb4J-?authuser=3. The loss is always NaN; Bi-Mamba 2 produces gradients that are too large.
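For anyone trying to narrow this down, here is a rough sketch of how one might find which parameters receive NaN/Inf gradients and skip the optimizer step when that happens. The helper names nonfinite_grad_names and guarded_step are made up for illustration; only standard PyTorch calls are used, and this does not fix the kernel itself.

```python
import torch
import torch.nn as nn


def nonfinite_grad_names(model: nn.Module) -> list:
    """Return the names of parameters whose gradient contains NaN or Inf."""
    bad = []
    for name, param in model.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            bad.append(name)
    return bad


def guarded_step(model, optimizer, loss, max_norm: float = 1.0) -> bool:
    """Backprop, then step only if every gradient is finite. Returns True if stepped."""
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    bad = nonfinite_grad_names(model)
    if bad:
        print("non-finite gradients in:", bad)
        return False
    # Clip the global gradient norm before stepping (max_norm is an arbitrary choice here).
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return True
```

Wrapping a single failing iteration in torch.autograd.set_detect_anomaly(True) can additionally point at the backward op that first produced the NaN, at the cost of slower execution.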
Hi @Hprairie, I previously built mamba-2/hydra-based models, and I am now trying to replace the layers with your bi-mamba2 module. However, I found the new model can easily get invalid gradients (e.g., infinite gradient norm) that never appeared with mamba-2/hydra.
I tested with both torch==2.1.0, triton==3.0.0, cu122 and torch==2.4.0, triton==3.0.0, cu121; it seems that the more bi-mamba2 layers I stack or the more processes I use, the more easily the model runs into this problem. Any ideas?
Besides, you mentioned that the kernel implements y=SS(x)+flip(SS(flip(x)))+Dx, but in BiMamba2() Line 108, the skip parameters self.D and self.fc_D are not used for Dx. Can I ask how to pass these parameters to bimamba_chunk_scan_combined(), or should we do something similar to Hydra? Thanks!!!
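To make the formula above concrete, here is a small reference sketch of y = SS(x) + flip(SS(flip(x))) + Dx. The toy_causal_scan function is only a stand-in for SS (a simple decayed running sum), not the real SSD kernel, and the (batch, seqlen, dim) layout is an assumption; the point is just where the flips and the D term go.

```python
import torch


def toy_causal_scan(x: torch.Tensor, decay: float = 0.5) -> torch.Tensor:
    """Stand-in for SS(.): h_t = decay * h_{t-1} + x_t along the sequence axis (dim 1)."""
    h = torch.zeros_like(x[:, 0])
    outputs = []
    for t in range(x.shape[1]):
        h = decay * h + x[:, t]
        outputs.append(h)
    return torch.stack(outputs, dim=1)


def bidirectional_reference(x: torch.Tensor, D: torch.Tensor, scan=toy_causal_scan) -> torch.Tensor:
    """y = SS(x) + flip(SS(flip(x))) + D * x, flipping over the sequence axis."""
    fwd = scan(x)
    bwd = torch.flip(scan(torch.flip(x, dims=[1])), dims=[1])
    return fwd + bwd + D * x  # D has shape (dim,) and broadcasts over batch and seqlen


x = torch.tensor([[[1.0], [2.0], [3.0]]])  # (batch=1, seqlen=3, dim=1)
y = bidirectional_reference(x, D=torch.ones(1))
```

A slow reference like this can be handy for comparing against an optimized kernel on small inputs, provided the stand-in scan is swapped for one that matches the kernel's actual recurrence.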