Thanks for this amazing codebase!

Could you please elaborate on how exactly nd_mamba2 handles 2D and 3D data? Does the code follow VSSD (Shi et al., 2024), where all tokens share the same hidden state? And how does the forward/backward speed compare to single-scan Mamba2?

Much appreciated!
1. For 2D and 3D data, zero padding is applied so that each spatial dimension becomes evenly divisible by the block size. The padded grid is then simply flattened (straightened) into a 1D token sequence, without using alternative unfolding orders such as z, m, row, or col. Because Mamba2 establishes a duality with the Transformer, I made a bold guess that, as in a Transformer, the attention effect produced at any position in Mamba2 is equivalent, so the choice among z, m, row, and col unfoldings should not affect the final result. This is only speculation, without experimental support. A sketch of the padding + flattening step follows below.
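As a rough illustration of that step, here is a minimal PyTorch sketch (not the repo's actual code; the function names and the `block_size` argument are hypothetical):

```python
import torch
import torch.nn.functional as F

def pad_and_flatten_2d(x: torch.Tensor, block_size: int):
    """Zero-pad H and W up to multiples of `block_size`, then flatten the
    padded grid row by row into a (B, L, C) token sequence."""
    B, C, H, W = x.shape
    pad_h = (block_size - H % block_size) % block_size
    pad_w = (block_size - W % block_size) % block_size
    x = F.pad(x, (0, pad_w, 0, pad_h))       # zeros on the right / bottom
    Hp, Wp = H + pad_h, W + pad_w
    tokens = x.flatten(2).transpose(1, 2)    # (B, Hp * Wp, C), plain raster order
    return tokens, (Hp, Wp)

def unflatten_2d(tokens: torch.Tensor, hw, channels: int):
    """Inverse of `pad_and_flatten_2d`: restore the padded (B, C, Hp, Wp) grid."""
    Hp, Wp = hw
    return tokens.transpose(1, 2).reshape(-1, channels, Hp, Wp)
```

The 3D case is analogous: pad (D, H, W) up to multiples of the block size and flatten all three spatial axes in raster order.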
2. The kernel adopts the official Mamba2 implementation, with the single-step (incremental decoding) path removed, since image tasks run inference over the whole sequence at once and never decode token by token. The handling of tokens sharing the same hidden state remains consistent with the official implementation.
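To make that concrete, here is a hedged sketch of what such a wrapper can look like, assuming the `Mamba2` module exported by the official `mamba_ssm` package (it consumes a `(B, L, D)` sequence in one full pass; never passing `inference_params` means the single-step decode path is simply not exercised) and reusing the hypothetical padding helpers above:

```python
import torch.nn as nn
from mamba_ssm import Mamba2  # official chunked (SSD) kernel

class NdMamba2Block2d(nn.Module):
    """Hypothetical 2D wrapper: pad -> flatten -> one full-sequence scan -> unflatten."""
    def __init__(self, dim: int, block_size: int = 8):
        super().__init__()
        # Note: the official defaults (expand=2, headdim=64) require
        # dim * expand to be divisible by headdim, e.g. dim = 64, 96, 128, ...
        self.mixer = Mamba2(d_model=dim)
        self.block_size = block_size

    def forward(self, x):                    # x: (B, C, H, W)
        tokens, hw = pad_and_flatten_2d(x, self.block_size)
        tokens = self.mixer(tokens)          # whole-sequence inference, no step cache
        return unflatten_2d(tokens, hw, tokens.shape[-1])
```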
3. As stated in 2, the speed of this module is determined by the official Mamba2 kernel, so its forward/backward speed matches that of a single scan as described in the official documentation.
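If you want to verify this on your own hardware, a quick micro-benchmark along these lines (shapes and iteration counts are arbitrary; assumes a CUDA device) can compare the ND wrapper against a plain single-scan `Mamba2` over a sequence of the same flattened length:

```python
import torch

def time_fwd_bwd(module, x, iters: int = 50) -> float:
    """Average forward + backward time per iteration, in milliseconds."""
    x = x.clone().requires_grad_(True)
    for _ in range(5):                       # warm-up / kernel autotuning
        module(x).sum().backward()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        module(x).sum().backward()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```

For example, time a `(B, 64, H, W)` image through the 2D wrapper and a `(B, H * W, 64)` sequence through `Mamba2(d_model=64)` directly; the two should be close, since the wrapper only adds the cheap pad/reshape around the same kernel.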