kevinzakka / recurrent-visual-attention

A PyTorch Implementation of "Recurrent Models of Visual Attention"

formula of loss_reinforce #10

Closed · ipod825 closed this issue 6 years ago

ipod825 commented 6 years ago

https://github.com/kevinzakka/recurrent-visual-attention/blob/b659b6ff06561d073320b8123811ee738f968d9f/trainer.py#L389 (screenshot of the linked line)

According to the formula in the paper, the gradient is summed over both samples and time steps but normalized only by the number of samples, i.e. it is averaged over samples and summed over time steps. So I think it would be more appropriate to calculate loss_reinforce as

loss_reinforce = torch.sum(-log_pi*adjusted_reward, dim=1)                                    
loss_reinforce = torch.mean(loss_reinforce)
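For reference, this is how I read the Monte Carlo approximation of the gradient in the paper (Eq. (2) in Mnih et al., 2014): summed over the T time steps, averaged only over the M sample episodes. The baseline-subtracted reward used in the code (adjusted_reward) takes the place of R^i but does not change the reduction.

\nabla_\theta J \approx \frac{1}{M} \sum_{i=1}^{M} \sum_{t=1}^{T} \nabla_\theta \log \pi\!\left(u_t^i \mid s_{1:t}^i ; \theta\right) R^i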

Though it's just a matter of a scalar factor and should be absorbed by an adaptive optimizer... What do you think?

kevinzakka commented 6 years ago

Great catch dude! I completely agree: it's a matter of scale and shouldn't change the result much with an optimizer like Adam. Cheers :)
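To make the "matter of scale" point concrete, here is a minimal sketch (with hypothetical shapes and random tensors, not the trainer's actual variables) showing that the two reductions differ only by the constant factor num_glimpses:

import torch

batch_size, num_glimpses = 4, 6          # hypothetical values
log_pi = torch.randn(batch_size, num_glimpses)
adjusted_reward = torch.randn(batch_size, num_glimpses)

# current reduction: mean over both the batch and the time steps
loss_current = torch.mean(-log_pi * adjusted_reward)

# proposed reduction: sum over time steps, then mean over the batch
loss_proposed = torch.mean(torch.sum(-log_pi * adjusted_reward, dim=1))

# the two agree up to the constant factor num_glimpses
print(torch.allclose(loss_proposed, loss_current * num_glimpses))  # True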

kevinzakka commented 6 years ago

I ran a quick test with SGD and momentum, and it's doing worse than before haha. Validation accuracy peaked at 50% but has been decreasing ever since. I'm going to switch to Adam, and if the accuracy increases, I'll switch back to SGD but add a more aggressive lr decay schedule.
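A minimal sketch of that plan, assuming hypothetical hyperparameters and a stand-in nn.Linear instead of the actual RAM model, would be something like:

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(10, 10)  # stand-in for the actual model (assumption)

# SGD with momentum plus a more aggressive step decay of the learning rate
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)  # halve the lr every 10 epochs (hypothetical schedule)

for epoch in range(30):
    # ... the forward/backward pass would go here ...
    optimizer.step()   # placeholder update; no gradients in this sketch
    scheduler.step()   # decay the learning rate on schedule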