yanji84 / deep-recurrent-attention-model

Apply reinforcement learning to visual attention

Policy gradients bug in LocationNet? #1

Open PrincipalComponent-zz opened 7 years ago

PrincipalComponent-zz commented 7 years ago

Hi @yanji84,

first of all, compliments on your code: the clear structure makes it easy to understand. However, I think there are two issues with how you compute the policy gradients in the backward method of the LocationNet class:

  1. you define a training operator locationNetTrainOp to optimize the parameters of the location network by REINFORCE (here). You add it to the CoreRNN here but never execute it as far as I can see (neither in CoreRNN nor in Main).
  2. you collect the policy gradients of all time steps, each paired with the parameters, in a list and then loop over it:

    gradsAndVarsAllSteps = [zip(g, params) for g in policyGradients]
    for gradsAndVars in gradsAndVarsAllSteps:
        locationNetTrainOp = optimize.apply_gradients(gradsAndVars)

    optimize.apply_gradients(gradsAndVars) does not itself perform a gradient update (it only builds an operator), and locationNetTrainOp is overwritten on every iteration. Thus, the training operator you return in the end would only apply the update using the gradients of the last time step (i.e. the last element of the list gradsAndVarsAllSteps). I think instead of the loop it should simply be:

    gradsAndVarsAllSteps = [zip(g, params) for g in policyGradients]
    locationNetTrainOp = optimize.apply_gradients(gradsAndVarsAllSteps)
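
The overwrite problem in the second point can be illustrated without TensorFlow. The following is a minimal sketch, not code from the repository: plain floats stand in for gradient tensors and strings for variables, and `sum_gradients_over_steps` is a hypothetical helper. It assumes the intended update is the sum of the per-step policy gradients for each parameter. Note also that TensorFlow's `apply_gradients` expects a flat list of (gradient, variable) pairs, so a list of per-step lists would probably have to be merged in some such way before being passed to it.

```python
def sum_gradients_over_steps(policy_gradients, params):
    """Sum the per-step gradients for each parameter.

    policy_gradients: one list per time step, holding one gradient per parameter.
    params: the parameters, in the same order as the gradients.
    Returns a flat list of (summed_gradient, parameter) pairs.
    """
    summed = [0.0] * len(params)
    for step_grads in policy_gradients:
        if len(step_grads) != len(params):
            raise ValueError("each step needs one gradient per parameter")
        summed = [acc + g for acc, g in zip(summed, step_grads)]
    return list(zip(summed, params))


# Toy data (hypothetical): three time steps, two parameters.
policy_gradients = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
params = ["w", "b"]

# The loop pattern from the issue: the name is rebound on every iteration,
# so only the pairs of the final time step survive.
last_only = None
for step_grads in policy_gradients:
    last_only = list(zip(step_grads, params))

# Combining all steps into one flat list of pairs instead.
combined = sum_gradients_over_steps(policy_gradients, params)

print(last_only)  # gradients of the last step only
print(combined)   # gradients summed over all steps
```

Summing is only one possible way to combine the steps; averaging over time steps would be an equally plausible choice and only rescales the update.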

Maybe I'm missing something but this should fix things.

Edit: since @tianyu-tristan has also worked with your code, I'm mentioning him here.

tianyu-tristan commented 7 years ago

I did not manage to run the original code successfully and didn't spend time debugging it like you did (great job), even though the code is elegant to read and understand. I think you got both of your points right. Are you able to run it now?

PrincipalComponent-zz commented 7 years ago

@tianyu-tristan I'm working on my own implementation, which also computes the REINFORCE updates explicitly. I haven't run @yanji84's code yet, only read through it. I will publish a first version of my code soon; maybe we can discuss it further then? Will let you know once I'm done.
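
As a rough illustration of what computing a REINFORCE update explicitly can look like, here is a minimal sketch. It is not taken from either implementation mentioned in this thread. It assumes a Gaussian location policy with a fixed standard deviation, scalar locations for brevity, and a single episode-level reward with a baseline; the function name and all values are hypothetical.

```python
def reinforce_mean_gradients(sampled, means, sigma, reward, baseline):
    """Per-step REINFORCE gradient estimates with respect to the policy means.

    Assumes each location l_t was sampled from N(mean_t, sigma^2), so that
    d/d(mean_t) log p(l_t) = (l_t - mean_t) / sigma^2.
    Scaling by the advantage (reward - baseline) gives the score-function
    estimate of the gradient of the expected reward (an ascent direction).
    """
    if sigma <= 0:
        raise ValueError("sigma must be positive")
    advantage = reward - baseline
    return [advantage * (l - m) / sigma ** 2 for l, m in zip(sampled, means)]


# Toy episode (hypothetical values): two glimpses, reward 1 for a correct
# classification, baseline 0.5.
grads = reinforce_mean_gradients(
    sampled=[0.5, -0.25],
    means=[0.25, 0.0],
    sigma=1.0,
    reward=1.0,
    baseline=0.5,
)
print(grads)
```

In a full model these per-step values would be backpropagated through the location network to its parameters; the sketch stops at the gradient with respect to the means.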