SudhanshuBokade opened this issue 3 days ago (status: Open)
There's double buffering: we use `2 * MAX_DSTATE` entries of `smem_delta_a`, one block of `MAX_DSTATE` for even chunks and one for odd chunks.
For the last chunk there's nothing to load; the value is just `1.f`.
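To make the double buffering concrete, here is a minimal host-side C++ simulation (not the kernel itself) of how each thread gets the `delta_a` value of the timestep after its last item. It rests on a few assumptions: the sizes are made up and far smaller than the real ones, the backward kernel is assumed to visit chunks from last to first, and each reverse-scan element is assumed to need `delta_a` of the *following* timestep. Thread 0 writes its first value into the half selected by `chunk % 2`; the last thread reads the other half, `(chunk + 1) % 2`, which still holds thread 0's value from the chunk handled in the previous iteration. The other threads use per-thread slots starting at `2 * MAX_DSTATE`. The names `reference_next` and `next_delta_a_via_smem` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, illustrative sizes (the real kernel's are much larger).
constexpr int MAX_DSTATE = 4;
constexpr int kNThreads = 3;
constexpr int kNItems = 2;
constexpr int kChunkSize = kNThreads * kNItems;

// Reference: every timestep needs delta_a of the next timestep; after the
// final timestep there is nothing to load, so the value is 1.f.
std::vector<float> reference_next(const std::vector<float>& delta_a) {
    std::vector<float> out(delta_a.size(), 1.f);
    for (std::size_t t = 0; t + 1 < delta_a.size(); ++t) out[t] = delta_a[t + 1];
    return out;
}

// Simulates the shared-memory hand-off, visiting chunks from last to first
// (assumed to match the backward kernel). With double_buffer == false,
// thread 0 always uses the same slot, to show what the parity offset prevents.
std::vector<float> next_delta_a_via_smem(const std::vector<float>& delta_a, int n_chunks,
                                         int state_idx, bool double_buffer) {
    assert(delta_a.size() == static_cast<std::size_t>(n_chunks * kChunkSize));
    std::vector<float> smem(2 * MAX_DSTATE + kNThreads, 0.f);
    std::vector<float> out(delta_a.size(), 0.f);
    for (int chunk = n_chunks - 1; chunk >= 0; --chunk) {
        const int cur = double_buffer ? chunk % 2 : 0;        // half written now
        const int nxt = double_buffer ? (chunk + 1) % 2 : 0;  // half from the previous iteration
        // Phase 1: every thread publishes delta_a of its first item.
        for (int tid = 0; tid < kNThreads; ++tid) {
            const float first = delta_a[chunk * kChunkSize + tid * kNItems];
            const int idx = tid == 0 ? state_idx + cur * MAX_DSTATE : tid + 2 * MAX_DSTATE;
            smem[idx] = first;
        }
        // (The kernel has a __syncthreads() between the two phases.)
        // Phase 2: each thread fills in the "next" value for each of its items.
        for (int tid = 0; tid < kNThreads; ++tid) {
            const int base = chunk * kChunkSize + tid * kNItems;
            for (int i = 0; i + 1 < kNItems; ++i) out[base + i] = delta_a[base + i + 1];
            out[base + kNItems - 1] = tid == kNThreads - 1
                ? (chunk == n_chunks - 1 ? 1.f : smem[state_idx + nxt * MAX_DSTATE])
                : smem[tid + 1 + 2 * MAX_DSTATE];
        }
    }
    return out;
}
```

With `double_buffer = true` the result matches `reference_next`. With `false`, thread 0 of chunk `c` overwrites the single slot before the last thread of chunk `c` reads the value left there by chunk `c + 1`, so the last item of every chunk except the final one gets the wrong value. The per-thread slots start at `2 * MAX_DSTATE` so they cannot collide with either half of the double buffer.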
You can print out all the values in `smem_delta_a` to debug and understand how it is used.
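When dumping the buffer, it helps to label each slot with its role. The host-side helper below is a sketch that assumes the layout implied by the indexing in the linked lines: `[0, MAX_DSTATE)` holds thread 0's first value for even chunks (one entry per state), `[MAX_DSTATE, 2 * MAX_DSTATE)` the same for odd chunks, and `2 * MAX_DSTATE + tid` each other thread's first value. How the shared buffer gets copied out (for example into a global debug array) is left open; inside the kernel, device-side `printf` works as well. `slot_role` and `dump_smem_delta_a` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Label one slot of smem_delta_a, assuming the layout:
//   [0, max_dstate)               thread 0's first value, even chunks (one per state)
//   [max_dstate, 2 * max_dstate)  thread 0's first value, odd chunks
//   [2 * max_dstate, ...)         first value of thread (idx - 2 * max_dstate)
std::string slot_role(int idx, int max_dstate) {
    assert(idx >= 0 && max_dstate > 0);
    if (idx < max_dstate) return "even-chunk buffer, state " + std::to_string(idx);
    if (idx < 2 * max_dstate) return "odd-chunk buffer, state " + std::to_string(idx - max_dstate);
    return "per-thread slot, thread " + std::to_string(idx - 2 * max_dstate);
}

// Format a host copy of the buffer, one labelled slot per line.
std::string dump_smem_delta_a(const std::vector<float>& smem, int max_dstate) {
    std::ostringstream os;
    for (std::size_t i = 0; i < smem.size(); ++i) {
        os << "[" << i << "] " << slot_role(static_cast<int>(i), max_dstate) << ": " << smem[i] << "\n";
    }
    return os.str();
}
```

Under this layout the slot for "thread 0" in the per-thread region stays unused, since thread 0 writes into the even/odd buffer instead.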
From what I understand, `smem_delta_a` is used to initialize the value of `delta_a` in shared memory in these lines: https://github.com/state-spaces/mamba/blob/bc84fb1172e6dea04a7dc402118ed19985349e95/csrc/selective_scan/selective_scan_bwd_kernel.cuh#L266-L268

Regarding these lines, I have the following queries:
1. Why is `(chunk + 1) % 2` used for the last thread when it is not the last chunk (`params.n_chunks - 1`)? In particular, why is there a modulo 2?
2. Why is `2 * MAX_DSTATE` added to the index when the current thread is not the last thread?