state-spaces / mamba

Mamba SSM architecture

Sequence parallelism in the mixer (Context Parallelism) #482

Open TranSirius opened 3 months ago

TranSirius commented 3 months ago

The general question is: does mamba-ssm currently support sequence parallelism in the mixer?

I noticed that Section 8.2 of the Mamba-2 paper proposes a potential way to split activations across multiple devices while mixing information among tokens. Does the current version of mamba-ssm support such a context-parallelism scheme?

By the way, if this can be confirmed, the suggested implementation would need to be incorporated into the fast scan algorithm. Since the fast scan is a parallel tree-traversal algorithm, each node should be computed on a single device. In the leaf-to-root pass, communication is invoked whenever two sibling nodes are computed on different devices, in order to transmit the hidden state; in the root-to-leaf pass, communication is triggered in the same way. I show a simple illustration of how to implement CP below. As a result, CP_SIZE is also constrained by the number of children per node in the fast scan algorithm. (Just to confirm whether I am understanding this correctly, thanks.)

[Image: sketch of the proposed context-parallel split of the tree scan across devices]
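To make the idea concrete, here is a minimal single-process sketch of the kind of divide-and-conquer prefix scan I have in mind, using a simplified scalar recurrence h_t = a_t * h_{t-1} + b_t (this is not the mamba-ssm kernel). In a context-parallel version, each subtree would live on its own group of devices, and the carry hand-off between the two halves at each level is exactly where the cross-device communication I described would be triggered.

```python
import torch


def combine(left, right):
    """Associative operator for the recurrence h = a * h_prev + b.

    Each element is a pair (a, b) representing the affine map h -> a * h + b;
    composing "left" followed by "right" gives (a_r * a_l, a_r * b_l + b_r).
    """
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r


def tree_scan(pairs):
    """Inclusive prefix scan over (a, b) pairs via recursive divide and conquer."""
    if len(pairs) == 1:
        return pairs
    mid = len(pairs) // 2
    left = tree_scan(pairs[:mid])    # would run on the "left" group of devices
    right = tree_scan(pairs[mid:])   # would run on the "right" group of devices
    # Communication point: the full prefix of the left subtree (its last entry)
    # is sent to the right subtree and folded into every prefix there.
    carry = left[-1]
    return left + [combine(carry, r) for r in right]


# Toy example: 8 steps with random decays a_t and input contributions b_t.
a, b = torch.rand(8), torch.randn(8)
prefixes = tree_scan([(a[i], b[i]) for i in range(8)])
h_scan = prefixes[-1][1]  # state after the full sequence (starting from h_0 = 0)

# Reference: plain sequential recurrence.
h = torch.zeros(())
for i in range(8):
    h = a[i] * h + b[i]
assert torch.allclose(h, h_scan)
```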
josiahbjorgaard commented 3 weeks ago

I don't believe it does; I'm currently working on an implementation of it.

I think there is one more detail beyond what you've sketched out here. There is a weighted cumulative sum over the final states of all previous chunks in the sequence, and it will need to be updated for each group of chunks once they have been scattered across multiple GPUs. It is shown in Figure 7 of the Mamba-2 paper (the yellow arrows). It is here in the code, but it is not modified for context parallelism.
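For concreteness, here is the recurrence I mean in a rough single-device form, with simplified shapes and made-up names (this is not the actual mamba-ssm code): each chunk produces a "local" final state computed as if it started from zero, plus the total decay accumulated over the chunk, and those per-chunk summaries are then combined across chunks.

```python
import torch


def combine_chunk_states(chunk_states, chunk_decays):
    """chunk_states: (C, d, n) final state of each chunk, computed from a zero
    initial state; chunk_decays: (C,) total decay accumulated over each chunk
    (the product of the A factors inside that chunk).

    Returns (C, d, n): the initial state each chunk should actually start from.
    """
    inits = torch.zeros_like(chunk_states)
    H = torch.zeros_like(chunk_states[0])
    for k in range(chunk_states.shape[0]):
        inits[k] = H                                # state carried into chunk k
        H = chunk_decays[k] * H + chunk_states[k]   # the "yellow arrows" step
    return inits
```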

Distributing it requires either (a) gathering the final states from all GPUs that hold earlier chunks of the sequence, computing the decay weights (the products of A elements), and then performing a weighted-sum reduction of those previous 'final' states on each GPU, or (b) passing the states point-to-point, GPU by GPU, so that the 'final' states are weighted and accumulated sequentially.
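A rough sketch of option (a), assuming each rank holds a contiguous block of chunks and has already computed its local per-chunk final states and decays as above. The function and tensor names are made up and this is not the mamba-ssm API; it is just the shape of the communication and reduction I have in mind:

```python
import torch
import torch.distributed as dist


def carried_in_states_cp(local_states, local_decays, group=None):
    """local_states: (c, d, n) per-chunk final states on this rank (zero init);
    local_decays:  (c,)       per-chunk total decays on this rank.
    Returns (c, d, n): the initial state for each local chunk.
    """
    world = dist.get_world_size(group)
    rank = dist.get_rank(group)

    # Option (a): all-gather every rank's chunk summaries.
    all_states = [torch.empty_like(local_states) for _ in range(world)]
    all_decays = [torch.empty_like(local_decays) for _ in range(world)]
    dist.all_gather(all_states, local_states, group=group)
    dist.all_gather(all_decays, local_decays, group=group)
    S = torch.cat(all_states, dim=0)  # (C, d, n) over the whole sequence
    a = torch.cat(all_decays, dim=0)  # (C,)

    # Weight updates: W[k, j] = product of decays of chunks j+1 .. k-1 for j < k,
    # i.e. how much chunk j's final state has decayed by the start of chunk k.
    # Decays are positive (exp of A * dt sums), so log/exp is safe here.
    C = a.shape[0]
    log_cum = torch.cumsum(torch.log(a), dim=0)                  # (C,)
    prev = torch.cat([torch.zeros(1, device=a.device), log_cum[:-1]])
    W = torch.exp(prev[:, None] - log_cum[None, :])              # (C, C)
    W = W * (torch.arange(C, device=a.device)[:, None]
             > torch.arange(C, device=a.device)[None, :])        # keep j < k only

    # Weighted-sum reduction of previous chunks' final states, done on every rank.
    inits = torch.einsum("kj,jdn->kdn", W, S)

    # Keep only this rank's contiguous block of chunks.
    c = local_states.shape[0]
    return inits[rank * c:(rank + 1) * c]
```

Option (b) would replace the all-gather with a chain of send/recv calls, each rank waiting for the carried-in state of its first chunk from the previous rank before weighting and forwarding its own.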

It looks to me like everything else can be computed per chunk (i.e., per GPU), except the convolution over the sequence that runs before the mixer, which may also need to be modified to avoid becoming a bottleneck when running sequence/context parallel.
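For the convolution, something like a halo exchange would probably suffice: each rank only needs the last (kernel_size - 1) positions of the previous rank's shard as left context for the depthwise causal conv1d. A rough sketch under those assumptions (again, made-up names and shard layout, not the mamba-ssm implementation):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def causal_conv1d_cp(x, weight, group=None):
    """x: (B, D, L_local) local shard of the sequence; weight: (D, 1, K) depthwise kernel."""
    rank = dist.get_rank(group)
    world = dist.get_world_size(group)
    halo = weight.shape[-1] - 1  # K - 1 positions of left context are needed

    # Exchange halos: send my last (K - 1) positions to the next rank and
    # receive the previous rank's last (K - 1) positions as my left context.
    left_ctx = torch.zeros(x.shape[0], x.shape[1], halo, dtype=x.dtype, device=x.device)
    ops = []
    if rank + 1 < world:
        ops.append(dist.P2POp(dist.isend, x[..., -halo:].contiguous(), rank + 1, group))
    if rank > 0:
        ops.append(dist.P2POp(dist.irecv, left_ctx, rank - 1, group))
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()

    # Causal conv: pad with the received left context instead of zeros
    # (rank 0 keeps zero padding, matching the single-device causal conv).
    x_padded = torch.cat([left_ctx, x], dim=-1)
    return F.conv1d(x_padded, weight, groups=x.shape[1])
```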

It would be great to get some feedback on this if anyone else is working on it or understands the context-parallel strategy for the SSD model.