Open WUHU-G opened 7 months ago
const int n_chunks = (seqlen + 2048 - 1) / 2048;
We do the scan on chunks of length at most 2048 (i.e. process one chunk at a time before moving to the next chunk). It's just for performance reasons, so that each chunk fits in the resources of a thread block of 128 threads. E.g. if you want to sum up 1M numbers with 128 workers, you might use those 128 workers to add up 2k numbers, then add the next 2k numbers, and so on.
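The chunking idea described above can be sketched on the CPU as follows. This is only an illustration under stated assumptions, not the actual CUDA kernel: `kChunkSize`, `num_chunks`, and `chunked_prefix_sum` are hypothetical names, and a plain prefix sum stands in for the selective scan. The point it shows is that each chunk is processed in turn and a running value is carried from one chunk into the next.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical constant mirroring the 2048 in the snippet above.
constexpr int kChunkSize = 2048;

// Ceiling division: number of chunks needed to cover seqlen elements.
int num_chunks(int seqlen) {
    return (seqlen + kChunkSize - 1) / kChunkSize;
}

// Inclusive prefix sum computed one chunk at a time. The running total
// ("carry") produced by one chunk is the starting value for the next.
std::vector<long long> chunked_prefix_sum(const std::vector<int>& input) {
    std::vector<long long> out(input.size());
    long long carry = 0;
    const int n_chunks = num_chunks(static_cast<int>(input.size()));
    for (int chunk = 0; chunk < n_chunks; ++chunk) {
        const std::size_t begin = static_cast<std::size_t>(chunk) * kChunkSize;
        const std::size_t end = std::min(begin + kChunkSize, input.size());
        // On a GPU this inner range is what a single thread block would
        // handle cooperatively; here it is a simple sequential loop.
        for (std::size_t i = begin; i < end; ++i) {
            carry += input[i];
            out[i] = carry;
        }
    }
    return out;
}
```

With this sketch, `num_chunks(2048)` is 1, `num_chunks(2049)` is 2, and a sequence of 1,000,000 elements needs 489 chunks, which is why the chunk count depends on the sequence length.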
Thank you, I now understand the role of chunk. I still have a question:
x = torch::empty({batch_size, dim, n_chunks, dstate * 2... This variable x should be the SSM state vector. If the number of chunks is very large, for example 512, x will need more storage (batch_size * dim * 512 * dstate * 2 elements), because the size of x depends on the number of chunks. Will this become a bottleneck when the model runs on long sequences?
------------------ Original message ------------------ From: "state-spaces/mamba"; Sent: Friday, March 29, 2024, 11:51 AM; Subject: Re: [state-spaces/mamba] What is the use of params.n_chunks? (Issue #273)
This takes less space than the input for most values of dstate.
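To make the size comparison concrete, here is a rough element-count sketch. It assumes the input has shape (batch_size, dim, seqlen) and that x has shape (batch_size, dim, n_chunks, dstate * 2) as in the snippet quoted in the question; the function names are hypothetical, and data types (bytes per element) are ignored.

```cpp
#include <cstdint>

// Number of 2048-element chunks needed to cover seqlen (ceiling division).
std::int64_t chunks_for(std::int64_t seqlen) {
    return (seqlen + 2048 - 1) / 2048;
}

// Elements in x, assuming shape (batch_size, dim, n_chunks, dstate * 2).
std::int64_t x_elements(std::int64_t batch_size, std::int64_t dim,
                        std::int64_t seqlen, std::int64_t dstate) {
    return batch_size * dim * chunks_for(seqlen) * dstate * 2;
}

// Elements in the input, assuming shape (batch_size, dim, seqlen).
std::int64_t input_elements(std::int64_t batch_size, std::int64_t dim,
                            std::int64_t seqlen) {
    return batch_size * dim * seqlen;
}
```

For example, with batch_size = 1, dim = 1024, seqlen = 1,048,576 (so n_chunks = 512) and dstate = 16, x holds 16,777,216 elements while the input holds 1,073,741,824, a ratio of 1/64. When seqlen is a multiple of 2048 the ratio is 2 * dstate / 2048, so under these assumptions x only matches the input size once dstate reaches about 1024.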
Hello, I was wondering about params.n_chunks. What does the n_chunks variable do, and why does it depend on the length of the sequence?