alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.
MIT License
960 stars, 86 forks

Up sweep in parallel scan #46

anhtienng closed 2 months ago

anhtienng commented 2 months ago

Thank you for your great work.

In your parallel scan, when Xa.size(2) == 2 or Xa.size(2) == 1, why do you skip the up-sweep operation? (line 64 in file)

Is it related to the point in your document that this isn't a static tree, but a representation of how our tensor evolves in memory?

anhtienng commented 2 months ago

I figured it out myself. When Xa.size(2) == 2 or Xa.size(2) == 1, we don't need the up-sweep to calculate the cumulative sums; they are already computed during the down-sweep.
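
For anyone landing here later, a minimal pure-Python sketch of why so few nodes need no extra pass. It assumes the scan computes the recurrence h[t] = a[t] * h[t-1] + x[t] with h[-1] = 0. The function names `sequential_scan` and `short_scan` are hypothetical and this is not the repository's code, only an illustration of the 1-node and 2-node cases.

```python
def sequential_scan(a, x):
    """Reference result: h[t] = a[t] * h[t-1] + x[t], starting from h[-1] = 0."""
    h, out = 0.0, []
    for a_t, x_t in zip(a, x):
        h = a_t * h + x_t
        out.append(h)
    return out


def short_scan(a, x):
    """Hypothetical helper: scan a sequence of length 1 or 2 with at most one combine."""
    x = list(x)  # work on a copy instead of modifying the input in place
    if len(x) == 2:
        # A single combine already produces the full result for both positions.
        x[1] += a[1] * x[0]
    elif len(x) != 1:
        raise ValueError("this sketch only covers lengths 1 and 2")
    # With one node there is nothing to combine: x[0] is already h[0].
    return x


a, x = [0.5, 2.0], [3.0, 1.0]
print(short_scan(a, x) == sequential_scan(a, x))  # True
print(short_scan(a[:1], x[:1]) == sequential_scan(a[:1], x[:1]))  # True
```

With one node the input is already the answer, and with two nodes one combine fills in the second position, so there is nothing left for a second sweep to propagate.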