state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

What is the use of params.n_chunks? #273

Open WUHU-G opened 7 months ago

WUHU-G commented 7 months ago

Hello, I was wondering about params.n_chunks. What does the n_chunks variable do, and why does it depend on the length of the sequence?

tridao commented 7 months ago

const int n_chunks = (seqlen + 2048 - 1) / 2048

We do the scan on chunks of length at most 2048 (i.e. process one chunk at a time before moving to the next chunk). It's just for performance reasons, to fit in the resources of a threadblock of 128 threads. E.g. if you want to sum up 1M numbers with 128 workers, you might use those 128 workers to add up 2k numbers, then add the next 2k numbers, and so on.
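The summation analogy can be sketched on the CPU as below. This is a hypothetical illustration, not the Mamba CUDA kernel: the names num_chunks, chunked_sum, kChunkSize and kNumWorkers are made up, and the 128 "workers" are plain loop indices rather than real threads. Each chunk of at most 2048 numbers is split among the workers, their partial sums are combined, and the result is carried into the next chunk.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical constants taken from the numbers in the comment above.
constexpr int kChunkSize = 2048;  // at most 2048 elements per chunk
constexpr int kNumWorkers = 128;  // stand-ins for the 128 threads of a block

// Same ceiling division as n_chunks = (seqlen + 2048 - 1) / 2048.
int num_chunks(int seqlen) {
    assert(seqlen >= 0);
    return (seqlen + kChunkSize - 1) / kChunkSize;
}

// Sums `data` one chunk at a time. Inside a chunk, worker w adds elements
// begin + w, begin + w + 128, ...; the 128 partial sums are then folded into
// a running carry before the loop moves on to the next chunk.
int64_t chunked_sum(const std::vector<int64_t>& data) {
    const int seqlen = static_cast<int>(data.size());
    int64_t carry = 0;  // total over all chunks processed so far
    for (int chunk = 0; chunk < num_chunks(seqlen); ++chunk) {
        const int begin = chunk * kChunkSize;
        const int end = std::min(begin + kChunkSize, seqlen);
        std::vector<int64_t> partial(kNumWorkers, 0);
        for (int w = 0; w < kNumWorkers; ++w) {
            for (int i = begin + w; i < end; i += kNumWorkers) {
                partial[w] += data[i];
            }
        }
        for (int64_t p : partial) carry += p;
    }
    return carry;
}
```

For 1,000,000 numbers, num_chunks returns 489: 488 full chunks cover 999,424 elements and the last chunk holds the remaining 576. In the actual scan, the value carried from one chunk to the next would presumably be the running state of the recurrence rather than a plain sum.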

WUHU-G commented 6 months ago

Thank you, I now understand the role of the chunks. I still have one question:

x = torch::empty({batch_size, dim, n_chunks, dstate * 2...

This variable x should be the SSM state. If the number of chunks is very large, for example n_chunks = 512, x will need more storage (batch_size * dim * 512 * dstate * 2 elements), because the size of x depends on n_chunks. Will this become a bottleneck when running the model on long sequences?
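To make the size estimate in the question concrete, the sketch below computes the element count of x. It assumes the shape is exactly {batch_size, dim, n_chunks, dstate * 2}, as the question's estimate does (the quoted line is truncated). ScanShape and x_numel are hypothetical helpers. Because n_chunks is ceil(seqlen / 2048), n_chunks = 512 corresponds to sequences of up to 512 * 2048 = 1,048,576 steps, so x grows linearly with seqlen.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical bundle of the sizes involved (illustration only).
struct ScanShape {
    int64_t batch_size;
    int64_t dim;
    int64_t seqlen;
    int64_t dstate;
};

// Same ceiling division as in the thread: chunks of at most 2048 steps.
int64_t num_chunks(int64_t seqlen) {
    return (seqlen + 2048 - 1) / 2048;
}

// Element count of x, assuming shape {batch_size, dim, n_chunks, dstate * 2}.
int64_t x_numel(const ScanShape& s) {
    assert(s.seqlen > 0 && s.dstate > 0);
    return s.batch_size * s.dim * num_chunks(s.seqlen) * s.dstate * 2;
}
```

With illustrative values not taken from the thread (batch_size = 1, dim = 1024, dstate = 16) and seqlen = 1,048,576, x_numel returns 1024 * 512 * 32 = 16,777,216 elements.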


tridao commented 6 months ago

This takes less space than the input for most values of dstate.
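The reasoning can be written out as a ratio of element counts. This assumes the input u has shape (batch_size, dim, seqlen), written (B, D, L) below, and that x has shape {batch_size, dim, n_chunks, dstate * 2} with N = dstate; it compares element counts only, since byte sizes also depend on the dtypes involved.

```latex
\frac{\mathrm{numel}(x)}{\mathrm{numel}(u)}
  = \frac{B \cdot D \cdot \lceil L / 2048 \rceil \cdot 2N}{B \cdot D \cdot L}
  = \frac{2N \, \lceil L / 2048 \rceil}{L}
\qquad\text{and, when } 2048 \mid L:\qquad
\frac{\mathrm{numel}(x)}{\mathrm{numel}(u)} = \frac{2N}{2048} = \frac{N}{1024}
```

So for dstate = 16, x holds 1/64 as many elements as the input, and it only reaches the input's size when dstate is 1024. For short sequences (a single chunk) the ratio is 2N / L, which exceeds 1 only when L < 2N.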