Human9000 / nd-Mamba2-torch

Pure PyTorch implementations of "bi-mamba2" and "vision-mamba2-torch"; supports 1D/2D/3D/ND data and export via jit.script/ONNX.

Explain ND processing and speed? #11

Open · YicongHong opened this issue 1 week ago

YicongHong commented 1 week ago

Thanks for this amazing codebase!

Could you please elaborate more on how exactly the nd_mamba2 deals with 2D and 3D data? Does the code follow VSSD (Shi et al., 2024), where all tokens share the same hidden state? How about the forward/backward speed compared to single-scan Mamba2?

Much appreciated!

Human9000 commented 1 day ago

Thank you for your question.

  1. For 2D and 3D data, zero padding is applied so that the input divides evenly into blocks, and then a simple flattening ("straightening") is performed, without using different unfolding orders such as (z, m, row, col); a minimal sketch of this pad-and-flatten step is given after this list. Because Mamba2 is unified with the Transformer formulation, I made a bold guess that, as in a Transformer, the attention effect produced at any position in Mamba2 is equivalent, so the unfolding order along z, m, row, and col should not affect the final result. This is only a conjecture, however, without experimental support.
  2. The kernel of the code adopts the official Mamba2 implementation, with single-step decoding removed (image tasks run inference on the whole input, so single-step decoding is not needed). The setting in which tokens share the same hidden state remains consistent with the official implementation.
  3. As stated in 2, the speed of this module depends on the official Mamba2 implementation, so its forward/backward speed matches a single scan as described in the official documentation.
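
For concreteness, here is a minimal sketch (not code from this repo) of the zero-pad-and-flatten idea described in point 1, assuming a (B, C, H, W) input and a hypothetical `chunk` size that the flattened token length must be a multiple of; `pad_and_flatten_2d` and `unflatten_2d` are illustrative names only:

```python
import torch
import torch.nn.functional as F


def pad_and_flatten_2d(x, chunk=64):
    """Zero-pad a (B, C, H, W) feature map so the flattened token length is a
    multiple of `chunk`, then flatten it row-major into a (B, L, C) sequence."""
    b, c, h, w = x.shape
    tokens = x.flatten(2).transpose(1, 2)           # (B, H*W, C): simple "straightening"
    pad_len = (-tokens.shape[1]) % chunk            # tokens needed to reach a multiple of `chunk`
    if pad_len:
        tokens = F.pad(tokens, (0, 0, 0, pad_len))  # zero padding along the sequence dimension
    return tokens, (h, w)


def unflatten_2d(tokens, hw):
    """Drop the padded tokens and restore the (B, C, H, W) layout."""
    h, w = hw
    tokens = tokens[:, : h * w]                     # discard the zero padding
    b, l, c = tokens.shape
    return tokens.transpose(1, 2).reshape(b, c, h, w)


x = torch.randn(2, 96, 30, 30)                      # H*W = 900, not a multiple of 64
seq, hw = pad_and_flatten_2d(x, chunk=64)           # seq: (2, 960, 96)
# `seq` would be passed through the Mamba2 block here; this sketch skips that step.
y = unflatten_2d(seq, hw)                           # back to (2, 96, 30, 30)
```

The same pattern extends to 3D by flattening (D, H, W) into one sequence dimension and remembering the original spatial shape for the inverse reshape.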