musically-ut / tf_rmtpp

Recurrent Marked Temporal Point Processes
MIT License

Support for LSTM/GRUs #7

ak-7 closed 5 years ago

ak-7 commented 5 years ago

Hello,

Are you planning to extend support to LSTMs/GRUs? From your code, I see that the RNN functionality is not encapsulated in an RNN class; the code that creates and updates the hidden states is intertwined with other variables. Do you have something in mind for the easiest way to extend it to LSTMs and GRUs?

musically-ut commented 5 years ago

Such an extension is not in the pipeline.

The functionality is not embedded inside an RNNCell because truncated BPTT, which was needed for longer sequences, is much trickier to implement in that setting. You can see examples of such cells in TPPRL, where we did not employ BPTT but instead kept the episode lengths controlled.

The extension to LSTMs + GRUs is not too difficult, though, because the only thing which would change is the formula for calculating h[i+1] from h[i], i.e. Eq. (9) in the paper.
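To illustrate the point about only the update formula changing, here is a minimal sketch in plain Python. It is not taken from the repository: the function names, the parameter dictionary, the scalar hidden state, and the single input `x` (standing in for the mark and time inputs) are all assumptions made for brevity. The GRU step uses the standard textbook gating equations.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def relu_step(h, x, p):
    """Vanilla rectified update, in the spirit of Eq. (9): h' = max(w_x*x + w_h*h + b, 0)."""
    return max(p["w_x"] * x + p["w_h"] * h + p["b"], 0.0)

def gru_step(h, x, p):
    """Standard GRU update for a scalar state (hypothetical parameter names)."""
    z = sigmoid(p["w_xz"] * x + p["w_hz"] * h + p["b_z"])  # update gate
    r = sigmoid(p["w_xr"] * x + p["w_hr"] * h + p["b_r"])  # reset gate
    h_tilde = math.tanh(p["w_xh"] * x + p["w_hh"] * (r * h) + p["b_h"])  # candidate state
    return (1.0 - z) * h + z * h_tilde

def run(step, xs, p, h0=0.0):
    """Unroll the recurrence; only `step` differs between the RNN variants."""
    hs, h = [], h0
    for x in xs:
        h = step(h, x, p)
        hs.append(h)
    return hs

relu_params = {"w_x": 1.0, "w_h": 0.5, "b": 0.0}
gru_params = {k: 0.0 for k in
              ["w_xz", "w_hz", "b_z", "w_xr", "w_hr", "b_r", "w_xh", "w_hh", "b_h"]}

relu_states = run(relu_step, [1.0, -3.0, 2.0], relu_params)
gru_states = run(gru_step, [1.0, 1.0], gru_params, h0=1.0)
```

The unrolling loop in `run` is untouched when `relu_step` is swapped for `gru_step`. An LSTM would additionally need a cell state carried alongside `h`.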

musically-ut commented 5 years ago

Please close the issue if the question was answered.

Thanks!

ak-7 commented 5 years ago

Thanks!

Could you elaborate a bit on why BPTT is trickier to implement with RNNCell? My understanding is that, in a batch, once you've reached the end of a sequence, the events will be zero and we can ignore those updates. What am I missing?

musically-ut commented 5 years ago

RNNCell can indeed handle the end of sequences and, hence, can work with batches of sequences of different lengths. However, truncated BPTT is about limiting the number of steps over which the gradient is calculated, which is a very different problem: it requires blocking the flow of gradients after N steps (where N is the truncation horizon for BPTT). This can be done using `tf.stop_gradient`, but it requires injecting extra information into the cell state and is messy in general.
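As a rough illustration of what "blocking the flow of gradients after N steps" means, here is a small sketch in plain Python with a hand-written backward pass. It is not the repository's code. It assumes a toy scalar linear recurrence `h[t] = w * h[t-1] + x[t]` and a loss equal to the final state. The names `forward`, `grad_w`, and `horizon` are hypothetical. Treating the state N steps back as a constant is the effect that applying `tf.stop_gradient` to the carried state would have in TensorFlow.

```python
def forward(w, xs, h0=0.0):
    """Return all hidden states [h_0, h_1, ..., h_T] of h[t] = w * h[t-1] + x[t]."""
    hs = [h0]
    for x in xs:
        hs.append(w * hs[-1] + x)
    return hs

def grad_w(w, xs, horizon=None, h0=0.0):
    """Gradient of the final state h_T w.r.t. w, back-propagated through
    at most `horizon` steps (None means full BPTT)."""
    hs = forward(w, xs, h0)
    T = len(xs)
    steps = T if horizon is None else min(horizon, T)
    grad = 0.0
    dh = 1.0  # d h_T / d h_t, starting at t = T
    for t in range(T, T - steps, -1):
        grad += dh * hs[t - 1]  # direct dependence of h_t on w
        dh *= w                 # chain rule back to h_{t-1}
    # Anything earlier than T - steps is treated as a constant (gradient blocked).
    return grad

full = grad_w(0.5, [1.0, 1.0, 1.0], horizon=None, h0=1.0)
truncated_2 = grad_w(0.5, [1.0, 1.0, 1.0], horizon=2, h0=1.0)
truncated_1 = grad_w(0.5, [1.0, 1.0, 1.0], horizon=1, h0=1.0)
```

In this toy case the truncated gradients drop the contributions of the earliest steps, so they differ from the full gradient. Inside an RNNCell the step counter needed to decide where to cut is not available by default, which is the extra information that would have to be carried in the cell state.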

I hope that makes the problem with BPTT clearer.

That said, I haven't looked around to see if there are any good ways of doing BPTT while using RNNCell with the recent versions of TensorFlow (or PyTorch). If you find such an example, please do share it with me and I would be happy to take a look.

ak-7 commented 5 years ago

Thanks for the explanation.