OscarKjell / text

Using Transformers from HuggingFace in R
https://r-text.org

Mac M1/M2 GPU Support #46

Closed: adamramey closed this issue 1 year ago

adamramey commented 1 year ago

Hello, I've managed to get the package up and running on my M1 Mac, but GPU support isn't working for some reason. I get the error "Unable to use CUDA (GPU), using CPU".

Any idea how we can set this up properly?

MattCowgill commented 1 year ago

Same problem here. Anyone got suggestions for a fix?

Thanks for a great package @OscarKjell

sjgiorgi commented 1 year ago

Unfortunately, I'm unable to test this, but I've attempted to add support for it. If you can install text from GitHub (rather than CRAN), you can pull the latest code (I've attempted to push a quick fix) and try running

textEmbed(...., device='mps', ...)

Note that MPS acceleration is available on macOS 12.3+.

Also, I think you will need PyTorch v1.12, and possibly even >= 1.13. As of this writing, we haven't tested text with either of these versions.
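A quick pre-flight check against the minimums mentioned here (macOS 12.3, PyTorch 1.12) can help explain a silent fallback to the CPU. The sketch below is a hypothetical helper, `mps_requirements_met()`, which is not part of text; it only compares version strings with base R's `numeric_version()`. The commented lines assume macOS's `sw_vers` tool and a `torch` module reachable through reticulate for obtaining the real versions.

```r
# Hypothetical helper (not part of the text package): checks the minimum
# versions mentioned above for Apple's MPS backend.
mps_requirements_met <- function(macos_version, torch_version,
                                 min_macos = "12.3", min_torch = "1.12") {
  # Keep only the leading numeric part, e.g. "2.0.1+cpu" -> "2.0.1"
  clean <- function(v) regmatches(v, regexpr("^[0-9]+(\\.[0-9]+)*", v))
  macos_ok <- numeric_version(clean(macos_version)) >= numeric_version(min_macos)
  torch_ok <- numeric_version(clean(torch_version)) >= numeric_version(min_torch)
  macos_ok && torch_ok
}

# Versions reported in this thread: macOS 13.2.1 and torch 1.13.0
mps_requirements_met("13.2.1", "1.13.0")

# On a Mac, the actual versions could be obtained with, e.g.:
# macos <- system("sw_vers -productVersion", intern = TRUE)
# torch <- reticulate::import("torch")$`__version__`
```

`numeric_version()` compares component by component, so "13.2.1" correctly counts as newer than "12.3", which a plain string comparison would not guarantee.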

MattCowgill commented 1 year ago

Thank you @sjgiorgi, will do

OscarKjell commented 1 year ago

Thanks a lot for this @sjgiorgi! Zhoujun will soon make a pull request with an update that works! I think you will have to set device = "mps:0" in textEmbed().

moomoofarm1 commented 1 year ago

You can also try device = "mps" in textEmbed(); that works too.
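Since both "mps" and "mps:0" are reported to work, a script shared between Apple Silicon and other machines could choose the device at run time. This is a minimal sketch under stated assumptions: `pick_device()` is a hypothetical helper (not part of text), and the automatic branch assumes reticulate can import the same `torch` module that text uses, where `torch.backends.mps.is_available()` exists from PyTorch 1.12 onwards.

```r
# Hypothetical helper (not part of the text package): returns "mps" when
# PyTorch reports that the MPS backend is available, otherwise "cpu".
pick_device <- function(mps_available = NULL) {
  if (is.null(mps_available)) {
    # Assumes the textrpp Python environment is active via reticulate
    torch <- reticulate::import("torch")
    mps_available <- torch$backends$mps$is_available()
  }
  if (isTRUE(mps_available)) "mps" else "cpu"
}

# With an explicit flag, no Python session is needed:
pick_device(TRUE)
pick_device(FALSE)

# In practice (requires text and its Python environment):
# embeddings <- textEmbed("I feel great!", device = pick_device())
```

The optional argument mainly exists so the fallback logic can be checked without Python; in normal use you would call `pick_device()` with no arguments.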

MattCowgill commented 1 year ago

Yay, I'm happy to confirm this works for me (M1 MacBook Pro, {text} 0.9.99.9)

> # Example text
> texts <- c("I feel great!")
> 
> # Defaults
> embeddings <- textEmbed(texts,
+                         device = "mps:0")
Completed layers output for  (variable: 1/1, duration: 6.218717 secs).
Completed layers aggregation for word_type_embeddings. 
Completed layers aggregation (variable 1/1, duration: 0.063519 secs).
Completed layers aggregation (variable 1/1, duration: 0.057727 secs).
MPS_for_MacM1+_available: True
Using mps!
sjgiorgi commented 1 year ago

@MattCowgill Thanks for reporting back!

@moomoofarm1 @OscarKjell Have we confirmed that we are seeing increased speed when using this?

MattCowgill commented 1 year ago

If anything, @sjgiorgi, I see a slight slowdown with device = "mps:0".

Please see the reprex with timings below.

library(text)
#> This is text (version 0.9.99.9).
#> Text is new and still rapidly improving.
#>                
#> Newer versions may have improved functions and updated defaults to reflect current understandings of the state-of-the-art.
#>                Please send us feedback based on your experience.
#> 
#> Please note that defaults has changed in the textEmbed-functions since last version; see help(textEmbed) or www.r-text.org for more details.

textrpp_initialize()
#> 
#> Successfully initialized text required python packages.
#> 
#> Python options: 
#>  type = "textrpp_condaenv", 
#>  name = "textrpp_condaenv".

texts <- c("I feel great!",
           "I don't feel so good",
           "Water is wet")

bench::mark(textEmbed(texts,
                      device = "mps:0"),
            textEmbed(texts,
                      device = "cpu"),
            min_iterations = 3,
            max_iterations = 3,
            check = FALSE)
#> Completed layers output for  (variable: 1/1, duration: 8.813812 secs).
#> Completed layers aggregation for word_type_embeddings. 
#> Completed layers aggregation (variable 1/1, duration: 0.610342 secs).
#> Completed layers aggregation (variable 1/1, duration: 0.444646 secs).
#> Completed layers output for  (variable: 1/1, duration: 6.690455 secs).
#> Completed layers aggregation for word_type_embeddings. 
#> Completed layers aggregation (variable 1/1, duration: 0.307038 secs).
#> Completed layers aggregation (variable 1/1, duration: 0.340210 secs).
#> Completed layers output for  (variable: 1/1, duration: 6.106852 secs).
#> Completed layers aggregation for word_type_embeddings. 
#> Completed layers aggregation (variable 1/1, duration: 0.284889 secs).
#> Completed layers aggregation (variable 1/1, duration: 0.255946 secs).
#> Completed layers output for  (variable: 1/1, duration: 6.581084 secs).
#> Completed layers aggregation for word_type_embeddings. 
#> Completed layers aggregation (variable 1/1, duration: 0.232711 secs).
#> Completed layers aggregation (variable 1/1, duration: 0.291940 secs).
#> Completed layers output for  (variable: 1/1, duration: 6.734517 secs).
#> Completed layers aggregation for word_type_embeddings. 
#> Completed layers aggregation (variable 1/1, duration: 0.290402 secs).
#> Completed layers aggregation (variable 1/1, duration: 0.328100 secs).
#> Completed layers output for  (variable: 1/1, duration: 5.629423 secs).
#> Completed layers aggregation for word_type_embeddings. 
#> Completed layers aggregation (variable 1/1, duration: 0.225370 secs).
#> Completed layers aggregation (variable 1/1, duration: 0.227123 secs).
#> Completed layers output for  (variable: 1/1, duration: 5.711532 secs).
#> Completed layers aggregation for word_type_embeddings. 
#> Completed layers aggregation (variable 1/1, duration: 0.226697 secs).
#> Completed layers aggregation (variable 1/1, duration: 0.227989 secs).
#> Completed layers output for  (variable: 1/1, duration: 6.046991 secs).
#> Completed layers aggregation for word_type_embeddings. 
#> Completed layers aggregation (variable 1/1, duration: 0.542720 secs).
#> Completed layers aggregation (variable 1/1, duration: 0.239926 secs).
#> 
#> Warning: Some expressions had a GC in every iteration; so filtering is
#> disabled.
#> # A tibble: 2 × 6
#>   expression                              min   median `itr/sec` mem_a…¹ gc/se…²
#>   <bch:expr>                         <bch:tm> <bch:tm>     <dbl> <bch:b>   <dbl>
#> 1 textEmbed(texts, device = "mps:0")    7.37s    7.82s     0.129  43.1MB   0.516
#> 2 textEmbed(texts, device = "cpu")      6.78s    6.86s     0.142    36MB   0.613
#> # … with abbreviated variable names ¹​mem_alloc, ²​`gc/sec`

Created on 2023-04-18 with reprex v2.0.2

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.2 (2022-10-31)
#>  os       macOS Ventura 13.2.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Perth
#>  date     2023-04-18
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  bench          1.1.2      2021-11-30 [1] CRAN (R 4.2.0)
#>  brio           1.1.3      2021-11-30 [1] CRAN (R 4.2.0)
#>  class          7.3-20     2022-01-16 [1] CRAN (R 4.2.2)
#>  cli            3.6.0      2023-01-09 [1] CRAN (R 4.2.0)
#>  codetools      0.2-18     2020-11-04 [1] CRAN (R 4.2.2)
#>  colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.2.0)
#>  cowplot        1.1.1      2020-12-30 [1] CRAN (R 4.2.0)
#>  dials          1.1.0      2022-11-04 [1] CRAN (R 4.2.2)
#>  DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.2.0)
#>  digest         0.6.31     2022-12-11 [1] CRAN (R 4.2.0)
#>  dplyr          1.1.0      2023-01-29 [1] CRAN (R 4.2.0)
#>  evaluate       0.20       2023-01-17 [1] CRAN (R 4.2.0)
#>  fansi          1.0.4      2023-01-22 [1] CRAN (R 4.2.0)
#>  fastmap        1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.2.0)
#>  fs             1.6.1      2023-02-06 [1] CRAN (R 4.2.0)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.2.0)
#>  future         1.31.0     2023-02-01 [1] CRAN (R 4.2.0)
#>  future.apply   1.10.0     2022-11-05 [1] CRAN (R 4.2.2)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.2.0)
#>  ggplot2        3.4.0.9000 2023-02-02 [1] Github (tidyverse/ggplot2@882584f)
#>  globals        0.16.2     2022-11-21 [1] CRAN (R 4.2.0)
#>  glue           1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.2.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.2.0)
#>  gtable         0.3.1      2022-09-01 [1] CRAN (R 4.2.0)
#>  hardhat        1.2.0      2022-06-30 [1] CRAN (R 4.2.0)
#>  here           1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>  htmltools      0.5.4      2022-12-07 [1] CRAN (R 4.2.0)
#>  ipred          0.9-13     2022-06-02 [1] CRAN (R 4.2.0)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.2.0)
#>  jsonlite       1.8.4      2022-12-06 [1] CRAN (R 4.2.0)
#>  knitr          1.42       2023-01-25 [1] CRAN (R 4.2.0)
#>  lattice        0.20-45    2021-09-22 [1] CRAN (R 4.2.2)
#>  lava           1.7.1      2023-01-06 [1] CRAN (R 4.2.0)
#>  lhs            1.1.6      2022-12-17 [1] CRAN (R 4.2.0)
#>  lifecycle      1.0.3      2022-10-07 [1] CRAN (R 4.2.0)
#>  listenv        0.9.0      2022-12-16 [1] CRAN (R 4.2.0)
#>  lubridate      1.9.1      2023-01-24 [1] CRAN (R 4.2.0)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  MASS           7.3-58.1   2022-08-03 [1] CRAN (R 4.2.2)
#>  Matrix         1.5-1      2022-09-13 [1] CRAN (R 4.2.2)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.2.0)
#>  nnet           7.3-18     2022-09-28 [1] CRAN (R 4.2.2)
#>  overlapping    2.1        2022-12-14 [1] CRAN (R 4.2.0)
#>  parallelly     1.34.0     2023-01-13 [1] CRAN (R 4.2.0)
#>  parsnip        1.0.3      2022-11-11 [1] CRAN (R 4.2.0)
#>  pillar         1.8.1      2022-08-19 [1] CRAN (R 4.2.0)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  png            0.1-8      2022-11-29 [1] CRAN (R 4.2.0)
#>  prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.2.0)
#>  profmem        0.6.0      2020-12-13 [1] CRAN (R 4.2.0)
#>  purrr          1.0.1      2023-01-10 [1] CRAN (R 4.2.0)
#>  R.cache        0.16.0     2022-07-21 [1] CRAN (R 4.2.0)
#>  R.methodsS3    1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo           1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils        2.12.2     2022-11-11 [1] CRAN (R 4.2.0)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp           1.0.10     2023-01-22 [1] CRAN (R 4.2.0)
#>  recipes        1.0.4      2023-01-11 [1] CRAN (R 4.2.0)
#>  reprex         2.0.2      2022-08-17 [1] CRAN (R 4.2.0)
#>  reticulate     1.28       2023-01-27 [1] CRAN (R 4.2.0)
#>  rlang          1.0.6      2022-09-24 [1] CRAN (R 4.2.0)
#>  rmarkdown      2.20       2023-01-19 [1] CRAN (R 4.2.0)
#>  rpart          4.1.19     2022-10-21 [1] CRAN (R 4.2.2)
#>  rprojroot      2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>  rsample        1.1.1      2022-12-07 [1] CRAN (R 4.2.0)
#>  rstudioapi     0.14       2022-08-22 [1] CRAN (R 4.2.0)
#>  scales         1.2.1      2022-08-20 [1] CRAN (R 4.2.0)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi        1.7.12     2023-01-11 [1] CRAN (R 4.2.0)
#>  styler         1.9.0      2023-01-15 [1] CRAN (R 4.2.0)
#>  survival       3.4-0      2022-08-09 [1] CRAN (R 4.2.2)
#>  testthat       3.1.6      2022-12-09 [1] CRAN (R 4.2.0)
#>  text         * 0.9.99.9   2023-04-18 [1] Github (oscarkjell/text@2131237)
#>  tibble         3.1.8      2022-07-22 [1] CRAN (R 4.2.0)
#>  tidyr          1.3.0      2023-01-24 [1] CRAN (R 4.2.0)
#>  tidyselect     1.2.0      2022-10-10 [1] CRAN (R 4.2.0)
#>  timechange     0.2.0      2023-01-11 [1] CRAN (R 4.2.0)
#>  timeDate       4022.108   2023-01-07 [1] CRAN (R 4.2.0)
#>  tune           1.0.1      2022-10-09 [1] CRAN (R 4.2.0)
#>  utf8           1.2.3      2023-01-31 [1] CRAN (R 4.2.0)
#>  vctrs          0.5.2      2023-01-23 [1] CRAN (R 4.2.0)
#>  withr          2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>  workflows      1.1.2      2022-11-16 [1] CRAN (R 4.2.0)
#>  xfun           0.37       2023-01-31 [1] CRAN (R 4.2.0)
#>  yaml           2.3.7      2023-01-23 [1] CRAN (R 4.2.0)
#>  yardstick      1.1.0      2022-09-07 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/mcowgill/Library/r-miniconda-arm64/envs/textrpp_condaenv/bin/python
#>  libpython:      /Users/mcowgill/Library/r-miniconda-arm64/envs/textrpp_condaenv/lib/libpython3.9.dylib
#>  pythonhome:     /Users/mcowgill/Library/r-miniconda-arm64/envs/textrpp_condaenv:/Users/mcowgill/Library/r-miniconda-arm64/envs/textrpp_condaenv
#>  version:        3.9.0 | packaged by conda-forge | (default, Nov 26 2020, 07:55:15) [Clang 11.0.0 ]
#>  numpy:          /Users/mcowgill/Library/r-miniconda-arm64/envs/textrpp_condaenv/lib/python3.9/site-packages/numpy
#>  numpy_version:  1.24.2
#> 
#>  NOTE: Python version was forced by use_python function
#> 
#> ──────────────────────────────────────────────────────────────────────────────
```
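To put the "slight slowdown" into a single number, the median times from the `bench::mark()` table above can be compared directly. This is plain arithmetic on the reported medians, not a new measurement, and it only covers this three-sentence input.

```r
# Median run times (seconds) from the bench::mark() output above
median_secs <- c(mps = 7.82, cpu = 6.86)

# Ratio > 1 means the MPS run was slower than the CPU run
ratio <- median_secs[["mps"]] / median_secs[["cpu"]]
round(ratio, 2)
```

The ratio comes out at about 1.14, i.e. the MPS run was roughly 14% slower than the CPU run here. With so little text, fixed overheads (model loading, moving data to the GPU) may dominate, so this says little about larger inputs.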
MattCowgill commented 1 year ago

This is with torch 1.13.0 FYI

OscarKjell commented 1 year ago

I have not been able to see any speed improvement using mps:1, but the good news is that @moomoofarm1 has improved a sorting function, which roughly halves the run time of textEmbed() and textEmbedRawLayers() on more than 200 rows of the example data.

We have also added a sort parameter to textEmbedRawLayers(); setting sort = FALSE returns the results as lists (rather than in tidy format), which reduces the time considerably.

library(text)
library(dplyr) # for bind_rows()

# Example data: five copies of the first text column (200 rows in total)
texts <- bind_rows(Language_based_assessment_data_8[1],
                   Language_based_assessment_data_8[1],
                   Language_based_assessment_data_8[1],
                   Language_based_assessment_data_8[1],
                   Language_based_assessment_data_8[1])
texts

# Old implementation
t1 <- Sys.time()
emb_old <- textEmbedRawLayers_OLD(texts)
t2 <- Sys.time()
t_old <- t2 - t1

# New implementation (improved sorting) on the default device
t3 <- Sys.time()
emb_new <- text::textEmbedRawLayers(texts)
t4 <- Sys.time()
t_new <- t4 - t3

# sort = FALSE (results returned as lists) on MPS
t5 <- Sys.time()
emb_fsort_mps <- textEmbedRawLayers(texts, sort = FALSE, device = "mps:1")
t6 <- Sys.time()
t_fsort_mps <- t6 - t5

# sort = FALSE on the default device
t7 <- Sys.time()
emb_fsort <- textEmbedRawLayers(texts, sort = FALSE)
t8 <- Sys.time()
t_fsort <- t8 - t7
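Subtracting two `Sys.time()` values returns a `difftime` whose units R chooses automatically, which is why the results below mix minutes and seconds. A minimal sketch of a hypothetical helper, `time_secs()` (not part of text), that always reports elapsed seconds via base R's `system.time()`; the text calls are left as comments because they need the package's Python environment.

```r
# Hypothetical helper (not part of the text package): returns the elapsed
# wall-clock time of an expression in seconds, using base R's system.time().
time_secs <- function(expr) {
  unname(system.time(expr)[["elapsed"]])
}

# Example with a stand-in workload
t_sleep <- time_secs(Sys.sleep(0.2))
print(t_sleep)

# With the calls above (requires text and its Python environment):
# t_fsort     <- time_secs(textEmbedRawLayers(texts, sort = FALSE))
# t_fsort_mps <- time_secs(textEmbedRawLayers(texts, sort = FALSE, device = "mps:1"))
```

Because `expr` is evaluated lazily, the work happens inside `system.time()`, so the returned number covers only the expression itself. For repeated runs, `bench::mark()` as used earlier in the thread is the more robust choice.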

# 200 rows of text

t_old        Time difference of 2.868509 mins
t_new        Time difference of 1.200108 mins
t_fsort_mps  Time difference of 56.20257 secs
t_fsort      Time difference of 43.55411 secs

# 600 rows of text

t_old        Time difference of 6.698215 mins
t_new        Time difference of 2.765697 mins
t_fsort_mps  Time difference of 1.481872 mins
t_fsort      Time difference of 1.153369 mins

# 1200 rows of text

t_old        Time difference of 14.84135 mins
t_new        Time difference of 6.070847 mins
t_fsort_mps  Time difference of 1.947607 mins
t_fsort      Time difference of 2.087477 mins
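Because the reported timings mix minutes and seconds, they are easier to compare once converted to a common unit. The sketch below only re-expresses the numbers reported above (minutes multiplied by 60) and computes ratios; it does not rerun anything.

```r
# Reported timings, converted to seconds (minutes * 60)
timings <- data.frame(
  rows        = c(200, 600, 1200),
  t_old       = c(2.868509, 6.698215, 14.84135) * 60,
  t_new       = c(1.200108, 2.765697, 6.070847) * 60,
  t_fsort_mps = c(56.20257, 1.481872 * 60, 1.947607 * 60),
  t_fsort     = c(43.55411, 1.153369 * 60, 2.087477 * 60)
)

# Speed-up relative to the old implementation (> 1 means faster)
timings$speedup_new   <- timings$t_old / timings$t_new
timings$speedup_fsort <- timings$t_old / timings$t_fsort

# MPS relative to the default device with sort = FALSE (> 1 means MPS is slower)
timings$mps_vs_cpu <- timings$t_fsort_mps / timings$t_fsort

round(timings[, c("rows", "speedup_new", "speedup_fsort", "mps_vs_cpu")], 2)
```

By this arithmetic, the new sorting alone makes textEmbedRawLayers() roughly 2.4 times faster than the old version, sort = FALSE on the default device is roughly 4 to 7 times faster, and MPS is slower than the default device at 200 and 600 rows, coming out slightly ahead only at 1200 rows.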