KL is used to compute the hessian (as in the original code).
Hi, I've tested the get_kl() function in various environments, but the return is always 0. The following is an example of get_kl() outputs in BipedalWalker-v2:
*** kl: all zeros [torch.DoubleTensor of size 15160x4]
*** kl sum: all zeros [torch.DoubleTensor of size 15160]
*** kl mean: 0 [torch.DoubleTensor of size 1]
** kl grad: 2.4390e-34 [torch.DoubleTensor of size 1]
** kl grad grad: 2.4390e-34 [torch.DoubleTensor of size 1]
('lagrange multiplier:', 2.509155592601353e-17, 'grad_norm:', 4.4529719813727475e-18)
fval before 2.0545072021432404e-13  a/e/r 2.1300610441917076e-13 5.018311185202932e-19 424457.7439662247  fval after -7.555384204846721e-15
Does this make sense?
@jtoyama4 Hi, have you figured out how get_kl() works? I have the same question as you.
@pxlong I think what get_kl() does is provide the KL so that its gradient can be taken (for computing the Hessian), and the KL-constraining part somehow works through the ratio in linesearch(),
but I really do not understand it theoretically.
@jtoyama4, thanks for the quick reply, but if what get_kl() returns is always 0, how can you get a valid/useful gradient from it? I am a little confused.
@pxlong a simple example:
In this case we have something like f(x) = x_0^2 - x^2, so f(x_0) = 0 but f'(x_0) = -2x_0.
The function is not identically zero; it just happens to equal 0 at that one specific point.
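A minimal autograd sketch of the same idea (my own toy snippet, using the current PyTorch tensor API rather than the Variable style used elsewhere in this thread; x0 = 3.0 is an arbitrary choice): the value at x_0 is exactly zero, yet the derivative there is not.

import torch

# f(x) = x0^2 - x^2 evaluated at x = x0: the value is 0, the derivative is not.
x0 = 3.0
x = torch.tensor(x0, requires_grad=True)
f = x0 ** 2 - x ** 2

f.backward()
print(f.item())       # 0.0
print(x.grad.item())  # -6.0, i.e. f'(x0) = -2 * x0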
@ikostrikov, thanks for your explanation. But I've tested this implementation in various envs, such as BipedalWalker-v2, MountainCarContinuous-v0, and Pendulum-v0 (everything except Reacher-v1), and none of them gives reasonable results (i.e. the agent learns nothing during training).
To debug it, I added some print statements, as you can see below:
def Fvp(v):
    kl = get_kl()
    kl = kl.mean()
    print('*** kl mean: ', kl)

    grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
    flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

    kl_v = (flat_grad_kl * Variable(v)).sum()
    print('*** kl_v: ', kl_v)

    grads = torch.autograd.grad(kl_v, model.parameters())
    flat_grad_grad_kl = torch.cat([grad.view(-1) for grad in grads]).data
    print('*** flat grad grad kl: ', flat_grad_grad_kl)
but all of the printed tensors are zero (as shown in the second message of this issue).
I am wondering what causes the bad performance on these environments?
Check the hyperparameters from the original implementation (modular_rl), and also its estimates of how long convergence takes. The default hyperparameters of this code are tuned specifically for Reacher-v1.
ok, thanks.
This part is used to compute the Hessian of the KL. The KL itself is 0 and the first derivative of the KL is 0, but the Hessian is not.
This is the reason why we have to compute a second-order approximation of the KL term: its first-order approximation is equal to zero.
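A minimal sketch of this (a toy snippet of mine, not the repo's code; toy_kl, the parameter values, and v are all made up for illustration): for a toy diagonal-Gaussian "policy", the KL at the current parameters is 0 and its gradient is 0, yet double backprop, as in Fvp above, still yields a nonzero Hessian-vector product.

import torch

# Toy "policy": a 1-D Gaussian parameterized by a mean and a log-std.
mean = torch.tensor([0.5], dtype=torch.float64, requires_grad=True)
log_std = torch.tensor([-1.0], dtype=torch.float64, requires_grad=True)
params = [mean, log_std]

def toy_kl():
    # KL( N(mean_old, std_old) || N(mean, std) ), where the "old" values are the
    # current values treated as constants (detached).
    mean_old, log_std_old = mean.detach(), log_std.detach()
    std, std_old = log_std.exp(), log_std_old.exp()
    return (log_std - log_std_old
            + (std_old.pow(2) + (mean_old - mean).pow(2)) / (2.0 * std.pow(2))
            - 0.5).sum()

kl = toy_kl()
print(kl.item())                              # 0.0: the two distributions coincide

grads = torch.autograd.grad(kl, params, create_graph=True)
flat_grad = torch.cat([g.view(-1) for g in grads])
print(flat_grad)                              # ~0: the first derivative also vanishes

v = torch.ones_like(flat_grad)
hvp = torch.autograd.grad((flat_grad * v).sum(), params)
print(torch.cat([g.view(-1) for g in hvp]))   # nonzero: the Hessian-vector product

So even though every value up to the first gradient prints as zero, the second torch.autograd.grad call is the one that carries the curvature information.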
@pxlong Sorry that it took me so long to fix the bug.
It didn't work because they've changed default argument values for some functions in PyTorch recently.
@pxlong And anyone else: I found that get_kl() relates to this statement from (Schulman et al., 2015), Trust Region Policy Optimization: "computing the Hessian of D_KL with respect to θ".
It still feels non-intuitive to me, but I guess the goal is automatic differentiation / Hessian calculation rather than getting an actual value out of get_kl().
Also, the "For two univariate normal distributions p and q the above simplifies to" derivation has the math that looks directly related to the code.
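For reference (my own transcription, with p as the old policy and q as the new one), I believe that simplified formula is:

D_KL(p || q) = log(sigma_q / sigma_p) + (sigma_p^2 + (mu_p - mu_q)^2) / (2 * sigma_q^2) - 1/2

which looks like what the code evaluates per action dimension.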
The paper cites Numerical Optimization so I guess I have some reading to do :)
*edit:
So if I understand correctly, get_kl() simply structures the KL term for autograd. The actual second-order computation (the Hessian-vector product) is built in Fvp.
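To convince myself, here is a generic conjugate-gradient sketch (my own code, not the repo's) of how a function like Fvp is typically consumed in TRPO: it solves H x = g using only Hessian-vector products, which is why get_kl() never needs to return a meaningful value.

import torch

def conjugate_gradient(fvp, b, n_iters=10, tol=1e-10):
    # Solve H x = b given only a function that returns Hessian-vector products H v.
    x = torch.zeros_like(b)
    r = b.clone()      # residual b - H x (x starts at 0)
    p = b.clone()      # current search direction
    rdotr = torch.dot(r, r)
    for _ in range(n_iters):
        Hp = fvp(p)
        alpha = rdotr / torch.dot(p, Hp)
        x = x + alpha * p
        r = r - alpha * Hp
        new_rdotr = torch.dot(r, r)
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x

# Toy check with an explicit SPD matrix standing in for the KL Hessian:
H = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
g = torch.tensor([1.0, -1.0])
x = conjugate_gradient(lambda v: H @ v, g)
print(x, H @ x)        # H @ x should be approximately g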
Hi. Thanks for publishing an implementation of TRPO.
I have a question about get_kl().
I thought get_kl() was supposed to calculate the KL divergence between the old policy and the new policy, but it seems to always return 0.
Also, I do not see a KL-constraining part in the parameter-updating process.
Is this code a modification of TRPO, or do I have some misunderstanding?
Thanks,