snowkylin / ntm

TensorFlow implementation of Neural Turing Machines (NTM), with its application on one-shot learning (MANN)
GNU Lesser General Public License v3.0
183 stars, 65 forks

Using MANN in A3C (reinforcement learning model) #21

Open: TheMnBN opened this issue 5 years ago

TheMnBN commented 5 years ago

Hi there,

I'm trying to integrate a memory network into an A3C agent. For reference, I closely followed this implementation of A3C: https://github.com/awjuliani/DeepRL-Agents/blob/master/A3C-Doom.ipynb

My aim is to replace the LSTM layer with a MANN module. This may be a far-fetched question, but do you have any advice on refactoring your MANN implementation for this purpose?
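One way to frame the swap is to give the MANN module the same step interface an LSTM cell has, `(input, state) -> (output, new_state)`, where the state carries the memory matrix plus the previous read and usage weights instead of the LSTM's `(c, h)`. The sketch below is a simplified, forward-only NumPy version with a single read head, content-based (cosine) reading, and the least-recently-used access (LRUA) write scheme from the original MANN paper. The controller (one `tanh` layer), the `params` names, `init_state`, and the decay `gamma` are hypothetical choices for illustration, not code from this repository; a trainable version would express the same operations in TensorFlow.

```python
import numpy as np

def cosine_softmax(key, memory):
    # Cosine similarity between the key and each memory row, normalised with a softmax.
    norms = np.linalg.norm(memory, axis=1) * np.linalg.norm(key) + 1e-8
    sim = memory @ key / norms
    e = np.exp(sim - sim.max())
    return e / e.sum()

def init_state(mem_slots, mem_width, hidden):
    # Hypothetical state layout: (memory, controller output, read vector, read weights, usage weights).
    return (np.full((mem_slots, mem_width), 1e-6), np.zeros(hidden),
            np.zeros(mem_width), np.zeros(mem_slots), np.zeros(mem_slots))

def mann_step(x, state, params, gamma=0.95):
    """One step with the same (input, state) -> (output, new_state) shape as an RNN cell."""
    memory, h_prev, r_prev, w_r_prev, w_u_prev = state
    # Hypothetical controller: a single tanh layer over [x, previous read, previous h].
    h = np.tanh(params["W_h"] @ np.concatenate([x, r_prev, h_prev]) + params["b_h"])
    key = params["W_k"] @ h                               # shared read/write key
    alpha = 1.0 / (1.0 + np.exp(-(params["w_a"] @ h)))  # scalar write gate (sigmoid)

    # LRUA write: interpolate the previous read weights with the least-used slot.
    lu = np.argmin(w_u_prev)
    w_lu = np.zeros_like(w_u_prev)
    w_lu[lu] = 1.0
    w_w = alpha * w_r_prev + (1.0 - alpha) * w_lu
    memory = memory.copy()
    memory[lu] = 0.0                                      # clear the least-used slot before writing
    memory = memory + np.outer(w_w, key)

    # Content-based read from the updated memory, then decay and update usage.
    w_r = cosine_softmax(key, memory)
    r = w_r @ memory
    w_u = gamma * w_u_prev + w_r + w_w

    return np.concatenate([h, r]), (memory, h, r, w_r, w_u)
```

In an A3C agent, the returned state would be carried between rollout chunks the same way the LSTM state is, so the memory persists across the episode.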

snowkylin commented 5 years ago

Generally speaking, MANN is harder to get to converge than other RNN models, and combining it blindly with another model can make training severely unstable. It took me a lot of time to finally get it to converge on the Omniglot dataset used in the original paper. So please set aside enough time and patience, and expect to adjust the model to fit your task. Good luck!

TheMnBN commented 5 years ago

Thanks so much for replying! You're absolutely right. RL on its own can already go horribly wrong under various (and usually unknown) circumstances. I couldn't find any working implementation of memory-augmented RL models (open-source or from the authors of the original papers), so I have to write one myself. Naively combining a memory network with RL is not a well-motivated approach, but I'm still implementing it as a baseline for my research.

If you don't mind keeping this issue thread open, I would like to continue this discussion here.

TheMnBN commented 5 years ago

My computation graph has a single tf.nn.dynamic_rnn op. I'm thinking of replacing it with a tf.while_loop whose body runs the MANN operations. Does this approach make sense? I noticed that you used a Python 'for' loop in your model, so I will try both and see which works. Either way, I need a way to terminate the loop, i.e., define a condition for tf.while_loop or a sequence length for the 'for' loop.
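The termination question maps directly onto the `cond`/`body`/`loop_vars` contract of `tf.while_loop`: carry a step counter in the loop variables and stop when it reaches the sequence length, which is the role `sequence_length` plays for `tf.nn.dynamic_rnn`. Below is a framework-free sketch of that pattern; `while_loop` here is a hypothetical plain-Python stand-in for `tf.while_loop`, the Python list plays the role of a `tf.TensorArray`, and `step_fn` is any function with the `(input, state) -> (output, new_state)` signature (for example, a MANN step).

```python
import numpy as np

def while_loop(cond, body, loop_vars):
    # Plain-Python stand-in for the tf.while_loop contract: repeat body while cond holds.
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars

def run_sequence(step_fn, inputs, seq_len, initial_state):
    """Unroll step_fn over inputs[:seq_len]; inputs has shape (max_time, input_dim), seq_len >= 1."""
    outputs = []  # collects per-step outputs, like a tf.TensorArray would

    def cond(t, state):
        # Terminate after seq_len steps, mirroring dynamic_rnn's sequence_length.
        return t < seq_len

    def body(t, state):
        out, state = step_fn(inputs[t], state)
        outputs.append(out)
        return t + 1, state

    _, final_state = while_loop(cond, body, (0, initial_state))
    return np.stack(outputs), final_state
```

The design trade-off between the two options: in TF1 graph mode, a Python 'for' loop is unrolled when the graph is built, so it needs a length fixed in advance, whereas `tf.while_loop` runs inside the graph, so `seq_len` can be a tensor that changes from one rollout to the next (useful when A3C rollouts end early at episode boundaries).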