kevinzakka / recurrent-visual-attention

A PyTorch Implementation of "Recurrent Models of Visual Attention"

Log probs calculation is wrong #26

Closed clvcooke closed 4 years ago

clvcooke commented 4 years ago

When calculating the log probability of the sample, the code currently doesn't take into account that a non-linearity has been applied.

Specifically: https://github.com/kevinzakka/recurrent-visual-attention/blob/master/model.py#L109

It assumes an untransformed normal distribution, but the sampled variable, l_t, has been transformed: https://github.com/kevinzakka/recurrent-visual-attention/blob/master/modules.py#L350
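
If I'm reading the linked lines right, the mismatch boils down to something like this (variable names, shapes, and the std value here are illustrative, not the repo's exact code):

```python
import torch
from torch.distributions import Normal

mu = torch.zeros(4, 2)  # hypothetical batch of location means
std = 0.17              # illustrative policy std

# modules.py (paraphrased): sample, then squash with tanh
l_t = Normal(mu, std).rsample()
l_t = torch.tanh(l_t)

# model.py (paraphrased): log prob computed as if l_t were the
# raw Normal sample -- the tanh transform is never accounted for
log_pi = Normal(mu, std).log_prob(l_t).sum(dim=1)
```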

The easy solution is to calculate the log probs before applying the non-linearity, which means the location network returns the log_probs and l_t (mu is no longer needed). A sketch follows.
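
Roughly, the fix could look like this (layer sizes and the fixed std are placeholder assumptions, not the repo's exact values):

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class LocationNetwork(nn.Module):
    """Sketch of the proposed fix: score the raw Normal sample
    *before* the tanh squashing, and return (log_pi, l_t)."""

    def __init__(self, input_size, output_size, std):
        super().__init__()
        self.std = std
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, h_t):
        # mean of the location policy (hidden state detached so the
        # REINFORCE gradient doesn't flow into the core network)
        mu = torch.tanh(self.fc(h_t.detach()))

        # draw the raw, pre-squash sample
        distribution = Normal(mu, self.std)
        raw_l_t = distribution.rsample().detach()

        # log prob of the sample under the distribution it was
        # actually drawn from, evaluated before the non-linearity
        log_pi = distribution.log_prob(raw_l_t).sum(dim=1)

        # squash into [-1, 1] only after scoring the sample
        l_t = torch.tanh(raw_l_t)
        return log_pi, l_t
```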

This probably hasn't had much of an effect in practice: in the linear region of tanh the two computations nearly agree. It is still theoretically incorrect, though.

kevinzakka commented 4 years ago

Good catch!

I haven't maintained this repo in over 2 years, so feel free to submit a pull request -- although PyTorch has changed a lot since.