mlverse / torch

R Interface to Torch
https://torch.mlverse.org

`set_num_threads` on Linux does not seem to work #1178

Open issue, opened by MrDomani 5 months ago.

Hi,

I'm using a fairly up-to-date Manjaro Linux. I've noticed that R's torch does not seem to utilize my CPU (AMD Ryzen 7, 5000 series) to its full extent. Furthermore, `torch_set_num_threads` does not seem to have any effect: the code takes roughly the same amount of time regardless of the thread count. Equivalent Python code does not have this issue.

I'm attaching a reproducible example, modelled after one in the documentation of the Python torch library. Let me know whether the setup is correct and whether you observe a similar effect on your side.

```r
library(torch)  # attach once; detach() would not unload the namespace anyway

size <- 1024
set.seed(2024)
X <- matrix(runif(size^2), size, size)
Y <- matrix(runif(size^2), size, size)

# The input tensors do not depend on the thread count, so create them once
t1 <- torch_tensor(X)
t2 <- torch_tensor(Y)

for (n_threads in 1:8) {
  torch_set_num_threads(n_threads)
  # Report what torch says it is using before timing
  cat(paste0("Number of threads: ", torch_get_num_threads(), "\n"))
  time_start <- Sys.time()
  out <- microbenchmark::microbenchmark(torch_matmul(t1, t2), times = 7)
  time_stop <- Sys.time()
  # Rough wall time per call (includes microbenchmark's own overhead)
  print((time_stop - time_start) / 7)
  print(out)
}

print(sessionInfo())
```

Output:

```
Number of threads: 1
Time difference of 0.4125318 secs
Unit: milliseconds
                 expr      min       lq     mean   median       uq      max neval
 torch_matmul(t1, t2) 385.0982 386.6031 392.8908 387.9618 390.5209 422.9279     7
Number of threads: 2
Time difference of 0.4029093 secs
Unit: milliseconds
                 expr      min       lq     mean   median       uq      max neval
 torch_matmul(t1, t2) 383.9286 385.3504 387.0242 387.8626 388.6715 389.3341     7
Number of threads: 3
Time difference of 0.412262 secs
Unit: milliseconds
                 expr      min       lq     mean   median       uq      max neval
 torch_matmul(t1, t2) 390.5602 393.3822 395.6617 394.6437 397.4655 402.7326     7
Number of threads: 4
Time difference of 0.4133101 secs
Unit: milliseconds
                 expr      min       lq     mean   median       uq      max neval
 torch_matmul(t1, t2) 389.1392 389.7427 398.4571 390.4203 390.6904 448.7739     7
Number of threads: 5
Time difference of 0.4015262 secs
Unit: milliseconds
                 expr      min       lq     mean   median       uq      max neval
 torch_matmul(t1, t2) 383.1957 385.1875 386.9126 387.8565 388.3859 390.1893     7
Number of threads: 6
Time difference of 0.4085842 secs
Unit: milliseconds
                 expr      min       lq     mean   median       uq      max neval
 torch_matmul(t1, t2) 387.7209 390.1149 391.9693 391.7547 393.8607 396.3586     7
Number of threads: 7
Time difference of 0.4061831 secs
Unit: milliseconds
                 expr      min       lq     mean   median       uq      max neval
 torch_matmul(t1, t2) 389.0678 389.5376 391.5841 390.2344 392.4632 397.7849     7
Number of threads: 8
Time difference of 0.4225501 secs
Unit: milliseconds
                 expr      min    lq    mean   median       uq      max neval
 torch_matmul(t1, t2) 396.3549 398.5 406.618 400.3949 404.1815 444.2131     7
```
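To put these timings in perspective, the medians can be turned into a speedup relative to one thread and an approximate throughput. The base-R sketch below assumes the usual count of roughly 2n³ floating-point operations for a dense n × n matrix product; the median values are copied from the output above.

```r
# Median times (ms) for torch_matmul(t1, t2), copied from the output above
median_ms <- c(387.9618, 387.8626, 394.6437, 390.4203,
               387.8565, 391.7547, 390.2344, 400.3949)
threads <- 1:8

# A 1024 x 1024 by 1024 x 1024 product needs roughly 2 * n^3 flops
n <- 1024
flops <- 2 * n^3

summary_tbl <- data.frame(
  threads   = threads,
  median_ms = median_ms,
  # Speedup relative to the single-thread run (ideal scaling would be ~threads)
  speedup   = round(median_ms[1] / median_ms, 2),
  # Approximate throughput in GFLOP/s
  gflops    = round(flops / (median_ms / 1000) / 1e9, 2)
)
print(summary_tbl)
```

Every speedup lands between 0.97 and 1.00, so adding threads makes no measurable difference, and the single-thread throughput works out to about 5.5 GFLOP/s.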
Session info:

```
R version 4.4.0 (2024-04-24)
Platform: x86_64-pc-linux-gnu
Running under: Manjaro Linux

Matrix products: default
BLAS:   /usr/lib/libblas.so.3.12.0
LAPACK: /usr/lib/liblapack.so.3.12.0

locale:
 [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C               LC_TIME=pl_PL.UTF-8
 [4] LC_COLLATE=en_GB.UTF-8     LC_MONETARY=pl_PL.UTF-8    LC_MESSAGES=en_GB.UTF-8
 [7] LC_PAPER=pl_PL.UTF-8       LC_NAME=C                  LC_ADDRESS=C
[10] LC_TELEPHONE=C             LC_MEASUREMENT=pl_PL.UTF-8 LC_IDENTIFICATION=C

time zone: Europe/Warsaw
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base

loaded via a namespace (and not attached):
 [1] microbenchmark_1.4.10 processx_3.8.4        bit_4.0.5             compiler_4.4.0
 [5] magrittr_2.0.3        cli_3.6.2             tools_4.4.0           rstudioapi_0.16.0
 [9] torch_0.13.0          Rcpp_1.0.12           bit64_4.0.5           coro_1.0.4
[13] callr_3.7.6           ps_1.7.6              rlang_1.1.3
```
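When comparing against the Python run, it may also help to record the thread-related environment variables and the core count R sees in each session. OpenMP and MKL read variables such as `OMP_NUM_THREADS` and `MKL_NUM_THREADS`; which of these the libtorch build used by R torch actually honours is an assumption to check, not a confirmed cause. A minimal base-R sketch:

```r
# Environment variables that commonly cap OpenMP/BLAS thread pools.
# Whether libtorch reads each of them is an assumption to verify.
thread_vars <- c("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS")

# unset = NA marks variables that are not set, instead of returning ""
vals <- Sys.getenv(thread_vars, unset = NA)
thread_env <- data.frame(variable = thread_vars, value = unname(vals),
                         stringsAsFactors = FALSE)
print(thread_env)

# Logical CPU count as seen by R (on Linux this counts logical CPUs)
n_cores <- parallel::detectCores()
cat("Logical CPUs:", n_cores, "\n")
```

Running the same check in the environment where the Python comparison was made shows whether the two sessions start with different thread limits.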