Open PrincipalComponent-zz opened 7 years ago
I did not manage to run the original code successfully, and I didn't spend time debugging it like you did (great job), even though the code is elegant and easy to read and understand. I think you got both of your points right. Are you able to run it now?
@tianyu-tristan I'm working on my own implementation, which also computes the REINFORCE updates explicitly. I haven't run @yanji84's code yet, only read through it. I will publish a first version of my code soon -- maybe we can discuss it further then? Will let you know once I'm done.
Hi @yanji84,
first of all, compliments on your code; the clear structure makes it easy to understand. However, I think there are two issues with how you compute the policy gradients in the `backward` method of the `LocationNet` class:

1. You define `locationNetTrainOp` to optimize the parameters of the location network by REINFORCE (here). You add it to the `CoreRNN` here but never execute it as far as I can see (neither in `CoreRNN` nor in `Main`).
2. You collect all policy gradients and the parameters in a list. `optimize.apply_gradients(gradsAndVars)` does not actually perform gradient updates (it's just an operator), and `locationNetTrainOp` is overwritten on every iteration. Thus, the training operator you return in the end would only apply the updates using the gradients of the last time step (i.e. the last element in the list `gradsAndVarsAllSteps`). I think instead of the loop it should simply be:

Maybe I'm missing something, but this should fix things.
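To make the overwriting problem concrete, here is a minimal, framework-free sketch. It is not the snippet from the original code and it is not TensorFlow: `ToyOptimizer`, `build_overwritten_op`, and `build_summed_op` are hypothetical names, and the toy `apply_gradients` only mimics the behavior described above, returning an update that changes nothing until it is run. Summing the per-step gradients before building a single operator is an assumption about what the fix should compute (the total REINFORCE gradient as a sum over time steps).

```python
class ToyOptimizer:
    """Hypothetical stand-in for a graph-mode optimizer (plain SGD)."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def apply_gradients(self, grads_and_vars):
        # Only *builds* an update; nothing changes until the returned
        # callable is invoked (the analogue of running the operator).
        pairs = list(grads_and_vars)

        def train_op():
            for grad, var in pairs:
                var["value"] -= self.learning_rate * grad

        return train_op


def build_overwritten_op(optimizer, grads_and_vars_all_steps):
    # Pattern described in the issue: the operator is reassigned on every
    # iteration, so only the last time step's gradients survive.
    train_op = None
    for grads_and_vars in grads_and_vars_all_steps:
        train_op = optimizer.apply_gradients(grads_and_vars)
    return train_op


def build_summed_op(optimizer, grads_and_vars_all_steps):
    # Assumed fix: sum each variable's gradient over all time steps,
    # then build one operator from the totals.
    totals = {}
    for grads_and_vars in grads_and_vars_all_steps:
        for grad, var in grads_and_vars:
            previous = totals.get(id(var), (0.0, var))[0]
            totals[id(var)] = (previous + grad, var)
    return optimizer.apply_gradients(list(totals.values()))


# One parameter, three time steps with gradients 1.0, 2.0 and 3.0.
optimizer = ToyOptimizer(learning_rate=0.1)

w_last_only = {"value": 0.0}
steps = [[(1.0, w_last_only)], [(2.0, w_last_only)], [(3.0, w_last_only)]]
build_overwritten_op(optimizer, steps)()  # building alone would do nothing
print("overwritten op:", w_last_only["value"])  # roughly -0.3 (last step only)

w_all_steps = {"value": 0.0}
steps = [[(1.0, w_all_steps)], [(2.0, w_all_steps)], [(3.0, w_all_steps)]]
build_summed_op(optimizer, steps)()
print("summed op:", w_all_steps["value"])  # roughly -0.6 (all steps)
```

In the toy version, the overwritten operator moves the parameter by only the last step's gradient, while the summed operator accounts for all three steps. In both cases the parameter stays unchanged until the returned operator is actually called, which mirrors the first point about the training operator never being executed.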
Edit: since @tianyu-tristan has worked with your code as well, I'm mentioning him too.