johnma2006 / mamba-minimal

Simple, minimal implementation of the Mamba SSM in one file of PyTorch.
Apache License 2.0
2.54k stars, 188 forks

Using cumsum instead of a for loop #18

Open. PeaBrane opened this issue 7 months ago.

PeaBrane commented 7 months ago

The selective scan can be performed with two cumulative sums (torch.cumsum), which effectively acts as a parallel scan but is natively supported by PyTorch.

I made a minimal commit in my fork here: https://github.com/PeaBrane/mamba-tiny/commit/2908f50274c10cc7bb72a273517811dae0b38a33. I have tested its correctness and functionality, and I observed an inference speedup of ~14x on an A30. However, I'm still not sure how close it comes to the original implementation's parallel scan. More details are here.

If there is interest, it would be nice if someone could review this change and discuss whether it could be merged here, albeit at some cost to the explicitness of the code (as I understand it, the repo is meant to be pedagogical).
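A minimal sketch of the two-cumsum idea, written independently of the linked commit, so its names and details may differ from it. It assumes mamba-minimal's shapes, with deltaA kept in log form as dA = delta * A of shape (b, l, d_in, n) and the input term dBu = delta * B * u of the same shape; scan_loop and scan_cumsum are hypothetical names, and the readout with C is omitted.

```python
import torch

def scan_loop(dA, dBu):
    """Sequential selective scan: h_t = exp(dA_t) * h_{t-1} + dBu_t, starting from h = 0."""
    h = torch.zeros_like(dBu[:, 0])
    hs = []
    for t in range(dA.shape[1]):          # one step per sequence position
        h = torch.exp(dA[:, t]) * h + dBu[:, t]
        hs.append(h)
    return torch.stack(hs, dim=1)          # (b, l, d_in, n)

def scan_cumsum(dA, dBu):
    """Same recurrence with two cumulative sums over the sequence dimension.

    With L_t = dA_0 + ... + dA_t, the weight of dBu_s in h_t is
    exp(dA_{s+1} + ... + dA_t) = exp(L_t - L_s), so
    h_t = exp(L_t) * sum_{s<=t} exp(-L_s) * dBu_s.
    """
    L = torch.cumsum(dA, dim=1)                                      # first cumsum
    return torch.exp(L) * torch.cumsum(torch.exp(-L) * dBu, dim=1)   # second cumsum

b, l, d_in, n = 2, 16, 4, 3
dA = -0.1 * torch.rand(b, l, d_in, n, dtype=torch.float64)   # delta * A, with A < 0
dBu = torch.randn(b, l, d_in, n, dtype=torch.float64)        # delta * B * u
print(torch.allclose(scan_loop(dA, dBu), scan_cumsum(dA, dBu)))
```

This naive form replaces the Python loop with two vectorized calls, but exp(-L) grows with sequence length when A is negative, so it is only reliable for short sequences or with additional care about numerical range.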

johnma2006 commented 7 months ago

Thank you, and so sorry for the late reply! I’ve been a bit busy recently, but let me figure out the best way to incorporate these ideas in a bit. Thank you!

huiserwang commented 7 months ago

I have tested the original Mamba implementation. It's very fast! For the input x I used length=3136, bs=128, and channel=192, with d_state=16 for B and C. The original implementation achieves an inference speedup of ~48x over the cumsum implementation.

PeaBrane commented 7 months ago

Are you testing the original implementation in training mode or inference mode? The inference (recurrent or online) mode is not comparable to the forward pass used for training, because the former is a single recurrent step while the latter takes in the full sequence. Either way, neither mamba-minimal nor mamba-tiny is optimized for training or inference; they are purely pedagogical.
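To make the distinction concrete, here is a hedged sketch of the two modes on the same toy recurrence. In recurrent mode each call consumes one position and a fixed-size state; in full-sequence mode all positions are processed at once (here via the two-cumsum scan). The shapes follow mamba-minimal's (b, l, d_in, n) convention as an assumption, and recurrent_step and full_sequence are hypothetical names.

```python
import torch

def recurrent_step(h, dA_t, dBu_t, C_t):
    """Inference mode: advance the (b, d_in, n) state by one token and read out y_t (b, d_in)."""
    h = torch.exp(dA_t) * h + dBu_t
    y_t = torch.einsum("bdn,bn->bd", h, C_t)
    return h, y_t

def full_sequence(dA, dBu, C):
    """Training-style forward: all l positions at once (here via the two-cumsum scan)."""
    L = torch.cumsum(dA, dim=1)
    h = torch.exp(L) * torch.cumsum(torch.exp(-L) * dBu, dim=1)   # (b, l, d_in, n)
    return torch.einsum("bldn,bln->bld", h, C)                    # (b, l, d_in)

b, l, d_in, n = 2, 10, 4, 3
dA = -0.1 * torch.rand(b, l, d_in, n, dtype=torch.float64)
dBu = torch.randn(b, l, d_in, n, dtype=torch.float64)
C = torch.randn(b, l, n, dtype=torch.float64)

# Recurrent mode: l separate calls, each seeing a single position
h = torch.zeros(b, d_in, n, dtype=torch.float64)
ys = []
for t in range(l):
    h, y_t = recurrent_step(h, dA[:, t], dBu[:, t], C[:, t])
    ys.append(y_t)
y_recurrent = torch.stack(ys, dim=1)

y_full = full_sequence(dA, dBu, C)
print(torch.allclose(y_recurrent, y_full))
```

Both paths compute the same outputs, but timing one recurrent_step call measures per-token cost while timing full_sequence measures a whole sequence, so speedups measured in the two modes are not directly comparable.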

wredan commented 6 months ago

I'd also like to point out that the cumsum implementation is the better option if you need to convert mamba-minimal or mamba-tiny to ONNX. The documentation for PyTorch's static (trace-based) ONNX exporter says:

It does not record any control-flow, like if-statements or loops

so with a for loop, the loop is unrolled for the sequence length used during export, and the exported model loses the dynamic sequence-length input.
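A small sketch of this effect, using torch.jit.trace directly instead of a full ONNX export so that no onnx package is needed. It assumes, as stated in PyTorch's docs for the TorchScript-based torch.onnx.export, that the exporter relies on the same tracing mechanism. The function names are hypothetical. The traced loop should keep only as many steps as the example used for tracing, while the traced cumsum version should follow the new length.

```python
import torch

def scan_loop(dA, dBu):
    # Python for loop over the sequence: tracing unrolls it for the example length
    h = torch.zeros_like(dBu[:, 0])
    hs = []
    for t in range(dA.shape[1]):
        h = torch.exp(dA[:, t]) * h + dBu[:, t]
        hs.append(h)
    return torch.stack(hs, dim=1)

def scan_cumsum(dA, dBu):
    # No Python control flow over the sequence: the traced graph is just cumsum/exp/mul
    L = torch.cumsum(dA, dim=1)
    return torch.exp(L) * torch.cumsum(torch.exp(-L) * dBu, dim=1)

def make_inputs(l, b=1, d=2, n=3):
    g = torch.Generator().manual_seed(l)
    dA = -torch.rand(b, l, d, n, generator=g, dtype=torch.float64)
    dBu = torch.randn(b, l, d, n, generator=g, dtype=torch.float64)
    return dA, dBu

# Trace both versions with a length-4 example (the loop version emits a TracerWarning)
traced_loop = torch.jit.trace(scan_loop, make_inputs(4))
traced_cumsum = torch.jit.trace(scan_cumsum, make_inputs(4))

# Run the traced functions on a longer sequence
dA8, dBu8 = make_inputs(8)
print(traced_loop(dA8, dBu8).shape)     # sequence dimension is frozen at the traced length
print(traced_cumsum(dA8, dBu8).shape)   # sequence dimension follows the input
```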

The official Mamba model's remarkable speed comes from the hardware-aware optimizations its authors made, but its use of Triton and low-level GPU-specific optimizations prevents me from converting the original model to ONNX with the official PyTorch exporter.

Just leaving this here for anyone who needs ONNX model conversion in the future. Also, thank you for mamba-minimal and mamba-tiny; they are great for understanding how Mamba works.

DustinEwan commented 6 months ago

I tested this cumsum approach and found that it doesn't actually produce the same outputs as the standard for-loop version.

With everything else equal, the current function is slow but ultimately produces a model with sensible output.

@PeaBrane's cumsum version is multiple orders of magnitude faster, but the model ends up producing mostly nonsensical output.

PeaBrane commented 6 months ago

By "nonsensical", do you mean you encounter nan or inf values, or that the outputs are semantically nonsensical? Note that the sentence-generation script is stochastic, so the generated output will be different every time. That said, I did encounter some stability issues when running the logcumsumexp scan on the GPU, where it would produce nan or inf values (but there was no problem on the CPU).
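One plausible way nan or inf can appear in cumsum-style scans is sketched below. This is an assumption for illustration, not a diagnosis of the GPU issue above, which involved logcumsumexp. In a naive two-cumsum scan, exp(-L) overflows to inf and exp(L) underflows to 0 in float32 once the accumulated decay is large, and 0 * inf gives nan, while the sequential loop stays finite. Function names are hypothetical.

```python
import torch

def scan_loop(dA, dBu):
    # Sequential reference: each step only multiplies by exp(dA_t) <= 1, so values stay bounded
    h = torch.zeros_like(dBu[:, 0])
    hs = []
    for t in range(dA.shape[1]):
        h = torch.exp(dA[:, t]) * h + dBu[:, t]
        hs.append(h)
    return torch.stack(hs, dim=1)

def scan_cumsum_naive(dA, dBu):
    L = torch.cumsum(dA, dim=1)   # becomes more negative as the sequence grows
    # exp(-L) overflows to inf and exp(L) underflows to 0 in float32; 0 * inf = nan
    return torch.exp(L) * torch.cumsum(torch.exp(-L) * dBu, dim=1)

dA = torch.full((1, 200, 1, 1), -1.0)   # float32, constant decay, L reaches -200
dBu = torch.ones(1, 200, 1, 1)

h_loop = scan_loop(dA, dBu)
h_naive = scan_cumsum_naive(dA, dBu)
print(torch.isfinite(h_loop).all().item())
print(torch.isnan(h_naive).any().item())
```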

dftidft commented 2 months ago

There is a problem with this code:

dA_cumsum = F.pad(dA[:, 1:], (0, 0, 0, 0, 0, 1)).flip(1).cumsum(1).exp().flip(1)

Here, dA[:, 1:] uses the t-th value in the sequence as input when predicting the t-th value.
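One way to probe a concern like this empirically is a causality test: perturb only the inputs after position t and check that the outputs up to t do not change. The sketch below wraps the quoted line in a hypothetical function, padded_reverse_scan, and assumes (a guess, not taken from the linked commit) that dA_cumsum is later used as cumsum(dBu * dA_cumsum) / dA_cumsum. Under that assumption the future-dependent terms cancel algebraically. The helpers scan_loop, reverse_cumsum_only, and is_causal are also hypothetical, and reverse_cumsum_only is deliberately non-causal as a control.

```python
import torch
import torch.nn.functional as F

def scan_loop(dA, dBu):
    # Reference recurrence h_t = exp(dA_t) * h_{t-1} + dBu_t
    h = torch.zeros_like(dBu[:, 0])
    hs = []
    for t in range(dA.shape[1]):
        h = torch.exp(dA[:, t]) * h + dBu[:, t]
        hs.append(h)
    return torch.stack(hs, dim=1)

def padded_reverse_scan(dA, dBu):
    # The quoted line: exp(dA_{t+1} + ... + dA_{l-1}), a reversed cumsum
    dA_cumsum = F.pad(dA[:, 1:], (0, 0, 0, 0, 0, 1)).flip(1).cumsum(1).exp().flip(1)
    # HYPOTHETICAL usage: dividing by dA_cumsum at t cancels the future terms,
    # leaving exp(dA_{s+1} + ... + dA_t) as the weight of dBu_s
    return torch.cumsum(dBu * dA_cumsum, dim=1) / dA_cumsum

def reverse_cumsum_only(dA, dBu):
    # Deliberately non-causal control: weights position t by exp(dA_t + ... + dA_{l-1})
    return torch.flip(torch.cumsum(torch.flip(dA, [1]), dim=1), [1]).exp() * dBu

def is_causal(scan_fn, b=2, l=12, d=3, n=4, t=5, seed=0):
    """True if outputs at positions <= t are unchanged when only positions > t change."""
    g = torch.Generator().manual_seed(seed)
    dA = -torch.rand(b, l, d, n, generator=g, dtype=torch.float64)
    dBu = torch.randn(b, l, d, n, generator=g, dtype=torch.float64)
    dA2, dBu2 = dA.clone(), dBu.clone()
    future = (b, l - t - 1, d, n)
    dA2[:, t + 1:] -= torch.rand(*future, generator=g, dtype=torch.float64)
    dBu2[:, t + 1:] += torch.randn(*future, generator=g, dtype=torch.float64)
    past1 = scan_fn(dA, dBu)[:, : t + 1]
    past2 = scan_fn(dA2, dBu2)[:, : t + 1]
    return torch.allclose(past1, past2)

dA = -torch.rand(2, 12, 3, 4, dtype=torch.float64)
dBu = torch.randn(2, 12, 3, 4, dtype=torch.float64)
print(is_causal(scan_loop), is_causal(padded_reverse_scan), is_causal(reverse_cumsum_only))
print(torch.allclose(padded_reverse_scan(dA, dBu), scan_loop(dA, dBu)))
```

Running the same is_causal check against the actual implementation would show whether its outputs depend on future positions; passing it on a few random inputs is evidence rather than proof, and a reversed cumsum can still have its own numerical-range issues on long sequences.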