kevinzakka / recurrent-visual-attention

A PyTorch Implementation of "Recurrent Models of Visual Attention"
MIT License

Caveat in last commit #12

Open ipod825 opened 6 years ago

ipod825 commented 6 years ago

https://github.com/kevinzakka/recurrent-visual-attention/commit/99c4cbe089439b83f86922a4e0154923da18a96f#diff-40d9c2c37e955447b1175a32afab171fL353 This is not an unnecessary detach. l_t is used in log_pi = Normal(mu, self.std).log_prob(l_t), which in turn is used in loss_reinforce = torch.sum(-log_pi*adjusted_reward, dim=1). That means that when you minimize the REINFORCE loss, you are altering the location network through both mu and l_t (and yes, log_pi is differentiable w.r.t. both mu and l_t). However, l_t is just mu + noise, and we only want the gradient to flow through mu.
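For concreteness, a minimal sketch of the point (toy tensors stand in for the network outputs; the shapes, std value, and adjusted_reward below are made up for illustration):

    import torch
    from torch.distributions import Normal

    std = 0.17
    mu = torch.randn(4, 2, requires_grad=True)   # output of the location network
    l_t = mu + torch.randn_like(mu) * std        # sampled location
    adjusted_reward = torch.randn(4, 1)          # stand-in for reward minus baseline

    # detach l_t so that REINFORCE updates the location network only through mu,
    # not through the sample itself
    log_pi = Normal(mu, std).log_prob(l_t.detach())
    loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
    loss_reinforce.mean().backward()             # mu.grad now carries the REINFORCE signal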

kevinzakka commented 6 years ago

Won't l_t.detach() stop the gradients for l_t and hence for mu? I mean, we'll still have a gradient for mu from log_pi, but there's a contribution from l_t as well.

This is my understanding:

We want the weights of the location network to be trained using REINFORCE. Now the hidden state vector h_t (detached because we do not want to train the weights of the RNN with REINFORCE) is fed through the fully-connected layer of the location net to produce the mean mu. Then using the reparametrization trick, we sample a location vector l_t from a Gaussian parametrized by mu.

Doing the above, we've made it possible to backpropagate through l_t and hence back to mu, which means we can train the weights of the location network.

ipod825 commented 6 years ago

mu and l_t are two separate Variables (though highly correlated). l_t.detach() does not stop you from calculating d loss_reinforce / d mu. For example, consider the following:

x = 1
y = x + 1
z = x * y

Both dz/dx and dz/dy are well defined. Even if you "detach" y, you can still calculate dz/dx.
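A quick way to see this in PyTorch (a toy check mirroring the example above):

    import torch

    x = torch.tensor(1.0, requires_grad=True)
    y = x + 1
    z = x * y.detach()      # "detach" y: treat it as a constant
    z.backward()
    print(x.grad)           # tensor(2.) -- dz/dx through the direct path only

    x.grad = None
    y = x + 1
    z = x * y               # no detach: gradient also flows through y
    z.backward()
    print(x.grad)           # tensor(3.) = y + x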

ipod825 commented 6 years ago

Oh. I think you know what I meant. I'll think more and reply again.

ipod825 commented 6 years ago

I think we shouldn't let the gradient flow through l_t. Intuitively, for loss_reinforce to decrease, we want log_pi to increase. For log_pi to increase, we want mu and l_t to be closer together. Assume mu < l_t: the gradient then tries to increase mu and decrease l_t simultaneously. However, decreasing l_t essentially decreases mu, since l_t = mu + noise. If you derive the formula for the gradients, one should be the negative of the other, because the kernel of the Gaussian is (l_t - mu)^2, so they cancel each other out.
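Spelling the derivation out (a sketch, writing sigma for self.std and ignoring the tanh squashing for the moment):

    \log \pi(l_t \mid \mu, \sigma) = -\frac{(l_t - \mu)^2}{2\sigma^2} - \log\left(\sigma\sqrt{2\pi}\right),
    \qquad
    \frac{\partial \log \pi}{\partial \mu} = \frac{l_t - \mu}{\sigma^2},
    \qquad
    \frac{\partial \log \pi}{\partial l_t} = -\frac{l_t - \mu}{\sigma^2}.

Since l_t = mu + noise and d l_t / d mu = 1 when nothing is detached, the chain rule adds the two partials together, so the total gradient of log_pi w.r.t. mu is exactly zero.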

kevinzakka commented 6 years ago

@ipod825 I need to think about it some more. Empirically, I haven't seen a performance difference between the two: I still reach ~1.3-1.4% error in about 30 epochs of training.

What's bugging me right now is that I learned about the reparametrization trick this weekend, which essentially makes it possible to backprop through a sampled variable. So right now, I'm confused as to why we even need REINFORCE to train our network. We could just use the reparametrization trick like in VAEs to make the whole process differentiable and directly optimize for the weights of the location network.

I'll give it some more thought tonight.

ipod825 commented 6 years ago

The performance might not be related to this gradient issue at all. If you check this thread, you'll see that many of the implementations online don't even learn anything in their location network but still get good performance on MNIST.

ipod825 commented 6 years ago

Also, I don't think the re-parametrization trick applies to this scenario. Re-parametrization requires your target function (in our scenario, the reward) to be differentiable w.r.t. its parameters. However, our reward is just an unknown function that we don't even have a formula for.
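A toy illustration of that point (hypothetical tensors; the reward here is essentially the 0/1 correctness of the final prediction): the reward is a piecewise-constant function of the sampled locations, so even with the reparametrization trick there is no gradient to propagate back through it.

    import torch

    logits = torch.randn(4, 10, requires_grad=True)   # stand-in for the classifier output
    labels = torch.randint(0, 10, (4,))
    predicted = logits.argmax(dim=1)                   # argmax is not differentiable
    R = (predicted == labels).float()                  # 0/1 reward: constant almost everywhere
    print(R.requires_grad)                             # False -- nothing to backprop through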

ipod825 commented 6 years ago

https://github.com/kevinzakka/recurrent-visual-attention/blob/99c4cbe089439b83f86922a4e0154923da18a96f/modules.py#L350 This line is related to this issue. You shouldn't apply tanh to l_t again. Say the fc output is 100, so mu = tanh(100) = 1.0. Even after adding noise, tanh(l_t) ≈ tanh(1.0) = 0.76159.

A better idea is to use torch.clamp(l_t, -1, 1).
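A quick check of those numbers:

    import torch
    print(torch.tanh(torch.tensor(100.0)))   # tensor(1.)
    print(torch.tanh(torch.tensor(1.0)))     # tensor(0.7616): the second tanh pulls l_t well below 1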

kevinzakka commented 6 years ago

@ipod825 The support of a normal distribution is unbounded, so it is not guaranteed that l_t stays within [-1, 1].

I was against using torch.clamp because it is not as smooth as tanh. Why do you think it's a better idea?

ipod825 commented 6 years ago
        mu = F.tanh(self.fc(h_t.detach()))
        # reparametrization trick
        noise = torch.zeros_like(mu)
        noise.data.normal_(std=self.std)
        l_t = mu + noise

        # bound between [-1, 1]
        l_t = F.tanh(l_t)

l_t is squashed by tanh twice, while mu is squashed only once. When mu saturates to 1.0, l_t is almost surely smaller than 1.0, as I described above. Second, if you modify the code as follows

        mu = torch.clamp(self.fc(h_t.detach()), -1, 1)
        # reparametrization trick
        noise = torch.zeros_like(mu)
        noise.data.normal_(std=self.std)
        l_t = mu + noise

        # bound between [-1, 1]
        l_t = torch.clamp(l_t, -1, 1)

And do not detach the l_t in

log_pi = Normal(mu, self.std).log_prob(l_t)

You can check that the gradient in the location network is actually 0, as predicted by the discussion above. But if you use tanh, the gradient won't be 0, because mu and l_t are not squashed in the same way.
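Here is one way to run that check numerically (a sketch with made-up values; raw stands in for the fc output, and the fixed noise keeps the check deterministic):

    import torch
    from torch.distributions import Normal

    std = 0.2
    noise = torch.tensor([0.1])

    def grad_wrt_raw(squash):
        raw = torch.tensor([0.3], requires_grad=True)   # stands in for self.fc(h_t.detach())
        mu = squash(raw)
        l_t = squash(mu + noise)                        # l_t is NOT detached
        Normal(mu, std).log_prob(l_t).sum().backward()
        return raw.grad

    print(grad_wrt_raw(lambda x: torch.clamp(x, -1, 1)))  # tensor([0.]): the two paths cancel exactly
    print(grad_wrt_raw(torch.tanh))                       # non-zero: the extra tanh breaks the symmetry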

xycforgithub commented 6 years ago

@ipod825 Have you tried your implementation with clamp and l_t.detach()? I tried that and got very good performance on the 6 glimpses, 8x8, 1 scale setting: around 0.58% error, whereas the paper reports 1.12%.

ipod825 commented 6 years ago

I never got an error lower than 1%. If you used only a vanilla RNN (as already implemented by @kevinzakka), that would be an interesting result. If you consistently get similar results, it would be nice if you could share your code and let others figure out why it works so well.

sujoyp commented 6 years ago

@kevinzakka

We want the weights of the location network to be trained using REINFORCE. Now the hidden state vector h_t (detached because we do not want to train the weights of the RNN with REINFORCE) is fed through the fully-connected layer of the location net to produce the mean mu. Then using the reparametrization trick, we sample a location vector l_t from a Gaussian parametrized by mu.

Why should we not train the weights of the RNN with REINFORCE?