state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Questions about Chunk_size using Triton optimization in SSD kernel #449

Open AlwaysFHao opened 1 week ago

AlwaysFHao commented 1 week ago

Thank you for your great work. I found that in Mamba2, chunk_size defaults to 256, yet my sequence length is only 200 and the model still runs normally. In issue #439, you pointed out that the SSD kernel handles this case (by padding the sequence length?) so that the requirement seq_len % chunk_size == 0 is effectively met, but I did not fully understand the pointer-offset logic in the Triton implementation. Could you explain how the kernel runs correctly with seq_len=200 and chunk_size=256, as in the example above?

Additionally, I tried setting chunk_size to 50 and 40 so that it divides seq_len=200 evenly in my task (although your kernel handles the non-divisible case anyway), but the loss became NaN. I then checked the kernel source and found that the block sizes used in the Triton code are all powers of 2. When I set chunk_size to 64, the NaN loss did not occur. However, because the Triton source is compiled before it runs, I am not sure how to debug the SSD kernel in your original version. So my question is: is it best to set chunk_size to a power of 2 for the Triton block optimizations in your version of the SSD kernel?

I apologize for bothering you with such a basic question, and I would appreciate any answers. Thank you very much!

tridao commented 1 week ago

Yes, chunk_size should be a power of 2; that's what Triton supports. To deal with a seqlen that is not divisible by chunk_size, we load with a mask: anything outside the seqlen is masked to zero. See here, for example: https://github.com/state-spaces/mamba/blob/8ffd905c91d207f5c0cc84fc2a2fb748655094f0/mamba_ssm/ops/triton/ssd_chunk_state.py#L225 There are similar masks in the other Triton kernels in this repo.
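
For readers following along, here is a minimal pure-Python sketch of the masking idea described above. It is not the repo's kernel: `masked_load`, `chunked_sums`, and `is_power_of_two` are hypothetical helpers written for illustration, and a plain sum stands in for the real per-chunk computation. In Triton itself, the equivalent of `masked_load` would be a `tl.load` call with its `mask` and `other` arguments.

```python
def is_power_of_two(n: int) -> bool:
    """True when n is a positive power of two (1, 2, 4, ...)."""
    return n > 0 and (n & (n - 1)) == 0


def masked_load(seq, start, chunk_size):
    """Mimic a masked block load: positions past len(seq) read as 0.0."""
    offsets = range(start, start + chunk_size)
    return [seq[i] if i < len(seq) else 0.0 for i in offsets]


def chunked_sums(seq, chunk_size):
    """Sum each chunk, zero-filling the tail of the last chunk."""
    # Assumption of this sketch: reject non-power-of-two sizes up front.
    if not is_power_of_two(chunk_size):
        raise ValueError("chunk_size must be a power of two")
    num_chunks = (len(seq) + chunk_size - 1) // chunk_size  # ceil division
    return [sum(masked_load(seq, c * chunk_size, chunk_size))
            for c in range(num_chunks)]


seq = [1.0] * 200                      # seq_len = 200, as in the question
one_chunk = chunked_sums(seq, 256)     # a single chunk with 56 masked slots
four_chunks = chunked_sums(seq, 64)    # 3 full chunks plus 8 valid slots
```

With seq_len=200 and chunk_size=256, the single chunk covers offsets 0 to 255, and the 56 out-of-range offsets contribute zeros, so the result matches the unpadded sequence. The explicit `ValueError` for sizes such as 50 or 40 is a choice made in this sketch only; the thread reports a NaN loss for those sizes, not an error.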

AlwaysFHao commented 1 week ago

I see! Thank you for your prompt response and perfect work!