mlverse / torch

R Interface to Torch
https://torch.mlverse.org

autograd functional? #738

Open rdinnager opened 2 years ago

rdinnager commented 2 years ago

I am interested in some of the functionality of torch.autograd.functional (https://github.com/pytorch/pytorch/blob/master/torch/autograd/functional.py) in PyTorch for R (especially jvp() and jacobian()). Are there plans to add this functionality to R torch? I was thinking about trying to do it myself and making a PR, but I don't want to duplicate any effort, and I want to make sure it would fit into the plan for torch. Since autograd.functional appears to only make use of existing torch functions, it shouldn't require any updates to lantern.

dfalbel commented 2 years ago

Yeah! We are interested in adding this to torch but have a few other priorities ahead of it in line. It would be awesome to have your contribution, and I can definitely help with discussing and reviewing PRs.

Linking to #561. I'll close #561 in favor of this issue.

rdinnager commented 2 years ago

Just having another look at the Python version of this. Interestingly, the jacobian() function here: https://github.com/pytorch/pytorch/blob/00ebbd5ef640f5647400c0a4cb75c4e1f1f539ef/torch/autograd/functional.py#L481 has a 'forward-mode' option, which uses forward-mode AD to compute the Jacobian. I honestly didn't even know that PyTorch could do forward-mode AD. The infrastructure for this, though, looks to be deep in unexported territory (in the torch._C module, details here: https://github.com/pytorch/pytorch/blob/master/torch/autograd/forward_ad.py), and so would probably require significant updates to lantern. I am not sure I am up for that undertaking at the moment, so how do you feel about only supporting 'reverse-mode' for this function for now, which can be accomplished using only the currently exported API of torch?
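For readers following along, here is a minimal sketch of what a reverse-mode-only Jacobian could look like using only functions already exported by R torch. The helper name `jacobian_reverse()` is hypothetical and is not the API that was eventually proposed. It assumes `f` maps a 1-d tensor to a tensor whose every element depends on the input, and it runs one backward pass per output element.

```r
library(torch)

# Hypothetical helper (not part of the torch package): builds the Jacobian of
# f at x one output row at a time using reverse-mode autograd.
jacobian_reverse <- function(f, x) {
  x <- x$detach()$requires_grad_(TRUE)
  y <- torch_flatten(f(x))           # treat the outputs as a flat vector
  n_out <- y$numel()
  rows <- vector("list", n_out)
  for (i in seq_len(n_out)) {
    # One backward pass per output element; keep the graph for the next row.
    g <- autograd_grad(y[i], x, retain_graph = TRUE)[[1]]
    rows[[i]] <- torch_flatten(g)
  }
  torch_stack(rows, dim = 1)         # shape: (n_out, n_in)
}

# Example function: f(x) = (x1 * x2, x1 + x2^2)
f <- function(x) torch_stack(list(x[1] * x[2], x[1] + x[2]^2))
J <- jacobian_reverse(f, torch_tensor(c(2, 3)))
```

The cost grows with the number of outputs, which is the usual trade-off of reverse mode; forward mode would instead need one pass per input.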

dfalbel commented 2 years ago

Oh I didn't know that was possible either. Sure, there's no problem in supporting only reverse-mode in a first iteration.

rdinnager commented 2 years ago

I'm getting set to make a start on this and wanted to ask what prefix would be good for the autograd.functional functions? autograd_functional_ seems a little long to me. Following the pattern of nn.functional, which has the prefix nnf_, it should be autogradf_, but we could potentially shorten it more as autogf_ or even agf_. So we would have functions like agf_jacobian(). What do you think?

dfalbel commented 2 years ago

I like agf_ as a prefix. Although since the autograd_ namespace is quite small, I think we could also omit the functional references and use autograd_ for them too. So it could be named autograd_jacobian, autograd_hessian, etc.

rgiordan commented 2 years ago

I too could really use the autograd_ functionals for zaminfluence. I thought I'd check in on the status of this issue and offer to help if I can!

rdinnager commented 2 years ago

Hi @rgiordan. Yes, I started working on this a while ago but then got pulled away by another project. I could certainly use help. I've been working on this branch on my own fork: https://github.com/rdinnager/torch/tree/feature/autograd_functional. I have just been copying Python code into R source files and then going through line by line and translating into R as best I can. Let me know if you want to work on a function or two, and we can split up the porting of the different functions.

skeydan commented 2 years ago

@rgiordan I was pointed to your blog post (https://rgiordan.github.io/code/2022/04/01/rtorch_example.html) by @dfalbel ... this is very nice! :-) Perhaps it's of interest to hear that there's an optim_lbfgs() in torch - with line search ... There's a post about it here: https://blogs.rstudio.com/ai/posts/2021-04-22-torch-for-optimization/
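As a rough illustration of the optimizer mentioned above, the following sketch minimizes a small Rosenbrock-style test function with `optim_lbfgs()` and the `"strong_wolfe"` line search. The test function, starting point, and number of steps are assumptions chosen for illustration, not taken from this thread.

```r
library(torch)

# Illustrative test function with its minimum at (1, 1).
rosenbrock <- function(x) (1 - x[1])^2 + 5 * (x[2] - x[1]^2)^2

x <- torch_tensor(c(-1, 1), requires_grad = TRUE)
optimizer <- optim_lbfgs(list(x), line_search_fn = "strong_wolfe")

# L-BFGS may evaluate the objective several times per step, so the loss and
# its gradient are computed inside a closure that the optimizer can call.
closure <- function() {
  optimizer$zero_grad()
  value <- rosenbrock(x)
  value$backward()
  value
}

for (i in 1:5) optimizer$step(closure)
final_value <- rosenbrock(x)$item()
```

Leaving `line_search_fn` at its default of `NULL` runs L-BFGS with a fixed step size (`lr`) instead of a line search.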

rgiordan commented 2 years ago

@skeydan Oh, thank you for pointing that out to me. I probably just browsed through the names and didn't realize it did line search as well --- that should be a big improvement!

For what it's worth, my go-to optimization method is Newton trust region with CG (e.g. scipy.optimize.minimize with trust-ncg and Hessian-vector products provided by autodiff). If you know of an R implementation / port of that method, I'd be very glad to know...
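The Hessian-vector products mentioned here can already be sketched in R torch with two reverse-mode passes (double backward). The helper name `hvp()` below is hypothetical and not part of the torch package, and this covers only the product itself, not a trust-region solver.

```r
library(torch)

# Hypothetical helper (not part of the torch package): computes H(x) %*% v
# for a scalar-valued f without ever forming the full Hessian.
hvp <- function(f, x, v) {
  x <- x$detach()$requires_grad_(TRUE)
  # First pass: gradient of f, keeping the graph so it can be differentiated again.
  g <- autograd_grad(f(x), x, create_graph = TRUE)[[1]]
  # Second pass: the gradient of <g, v> with respect to x equals H %*% v.
  autograd_grad((g * v)$sum(), x)[[1]]
}

# Example: f(x) = sum(x^3) + x1 * x2, whose Hessian at (1, 2) is [[6, 1], [1, 12]].
f <- function(x) (x^3)$sum() + x[1] * x[2]
Hv <- hvp(f, torch_tensor(c(1, 2)), torch_tensor(c(1, 0)))
```

A function like this plays the role that the `hessp` argument plays for trust-ncg in scipy.optimize.minimize.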