kevinzakka / recurrent-visual-attention

A PyTorch Implementation of "Recurrent Models of Visual Attention"
MIT License

Question on how to train Location network with detach function #37

Closed lizhenstat closed 3 years ago

lizhenstat commented 3 years ago

Hi, thanks for your work and for releasing the code. I have one question related to training the location network with the REINFORCE algorithm. If I understand correctly, in modules.py the following part is the implementation of REINFORCE:

# compute mean
feat = F.relu(self.fc(h_t.detach())) 
mu = torch.tanh(self.fc_lt(feat))

# reparametrization trick
l_t = torch.distributions.Normal(mu, self.std).rsample() 
l_t = l_t.detach()
log_pi = Normal(mu, self.std).log_prob(l_t)

and for calculating loss_reinforce and the reward, the relevant part is the following:

# calculate reward
predicted = torch.max(log_probas, 1)[1]
R = (predicted.detach() == y).float()
R = R.unsqueeze(1).repeat(1, self.num_glimpses)

...
...

# compute reinforce loss
# summed over timesteps and averaged across batch
adjusted_reward = R - baselines.detach()
loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1) # gradient ascent (negative)
loss_reinforce = torch.mean(loss_reinforce, dim=0) 

My question is: how do we update the parameters of the fully connected layers if we detach all of the related tensors? I read some examples of REINFORCE implementations, such as the PyTorch documentation and the official PyTorch REINFORCE example; however, I still cannot figure out how the detach function works here. I also looked at the similar issues #29 and #20.

Any help would be appreciated and thanks for your time! Best wishes

malashinroman commented 3 years ago

Hi, please note that log_pi is never detached from the computational graph. Therefore we can backpropagate through the location network here: loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1) # gradient ascent (negative)

Minimizing this loss will increase the probability (log_pi) of selecting actions that provided a good reward in the past. (More strictly, it will increase or decrease the probability of mapping a particular hidden-state vector to a particular action, according to the reward of the whole trajectory!)
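
To see this concretely, here is a minimal standalone sketch (toy layer sizes, a dummy reward of 1, and stand-in layers fc / fc_lt rather than the repo's modules.py) showing that the REINFORCE loss still reaches the location network's weights through log_pi, while h_t and the sampled l_t are cut off by detach():

import torch
import torch.nn.functional as F
from torch.distributions import Normal

torch.manual_seed(0)
hidden_size, std = 8, 0.17

fc = torch.nn.Linear(hidden_size, hidden_size)   # stands in for self.fc
fc_lt = torch.nn.Linear(hidden_size, 2)          # stands in for self.fc_lt

h_t = torch.randn(4, hidden_size, requires_grad=True)  # pretend core-network output

feat = F.relu(fc(h_t.detach()))       # detach: nothing flows back into h_t from here
mu = torch.tanh(fc_lt(feat))          # mu is still a function of fc / fc_lt weights
l_t = Normal(mu, std).rsample().detach()             # sampled location, treated as a constant
log_pi = Normal(mu, std).log_prob(l_t).sum(dim=1)    # depends on mu -> still in the graph

adjusted_reward = torch.ones(4)                      # dummy reward for the sketch
loss_reinforce = torch.mean(-log_pi * adjusted_reward)
loss_reinforce.backward()

print(fc.weight.grad is not None, fc_lt.weight.grad is not None)  # True True: locator is trained
print(h_t.grad)                                                   # None: blocked by detach()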

lizhenstat commented 3 years ago

@malashinroman Hi, thanks for your quick and thorough reply! As I understand it now, detach() is used as follows: (1) the two FC layers inside the location network are trained through loss_reinforce via "mu", and "h_t" is not influenced by the location network; (2) the FC layer inside the baseline network is trained through loss_baseline, not by REINFORCE, and "h_t" is likewise not influenced by the baseline network. Is this right?
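
One way to double-check this summary is a small standalone experiment (a sketch only: core / classifier / locator / baseliner below are stand-in single linear layers, and the nll_loss / mse_loss definitions are assumptions for the sketch rather than quotes from trainer.py). Backpropagating each loss separately shows which parameter group it actually touches:

import torch
import torch.nn.functional as F
from torch.distributions import Normal

torch.manual_seed(0)
B, H, std = 4, 8, 0.17

core = torch.nn.Linear(H, H)          # stand-in for the core RNN
classifier = torch.nn.Linear(H, 10)   # action (classification) head
locator = torch.nn.Linear(H, 2)       # stand-in for the location network
baseliner = torch.nn.Linear(H, 1)     # baseline network
parts = {"core": core, "classifier": classifier,
         "locator": locator, "baseliner": baseliner}

x = torch.randn(B, H)
y = torch.randint(0, 10, (B,))

h_t = torch.relu(core(x))
log_probas = F.log_softmax(classifier(h_t), dim=1)
mu = torch.tanh(locator(h_t.detach()))             # detached, as in modules.py
l_t = Normal(mu, std).rsample().detach()
log_pi = Normal(mu, std).log_prob(l_t).sum(dim=1)
b_t = baseliner(h_t.detach()).squeeze(1)           # baseline also reads a detached h_t

R = (log_probas.argmax(dim=1) == y).float()
losses = {
    "loss_action": F.nll_loss(log_probas, y),                    # supervised loss
    "loss_baseline": F.mse_loss(b_t, R),                         # baseline loss
    "loss_reinforce": torch.mean(-log_pi * (R - b_t.detach())),  # REINFORCE loss
}

for loss_name, loss in losses.items():
    for m in parts.values():
        m.zero_grad()
    loss.backward(retain_graph=True)
    touched = [name for name, m in parts.items()
               if m.weight.grad is not None and m.weight.grad.abs().sum() > 0]
    print(loss_name, "->", touched)
# Expected: loss_action -> core + classifier, loss_baseline -> baseliner,
#           loss_reinforce -> locator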

Thanks a lot!!

malashinroman commented 3 years ago

Yes, that is right.

h_t is only influenced by the supervised loss. I wonder whether it would be possible to adjust it through the locator as well. That refers to (1) in your list.
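
A tiny standalone check of that idea (same kind of toy setup as above, not the repo's code): if the location head reads h_t instead of h_t.detach(), loss_reinforce produces a gradient on h_t, i.e. the locator would then also shape the core network.

import torch
from torch.distributions import Normal

torch.manual_seed(0)
fc_lt = torch.nn.Linear(8, 2)                 # stand-in location head
h_t = torch.randn(3, 8, requires_grad=True)   # pretend core-network output

mu = torch.tanh(fc_lt(h_t))                   # no detach(): h_t stays in the graph
l_t = Normal(mu, 0.17).rsample().detach()
log_pi = Normal(mu, 0.17).log_prob(l_t).sum(dim=1)

loss_reinforce = torch.mean(-log_pi * torch.ones(3))   # dummy reward of 1
loss_reinforce.backward()
print(h_t.grad is not None)                   # True: REINFORCE now reaches h_t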