ipod825 opened 6 years ago
Won't l_t.detach() stop the gradients for l_t and hence for mu? I mean, we'll still have a gradient for mu from log_pi, but there's a contribution from l_t as well.
This is my understanding:
We want the weights of the location network to be trained using REINFORCE. Now the hidden state vector h_t (detached, because we do not want to train the weights of the RNN with REINFORCE) is fed through the fully-connected layer of the location network to produce the mean mu. Then, using the reparametrization trick, we sample a location vector l_t from a Gaussian parametrized by mu.
Doing the above, we've made it possible to backpropagate through l_t and hence back to mu, which means we can train the weights of the location network.
mu and l_t are two separate Variables (though highly correlated). l_t.detach() does not stop you from calculating d loss_reinforce / d mu. For example, consider the following math:
x = 1
y = x + 1
z = x * y
Both dz/dx and dz/dy are well defined. Even if you "detach" y, you can still calculate dz/dx.
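As a quick sanity check, here is a throwaway PyTorch version of that example (not code from this repo):
import torch

x = torch.tensor(1.0, requires_grad=True)
y = x + 1
z = x * y.detach()  # "detach" y
z.backward()
print(x.grad)       # tensor(2.): dz/dx is still defined and equals the value of y
Detaching y only removes the contribution that flows through y's own dependence on x (without the detach, x.grad would be 3).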
Oh. I think you know what I meant. I'll think more and reply again.
I think we shouldn't flow the information through l_t.
Intuitively, for loss_reinforce to decrease, we want log_pi to increase. For log_pi to increase, we want mu and l_t to be closer. Assume mu < l_t; the gradient flow then tries to increase mu and decrease l_t simultaneously. However, decreasing l_t essentially means decreasing mu, since l_t = mu + noise.
If you try deriving the formula for the gradients, one should be the negative of the other, as the kernel of the Gaussian is (l_t - mu)^2, so they should cancel each other out.
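Spelling it out: log_pi = -(l_t - mu)^2 / (2*std^2) + const, so d log_pi / d mu = (l_t - mu) / std^2 and d log_pi / d l_t = -(l_t - mu) / std^2. Since l_t = mu + noise implies d l_t / d mu = 1, the total gradient of log_pi with respect to mu is the sum of these two terms, which is exactly 0.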
@ipod825 I need to think about it some more. Empirically, I haven't seen a performance difference between the 2. I still reach ~1.3-1.4% error in about 30 epochs of training.
What's bugging me right now is that I learned about the reparametrization trick this weekend, which essentially makes it possible to backprop through a sampled variable. So right now, I'm confused as to why we even need REINFORCE to train our network. We could just use the reparametrization trick like in VAEs to make the whole process differentiable and directly optimize for the weights of the location network.
I'll give it some more thought tonight.
The performance issue might not be related to this formula issue. If you check this thread, you'll see that many of the implementations online don't even learn anything for their location network but still get good performance on MNIST.
Also, I don't think the reparametrization trick applies to this scenario. Reparametrization requires your target function (in our scenario, the reward) to be differentiable w.r.t. its parameters. However, our reward is just an unknown function that we don't even have a formula for.
https://github.com/kevinzakka/recurrent-visual-attention/blob/99c4cbe089439b83f86922a4e0154923da18a96f/modules.py#L350
This line is related to this issue.
You shouldn't apply tanh on l_t again. Say the output of the fc layer is 100, so mu = tanh(100) = 1.0. Even after adding noise, tanh(l_t) ~ tanh(1.0) = 0.76159.
A better idea is to use torch.clamp(l_t, -1, 1).
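For example (a quick standalone check with an arbitrary noise std, not repo code):
import torch

x = torch.tensor([100.0])                     # a large fc output
mu = torch.tanh(x)                            # saturates to 1.0
l_t = torch.tanh(mu + 0.17 * torch.randn(1))  # second tanh: stays near tanh(1.0) ~ 0.76, not near 1.0
print(mu, l_t)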
@ipod825 A normal distribution has unbounded support, so it is not guaranteed that l_t will stay within [-1, 1].
I was against using torch.clamp because it is not as smooth as tanh. Why do you think it's a better idea?
mu = F.tanh(self.fc(h_t.detach()))
# reparametrization trick
noise = torch.zeros_like(mu)
noise.data.normal_(std=self.std)
l_t = mu + noise
# bound between [-1, 1]
l_t = F.tanh(l_t)
l_t is squeezed by tanh two times while mu is squeezed only once. When mu saturates to 1.0, l_t is almost surely smaller than 1.0, as I described above.
Second, if you modify the code as follows
mu = torch.clamp(self.fc(h_t.detach()), -1, 1)
# reparametrization trick
noise = torch.zeros_like(mu)
noise.data.normal_(std=self.std)
l_t = mu + noise
# bound between [-1, 1]
l_t = torch.clamp(l_t, -1, 1)
and do not detach l_t in
log_pi = Normal(mu, self.std).log_prob(l_t)
then you can check that the gradient in the location network is actually 0, as predicted by the discussion above. But if you use tanh, the gradient wouldn't be 0, as mu and l_t are not squeezed in the same way.
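Here is a minimal standalone check of that claim (made-up values, not the repo's code):
import torch
from torch.distributions import Normal

torch.manual_seed(0)
std = 0.17                  # arbitrary noise std for illustration
noise = torch.randn(1) * std

def loc_grad(squash):
    # x stands in for self.fc(h_t.detach()); l_t is NOT detached here
    x = torch.tensor([0.3], requires_grad=True)
    mu = squash(x)
    l_t = squash(mu + noise)
    Normal(mu, std).log_prob(l_t).sum().backward()
    return x.grad

print(loc_grad(lambda v: torch.clamp(v, -1, 1)))  # effectively zero: the mu and l_t paths cancel
print(loc_grad(torch.tanh))                       # non-zero: the two paths no longer cancel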
@ipod825 Have you tried your implementation using clamp and l_t.detach()? I tried that and got a very high performance on 6 glimpses, 8*8, 1 scale setting, around 0.58%. Paper reported 1.12%.
I never got an error lower than 1%. If you used only a vanilla RNN (as already implemented by @kevinzakka), that would be an interesting result. If you consistently get similar results, it would be nice if you could share your code and let others figure out why it works so well.
@kevinzakka

We want the weights of the location network to be trained using REINFORCE. Now the hidden state vector h_t (detached because we do not want to train the weights of the RNN with REINFORCE) is fed through the fully-connected layer of the location net to produce the mean mu. Then using the reparametrization trick, we sample a location vector l_t from a Gaussian parametrized by mu.

Why should we not train the weights of the RNN with REINFORCE?
https://github.com/kevinzakka/recurrent-visual-attention/commit/99c4cbe089439b83f86922a4e0154923da18a96f#diff-40d9c2c37e955447b1175a32afab171fL353
This is not an unnecessary detach. l_t is used in log_pi = Normal(mu, self.std).log_prob(l_t), which is then used in loss_reinforce = torch.sum(-log_pi*adjusted_reward, dim=1). This means that when minimizing the REINFORCE loss, you are altering your location network through both mu and l_t (and yes, log_pi is differentiable w.r.t. both mu and l_t). However, l_t is just mu + noise, and we only want the gradient to flow through mu.
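Roughly, keeping the detach looks like this (a sketch using the variable names from the repo), so that only the mu path carries gradient into the location network:
log_pi = Normal(mu, self.std).log_prob(l_t.detach())
loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)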