Open · Unrealluver opened this issue 11 months ago
I have a similar issue. This is my result:
---------------------------------------------------------------------- Captured stdout call ----------------------------------------------------------------------
Output max diff: 0.25
Output mean diff: 0.005645751953125
==================================================================== short test summary info =====================================================================
FAILED tests/ops/triton/test_selective_state_update.py::test_causal_conv1d_update[2048-64-True-itype2] - AssertionError: assert False
FAILED tests/ops/triton/test_selective_state_update.py::test_causal_conv1d_update[4096-64-True-itype2] - AssertionError: assert False
================================================================= 2 failed, 72 passed in 20.46s ==================================================================
It seems there are some correctness issues.
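(If it helps narrow this down, the failing parametrizations can be re-run on their own straight from Python; below is a minimal sketch using pytest's standard pytest.main API, assuming the test path from the summary above relative to your checkout of the repo.)

import pytest

# Re-run only the failing test function in verbose mode so the two itype2
# parametrizations can be inspected in isolation; the exit code is 0 on success.
exit_code = pytest.main([
    "tests/ops/triton/test_selective_state_update.py::test_causal_conv1d_update",
    "-v",
])
print("pytest exit code:", exit_code)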
I have a similar question:
test_mamba_inner_fn(True, True, 128, torch.float32, torch.float32)
Then I get:
Output max diff: 0.0
Output mean diff: 0.0
dxz max diff: 320.0
dA max diff: 0.0
dD max diff: 0.0
ddelta_bias max diff: 0.0
dout_proj_weight max diff: 0.0
ddelta_proj_weight max diff: 0.0
dx_proj_weight max diff: 312.0
dconv1d_weight max diff: 512.0
dconv1d_bias max diff: 320.0
I'm getting similar results, for example for the test test_mamba_inner_fn with is_variable_B=True, is_variable_C=True, seqlen=128, itype=torch.float32, wtype=torch.complex64:
Output max diff: 0.0
Output mean diff: 0.0
dxz max diff: 5888.0
dx max diff: 5888.0
dz max diff: 1056.0
dA max diff: 1417760.5
dD max diff: 0.00042724609375
ddelta_bias max diff: 1024.0
dout_proj_weight max diff: 0.0
ddelta_proj_weight max diff: 98304.0
dx_proj_weight max diff: 11264.0
dconv1d_weight max diff: 20480.0
dconv1d_bias max diff: 28672.0
I'm wondering whether anyone has an explanation for this.
I have changed the scale of the following tensors:
xz = (0.01 * torch.rand(bs, 2 * dim, seq_len, dtype=itype, device=device)).requires_grad_()
g = torch.randn_like(out) * 0.01
The discrepancies then become rather small, so possibly the large differences in the gradients were due to numerical instabilities when the values become too large.
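To illustrate why that rescaling shrinks the reported numbers, here is a small self-contained sketch (a toy recurrence of my own, not the actual Mamba test or kernels): comparing the gradients of an fp32 run against an fp64 run of the same computation, the absolute max diff grows roughly in proportion to the scale of the inputs and the upstream gradient, while the relative error stays about the same.

import torch

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

def scan(x):
    # Toy gated recurrence standing in for a selective-scan-like op; the real
    # Mamba kernels are different, this only shows how errors track the scale.
    h = torch.zeros_like(x[..., 0])
    outs = []
    for t in range(x.shape[-1]):
        h = torch.sigmoid(x[..., t]) * h + x[..., t]
        outs.append(h)
    return torch.stack(outs, dim=-1)

def grad_error(scale):
    # Same inputs and upstream gradient at two precisions; fp64 serves as the reference.
    x64 = (scale * torch.randn(4, 64, 128, dtype=torch.float64, device=device)).requires_grad_()
    x32 = x64.detach().to(torch.float32).requires_grad_()
    g = scale * torch.randn(4, 64, 128, dtype=torch.float64, device=device)
    scan(x64).backward(g)
    scan(x32).backward(g.to(torch.float32))
    abs_err = (x32.grad.double() - x64.grad).abs().max().item()
    rel_err = abs_err / x64.grad.abs().max().item()
    return abs_err, rel_err

for scale in (1.0, 0.01):
    abs_err, rel_err = grad_error(scale)
    print(f"scale={scale}: grad abs max diff={abs_err:.3e}, relative={rel_err:.3e}")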
Greetings!
Thanks for your awesome work! The GPU-level optimization of the Mamba operator is impressive to me. But I ran into an accuracy problem when running the unit tests in mamba/tests/ops/test_selective_scan.py. I got the output below. It is worth noticing that the variables' gradients, such as dxz, dA, ddelta_bias, etc., differ from the reference Mamba implementation. Could you share some reasons for this? And from what kind of results can we judge the Mamba operator's reliability? Looking forward to your reply~
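On the question of judging reliability: as far as I can tell, the tests print these max/mean diffs for information and then assert closeness within dtype-dependent tolerances, so diffs on the order of the dtype's precision are expected and only larger ones fail. A rough sketch of that kind of check (the helper name and the tolerance values are illustrative assumptions, not the repo's actual thresholds):

import torch

def assert_close(actual, reference, itype, name):
    # Illustrative dtype-dependent tolerances; the real test suite's values may differ.
    rtol, atol = {
        torch.float32: (3e-4, 1e-3),
        torch.float16: (3e-3, 5e-3),
        torch.bfloat16: (3e-2, 5e-2),
    }[itype]
    diff = (actual.float() - reference.float()).abs()
    print(f"{name} max diff: {diff.max().item()}")
    print(f"{name} mean diff: {diff.mean().item()}")
    assert torch.allclose(actual.float(), reference.float(), rtol=rtol, atol=atol)

# Example: a perturbation well below the tolerance passes the check.
ref = torch.randn(4, 8)
assert_close(ref + 1e-5, ref, torch.float32, "Output")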