seann999 / tensorflow_mnist_ram

An attempt to implement the recurrent attention model (RAM) from "Recurrent Models of Visual Attention" (Mnih+ 2014)
MIT License

Backprop for Reinforce #4

Open peyush opened 7 years ago

peyush commented 7 years ago

Hi. According to the paper, only the location network should be trained with REINFORCE, i.e., the weights used to compute the next set of locations in the RNN. The rest of the network should be trained with gradients of the cross-entropy loss only, as in a normal RNN. But in this implementation, all the weights are being trained with the REINFORCE gradient. Am I missing something?

jtkim-kaist commented 7 years ago

According to line 181 in ram.py, the cross-entropy loss and the location-policy term (log-probability times reward) are summed, and in the backprop phase the summation operator is just a gradient router: the cross-entropy part and the location-policy part do not influence each other through the sum. Therefore the core network is trained with gradients of the cross-entropy loss, and the location network is trained by the policy gradient. I hope this explanation matches the intent of the author of this code.
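A minimal sketch of the "summation as gradient router" point, not the repo's actual code: the loss terms and parameters below are toy stand-ins, and the gradients are checked by finite differences rather than TensorFlow. The gradient of the summed loss with respect to a parameter that appears in only one term equals the gradient of that term alone.

```python
import math

def ce_loss(w_core):
    # Toy stand-in for the cross-entropy term; depends only on the
    # (hypothetical) core-network weight.
    return (w_core - 2.0) ** 2

def reinforce_loss(w_loc, reward=1.5):
    # Toy stand-in for the REINFORCE term, -log pi(l) * R; depends only
    # on the (hypothetical) location-network weight.
    return -math.log(1.0 / (1.0 + math.exp(-w_loc))) * reward

def grad(f, x, eps=1e-6):
    # Central finite-difference approximation to df/dx.
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)

w_core, w_loc = 0.5, 0.3

# Gradient of the summed loss w.r.t. each parameter vs. the gradient of
# that parameter's own term alone: the sum just routes gradients, so the
# two terms do not influence each other.
g_core_sum   = grad(lambda w: ce_loss(w) + reinforce_loss(w_loc), w_core)
g_core_alone = grad(ce_loss, w_core)
g_loc_sum    = grad(lambda w: ce_loss(w_core) + reinforce_loss(w), w_loc)
g_loc_alone  = grad(reinforce_loss, w_loc)

print(abs(g_core_sum - g_core_alone) < 1e-6)  # True
print(abs(g_loc_sum - g_loc_alone) < 1e-6)    # True
```

This is exactly why summing the two losses in the graph still trains the core network only on cross-entropy gradients while the location network receives only the policy-gradient signal (assuming, as in the thread, that gradients are not propagated from the classifier into the sampled locations).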