Closed: rgiordan closed this issue 2 years ago.
If you have the time or interest, I think an approachable write-up on how to use autodiff via torch in R would be widely appreciated as well.
I finally took a look into torch for R, and I don't think it's a great solution: it appears to require custom torch-specific versions of most basic operations. (See the package's function reference.)
Consider the following example. The native %*% matrix multiplication operator does not work with torch tensors:
library(torch)

x <- c(1, 2)
x_torch <- torch_tensor(x, requires_grad = TRUE)
a <- matrix(c(2, 1, 1, 3), nrow = 2)
a_torch <- torch_tensor(a, requires_grad = TRUE)

# Works for ordinary R objects
y <- a %*% x

y_torch <- a_torch %*% x_torch
# Fails with
# Error in a_torch %*% x_torch :
#   requires numeric/complex matrix/vector arguments

# Works
y_torch <- torch_matmul(a_torch, x_torch)
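For readers new to R torch, here is a small sketch (my own illustration, not part of the original comment) of pulling the torch_matmul() result back into base R and taking a reverse-mode gradient. It assumes torch_float64() so the comparison with base R is exact, and calls $detach() before as_array() to be safe, since y_torch is part of the autograd graph.

```r
library(torch)

x <- c(1, 2)
a <- matrix(c(2, 1, 1, 3), nrow = 2)

# Double precision (an assumption here) so results match base R exactly
x_torch <- torch_tensor(x, dtype = torch_float64(), requires_grad = TRUE)
a_torch <- torch_tensor(a, dtype = torch_float64())

y_torch <- torch_matmul(a_torch, x_torch)

# Convert back to a plain R vector and compare with base R
y_from_torch <- as.vector(as_array(y_torch$detach()))
y_base <- as.vector(a %*% x)

# Reverse-mode autodiff: the gradient of sum(a %*% x) with respect to x is colSums(a)
torch_sum(y_torch)$backward()
grad_x <- as.vector(as_array(x_torch$grad))

print(y_from_torch)
print(grad_x)
```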
If you have to go through your R code and replace every single operation with special torch functions, it's not clear why you're not just writing in another language.
I'll also add that, though I only tried for an hour or so, I couldn't figure out how to get a Hessian matrix out of R torch. It appears to be possible in PyTorch (see this Stack Overflow post, for example), but I couldn't get the same idea to work in R. That could be for lack of trying, though.
I'm looking into this again and have warmed to the torch solution (rather than hand-code a fix to https://github.com/rgiordan/zaminfluence/issues/15 or do something in C++). As part of this I'll try to write a quick intro to using R torch on my blog.
Unfortunately it seems true that Jacobians, JVPs, and Hessians are still a work in progress for R. But for zaminfluence it's not a deal-breaker.
To answer my past self: the reason you're not just working in another language is so that ordinary R users don't have to install and maintain that other language. :P
Replying to the point above that the native %*% matrix multiplication operator does not work with torch: a workaround I used in my styleganr package (https://github.com/rdinnager/styleganr) is to define this in the package source:
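# Keep the built-in operator available as the default method
`%*%.default` <- .Primitive("%*%")

# Turn %*% into an S3 generic that dispatches on its first argument
`%*%` <- function(x, ...) {
  UseMethod("%*%", x)
}

#' @export
`%*%.torch_tensor` <- function(e1, e2) {
  # Convert e1 to a tensor on e2's device if it is not one already
  if (!is_torch_tensor(e1)) {
    e1 <- torch_tensor(e1, device = e2$device)
  }
  torch_matmul(e1, e2)
}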
Seemed to work pretty well for me.
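To try the workaround above outside a package, here is a script-level sketch (my own test harness, not from styleganr) that defines the same methods at the top level and checks both code paths. The float64 dtype and the example values are assumptions for illustration.

```r
library(torch)

# Same workaround as above, defined at the top level of a script
`%*%.default` <- .Primitive("%*%")
`%*%` <- function(x, ...) {
  UseMethod("%*%", x)
}
`%*%.torch_tensor` <- function(e1, e2) {
  if (!is_torch_tensor(e1)) {
    e1 <- torch_tensor(e1, device = e2$device)
  }
  torch_matmul(e1, e2)
}

a <- matrix(c(2, 1, 1, 3), nrow = 2)
x <- c(1, 2)
a_torch <- torch_tensor(a, dtype = torch_float64())
x_torch <- torch_tensor(x, dtype = torch_float64())

# Ordinary R objects fall through to the built-in operator via the default method
y_base <- a %*% x

# Torch tensors now dispatch to torch_matmul()
y_torch <- a_torch %*% x_torch

print(y_base)
print(y_torch)
```

Defining %*% at the top level masks the base operator for code evaluated in the global environment, so every %*% call there goes through R-level S3 dispatch first.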
As an aside, I’ve been meaning to suggest to the torch team that they define these kinds of primitive methods in the package namespace (it’s also what Matrix and various other packages do). But I’m not sure whether it causes complications switching back and forth between R6 and S3/S4.
Most operators are already defined in torch; you can see them here: https://github.com/mlverse/torch/blob/main/R/operators.R
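As a quick illustration (my own, not from the thread), element-wise arithmetic operators such as +, - and * work directly on tensors, while the matrix product still goes through torch_matmul() unless a method like the workaround above is added. The matrices and the float64 dtype are assumptions chosen so the results are easy to check.

```r
library(torch)

a <- torch_tensor(matrix(c(2, 1, 1, 3), nrow = 2), dtype = torch_float64())
b <- torch_tensor(diag(2), dtype = torch_float64())

# Element-wise arithmetic dispatches to torch's own operator methods
s <- a + b
d <- a - b
p <- a * b   # element-wise (Hadamard) product, not a matrix product

# The matrix product needs an explicit torch function
m <- torch_matmul(a, b)

print(as_array(s))
print(as_array(m))
```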
%*% is a special case for complicated reasons I don't fully understand. See the conversation at https://github.com/mlverse/torch/issues/728 (starting at the fourth comment) for more on that.
Also, the end of that conversation discusses downsides of the solution I proposed above, which originally comes from https://stackoverflow.com/questions/40580149/overload-matrix-multiplication-for-s3-class-in-r. I should have mentioned that before.
I'm now (as of https://github.com/rgiordan/zaminfluence/pull/32) using torch to replace the manual derivatives for IV and linear regression, and it's a big win in terms of performance as well as maintainability and readability!
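To make the "replace manual derivatives" point concrete, here is a hypothetical sketch (simulated data, not zaminfluence's actual code) that compares an autograd gradient of the sum of squared residuals with the hand-derived formula -2 t(X) (y - X b). The data, the coefficient values, and the float64 dtype are all assumptions for illustration.

```r
library(torch)

# Hypothetical small regression problem
set.seed(1)
n <- 20
X <- cbind(1, rnorm(n))
y <- X %*% c(1, 2) + rnorm(n)
b <- c(0.3, -0.7)

X_t <- torch_tensor(X, dtype = torch_float64())
y_t <- torch_tensor(as.numeric(y), dtype = torch_float64())
b_t <- torch_tensor(b, dtype = torch_float64(), requires_grad = TRUE)

# Sum of squared residuals, written with torch operations
resid <- y_t - torch_matmul(X_t, b_t)
loss <- torch_sum(resid * resid)

# Gradient from autograd
grad_ad <- as.vector(as_array(autograd_grad(loss, b_t)[[1]]))

# Hand-derived gradient: -2 t(X) (y - X b)
grad_manual <- as.numeric(-2 * t(X) %*% (y - X %*% b))

print(rbind(grad_ad, grad_manual))
```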
Actually, you can get Hessian matrices in R by differentiating the gradient a second time, as long as the gradient is computed with create_graph=TRUE. For example:
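# EvalLogProb() is the model's log-probability function, assumed defined elsewhere
beta_ad <- torch_tensor(beta, requires_grad = TRUE)
loss <- -1 * EvalLogProb(beta_ad)

# Keep the graph of the gradient so that it can itself be differentiated
grad <- autograd_grad(loss, beta_ad, retain_graph = TRUE, create_graph = TRUE)[[1]]

# Row d of the Hessian is the gradient of the d-th entry of the gradient
hess <- matrix(NA, length(beta), length(beta))
for (d in seq_along(beta)) {
  hess[d, ] <- as.numeric(autograd_grad(grad[d], beta_ad, retain_graph = TRUE)[[1]])
}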
It's probably inefficient relative to what I think is the usual forward-mode/reverse-mode combination, but maybe using retain_graph means it's not too bad. (I would love to hear a torch expert's opinion on this.)
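One way to sanity-check the Hessian loop above is on a function whose Hessian is known in closed form. This sketch uses a hypothetical quadratic f(b) = b' A b / 2 with symmetric A, whose Hessian is exactly A, and follows the same create_graph = TRUE pattern; A, beta and the float64 dtype are assumptions for illustration.

```r
library(torch)

# Hypothetical test problem: f(b) = b' A b / 2, whose Hessian is A
A <- matrix(c(2, 1, 1, 3), nrow = 2)
A_torch <- torch_tensor(A, dtype = torch_float64())
beta <- c(0.5, -1)

beta_ad <- torch_tensor(beta, dtype = torch_float64(), requires_grad = TRUE)
loss <- torch_dot(beta_ad, torch_matmul(A_torch, beta_ad)) / 2

# First derivative, keeping its graph so it can be differentiated again
grad <- autograd_grad(loss, beta_ad, retain_graph = TRUE, create_graph = TRUE)[[1]]

# Differentiate each gradient entry to fill in one row of the Hessian
hess <- matrix(NA_real_, length(beta), length(beta))
for (d in seq_along(beta)) {
  row_d <- autograd_grad(grad[d], beta_ad, retain_graph = TRUE)[[1]]
  hess[d, ] <- as.vector(as_array(row_d))
}

print(hess)
```

Each pass through the loop is one extra reverse-mode sweep through the gradient's graph, so the cost likely grows linearly with length(beta); for small parameter vectors that is probably acceptable.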
I think this means a general-purpose version of zaminfluence is possible without waiting for the autograd functions to be exposed in the R version of torch.
I mentioned this to Rachael the other day on Twitter, but might be useful to add here:
Since autograd appears to be the primary reason for the Python dependency, I'd recommend taking a look at the new(ish) torch implementation for R. This has direct bindings to the underlying libtorch C++ libraries, so no Python install required. The syntax is also pretty similar to the PyTorch implementation, which I'm sure you are familiar with.
Originally posted by @grantmcdermott in https://github.com/rgiordan/zaminfluence/issues/1#issuecomment-745469552