state-spaces / mamba

Mamba SSM architecture

The Accuracy Problem of Mamba Operator #75

Open Unrealluver opened 11 months ago

Unrealluver commented 11 months ago

Greetings!

Thanks for your awesome work! The GPU-level optimization of the Mamba operator is impressive to me. However, I ran into an accuracy problem when running the unit test mamba/tests/ops/test_selective_scan.py. I got the output below:

mamba_inner_fn
Output max diff: 0.0
Output mean diff: 0.0
dxz max diff: 512.0
dA max diff: 131072.0
dD max diff: 0.0003662109375
ddelta_bias max diff: 8.0
dout_proj_weight max diff: 80.0
ddelta_proj_weight max diff: 1024.0
dx_proj_weight max diff: 320.0
dconv1d_weight max diff: 864.0
dconv1d_bias max diff: 704.0

It is worth noticing that the gradients of variables such as dxz, dA, ddelta_bias, etc. differ from those of the reference Mamba implementation. Could you share some possible reasons for this? And from what kind of results can we judge the Mamba operator's reliability? Looking forward to your reply~
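P.S. By "judge the reliability" I mean something like the relative-error check below. This is just a sketch of what I would look at, not the repository's test code, and the names in the usage comment are placeholders:

import torch

def rel_err(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> float:
    """Max elementwise relative error of `a` against reference `b`."""
    a, b = a.double(), b.double()
    return ((a - b).abs() / (b.abs() + eps)).max().item()

# Hypothetical usage (placeholder names, not the test's variables):
#   dA_fused: A.grad produced by the fused CUDA path
#   dA_ref:   A.grad produced by the pure-PyTorch reference path
# print("dA relative diff:", rel_err(dA_fused, dA_ref))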

jacklishufan commented 10 months ago

I have a similar issue. This is my result:


---------------------------------------------------------------------- Captured stdout call ----------------------------------------------------------------------
Output max diff: 0.25
Output mean diff: 0.005645751953125
==================================================================== short test summary info =====================================================================
FAILED tests/ops/triton/test_selective_state_update.py::test_causal_conv1d_update[2048-64-True-itype2] - AssertionError: assert False
FAILED tests/ops/triton/test_selective_state_update.py::test_causal_conv1d_update[4096-64-True-itype2] - AssertionError: assert False
================================================================= 2 failed, 72 passed in 20.46s ==================================================================

It seems there are some correctness issues.
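For debugging, note that the failing assertions only report a bare "assert False". A small sketch (generic PyTorch, not the repository's actual test code) of a check that prints the observed error before asserting:

import torch

def assert_close_verbose(out: torch.Tensor, out_ref: torch.Tensor,
                         rtol: float = 1e-3, atol: float = 1e-2) -> None:
    """Print max/mean absolute diff before asserting, instead of failing with a bare assert."""
    diff = (out.float() - out_ref.float()).abs()
    print(f"max diff: {diff.max().item():.6g}, mean diff: {diff.mean().item():.6g}")
    assert torch.allclose(out.float(), out_ref.float(), rtol=rtol, atol=atol), (
        f"max diff {diff.max().item():.6g} exceeds rtol={rtol}, atol={atol}"
    )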

tyshiwo1 commented 10 months ago

I have a similar question. I run:

test_mamba_inner_fn(True, True, 128, torch.float32, torch.float32)

Then I get:

Output max diff: 0.0
Output mean diff: 0.0
dxz max diff: 320.0
dA max diff: 0.0
dD max diff: 0.0
ddelta_bias max diff: 0.0
dout_proj_weight max diff: 0.0
ddelta_proj_weight max diff: 0.0
dx_proj_weight max diff: 312.0
dconv1d_weight max diff: 512.0
dconv1d_bias max diff: 320.0

TangTangFei commented 7 months ago

I'm getting similar results:

For example, for the test test_mamba_inner_fn with is_variable_B=True, is_variable_C=True, seqlen=128, itype=torch.float32, wtype=torch.complex64, I get:

Output max diff: 0.0
Output mean diff: 0.0
dxz max diff: 5888.0
dx max diff: 5888.0
dz max diff: 1056.0
dA max diff: 1417760.5
dD max diff: 0.00042724609375
ddelta_bias max diff: 1024.0
dout_proj_weight max diff: 0.0
ddelta_proj_weight max diff: 98304.0
dx_proj_weight max diff: 11264.0
dconv1d_weight max diff: 20480.0
dconv1d_bias max diff: 28672.0

I'm wondering whether anyone has an explanation for this.

TangTangFei commented 7 months ago

I have changed the scale of the following tensors:

xz = (0.01 * torch.rand(bs, 2 * dim, seq_len, dtype=itype, device=device)).requires_grad_()  # scale the input down by 100x
g = torch.randn_like(out) * 0.01  # scale the upstream gradient down by 100x

The discrepancies then become rather small, so the large differences in the gradients were possibly due to numerical instabilities when the values become too large.
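As a tiny standalone illustration of that effect (not the Mamba kernel, just a cumulative sum computed in float32 versus float64), the maximum absolute diff grows with the input scale while the error relative to the tensor's magnitude stays near float32 precision:

import torch

# Illustration only: the absolute error of a float32 reduction grows with the input
# scale, while the error normalized by the tensor's magnitude does not.
torch.manual_seed(0)
x64 = torch.randn(4096, dtype=torch.float64)
for scale in (1.0, 1e2, 1e4):
    ref = (scale * x64).cumsum(0)                      # float64 "ground truth"
    approx = (scale * x64).float().cumsum(0).double()  # same reduction in float32
    abs_err = (approx - ref).abs().max().item()
    norm_err = abs_err / ref.abs().max().item()        # error relative to magnitude
    print(f"scale={scale:g}  max abs diff={abs_err:.3e}  normalized diff={norm_err:.3e}")

The absolute diff scales roughly linearly with the input scale while the normalized diff does not, which matches the pattern in the gradients above.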