lhnguyen102 / cuTAGI

CUDA implementation of Tractable Approximate Gaussian Inference
MIT License
29 stars, 9 forks

Update LSTM hidden states in backward pass #77

Closed: van-dai-vuong closed this issue 1 month ago

van-dai-vuong commented 1 month ago

I was wondering whether the code that updates the LSTM's hidden states in the function below is correct. The forget, input, and cell-state gates use delta_m, whereas the output gate uses delta_m_out. Is that intended?

[Two screenshots of the function, taken 2024-07-15]
lhnguyen102 commented 1 month ago

@van-dai-vuong That is definitely a bug: it should be delta_m_out. Could you make a PR to correct it, along with the CUDA version? Thanks!
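To make the point concrete, here is a minimal sketch (not cuTAGI's code) of why every gate path should be driven by the same output delta. It assumes a single scalar LSTM unit with h = o * tanh(c) and c = f * c_prev + i * g, and uses a plain first-order chain rule. The function names lstm_output and lstm_gate_deltas are hypothetical. The actual cuTAGI functions and their CUDA counterparts also handle variances, weights, and batching, all of which are omitted here.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def lstm_output(z_f, z_i, z_g, z_o, c_prev):
    """Hypothetical forward pass of one scalar LSTM unit: h = o * tanh(c)."""
    f, i, o = sigmoid(z_f), sigmoid(z_i), sigmoid(z_o)
    g = math.tanh(z_g)
    c = f * c_prev + i * g
    return o * math.tanh(c)


def lstm_gate_deltas(z_f, z_i, z_g, z_o, c_prev, delta_m_out):
    """Hypothetical sketch: propagate the output-mean delta of h back to
    each gate's pre-activation with a first-order chain rule.

    Every gate path is scaled by the SAME delta_m_out. Using a different
    delta for some gates is the inconsistency discussed in this issue.
    """
    f, i, o = sigmoid(z_f), sigmoid(z_i), sigmoid(z_o)
    g = math.tanh(z_g)
    c = f * c_prev + i * g
    tanh_c = math.tanh(c)

    # Jacobians of each gate activation w.r.t. its pre-activation
    j_f, j_i, j_o = f * (1.0 - f), i * (1.0 - i), o * (1.0 - o)
    j_g = 1.0 - g * g

    # dh/dc is shared by the forget, input, and cell-candidate paths
    dh_dc = o * (1.0 - tanh_c * tanh_c)

    return {
        "f": dh_dc * c_prev * j_f * delta_m_out,  # through c = f * c_prev + ...
        "i": dh_dc * g * j_i * delta_m_out,       # through c = ... + i * g
        "g": dh_dc * i * j_g * delta_m_out,       # through c = ... + i * g
        "o": tanh_c * j_o * delta_m_out,          # through h = o * tanh(c)
    }
```

Called with delta_m_out=1.0, lstm_gate_deltas returns the local derivatives dh/dz for each gate. These can be checked against finite differences of lstm_output. Scaling delta_m_out scales all four entries together, which is what the fix restores.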

van-dai-vuong commented 1 month ago

@lhnguyen102 I will do it.

lhnguyen102 commented 1 month ago

@van-dai-vuong Cool. Btw, I will make a PR tonight to include the previous states for the LSTM.

van-dai-vuong commented 1 month ago

@lhnguyen102 Thanks a lot!