rgiordan / zaminfluence

Tools in R for computing and using Z-estimator approximate influence functions.
Apache License 2.0

Look into R torch implementation for autodiff #12

Closed · rgiordan closed this issue 2 years ago

rgiordan commented 3 years ago

I mentioned this to Rachael the other day on Twitter, but it might be useful to add here:

Since autograd appears to be the primary reason for the Python dependency, I'd recommend taking a look at the new(ish) torch implementation for R. This has direct bindings to the underlying libtorch C++ libraries, so no Python install required. The syntax is also pretty similar to the PyTorch implementation, which I'm sure you are familiar with.

Originally posted by @grantmcdermott in https://github.com/rgiordan/zaminfluence/issues/1#issuecomment-745469552
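For anyone who hasn't tried it, the basic reverse-mode interface looks like this (a minimal sketch, assuming only that the torch package installs cleanly):

library(torch)

# Differentiate f(x) = sum(x^2) at x = (1, 2, 3); the gradient is 2x.
x <- torch_tensor(c(1, 2, 3), requires_grad = TRUE)
y <- torch_sum(x * x)
y$backward()
x$grad  # 2, 4, 6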

alexpghayes commented 3 years ago

If you have the time or interest, I think that an approachable write up on how to use autodiff via torch in R would be widely appreciated as well.

rgiordan commented 3 years ago

I finally took a look into torch for R, and I don't think it's a great solution, as it appears to require custom Torch-specific versions of most basic operations. (See the function reference here.)

Consider the following example. The native %*% matrix multiplication operator does not work with Torch:

library(torch)

x <- c(1, 2)
x_torch <- torch_tensor(x, requires_grad = TRUE)
a <- matrix(c(2, 1, 1, 3), nrow = 2)
a_torch <- torch_tensor(a, requires_grad = TRUE)

y <- a %*% x  # base R: works
y_torch <- a_torch %*% x_torch
# Fails with
# Error in a_torch %*% x_torch :
#   requires numeric/complex matrix/vector arguments

# Works, but needs the Torch-specific function:
y_torch <- torch_matmul(a_torch, x_torch)

If you have to go through your R code and replace every single operation with special torch functions, it's not clear why you're not just writing in another language.

I'll also add that, though I only tried for an hour or so, I couldn't figure out how to get a Hessian matrix out of R torch. It appears to be possible in ordinary (Python) Torch (see this Stack Overflow post for example), but I couldn't get the same idea to work in R. Could be for lack of trying, though.

rgiordan commented 2 years ago

I'm looking into this again and have warmed to the torch solution (rather than hand-coding a fix to https://github.com/rgiordan/zaminfluence/issues/15 or doing something in C++). As part of this I'll try to write a quick intro to using R torch on my blog.

Unfortunately it seems true that Jacobians, JVPs, and Hessians are still a work in progress for R. But for zaminfluence it's not a deal-breaker.

https://github.com/mlverse/torch/issues/738
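In the meantime, a Jacobian can be assembled row by row with autograd_grad (a hypothetical sketch; the function f and its two outputs are made up for illustration):

library(torch)

# f maps R^3 -> R^2: outputs are sum(x^2) and prod(x).
f <- function(x) torch_stack(list(torch_sum(x * x), torch_prod(x)))

x <- torch_tensor(c(1, 2, 3), requires_grad = TRUE)
y <- f(x)

# One reverse-mode pass per output row; retain_graph lets us reuse the graph.
jac <- matrix(NA, 2, 3)
for (i in 1:2) {
    jac[i, ] <- as.numeric(autograd_grad(y[i], x, retain_graph = TRUE)[[1]])
}
jac  # row 1: 2x = (2, 4, 6); row 2: prod(x)/x_i = (6, 3, 2)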

rgiordan commented 2 years ago

To answer my past self: the reason you're not just working in another language is so that ordinary R users don't have to install and maintain that other language. :P

rdinnager commented 2 years ago

> Consider the following example. The native %*% matrix multiplication operator does not work with Torch:

A workaround for this that I used in my styleganr (https://github.com/rdinnager/styleganr) package is to define this in the package source:

`%*%.default` <- .Primitive("%*%")  # keep the base definition as the default method

`%*%` <- function(x, ...) {  # make %*% an S3 generic
  UseMethod("%*%", x)
}

#' @export
`%*%.torch_tensor` <- function(e1, e2) {
  # Dispatch is on e1, so e1 is already a tensor; promote a plain-R e2
  # to a tensor on the same device before multiplying.
  if (!is_torch_tensor(e2)) {
    e2 <- torch_tensor(e2, device = e1$device)
  }
  torch_matmul(e1, e2)
}
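With those definitions loaded, a product with a plain R operand on the right dispatches through torch (hypothetical usage):

a_torch <- torch_tensor(matrix(c(2, 1, 1, 3), nrow = 2))
a_torch %*% c(1, 2)  # dispatches to torch_matmul and returns a torch tensor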

Seemed to work pretty well for me.

grantmcdermott commented 2 years ago

As an aside, I’ve been meaning to suggest to the torch team that they define these kinds of primitive methods in the package namespace (it’s also what Matrix and various other packages do). But I’m not sure whether it causes complications switching back and forth between R6 and S3/S4.

rdinnager commented 2 years ago

> As an aside, I’ve been meaning to suggest to the torch team that they define these kinds of primitive methods in the package namespace (it’s also what Matrix and various other packages do). But I’m not sure whether it causes complications switching back and forth between R6 and S3/S4.

Most operators are already defined in torch; you can see them here: https://github.com/mlverse/torch/blob/main/R/operators.R

%*% is a special case, for complicated reasons I don't fully understand. See the conversation here for more on that: https://github.com/mlverse/torch/issues/728 (starting at the fourth comment).
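For example, arithmetic operators already have torch_tensor methods, while %*% does not (a quick illustration, assuming only that torch is attached):

library(torch)

x <- torch_tensor(c(1, 2))
x + 1  # works: "+" has a torch_tensor method
x * x  # works: so does "*"
# matrix(1, 1, 2) %*% x  # fails: %*% has no torch_tensor method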

Also, the downsides of the solution I proposed above are discussed at the end of that conversation (the solution originally comes from this Stack Overflow question: https://stackoverflow.com/questions/40580149/overload-matrix-multiplication-for-s3-class-in-r), which I should have mentioned before.

rgiordan commented 2 years ago

I'm now (as of https://github.com/rgiordan/zaminfluence/pull/32) using torch to replace the manual derivatives for IV and linear regression, and it's a big win in terms of performance as well as maintainability and readability!

rgiordan commented 2 years ago

Actually, you can get Hessian matrices in R by differentiating gradients with create_graph=TRUE. For example:

# EvalLogProb() and beta come from the surrounding zaminfluence code.
beta_ad <- torch_tensor(beta, requires_grad = TRUE)
loss <- -1 * EvalLogProb(beta_ad)

# create_graph = TRUE makes the gradient itself differentiable.
grad <- autograd_grad(loss, beta_ad, retain_graph = TRUE, create_graph = TRUE)[[1]]

# Differentiate each entry of the gradient to get one row of the Hessian.
hess <- matrix(NA, length(beta), length(beta))
for (d in seq_along(beta)) {
    hess[d, ] <- as.numeric(autograd_grad(grad[d], beta_ad, retain_graph = TRUE)[[1]])
}

It's probably inefficient relative to what I think is the usual forward mode / reverse mode combination, but maybe using retain_graph means it's not too bad. (I would love to hear a torch expert's opinion on this.)
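For a concrete sanity check of the recipe (a made-up quadratic, not zaminfluence code): with loss = sum(beta * (A %*% beta)), the Hessian should come out as A + t(A).

library(torch)

A <- matrix(c(2, 1, 1, 3), nrow = 2)
A_torch <- torch_tensor(A)
beta <- c(0.5, -1.0)

beta_ad <- torch_tensor(beta, requires_grad = TRUE)
loss <- torch_sum(beta_ad * torch_matmul(A_torch, beta_ad))

grad <- autograd_grad(loss, beta_ad, retain_graph = TRUE, create_graph = TRUE)[[1]]
hess <- matrix(NA, 2, 2)
for (d in 1:2) {
    hess[d, ] <- as.numeric(autograd_grad(grad[d], beta_ad, retain_graph = TRUE)[[1]])
}
hess  # equals A + t(A) = 2 * A here, since A is symmetric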

I think this means a general-purpose version of zaminfluence is possible without waiting for the autograd functions to be exposed in the R version of torch.

rgiordan commented 2 years ago

I wrote a blog post about torch in R. I now think the torch package is the way forward, and I'll close this issue.