MzeroMiko / VMamba

VMamba: Visual State Space Models, code is based on mamba
MIT License

Numbers from Triton CSM grad don't match the PyTorch version #186

ShixuanGu closed this issue 1 month ago

ShixuanGu commented 1 month ago

Thanks for the great work! When running vmamba_checks.py to compare cross_scan, cross_merge, CrossScanTriton, and CrossMergeTriton, I got:
"test cross scan tensor(0., device='cuda:0', grad_fn=) tensor(1.9073e-06, device='cuda:0') tensor(1.9073e-06, device='cuda:0', grad_fn=) tensor(0., device='cuda:0')"

It seems that the gradient of cross_scan_pytorch doesn't match Triton, and the forward result of cross_merge doesn't match Triton either. Since both mismatches are the same number, I'm wondering whether this is expected behavior related to Triton. Could it cause any problems?

Also, random inputs always seem to give the same mismatch value.

MzeroMiko commented 1 month ago

Yes, it is normal.

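This is because results computed on a flipped tensor in PyTorch may not exactly equal those computed on the original tensor: floating-point addition is not associative, so changing the order in which values are added changes the rounding.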

Here's a small experiment for you:

import torch; a = torch.randn((233, 455)); print(a.flip(dims=[-1]).sum() - a.sum())

Every time you run it, you'll get a different answer, usually not exactly zero.
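The point above is that floating-point addition is not associative, so any reordering (such as a flip) can change the last bits of a result. Below is a minimal sketch in plain Python (standard library only, so it runs without a GPU) that shows order-dependent summation, checks that the reported 1.9073e-06 equals 2**-19 (one float32 rounding step, or ULP, for values in [16, 32)), and compares results with a tolerance instead of exact equality. The helper names float32_spacing and order_difference are hypothetical. Reading 1.9073e-06 as a one-ULP difference is an assumption, since this thread does not say exactly what quantity vmamba_checks.py prints.

```python
import math
import random
import struct


def float32_spacing(x: float) -> float:
    """Gap between float32(x) and the next larger float32 value (one ULP), for finite x >= 0."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    (x32,) = struct.unpack("<f", struct.pack("<f", x))
    (next_up,) = struct.unpack("<f", struct.pack("<I", bits + 1))
    return next_up - x32


def order_difference(values: list[float]) -> float:
    """Sum the same values forward and in reverse; any gap comes only from rounding."""
    return sum(values) - sum(reversed(values))


if __name__ == "__main__":
    # Grouping alone changes the rounded result: prints False.
    print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))

    random.seed(0)
    values = [random.gauss(0.0, 1.0) for _ in range(233 * 455)]
    print("forward - reversed:", order_difference(values))

    # Assumption: the reported mismatch is one float32 step for values in [16, 32).
    print(2 ** -19, float32_spacing(16.0) == 2 ** -19)  # 1.9073486328125e-06 True

    # Compare with a tolerance instead of exact equality.
    forward, backward = sum(values), sum(reversed(values))
    print(math.isclose(forward, backward, rel_tol=1e-9, abs_tol=1e-6))
```

In PyTorch code, torch.allclose or torch.testing.assert_close play the role of math.isclose here when comparing the PyTorch and Triton outputs.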