mlverse / torch

R Interface to Torch
https://torch.mlverse.org

Runtime #268

Open MaximilianPi opened 4 years ago

MaximilianPi commented 4 years ago

Hi all, first of all, I am very excited about this project because I'm already using PyTorch in my own R package (via the torch pip wheel), and the prospect of using torch natively, without the Python intermediate step, is very appealing.

I use PyTorch mostly for smaller statistical models (the datasets can still be very large) where the overhead plays an important role. For example, reimplementing the core model of my package, from R6 classes that called the imported torch Python module (via reticulate) to native Python classes imported into R via reticulate::import_from_path, reduced the runtime by about 30% on average, even for large datasets.

I compared r-torch (this package), python-torch (the training loop written in Python and imported into R), and imported-torch (torch imported into R via reticulate, with the training loop written in R) by fitting a small neural network with four layers (450 weights in total) and benchmarking the training loop. I found that the native Python implementation is 20x faster than the r-torch loop, and even the imported-torch training loop is 5x faster than the r-torch loop:

(benchmark results image; times in milliseconds)

Do you have any ideas why the torch package is so much slower for smaller networks? (At least, I assume this only applies to small models.)

R-Torch Code:

library(torch)
set.seed(1)

X = matrix(1.0, 200, 20)
Y = matrix(0.5, 200, 1)

net = nn_module(
  initialize = function() {
      self$fc1 = nn_linear(20, 10)
      self$fc2 = nn_linear(10, 10)
      self$fc3 = nn_linear(10, 10)
      self$fc4 = nn_linear(10, 1)
  },

  forward = function(x) {
    x %>% 
      self$fc1() %>%
      nnf_relu() %>%
      self$fc2() %>%
      nnf_relu() %>%
      self$fc3() %>%
      nnf_relu() %>%
      self$fc4()
  }
)
model <- net()
opt = optim_adam(model$parameters)
XT = torch_tensor(X, dtype=torch_float32())
YT = torch_tensor(Y, dtype=torch_float32())

result = 
    microbenchmark::microbenchmark({

        for(i in 1:50) {
            opt$zero_grad()
            pred = model(XT)
            loss = nnf_mse_loss(pred, YT)
            loss$backward()
            opt$step()
        }
    })

Imported-torch code:

library(reticulate)
torch = import("torch")
set.seed(1)

X = matrix(1.0, 200, 20)
Y = matrix(0.5, 200, 1)

nn = torch$nn

model = nn$Sequential(
  (nn$Linear(20L, 10L)),
  (nn$ReLU()),
  (nn$Linear(10L, 10L)),
  (nn$ReLU()),
  (nn$Linear(10L, 10L)),
  (nn$ReLU()),
  (nn$Linear(10L, 1L))
)

XT = torch$tensor(X, dtype = torch$float32)
YT = torch$tensor(Y, dtype = torch$float32)

opt = torch$optim$Adam(model$parameters())

result = 
    microbenchmark::microbenchmark({
            for(i in 1:50) {
                opt$zero_grad()
                pred = model(XT)
                loss = torch$nn$functional$mse_loss(pred, YT)
                loss$backward()
                opt$step()
            }
    })

Native Python code: A) Python part

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(20, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 10)
        self.fc4 = nn.Linear(10, 1)

    def forward(self, X):
        x = F.relu( self.fc1(X) )
        x = F.relu( self.fc2(x) )
        x = F.relu( self.fc3(x) )
        x = self.fc4(x)
        return x

class Train:
    def __init__(self):
        self.net = Net()

    def train(self, X, Y):
        opt = torch.optim.Adam(self.net.parameters())

        for i in range(50):
            opt.zero_grad()
            Pred = self.net(X)
            loss = F.mse_loss(Pred, Y)
            loss.backward()
            opt.step()

B) R part:

library(reticulate)
torch = import("torch")
set.seed(1)

X = matrix(1.0, 200, 20)
Y = matrix(0.5, 200, 1)

py_torch = reticulate::import_from_path("python", "python/")
net = py_torch$Train()
XT = torch$tensor(X, dtype = torch$float32)
YT = torch$tensor(Y, dtype = torch$float32)

result = 
    microbenchmark::microbenchmark({
        net$train(XT, YT)
    })

Session Info:

R version 3.6.3 (2020-02-29)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0

locale:
 [1] LC_CTYPE=C.UTF-8       LC_NUMERIC=C           LC_TIME=C.UTF-8       
 [4] LC_COLLATE=C.UTF-8     LC_MONETARY=C.UTF-8    LC_MESSAGES=C.UTF-8   
 [7] LC_PAPER=C.UTF-8       LC_NAME=C              LC_ADDRESS=C          
[10] LC_TELEPHONE=C         LC_MEASUREMENT=C.UTF-8 LC_IDENTIFICATION=C   

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
[1] compiler_3.6.3

Ubuntu is running under WSL2.

MaximilianPi commented 4 years ago

Profiles from r-torch and imported torch: (profiler screenshots)

Besides the fact that all steps are slower than with the imported python-torch, the optimizer's step method is disproportionately slower than the other function calls.

dfalbel commented 4 years ago

Hi @MaximilianPi ,

Thanks for trying torch and for your benchmarks! The main culprit is the dispatcher used to decide how to pass an R object to LibTorch. It's defined here:

https://github.com/mlverse/torch/blob/master/R/codegen-utils.R#L34

We will eventually rewrite it in C++ to reduce the overhead. After that we still expect r-torch to be a little slower than raw PyTorch, but not on a 20x scale.

Also, this overhead should be proportionally smaller with larger models and larger batch sizes as the actual computation will take longer than the dispatcher overhead.
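For illustration, here is a minimal sketch of the same training loop scaled up (the sample count and layer widths below are arbitrary, not from the original benchmark), where the LibTorch computation should dominate the per-call dispatch overhead:

library(torch)

# Larger toy problem: 20,000 samples and wider layers, so the actual
# computation outweighs the per-call dispatch overhead.
X_big = torch_tensor(matrix(1.0, 20000, 20), dtype = torch_float32())
Y_big = torch_tensor(matrix(0.5, 20000, 1), dtype = torch_float32())

big_net = nn_module(
  initialize = function() {
    self$fc1 = nn_linear(20, 512)
    self$fc2 = nn_linear(512, 512)
    self$fc3 = nn_linear(512, 1)
  },
  forward = function(x) {
    x %>%
      self$fc1() %>% nnf_relu() %>%
      self$fc2() %>% nnf_relu() %>%
      self$fc3()
  }
)

model_big = big_net()
opt_big = optim_adam(model_big$parameters)

# Benchmark 50 training iterations, as in the original comparison.
microbenchmark::microbenchmark({
  for (i in 1:50) {
    opt_big$zero_grad()
    loss = nnf_mse_loss(model_big(X_big), Y_big)
    loss$backward()
    opt_big$step()
  }
}, times = 10)

The absolute times will of course be larger, but the gap between r-torch and PyTorch should shrink as the per-iteration compute grows.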

MaximilianPi commented 4 years ago

Hi @dfalbel,

thanks for your quick response!

Yeah, I played around with it, and for larger problems there is already no notable difference between r-torch and imported python-torch.

MaximilianPi commented 3 years ago

Update 02/2021: I repeated the benchmarks with the current development version (although my hardware has changed, we are more interested in the relative runtimes between the different options, so that shouldn't matter):

(benchmark results image)

Profiles: imported torch (reticulate) and r-torch (profiler screenshots)

It seems like you were able to reduce the overhead by 50%! For small models, r-torch is now only 10x slower than the native Python runtime. Very cool!

dfalbel commented 3 years ago

The last few PRs should have reduced the R overhead a little more. We will keep improving it :)

MaximilianPi commented 3 years ago

Update 04/2021:

(my hardware has changed again...)

(benchmark results image)

In short, the r-torch/imported-torch runtime ratio went from 4.924 to 3.661 to 3.369.

Also, the advantage (overhead) of native PyTorch over r-torch decreased from 19x to 12x (this applies only to small models, of course).

Profiles: imported torch and r-torch (profiler screenshots)

Cool! I saw that you succeeded in moving your dispatcher to C++. Memory and time usage (relative to opt$zero_grad) dropped significantly (the lower memory usage in particular is very nice!).

mikeyEcology commented 3 years ago

Hi, I have a similar question about runtime. I'm deploying a model in R that I trained in PyTorch, and it's taking a bit longer to run in R (on the CPU). It's very possible that I'm not setting up my deployment function properly; this is my first attempt at this, so please let me know if I can improve the function. It seems like torch::with_no_grad during evaluation is the R equivalent of with torch.no_grad(), but if I leave it out of my function, it does not affect inference time. For now I'm focusing on using the CPU. Here is the function I'm using:

# deployment function
deploy_model <- function(model, dl, device, num_classes=10, labeled=TRUE, gpu=FALSE){

  # length of data
  len_data <- length(dl)

  # make output table
  tbl_out <- matrix(NA, len_data, (num_classes+2))

  # send model to device
  if(gpu){
    model$to(device=device)
  }

  # add progress bar
  #pb = utils::txtProgressBar(min = 0, max = len_data, initial = 0) 

  # set model to evaluation 
  model$eval()

  # loop through data loader
  z <- torch::enumerate(dl)
  toc <- Sys.time()
  torch::with_no_grad({ # Do I need this step? It doesn't seem to increase speed
    for(i in seq_along(z)){
      batch <- z[[i]]

      if(gpu){
        x <- batch[[1]]$to(device=device)
      } else {
        x <- batch[[1]]
      }

      # run model
      output <- model(x)

      # run softmax
      y_out <- softmax(as.numeric(output[1]))

      # get ground truth if there's a label
      y_gt <- ifelse(labeled, 
                     as.numeric(batch[[2]]), NA)

      # make a row of the output for this batch
      row_out <- c(batch[[3]],
                   y_gt,
                   y_out)

      # put this into out table
      tbl_out[i,] <- row_out

      # update progress bar
      #utils::setTxtProgressBar(pb,i)

    } # end loop through batch
  }) # end no grad

  tic <- Sys.time()
  # return temporal information
  runtime <- tic-toc
  timeper <- runtime/len_data
  print(paste0("inference time of: ", timeper, " seconds per sample."))

  return(tbl_out)
}

In R the average time per sample is 0.22 seconds and in Python it is 0.14 seconds. It is possible that the difference is due to the C++ dispatch overhead described above, and this would make sense. But it would also be helpful to know whether I am doing something with the enumerate function, or somewhere else, that is slowing things down. Sorry if this is a trivial question, but I've had trouble finding other examples of people deploying models with this new package. Thank you in advance.
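For reference, here is a stripped-down sketch of an alternative loop (untested; it assumes dl is a torch dataloader with batches shaped as in the function above, and it only collects the class probabilities rather than the full output table) that iterates the dataloader directly with coro::loop() and keeps the softmax inside torch via nnf_softmax():

library(torch)

deploy_model_sketch <- function(model, dl, device, gpu = FALSE) {
  model$eval()
  if (gpu) model$to(device = device)

  results <- list()
  torch::with_no_grad({
    # coro::loop() iterates the dataloader one batch at a time, so no
    # intermediate list of all batches is materialized.
    coro::loop(for (batch in dl) {
      x <- if (gpu) batch[[1]]$to(device = device) else batch[[1]]
      output <- model(x)
      # softmax over the class dimension, computed in torch rather than base R
      probs <- nnf_softmax(output, dim = 2)
      results[[length(results) + 1]] <- as_array(probs$cpu())
    })
  })
  do.call(rbind, results)
}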

dfalbel commented 3 years ago

Your code chunk looks correct to me, and I'd say it's expected that R is a bit slower than Python for the same task. Just a few comments:

Hope this helps!

mikeyEcology commented 3 years ago

Great! Thanks a lot for your input! I didn't realize there was a torch_softmax function so I'll incorporate that too. I really appreciate your quick responses!