Closed ShixuanGu closed 1 month ago
Yes, it is normal.
This is because operating on a flipped tensor in PyTorch may not produce bit-identical results to the same operation on the original tensor: flipping changes the order in which elements are accumulated, and floating-point addition is not associative.
Here's a small experiment for you:
```python
a = torch.randn((233, 455))
print(a.flip(dims=[-1]).sum() - a.sum())
```

Every time you run it, you'll get a slightly different (nonzero) answer.
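The same effect can be shown without torch at all. This is an illustrative sketch (not from the repo): summing the same list of floats forward and backward gives different results, because each intermediate sum rounds differently.

```python
# Floating-point addition is not associative: summing the same values
# in a different order can give a different result. The values below
# are chosen to make the effect large and deterministic.
vals = [0.1, 0.2, 0.3, 1e16, -1e16]

forward = 0.0
for v in vals:            # left-to-right: the small values are absorbed
    forward += v          # into 1e16 and lost before -1e16 cancels it

backward = 0.0
for v in reversed(vals):  # right-to-left: 1e16 cancels first, so the
    backward += v         # small values survive

print(forward, backward, forward - backward)
```

Flipping a tensor before reducing it reorders the accumulation in exactly this way, which is why `a.flip(dims=[-1]).sum()` need not equal `a.sum()` bit-for-bit.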
Thanks for the great work! When running vmamba_checks.py, for the comparison of cross_scan, cross_merge, CrossScanTriton, and CrossMergeTriton, I got:

```
test cross scan tensor(0., device='cuda:0', grad_fn=)
tensor(1.9073e-06, device='cuda:0')
tensor(1.9073e-06, device='cuda:0', grad_fn=)
tensor(0., device='cuda:0')
```
It seems the gradient of cross_scan_pytorch doesn't match the Triton version, and the forward result of cross_merge doesn't match the Triton version either. Is this expected behavior related to Triton, given that both mismatches are the same number? Could it cause any potential problems?
Also, random inputs always give the same mismatch number.
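As a rough sanity check (my own back-of-the-envelope, not from the repo): a mismatch of 1.9073e-06 can be expressed in units of float32 machine epsilon to see whether it is plausibly just rounding error from a reordered reduction. The `scale` of 1.0 is an assumption about the typical magnitude of the compared tensors.

```python
# Express the reported mismatch in ulps (units of float32 machine
# epsilon) at an assumed tensor magnitude of ~1.0. A difference of a
# few dozen ulps is typical accumulated rounding error for reductions
# computed in a different order.
mismatch = 1.9073e-06       # the value reported in the issue
scale = 1.0                 # assumed typical magnitude of the tensors
eps32 = 2.0 ** -23          # float32 machine epsilon, ~1.19e-7

ulps = mismatch / (scale * eps32)
print(f"mismatch is about {ulps:.1f} ulps at scale {scale}")
```

Here the mismatch comes out to roughly 16 ulps (1.9073e-06 is almost exactly 2^-19), which is consistent with ordinary float32 rounding differences rather than a logic bug; that would also explain why different random inputs reproduce the same mismatch magnitude.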