state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Query regarding indexing of smem_delta_a #601

Open · SudhanshuBokade opened 3 days ago

SudhanshuBokade commented 3 days ago

From what I understand, smem_delta_a is the shared-memory buffer that holds delta_a values, and it is initialized in these lines: https://github.com/state-spaces/mamba/blob/bc84fb1172e6dea04a7dc402118ed19985349e95/csrc/selective_scan/selective_scan_bwd_kernel.cuh#L266-L268

Regarding these lines, I have the following queries:

  1. What is chunk?
  2. Why is (chunk + 1) % 2 used for the last thread when the current chunk is not the last one (params.n_chunks - 1)? In particular, why the modulo 2?
  3. In L268, why is the index offset by 2 * MAX_DSTATE when the current thread is not the last thread?
tridao commented 3 days ago

There's double buffering: we use 2 * MAX_DSTATE entries to store the values for the even and the odd chunk. For the last chunk there's nothing to load; the value is just 1.f.
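To make the indexing concrete, here is a small host-side model in Python. It is a sketch under stated assumptions, not the kernel's code: it assumes chunk is the index of the sequence piece currently being processed (out of params.n_chunks), that smem_delta_a holds two MAX_DSTATE-sized halves (even and odd chunk) followed by per-thread slots starting at offset 2 * MAX_DSTATE, that the last thread reads the next chunk's half via (chunk + 1) % 2, and that every other thread reads the value published by thread threadIdx.x + 1 (the delta_a that follows its own last item). The sizes and the helper names chunk_slot, thread_slot and next_delta_a_index are hypothetical.

```python
# Hypothetical host-side model of how smem_delta_a is indexed.
# ASSUMPTIONS: MAX_DSTATE and N_THREADS are illustrative values, and the
# helper names are invented for this sketch; they are not the kernel's API.
from typing import Optional

MAX_DSTATE = 4  # illustrative; the real constant is defined by the kernel
N_THREADS = 3   # illustrative number of threads in the block


def chunk_slot(state_idx: int, chunk: int) -> int:
    """Double-buffer slot: [0, MAX_DSTATE) for even chunks,
    [MAX_DSTATE, 2 * MAX_DSTATE) for odd chunks."""
    return state_idx + (chunk % 2) * MAX_DSTATE


def thread_slot(thread_idx: int) -> int:
    """Per-thread slot, assumed to sit after both double-buffer halves."""
    return thread_idx + 2 * MAX_DSTATE


def next_delta_a_index(thread_idx: int, chunk: int, state_idx: int,
                       n_chunks: int) -> Optional[int]:
    """Index a thread reads for the delta_a that follows its last item.
    Returns None when there is nothing to load and the value is 1.f."""
    if thread_idx == N_THREADS - 1:
        if chunk == n_chunks - 1:
            return None  # last thread of the last chunk: use 1.f
        # The next chunk's first value lives in the other parity half.
        return chunk_slot(state_idx, chunk + 1)
    # Any other thread reads the first value of the next thread.
    return thread_slot(thread_idx + 1)


if __name__ == "__main__":
    for chunk in (3, 2, 1):
        print(chunk, [next_delta_a_index(t, chunk, state_idx=1, n_chunks=4)
                      for t in range(N_THREADS)])
```

Under these assumptions, the modulo 2 only selects which half of the first 2 * MAX_DSTATE entries belongs to a chunk, and the 2 * MAX_DSTATE offset simply skips past both halves so per-thread values never collide with the chunk-boundary values.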

tridao commented 3 days ago

You can print out all the values in smem_delta_a to debug and understand what is stored where.
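In the same spirit of printing the buffer, the effect of double buffering can be traced without a GPU. This Python simulation is illustrative only, under these assumptions: chunks are visited from last to first (as in a backward scan), thread 0 writes each chunk's first delta_a into the half chosen by chunk % 2, and after a barrier the last thread reads the half chosen by (chunk + 1) % 2. MAX_DSTATE, simulate and the double_buffer switch are hypothetical; the switch exists only to contrast against a single shared slot.

```python
# Hypothetical simulation of why the chunk boundary value is double-buffered.
# ASSUMPTIONS: chunks are visited from last to first, and within each chunk
# the new value is written before the last thread reads the previous chunk's
# value; sizes and names are illustrative, not taken from the kernel.
from typing import Dict, List

MAX_DSTATE = 2  # illustrative size


def simulate(first_delta_a: List[float], double_buffer: bool = True,
             state_idx: int = 0) -> Dict[int, float]:
    """first_delta_a[c] stands in for delta_a of the first element of chunk c.
    Returns, per chunk, the value the last thread uses after its final item."""
    n_chunks = len(first_delta_a)
    smem = [0.0] * (2 * MAX_DSTATE)
    boundary: Dict[int, float] = {}

    def slot(chunk: int) -> int:
        offset = (chunk % 2) * MAX_DSTATE if double_buffer else 0
        return state_idx + offset

    for chunk in range(n_chunks - 1, -1, -1):
        # Thread 0 publishes this chunk's first delta_a.
        smem[slot(chunk)] = first_delta_a[chunk]
        # (in a kernel, a __syncthreads() would separate this write and read)
        if chunk == n_chunks - 1:
            boundary[chunk] = 1.0  # nothing follows the last chunk
        else:
            boundary[chunk] = smem[slot(chunk + 1)]
        print(f"chunk {chunk}: smem={smem} boundary={boundary[chunk]}")
    return boundary


if __name__ == "__main__":
    simulate([0.1, 0.2, 0.3, 0.4])
    simulate([0.1, 0.2, 0.3, 0.4], double_buffer=False)
```

With double buffering, every chunk c except the last sees first_delta_a[c + 1], the value written one iteration earlier. With a single slot, the write for chunk c overwrites the value of chunk c + 1 before the last thread reads it, so each chunk wrongly sees its own value.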