state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Question regarding what thread_reverse_data[i].x in "selective_scan_bwd_kernel.cuh" file contains #598

Open SudhanshuBokade opened 5 days ago

SudhanshuBokade commented 5 days ago

At line 285, the code sets `dx = thread_reverse_data[i].y`, but according to my calculation of the gradient with respect to x, it should be `dx = dout * C`:

https://github.com/state-spaces/mamba/blob/bc84fb1172e6dea04a7dc402118ed19985349e95/csrc/selective_scan/selective_scan_bwd_kernel.cuh#L285

It also seems to be set up that way at lines 260-263, where `thread_reverse_data[i].y` is initialized to `dout * C`:

https://github.com/state-spaces/mamba/blob/bc84fb1172e6dea04a7dc402118ed19985349e95/csrc/selective_scan/selective_scan_bwd_kernel.cuh#L260C8-L264C8

However, a reverse-scan operation is applied to `thread_reverse_data` after that. So does `thread_reverse_data[i].y` still contain dy * C (that is, `dout * C`) after the reverse scan?
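Whether `.y` still holds `dout * C` depends on what the reverse scan does to each (`.x`, `.y`) pair. Below is a minimal pure-Python sketch, not the kernel's code, under two assumptions the thread does not confirm: element t is packed as (a[t+1], dout[t] * C[t]), where a[t] is the per-step decay, and the scan's combine step composes affine maps g -> a*g + b. All function names are hypothetical. Under those assumptions, the scanned `.y` slot equals the accumulated gradient dh[t] = dout[t]*C[t] + a[t+1]*dh[t+1], which reduces to dout[t]*C[t] only at the last step.

```python
import math
from typing import List, Tuple

# (decay, value) pair, mirroring the .x / .y slots of a float2 element.
Pair = Tuple[float, float]
IDENTITY: Pair = (1.0, 0.0)  # the affine map g -> 1*g + 0


def combine(acc: Pair, cur: Pair) -> Pair:
    """Apply the accumulated map `acc` first, then g -> cur[0]*g + cur[1]."""
    return (cur[0] * acc[0], cur[0] * acc[1] + cur[1])


def reverse_inclusive_scan(elems: List[Pair]) -> List[Pair]:
    """Sequential reference for an inclusive scan running from the last element to the first."""
    out: List[Pair] = [IDENTITY] * len(elems)
    acc = IDENTITY
    for t in reversed(range(len(elems))):
        acc = combine(acc, elems[t])
        out[t] = acc
    return out


def scanned_y(a: List[float], dout_c: List[float]) -> List[float]:
    """Pack element t as (a[t+1], dout[t]*C[t]), reverse-scan, and return the .y slots.

    Hypothetical packing: past the last step there is no next decay, so 1.0 is used.
    """
    T = len(a)
    elems = [(a[t + 1] if t + 1 < T else 1.0, dout_c[t]) for t in range(T)]
    return [y for _, y in reverse_inclusive_scan(elems)]


def naive_dh(a: List[float], dout_c: List[float]) -> List[float]:
    """Direct backward recurrence dh[t] = dout[t]*C[t] + a[t+1]*dh[t+1]."""
    T = len(a)
    dh = [0.0] * T
    for t in reversed(range(T)):
        dh[t] = dout_c[t] + (a[t + 1] * dh[t + 1] if t + 1 < T else 0.0)
    return dh


a = [0.5, 0.8, 0.9, 0.7]        # per-step decays exp(delta_t * A) (made-up values)
dout_c = [1.0, 2.0, 3.0, 4.0]   # pre-scan .y values dout[t] * C[t] (made-up values)
y_after = scanned_y(a, dout_c)
assert all(math.isclose(u, v) for u, v in zip(y_after, naive_dh(a, dout_c)))
print([round(v, 6) for v in y_after])  # → [6.776, 7.22, 5.8, 4.0]
```

Only the last entry (4.0) equals its pre-scan value; every earlier entry has picked up the decayed gradient flowing back from later steps.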

Thank you very much for your help

tridao commented 3 days ago

`dx` in the code might actually be the gradient with respect to the hidden states, so `dh` may be a better variable name. At some point while writing up the paper, we changed the notation.
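A quick numerical check of that reading, as a pure-Python sketch with a single scalar state (a simplification of the kernel's per-state loop; all names are hypothetical): take h[t] = a[t]*h[t-1] + b[t], with b[t] standing in for the discretized input, y[t] = C[t]*h[t], and L = sum_t dout[t]*y[t]. Since b[t] enters the loss only through h[t], a finite difference of L with respect to b[t] measures the gradient with respect to the hidden state h[t]. It matches the backward recurrence dh[t] = dout[t]*C[t] + a[t+1]*dh[t+1], not dout[t]*C[t] alone.

```python
import math
from typing import List


def forward_loss(a: List[float], b: List[float], C: List[float], dout: List[float]) -> float:
    """h[t] = a[t]*h[t-1] + b[t] with h[-1] = 0; returns L = sum_t dout[t] * C[t] * h[t]."""
    h, loss = 0.0, 0.0
    for t in range(len(a)):
        h = a[t] * h + b[t]
        loss += dout[t] * C[t] * h
    return loss


def hidden_state_grads(a: List[float], C: List[float], dout: List[float]) -> List[float]:
    """dh[t] = dout[t]*C[t] + a[t+1]*dh[t+1], computed from the last step backwards."""
    dh = [0.0] * len(a)
    carry = 0.0  # a[t+1] * dh[t+1]; zero after the final step
    for t in reversed(range(len(a))):
        dh[t] = dout[t] * C[t] + carry
        carry = a[t] * dh[t]
    return dh


def finite_diff_grads(a, b, C, dout, eps=1e-6):
    """Central differences of L w.r.t. each b[t]; nudging b[t] nudges h[t] by the same amount."""
    grads = []
    for t in range(len(b)):
        plus, minus = list(b), list(b)
        plus[t] += eps
        minus[t] -= eps
        grads.append((forward_loss(a, plus, C, dout) - forward_loss(a, minus, C, dout)) / (2 * eps))
    return grads


a = [0.5, 0.8, 0.9, 0.7]
b = [0.3, -1.2, 0.4, 2.0]
C = [1.5, -0.5, 2.0, 1.0]
dout = [1.0, 2.0, 1.5, 4.0]
dh = hidden_state_grads(a, C, dout)
fd = finite_diff_grads(a, b, C, dout)
assert all(math.isclose(x, y, rel_tol=1e-6, abs_tol=1e-6) for x, y in zip(dh, fd))
print([round(v, 6) for v in dh])  # → [4.876, 4.22, 5.8, 4.0]
```

Here dout[t]*C[t] is [1.5, -1.0, 3.0, 4.0], so only the last step's gradient is the "local" term by itself.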

SudhanshuBokade commented 2 days ago

Sorry, I was not clear. I understand that `dx` is the gradient with respect to the hidden states. My question is whether `thread_reverse_data[i].y` still contains dy * C after the reverse-scan operation. It is initialized to dy * C in these lines:

https://github.com/state-spaces/mamba/blob/bc84fb1172e6dea04a7dc402118ed19985349e95/csrc/selective_scan/selective_scan_bwd_kernel.cuh#L260-L263
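Why can this backward recurrence be handed to a block-wide reverse scan instead of a sequential loop? Because composing affine maps g -> a*g + b is associative, so the pairs can be grouped in any order. The pure-Python sketch below checks associativity and simulates threads that each own a few consecutive items, similar in spirit to the per-thread `thread_reverse_data[i]` array. The three-phase structure is a textbook way to build a block scan and is an assumption here, not a description of the kernel's actual reverse-scan implementation; all names are hypothetical.

```python
import math
from typing import List, Tuple

Pair = Tuple[float, float]  # (decay, value), like the .x / .y slots
IDENTITY: Pair = (1.0, 0.0)


def combine(acc: Pair, cur: Pair) -> Pair:
    """Apply the accumulated map `acc` first, then g -> cur[0]*g + cur[1]."""
    return (cur[0] * acc[0], cur[0] * acc[1] + cur[1])


def fold_right(items: List[Pair], seed: Pair = IDENTITY) -> List[Pair]:
    """Inclusive reverse scan of `items`, starting from an incoming carry `seed`."""
    out: List[Pair] = [IDENTITY] * len(items)
    acc = seed
    for t in reversed(range(len(items))):
        acc = combine(acc, items[t])
        out[t] = acc
    return out


def blocked_reverse_scan(elems: List[Pair], items_per_thread: int) -> List[Pair]:
    """Simulate threads that each own `items_per_thread` consecutive elements."""
    chunks = [elems[i:i + items_per_thread] for i in range(0, len(elems), items_per_thread)]
    # Phase 1: each simulated thread reduces its own items to a single pair.
    totals = [fold_right(chunk)[0] for chunk in chunks]
    # Phase 2: exclusive reverse scan over the per-thread totals gives each thread's carry-in.
    carries: List[Pair] = [IDENTITY] * len(chunks)
    acc = IDENTITY
    for k in reversed(range(len(chunks))):
        carries[k] = acc
        acc = combine(acc, totals[k])
    # Phase 3: each thread rescans its own items seeded with its carry-in.
    out: List[Pair] = []
    for chunk, carry in zip(chunks, carries):
        out.extend(fold_right(chunk, carry))
    return out


def close(p: Pair, q: Pair) -> bool:
    return math.isclose(p[0], q[0]) and math.isclose(p[1], q[1])


x, y, z = (0.8, 1.0), (0.9, -2.0), (0.7, 3.0)
assert close(combine(combine(x, y), z), combine(x, combine(y, z)))  # associativity

elems = [(0.8, 1.0), (0.9, 2.0), (0.7, 3.0), (0.6, -1.0), (0.5, 0.5), (1.0, 4.0)]
sequential = fold_right(elems)
for per_thread in (1, 2, 4):
    assert all(close(p, q) for p, q in zip(blocked_reverse_scan(elems, per_thread), sequential))
```

A real kernel may split the work differently, but any grouping yields the same `.y` values up to floating-point rounding, which is what makes the parallel reverse scan a drop-in replacement for the sequential backward pass.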

SudhanshuBokade commented 1 day ago

Thanks for the help @tridao

Can you please also take a look at this question?

https://github.com/state-spaces/mamba/issues/598#issuecomment-2426704304