state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Question: does Mamba support variable-length input or cu_seqlens like FlashAttention? #180

Status: Open. zigzagcai opened this issue 6 months ago.

zigzagcai commented 6 months ago

FlashAttention supports cu_seqlens (cumulative sequence lengths), which lets it drop the padding from variable-length inputs in a batch and store only the real tokens. This is useful for computational efficiency when packing multiple short sequences together.
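To illustrate what cu_seqlens encodes, here is a minimal pure-Python sketch (not FlashAttention's code; the helpers `pack` and `unpack` are hypothetical). Variable-length sequences are concatenated without padding, and cu_seqlens holds the cumulative offsets, so sequence i occupies tokens cu_seqlens[i] up to (but not including) cu_seqlens[i + 1].

```python
from itertools import accumulate
from typing import List, Sequence, Tuple


def pack(seqs: Sequence[Sequence[int]]) -> Tuple[List[int], List[int]]:
    """Concatenate sequences without padding and return (tokens, cu_seqlens)."""
    tokens = [tok for seq in seqs for tok in seq]
    # cu_seqlens[0] = 0 and cu_seqlens[i + 1] = total length of sequences 0..i
    cu_seqlens = [0] + list(accumulate(len(seq) for seq in seqs))
    return tokens, cu_seqlens


def unpack(tokens: Sequence[int], cu_seqlens: Sequence[int]) -> List[List[int]]:
    """Recover the individual sequences from the packed tokens."""
    return [list(tokens[s:e]) for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:])]


batch = [[11, 12, 13], [21], [31, 32]]
tokens, cu_seqlens = pack(batch)
print(tokens)      # [11, 12, 13, 21, 31, 32]
print(cu_seqlens)  # [0, 3, 4, 6]
assert unpack(tokens, cu_seqlens) == batch
```

A padded batch of the same data would need 3 × 3 = 9 slots instead of 6. FlashAttention's varlen functions take this offset array as an int32 tensor of length batch_size + 1.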

So, does Mamba have a similar mechanism, i.e. variable-length input or cu_seqlens as in FlashAttention?

tridao commented 6 months ago

Yes, there should be ways to deal with variable length. It's not implemented yet, however.

zigzagcai commented 6 months ago

Got it. Thank you Tri Dao!

zigzagcai commented 6 months ago

> Yes, there should be ways to deal with variable length. It's not implemented yet, however.

Sorry, but I'm still a bit confused:

Is it theoretically possible for Mamba to provide a variable-length API like FlashAttention's flash_attn_varlen_qkvpacked_func (https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1698610752)? In most cases, for computational efficiency, we want to concatenate short samples into one packed sample. For Transformer-based models, we can use the FlashAttention API, which accepts cu_seqlens, to process packed samples.
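Semantically, a causal varlen attention call on packed input behaves like causal attention restricted to each sequence, i.e. a block-diagonal causal mask. The pure-Python sketch below (the helper `varlen_causal_mask` is hypothetical) builds that mask from cu_seqlens only to show the semantics; FlashAttention does not materialize such a mask.

```python
from typing import List, Sequence


def varlen_causal_mask(cu_seqlens: Sequence[int]) -> List[List[bool]]:
    """mask[i][j] is True iff packed token i may attend to packed token j."""
    total = cu_seqlens[-1]
    # Segment id of each packed position, e.g. [0, 0, 0, 1, 2, 2] for [0, 3, 4, 6].
    seg = [0] * total
    for seg_id, (start, end) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
        for pos in range(start, end):
            seg[pos] = seg_id
    # Causal inside a segment, fully blocked across segments.
    return [[seg[i] == seg[j] and j <= i for j in range(total)] for i in range(total)]


mask = varlen_causal_mask([0, 3, 4, 6])
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
# x.....
# xx....
# xxx...
# ...x..
# ....x.
# ....xx
```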

zigzagcai commented 6 months ago

My understanding is that since the conv1d and the parallel associative scan in the Mamba block are linear operations, in theory the Mamba block could process packed sequences with the help of an attention mask or cu_seqlens. For example, we would want the Mamba block to process inputs of shape (packed_sequence, hidden_size) instead of (batch_size, seq_length, hidden_size), as FlashAttention does.
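For the conv1d part, handling packed input mostly means that the kernel must not reach across a sequence boundary. Below is a minimal single-channel sketch in pure Python, assuming a depthwise causal kernel where w[0] multiplies the current token; the function names are hypothetical and are not the mamba or causal-conv1d API.

```python
from typing import List, Sequence


def causal_conv1d(x: Sequence[float], w: Sequence[float]) -> List[float]:
    """y[t] = sum_k w[k] * x[t - k], treating positions before 0 as zero padding."""
    return [sum(w[k] * x[t - k] for k in range(len(w)) if t - k >= 0)
            for t in range(len(x))]


def causal_conv1d_varlen(x: Sequence[float], w: Sequence[float],
                         cu_seqlens: Sequence[int]) -> List[float]:
    """Same convolution on a packed sequence, with taps cut off at each sequence start."""
    y = [0.0] * len(x)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        for t in range(start, end):
            # Taps reaching back past `start` would read the previous sequence; skip them.
            y[t] = sum(w[k] * x[t - k] for k in range(len(w)) if t - k >= start)
    return y


seqs = [[1.0, 2.0, 3.0], [4.0, 5.0]]
w = [0.5, 0.25, 0.125]
packed = seqs[0] + seqs[1]
cu = [0, 3, 5]

varlen = causal_conv1d_varlen(packed, w, cu)
separate = causal_conv1d(seqs[0], w) + causal_conv1d(seqs[1], w)
assert varlen == separate  # packing with boundaries matches per-sequence processing

naive = causal_conv1d(packed, w)
print(naive[3], varlen[3])  # 3.0 2.0: the naive version mixes in tokens of sequence 1
```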

I'm not sure whether my understanding is correct. I'm just curious whether it is possible to feed one packed sequence of shape (packed_sequence, hidden_size) into the Mamba block, as has been done for LSTMs (here) and Transformer blocks.

zigzagcai commented 6 months ago

I have another question: could Mamba be parallelized over the seq_len dimension, as is done in FlashAttention?

tridao commented 6 months ago

It's theoretically possible to process variable lengths / packed sequences, but the implementation will be a bit tricky. Parallelizing over the seq_len dimension reduces to how one would parallelize an associative scan (e.g., with a Blelloch scan).
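To make the reduction concrete: a scalar SSM recurrence h_t = a_t * h_{t-1} + b_t is a composition of affine maps, and composition is associative, so it can be computed with a Blelloch scan. (In Mamba, a_t comes from the discretized A and b_t from the discretized B times the input.) The pure-Python sketch below uses hypothetical function names and is not how the Mamba CUDA kernel is organized; the inner loop of each sweep level is the part a GPU would run in parallel.

```python
import math
from typing import List, Sequence, Tuple

Elem = Tuple[float, float]  # (a, b) represents the affine map h -> a * h + b
IDENTITY: Elem = (1.0, 0.0)


def combine(first: Elem, second: Elem) -> Elem:
    """Apply `first`, then `second`: a2 * (a1 * h + b1) + b2. Associative, not commutative."""
    a1, b1 = first
    a2, b2 = second
    return (a1 * a2, a2 * b1 + b2)


def sequential_states(a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Reference loop: h_t = a_t * h_{t-1} + b_t with h_{-1} = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def blelloch_states(a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Same states via a Blelloch (up-sweep / down-sweep) exclusive scan."""
    n = len(a)
    size = 1
    while size < n:
        size *= 2
    elems = list(zip(a, b)) + [IDENTITY] * (size - n)  # pad to a power of two
    x = list(elems)
    # Up-sweep: each right node accumulates the reduction of its subtree.
    step = 1
    while step < size:
        for k in range(0, size, 2 * step):  # independent iterations
            left, right = k + step - 1, k + 2 * step - 1
            x[right] = combine(x[left], x[right])
        step *= 2
    # Down-sweep: push exclusive prefixes from the root back down to the leaves.
    x[size - 1] = IDENTITY
    step = size // 2
    while step >= 1:
        for k in range(0, size, 2 * step):  # independent iterations
            left, right = k + step - 1, k + 2 * step - 1
            prefix, left_sum = x[right], x[left]
            x[left] = prefix
            x[right] = combine(prefix, left_sum)  # prefix first: order matters
        step //= 2
    # Inclusive prefix = exclusive prefix followed by the element itself.
    # With h_{-1} = 0, the state is the offset b of the composed map.
    return [combine(x[i], elems[i])[1] for i in range(n)]


a = [0.9, 0.5, 0.8, 0.3, 0.7]
b = [1.0, 2.0, -1.0, 0.5, 3.0]
assert all(math.isclose(p, q, rel_tol=1e-9, abs_tol=1e-12)
           for p, q in zip(blelloch_states(a, b), sequential_states(a, b)))
```

The results agree up to floating-point rounding rather than bit-for-bit, because the scan associates the products in a different order than the sequential loop.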

albertfgu commented 6 months ago

In practice, depending on your setting, you may be able to simply concatenate the sequences and pass the whole sequence in (without enforcing state resetting at sequence boundaries). I've used this in the past where it has worked fine in some settings.

deroholic commented 6 months ago

> In practice, depending on your setting, you may be able to simply concatenate the sequences and pass the whole sequence in (without enforcing state resetting at sequence boundaries). I've used this in the past where it has worked fine in some settings.

It is often done that way, but it causes cross-contamination between samples during training, which is usually undesirable.
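The contamination is easy to see on a toy scalar recurrence h_t = a_t * h_{t-1} + b_t: when sample 2 is appended to sample 1 without a reset, the state left over from sample 1 leaks into the outputs of sample 2, fading only as fast as the a_t factors allow. A pure-Python sketch with made-up numbers (the helper `scan` is hypothetical):

```python
from typing import List, Sequence


def scan(a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Scalar linear recurrence h_t = a_t * h_{t-1} + b_t with h_{-1} = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


# Two samples, each given as (decay factors a_t, inputs b_t).
a1, b1 = [0.9, 0.9], [1.0, 1.0]
a2, b2 = [0.9, 0.9], [0.0, 2.0]

alone = scan(a2, b2)                 # sample 2 processed on its own
packed = scan(a1 + a2, b1 + b2)[2:]  # sample 2 after sample 1, no state reset

print(alone)   # [0.0, 2.0]
print(packed)  # differs: sample 1's final state leaks into sample 2
assert packed != alone
```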

albertfgu commented 6 months ago

Yes. I'm just saying sometimes it's also fine :)

zigzagcai commented 6 months ago

> In practice, depending on your setting, you may be able to simply concatenate the sequences and pass the whole sequence in (without enforcing state resetting at sequence boundaries). I've used this in the past where it has worked fine in some settings.

Hi @albertfgu @tridao, I have another question about Mamba. Does that mean the selective SSM mechanism can learn the boundary patterns through delta, or can we set delta -> inf to manually specify the sequence boundaries in a packed (concatenated) sequence input? I found the following description in Section 3.5.2 of the Mamba paper: [image: excerpt from Section 3.5.2]
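If the boundaries are known, the delta -> inf limit can be imposed by hand on a toy scalar recurrence h_t = a_t * h_{t-1} + b_t: with A < 0, a_t = exp(delta_t * A) goes to 0 as delta_t grows, which wipes the previous state. The sketch below only models that decay factor and leaves b_t untouched, a simplification since in Mamba delta also scales the input term. The helpers `scan` and `reset_at_boundaries` and the values `delta_at_start` and `A` are hypothetical; this is an illustration, not necessarily how the variable-length support mentioned later in this thread works.

```python
import math
from typing import List, Sequence


def scan(a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Scalar linear recurrence h_t = a_t * h_{t-1} + b_t with h_{-1} = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def reset_at_boundaries(a: Sequence[float], cu_seqlens: Sequence[int],
                        delta_at_start: float = 1e4, A: float = -1.0) -> List[float]:
    """Replace the decay at each sequence start with exp(delta * A) for a huge delta."""
    a = list(a)
    for start in cu_seqlens[:-1]:
        a[start] = math.exp(delta_at_start * A)  # exp(-1e4) underflows to exactly 0.0
    return a


a = [0.9, 0.9, 0.9, 0.9]
b = [1.0, 1.0, 0.0, 2.0]
cu = [0, 2, 4]  # two sequences of length 2

packed = scan(reset_at_boundaries(a, cu), b)
separate = scan(a[:2], b[:2]) + scan(a[2:], b[2:])
assert packed == separate  # the reset removes the cross-sample leakage
```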

zigzagcai commented 6 months ago

I also saw blog posts on together.ai and cartesia.ai whose "next steps" sections list variable-length training on the future roadmap. It would be fantastic if Mamba could offer such a feature in the future, as Transformers do! [image]

zigzagcai commented 5 months ago

Update: variable-length sequence support for Mamba is implemented in https://github.com/state-spaces/mamba/pull/244