Open SudhanshuBokade opened 5 days ago
dx in the code might actually be the gradient wrt the hidden states, so maybe dh is a better variable name.
At some point in the process of writing up the paper we changed the notation.
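A short derivation may help with the naming. It assumes a simplified model with one state dimension, a scalar loss L, and the discretization \bar{A}_t = \exp(\Delta_t A). These symbols are illustrative and not necessarily the paper's notation. Under those assumptions, the gradient with respect to the hidden state accumulates backward in time, and the gradient with respect to the input x is read off from it:

```math
\begin{aligned}
h_t &= \bar{A}_t\, h_{t-1} + \Delta_t B_t x_t, \qquad y_t = C_t h_t + D x_t, \qquad \bar{A}_t = \exp(\Delta_t A) \\
\frac{\partial L}{\partial h_t} &= C_t \frac{\partial L}{\partial y_t} + \bar{A}_{t+1} \frac{\partial L}{\partial h_{t+1}}, \qquad \frac{\partial L}{\partial h_T} = C_T \frac{\partial L}{\partial y_T} \\
\frac{\partial L}{\partial x_t} &= \Delta_t B_t \frac{\partial L}{\partial h_t} + D \frac{\partial L}{\partial y_t}
\end{aligned}
```

With several state dimensions, the first term of \partial L / \partial x_t is summed over the state dimension.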
Sorry, I was not clear. I understand that dx is the gradient wrt the hidden states. My question is whether thread_reverse_data[i].y still contains dy * C after the reverse scan op. It is initialized to dy * C in these lines:
Thanks for the help @tridao
Can you please also take a look at this question?
https://github.com/state-spaces/mamba/issues/598#issuecomment-2426704304
At line 285, the code sets dx = thread_reverse_data[i].y:
https://github.com/state-spaces/mamba/blob/bc84fb1172e6dea04a7dc402118ed19985349e95/csrc/selective_scan/selective_scan_bwd_kernel.cuh#L285
But by my calculation of the gradient, it should be dx = dout * C. That also seems to be the intent of the code, since thread_reverse_data is initialized to dout * C at lines 260-263:
https://github.com/state-spaces/mamba/blob/bc84fb1172e6dea04a7dc402118ed19985349e95/csrc/selective_scan/selective_scan_bwd_kernel.cuh#L260C8-L264C8
However, a reverse scan operation is applied to thread_reverse_data after that. So does thread_reverse_data still contain dy * C after the reverse scan op?
Thank you very much for your help.
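To make the question concrete, here is a minimal Python sketch. It assumes a simplified scalar model (one channel, one state dimension) with h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * x_t and y_t = C_t * h_t + D * x_t. It is not the CUDA kernel's code. The (decay, value) pair layout, the combine operator, and the Hillis-Steele scan are illustrative assumptions. The sketch initializes the value component to dy * C and runs a reverse scan. It then compares the resulting dx against finite differences, and also checks the alternative that uses dy * C alone.

```python
import math
import random


def combine(left, right):
    # A pair (a, b) stands for the affine map z -> a * z + b.
    # combine(left, right) applies right first, then left:
    # z -> a1 * (a2 * z + b2) + b1. The operator is associative,
    # which is what lets the recurrence be evaluated as a parallel scan.
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a1 * b2 + b1)


def reverse_scan(pairs):
    # Hillis-Steele inclusive scan from the right:
    # out[t] = pairs[t] o pairs[t+1] o ... o pairs[-1].
    out = list(pairs)
    n = len(out)
    offset = 1
    while offset < n:
        out = [combine(out[t], out[t + offset]) if t + offset < n else out[t]
               for t in range(n)]
        offset *= 2
    return out


def forward(x, delta, A, B, C, D):
    # Hypothetical scalar model: one channel, one state dimension.
    h, ys = 0.0, []
    for t in range(len(x)):
        h = math.exp(delta[t] * A) * h + delta[t] * B[t] * x[t]
        ys.append(C[t] * h + D * x[t])
    return ys


def backward(x, delta, A, B, C, D, dy):
    # Gradient of L = sum_t dy_t * y_t with respect to x.
    T = len(x)
    # b starts as the local term dy_t * C_t; a is the decay into the next
    # step, exp(delta_{t+1} * A), with 0 for the last step (nothing follows it).
    pairs = [(math.exp(delta[t + 1] * A) if t + 1 < T else 0.0, dy[t] * C[t])
             for t in range(T)]
    dh = [b for _, b in reverse_scan(pairs)]  # accumulated dL/dh_t
    dx = [delta[t] * B[t] * dh[t] + D * dy[t] for t in range(T)]
    return dh, dx


def numeric_dx(x, delta, A, B, C, D, dy, eps=1e-6):
    # Central finite differences of L = sum_t dy_t * y_t with respect to x.
    def loss(xs):
        return sum(g * y for g, y in zip(dy, forward(xs, delta, A, B, C, D)))
    grads = []
    for t in range(len(x)):
        xp, xm = list(x), list(x)
        xp[t] += eps
        xm[t] -= eps
        grads.append((loss(xp) - loss(xm)) / (2 * eps))
    return grads


random.seed(0)
T, A, D = 8, -0.5, 0.3
x = [random.uniform(-1.0, 1.0) for _ in range(T)]
delta = [random.uniform(0.1, 1.0) for _ in range(T)]
B = [random.uniform(0.5, 1.0) for _ in range(T)]
C = [random.uniform(0.5, 1.0) for _ in range(T)]
dy = [random.uniform(0.5, 1.0) for _ in range(T)]

dh, dx = backward(x, delta, A, B, C, D, dy)
dx_local = [delta[t] * B[t] * dy[t] * C[t] + D * dy[t] for t in range(T)]
ref = numeric_dx(x, delta, A, B, C, D, dy)
print("scanned dh  vs numeric:", max(abs(p - q) for p, q in zip(dx, ref)))
print("dy * C only vs numeric:", max(abs(p - q) for p, q in zip(dx_local, ref)))
```

In this sketch, only the last element still holds exactly dy * C after the scan. Every earlier element also carries the decayed gradient flowing back from later steps, and only the scanned values match the finite-difference gradient.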