hidasib / GRU4Rec

GRU4Rec is the original Theano implementation of the algorithm from the paper "Session-based Recommendations with Recurrent Neural Networks", published at ICLR 2016, and its follow-up "Recurrent Neural Networks with Top-k Gains for Session-based Recommendations". The code is optimized for execution on the GPU.

Do errors backpropagate Through Time? #8

Closed Songweiping closed 7 years ago

Songweiping commented 7 years ago

Hi @hidasib ,

I have a question about GRU4Rec training. Since the input length in a mini-batch is set to 1, does this mean that errors are not backpropagated through time as in a standard RNN? If not, could you please help me out and explain the training procedure?

Thanks

hidasib commented 7 years ago

Error is not propagated through time in the public version. The reason is that we tried propagating the error through time (BPTT) and it gave us bad results. While this seems counterintuitive compared to other fields, such as NLP, it makes sense on session data. In real-life session data there are lots of very short sessions (e.g. 2-3 clicks).
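
For illustration, here is a minimal sketch of this single-step training scheme. This is not the repository's Theano code; it is a hypothetical PyTorch re-implementation with made-up sizes, a plain cross-entropy loss instead of GRU4Rec's ranking losses, and no negative sampling. The point it shows is that the hidden state is carried over between the steps of a session but detached from the computation graph, so each update's gradient flows through the current step only.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, not taken from the repository's configuration.
n_items, hidden_size, batch_size = 1000, 100, 32

embed = nn.Embedding(n_items, hidden_size)
cell = nn.GRUCell(hidden_size, hidden_size)
out = nn.Linear(hidden_size, n_items)
params = list(embed.parameters()) + list(cell.parameters()) + list(out.parameters())
optimizer = torch.optim.Adagrad(params, lr=0.05)

# One hidden state per parallel session in the mini-batch.
hidden = torch.zeros(batch_size, hidden_size)

def train_step(in_items, target_items, session_start):
    """One update: each active session contributes a single (input, target) pair."""
    global hidden
    # Positions where a new session begins get a fresh (zero) hidden state.
    hidden[session_start] = 0.0
    h = cell(embed(in_items), hidden)              # one GRU step, length-1 input
    loss = nn.functional.cross_entropy(out(h), target_items)
    optimizer.zero_grad()
    loss.backward()                                # gradient of this step only
    optimizer.step()
    # Detach the state: it is reused as the next step's initial state,
    # but no error will be propagated back through it (no BPTT).
    hidden = h.detach()
    return loss.item()
```

Because the state is detached after every update, the network still conditions on the whole session prefix at prediction time, while training never unrolls more than one step.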

There are three choices:

1. Use BPTT on fixed-length sub-sessions of length X. This allows proper mini-batch training and propagates the error through time, but every session shorter than X has to be discarded.
2. Use BPTT on full sessions, which have different lengths. This uses all the data and propagates the error through time; mini-batching is not strictly impossible, but it has several problems (e.g. imagine having one very long session (length 500+) and several very short ones (length 2-3) in the same mini-batch).
3. Don't propagate the error through time.

We tested all three options. 1.) gave worse results than 3.), because we lose more by discarding data than we gain by doing BPTT. This holds even if we filter the test set for longer sessions: it is entirely possible that some items occur only in short sessions in the training set, and if we discard that data, those items won't be modelled at all. We even tried mixed training, i.e. training without BPTT, then with sub-sessions of length 2, then length 3, etc., in different orders and alternations, but the differences compared to the base method were within -5% to +1%. 2.) was very unstable, because some sessions are very long and the error has to be propagated through all of their steps. There might be a way to make 2.) work, but I wouldn't expect a great improvement from it.
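
For comparison, option 1.) could look roughly like the sketch below. Again, this is only a hypothetical illustration, reusing the modules from the previous snippet and processing a single sub-session at a time for simplicity (a real implementation would batch sub-sessions of equal length): the error is backpropagated through all X steps, but any session with fewer than X+1 items has to be dropped.

```python
bptt_len = 5  # hypothetical sub-session length X

def train_subsession(item_seq):
    """Truncated BPTT over one fixed-length sub-session (option 1).

    item_seq: LongTensor with bptt_len + 1 item IDs (inputs plus final target).
    """
    if item_seq.numel() < bptt_len + 1:
        return None  # sessions shorter than X are discarded entirely
    h = torch.zeros(1, hidden_size)
    loss = 0.0
    for t in range(bptt_len):
        # The state stays in the computation graph across steps.
        h = cell(embed(item_seq[t].view(1)), h)
        loss = loss + nn.functional.cross_entropy(out(h), item_seq[t + 1].view(1))
    optimizer.zero_grad()
    loss.backward()   # gradient flows back through all bptt_len steps
    optimizer.step()
    return loss.item() / bptt_len
```

The `if` guard at the top is where the data loss described above happens: items that only ever appear in sessions shorter than X never reach the model.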