tristandeleu / pytorch-maml-rl

Reinforcement Learning with Model-Agnostic Meta-Learning in Pytorch
MIT License

Question: is hessian_vector_product in MetaLearner needed for TRPO, or for MAML? #30

Closed: eugval closed this issue 5 years ago

eugval commented 5 years ago

Hello!

Great implementation, thank you for putting this out there!

I am using it as a design reference for building a supervised-learning MAML implementation, and I have a quick question about the outer-loop gradient:

From my understanding, the hessian_vector_product calculation is only needed as part of the TRPO implementation, in order to do conjugate gradients and line search. Is that right?

What I mean is: if I want to do supervised learning, I can just use the autograd.grad(create_graph=True) trick to create gradients that I can backpropagate through, and then in the outer loop use the standard PyTorch Adam implementation, right?

(I know this is getting slightly off-topic for this repo, apologies, but this should also extend seamlessly to having multiple inner updates, right? I'd just need to set create_graph=True for all of them.)

I would really appreciate any input on this! Again, great repo, thanks!!

tristandeleu commented 5 years ago

Thank you for the kind words!

> From my understanding, the hessian_vector_product calculation is only needed as part of the TRPO implementation, in order to do conjugate gradients and line search. Is that right?

That's right, the Hessian-vector product function is mainly used inside the conjugate gradient routine to compute the TRPO update.
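For anyone reading along, the double-backward trick behind a Hessian-vector product looks roughly like this (a minimal sketch, not this repo's exact code; in the repo the scalar being differentiated is the KL divergence between the old and new policies, and the product is evaluated repeatedly inside conjugate gradient):

```python
import torch

def hessian_vector_product(loss, params, vector, damping=1e-2):
    """Compute (H + damping * I) @ vector without ever materializing H."""
    # First backward pass: gradient of the scalar loss w.r.t. the
    # parameters, with create_graph=True so it can be differentiated again.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])

    # Differentiating the scalar <grad, vector> w.r.t. the parameters
    # yields H @ vector (the Hessian-vector product).
    grad_dot_vector = torch.dot(flat_grad, vector)
    hvp = torch.autograd.grad(grad_dot_vector, params, retain_graph=True)
    flat_hvp = torch.cat([h.reshape(-1) for h in hvp])

    # Damping keeps the conjugate gradient solve well-conditioned.
    return flat_hvp + damping * vector
```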

> What I mean is: if I want to do supervised learning, I can just use the autograd.grad(create_graph=True) trick to create gradients that I can backpropagate through, and then in the outer loop use the standard PyTorch Adam implementation, right?

Exactly, you'd only need autograd.grad (still with create_graph=True if you don't want the first-order approximation), and any standard PyTorch optimizer will do for the outer loop.
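For concreteness, here is a minimal sketch of such an outer loop for supervised MAML. It assumes a hypothetical model with a functional forward pass that accepts an explicit `params` dict (similar in spirit to this repo's MetaModule); the names are illustrative, not part of the repo:

```python
import torch
import torch.nn.functional as F
from collections import OrderedDict

def outer_step(model, meta_optimizer, task_batch, inner_lr=0.01):
    """One meta-update over a batch of tasks (second-order MAML)."""
    outer_loss = 0.0
    for x_support, y_support, x_query, y_query in task_batch:
        # Inner loop: one gradient step on the support set.
        params = OrderedDict(model.named_parameters())
        inner_loss = F.mse_loss(model(x_support, params=params), y_support)
        # create_graph=True keeps the inner update differentiable, so the
        # outer loss can backpropagate through it (full second-order MAML).
        grads = torch.autograd.grad(inner_loss, list(params.values()),
                                    create_graph=True)
        params = OrderedDict(
            (name, param - inner_lr * grad)
            for (name, param), grad in zip(params.items(), grads))
        # Outer loss: evaluated on the query set with the adapted params.
        outer_loss = outer_loss + F.mse_loss(
            model(x_query, params=params), y_query)
    # Any standard optimizer (e.g. Adam) handles the meta-update.
    meta_optimizer.zero_grad()
    outer_loss.backward()
    meta_optimizer.step()
```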

> (I know this is getting slightly off-topic for this repo, apologies, but this should also extend seamlessly to having multiple inner updates, right? I'd just need to set create_graph=True for all of them.)

Yes, it does extend nicely to the case where you have multiple inner updates (this repo might even support it for RL, but I haven't checked that carefully). One thing to note, though, is that your update_params function (or something along those lines) should take as input the values of the parameters from the previous step of the inner update, not the initial meta-parameters. I originally made that mistake in my implementation of supervised MAML.
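Concretely, the pitfall is computing every inner step from the initial meta-parameters instead of from the output of the previous step. A sketch of a correct multi-step inner loop, using the same hypothetical functional model as in the sketch above:

```python
import torch
import torch.nn.functional as F
from collections import OrderedDict

def adapt(model, x_support, y_support, num_steps=5, inner_lr=0.01):
    """Run several differentiable inner updates and return the result."""
    # Start from the current meta-parameters...
    params = OrderedDict(model.named_parameters())
    for _ in range(num_steps):
        # ...and feed each step the params produced by the *previous*
        # step, not the initial meta-parameters.
        loss = F.mse_loss(model(x_support, params=params), y_support)
        # create_graph=True on every step keeps the whole chain of
        # updates differentiable for the outer loop.
        grads = torch.autograd.grad(loss, list(params.values()),
                                    create_graph=True)
        params = OrderedDict(
            (name, param - inner_lr * grad)
            for (name, param), grad in zip(params.items(), grads))
    return params
```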