Thank you for the kind words!
> From my understanding, the `hessian_vector_product` calculation is only needed as part of the TRPO implementation, in order to do conjugate gradients and line search. Is that right?
That's right, the Hessian-vector product function is mainly used in the conjugate gradient function to compute the update of TRPO.
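(For anyone reading this later: the trick is just two passes of `torch.autograd.grad`. Below is a generic sketch of a Hessian-vector product, not necessarily the exact function from this repo, and the `damping` term is a common addition before running conjugate gradient that I'm assuming here.)

```python
import torch

def hessian_vector_product(loss, params, vector, damping=1e-2):
    """Compute (H + damping * I) @ vector without ever forming the Hessian H."""
    # First pass: gradient of the loss w.r.t. the parameters, keeping the
    # graph so it can be differentiated a second time.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat_grads = torch.cat([g.reshape(-1) for g in grads])
    # Second pass: differentiate the dot product (grad . vector), which
    # yields the Hessian-vector product directly.
    grad_dot_vector = torch.dot(flat_grads, vector)
    hvp = torch.autograd.grad(grad_dot_vector, params)
    flat_hvp = torch.cat([h.reshape(-1) for h in hvp])
    return flat_hvp + damping * vector
```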
> What I mean is that if I want to do supervised learning, I can just use the `autograd.grad(create_graph=True)` trick in order to create gradients that I can backpropagate through, and then in the outer loop just use the standard PyTorch Adam implementation, right?
Exactly, you'd only need to use `autograd.grad` (still with `create_graph=True` if you don't want the first-order approximation), and any standard PyTorch optimizer will do for the outer loop.
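For reference, here is a minimal sketch of that supervised setup (the toy sinusoid task and all the names below are made up for illustration, not taken from this repo): the inner step uses `autograd.grad(create_graph=True)`, and the outer loop is plain Adam.

```python
import torch
import torch.nn.functional as F

# Tiny two-layer network written functionally, so the adapted parameters
# can be plain tensors inside the autograd graph.
def forward(params, x):
    w1, b1, w2, b2 = params
    return F.linear(torch.relu(F.linear(x, w1, b1)), w2, b2)

def inner_update(params, x_support, y_support, inner_lr=0.01):
    loss = F.mse_loss(forward(params, x_support), y_support)
    # create_graph=True keeps the inner step differentiable, so the outer
    # loss can backpropagate through it (drop it for first-order MAML).
    grads = torch.autograd.grad(loss, params, create_graph=True)
    return [p - inner_lr * g for p, g in zip(params, grads)]

params = [torch.randn(40, 1) * 0.1, torch.zeros(40),
          torch.randn(1, 40) * 0.1, torch.zeros(1)]
for p in params:
    p.requires_grad_()
meta_optimizer = torch.optim.Adam(params, lr=1e-3)  # standard outer-loop optimizer

for _ in range(1000):
    # Toy task: fit a random sinusoid from a few support points.
    amplitude, phase = torch.rand(1) * 4 + 1, torch.rand(1) * 3.14
    x_s, x_q = torch.randn(10, 1), torch.randn(10, 1)
    y_s, y_q = amplitude * torch.sin(x_s + phase), amplitude * torch.sin(x_q + phase)

    adapted = inner_update(params, x_s, y_s)
    outer_loss = F.mse_loss(forward(adapted, x_q), y_q)
    meta_optimizer.zero_grad()
    outer_loss.backward()   # second-order gradients flow back into `params`
    meta_optimizer.step()
```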
> (And I know that this is getting slightly off topic for this repo, apologies, but this should also extend seamlessly to having multiple inner updates, right? I'd just need to set `create_graph=True` for all of them.)
Yes, it does extend nicely to the case where you have multiple inner updates (maybe even this repo could support that for RL, but I haven't checked carefully). One thing to note though is that your `update_params` function (or something along these lines) should take as input the values of the parameters from the previous step of the inner update. I made that mistake originally in my implementation of supervised MAML.
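To make the multi-step point concrete (continuing the toy `inner_update` sketch above, so again just an illustration): each step should start from the parameters returned by the previous step, not from the original meta-parameters, and the outer loss is then computed with the final adapted parameters.

```python
def adapt(params, x_support, y_support, num_steps=5, inner_lr=0.01):
    adapted = params
    for _ in range(num_steps):
        # Each step starts from the parameters produced by the previous step,
        # and inner_update keeps create_graph=True so the whole chain of
        # updates stays differentiable for the outer optimizer.
        adapted = inner_update(adapted, x_support, y_support, inner_lr)
    return adapted
```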
Hello!
Great implementation, thank you for putting this out there!
I am using it as a design reference for a supervised learning MAML implementation, and I have a quick question about the outer-loop gradient:
From my understanding, the `hessian_vector_product` calculation is only needed as part of the TRPO implementation, in order to do conjugate gradients and line search. Is that right?

What I mean is that if I want to do supervised learning, I can just use the `autograd.grad(create_graph=True)` trick in order to create gradients that I can backpropagate through, and then in the outer loop just use the standard PyTorch Adam implementation, right?

(And I know that this is getting slightly off topic for this repo, apologies, but this should also extend seamlessly to having multiple inner updates, right? I'd just need to set `create_graph=True` for all of them.)

I would really appreciate any input on these! Again, great repo! Thanks!!