greta-dev / greta

simple and scalable statistical modelling in R
https://greta-stats.org

MCMC fails with TF2 - Evaluation error: AttributeError: module 'tensorflow_probability.python.bijectors' has no attribute 'AffineScalar' #549

Closed: njtierney closed this issue 1 year ago

njtierney commented 2 years ago
library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
greta_sitrep()
#> ℹ checking if python available
#> ✔ python (v3.8) available
#> 
#> ℹ checking if TensorFlow available
#> ✔ TensorFlow (v2.9.2) available
#> 
#> ℹ checking if TensorFlow Probability available
#> ✔ TensorFlow Probability (v0.17.0) available
#> 
#> ℹ checking if greta conda environment available
#> ✔ greta conda environment available
#> 
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
#> ℹ greta is ready to use!
x <- normal(0,1)
m <- model(x)
#> Loaded Tensorflow version 2.9.2
draws <- mcmc(m)
#> Called from: force(expr)
#> running 4 chains simultaneously on up to 8 CPU cores
#> 
#>     warmup ====================================== 1000/1000 | eta:  0s
#>   sampling ====================================== 1000/1000 | eta:  0s

# what about lognormal?
ln <- lognormal(0,1)
m_ln <- model(ln)
draws_ln <- mcmc(m_ln)
#> Called from: force(expr)
#> Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError: Evaluation error: AttributeError: module 'tensorflow_probability.python.bijectors' has no attribute 'AffineScalar'
#> .

Created on 2022-08-17 by the reprex package (v2.0.1)

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22)
#>  os       macOS Monterey 12.3.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2022-08-17
#>  pandoc   2.18 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  abind         1.4-5      2016-07-21 [1] CRAN (R 4.2.0)
#>  backports     1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>  callr         3.7.1      2022-07-13 [1] CRAN (R 4.2.0)
#>  cli           3.3.0.9000 2022-06-15 [1] Github (r-lib/cli@31a5db5)
#>  coda          0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>  codetools     0.2-18     2020-11-04 [1] CRAN (R 4.2.0)
#>  crayon        1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>  digest        0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>  ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>  evaluate      0.16       2022-08-09 [1] CRAN (R 4.2.0)
#>  fansi         1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>  fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>  fs            1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>  future        1.27.0     2022-07-22 [1] CRAN (R 4.2.0)
#>  globals       0.16.0     2022-08-05 [1] CRAN (R 4.2.0)
#>  glue          1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  greta       * 0.4.2.9000 2022-08-17 [1] local
#>  here          1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>  highr         0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>  hms           1.1.1      2021-09-26 [1] CRAN (R 4.2.0)
#>  htmltools     0.5.3      2022-07-18 [1] CRAN (R 4.2.0)
#>  jsonlite      1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>  knitr         1.39       2022-04-26 [1] CRAN (R 4.2.0)
#>  lattice       0.20-45    2021-09-22 [1] CRAN (R 4.2.0)
#>  lifecycle     1.0.1      2021-09-24 [1] CRAN (R 4.2.0)
#>  listenv       0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>  magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  Matrix        1.4-1      2022-03-23 [1] CRAN (R 4.2.0)
#>  parallelly    1.32.1     2022-07-21 [1] CRAN (R 4.2.0)
#>  pillar        1.8.0      2022-07-18 [1] CRAN (R 4.2.0)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  png           0.1-7      2013-12-03 [1] CRAN (R 4.2.0)
#>  prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>  processx      3.7.0      2022-07-07 [1] CRAN (R 4.2.0)
#>  progress      1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>  ps            1.7.1      2022-06-18 [1] CRAN (R 4.2.0)
#>  purrr         0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>  R.cache       0.16.0     2022-07-21 [1] CRAN (R 4.2.0)
#>  R.methodsS3   1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo          1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils       2.12.0     2022-06-28 [1] CRAN (R 4.2.0)
#>  R6            2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp          1.0.9      2022-07-08 [1] CRAN (R 4.2.0)
#>  reprex        2.0.1      2021-08-05 [1] CRAN (R 4.2.0)
#>  reticulate    1.25       2022-05-11 [1] CRAN (R 4.2.0)
#>  rlang         1.0.4      2022-07-12 [1] CRAN (R 4.2.0)
#>  rmarkdown     2.14       2022-04-25 [1] CRAN (R 4.2.0)
#>  rprojroot     2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>  rstudioapi    0.13       2020-11-12 [1] CRAN (R 4.2.0)
#>  sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi       1.7.8      2022-07-11 [1] CRAN (R 4.2.0)
#>  stringr       1.4.0      2019-02-10 [1] CRAN (R 4.2.0)
#>  styler        1.7.0      2022-03-13 [1] CRAN (R 4.2.0)
#>  tensorflow    2.9.0      2022-05-21 [1] CRAN (R 4.2.0)
#>  tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.2.0)
#>  tfruns        1.5.0      2021-02-26 [1] CRAN (R 4.2.0)
#>  tibble        3.1.8      2022-07-22 [1] CRAN (R 4.2.0)
#>  utf8          1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs         0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>  whisker       0.4        2019-08-28 [1] CRAN (R 4.2.0)
#>  withr         2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>  xfun          0.32.1     2022-08-11 [1] https://yihui.r-universe.dev (R 4.2.0)
#>  yaml          2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#>
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#>
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16) [Clang 12.0.1 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.22.4
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow
#>
#>  NOTE: Python version was forced by use_python function
#>
#> ──────────────────────────────────────────────────────────────────────────────
```

njtierney commented 2 years ago

OK so I've traced it back to

https://github.com/njtierney/greta/blob/tf2-poke-tf-fun/R/inference_class.R#L183-L186

I can't trace it much further than that for the moment. I'm unsure why this happens for lognormal but not for normal.

njtierney commented 2 years ago

OK, tracing further, it seems it is related to here:

https://github.com/njtierney/greta/blob/tf2-poke-tf-fun/R/tf_functions.R#L591-L613

The issue is specifically this code

tfp$bijectors$AffineScalar

And reading the changelog - https://github.com/tensorflow/probability/releases/tag/v0.15.0

They state:

BREAKING CHANGE: Remove deprecated AffineScalar bijector. Please use tfb.Shift(shift)(tfb.Scale(scale)) instead.

So that gives us a starting point, will experiment
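
The replacement suggested in the changelog composes two bijectors, so it should reproduce the forward map of the removed `AffineScalar(shift, scale)`, which is `shift + scale * x`. Below is a minimal base R sketch of that arithmetic. It uses plain closures rather than TFP objects, and the helper names `affine_scalar`, `shift_by` and `scale_by` are made up for illustration:

```r
# Plain-R stand-ins for the forward maps; these are NOT TFP bijectors.
affine_scalar <- function(shift = 0, scale = 1) function(x) shift + scale * x
shift_by <- function(shift) function(x) x + shift
scale_by <- function(scale) function(x) x * scale

x <- c(-2, 0, 0.5, 3)
old <- affine_scalar(shift = 10, scale = -1)

# Shift(shift)(Scale(scale)): scale first, then shift
new <- function(x) shift_by(10)(scale_by(-1)(x))

# Reversing the order gives a different map: scale * (x + shift)
swapped <- function(x) scale_by(-1)(shift_by(10)(x))

same_as_old <- isTRUE(all.equal(old(x), new(x)))
order_matters <- !isTRUE(all.equal(old(x), swapped(x)))
```

Both flags come out `TRUE`: the nested form matches the old affine map, and swapping the two steps does not.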

njtierney commented 1 year ago

OK so current fix is:


tf_scalar_pos_bijector <- function(dim, lower, upper) {
  tf_scalar_biject(
    # AffineScalar was removed in TFP 0.15.0; with no scale, Shift alone replaces
    # tfp$bijectors$AffineScalar(shift = fl(lower))
    tfp$bijectors$Shift(fl(lower)),
    tfp$bijectors$Exp(),
    dim = dim
  )
}

This will need to change for

tf_scalar_neg_bijector <- function(dim, lower, upper) {
  tf_scalar_biject(
    tfp$bijectors$AffineScalar(shift = fl(upper), scale = fl(-1)),
    tfp$bijectors$Exp(),
    dim = dim
  )
}

Which has the additional scale argument, which I think needs to be handled as a subsequent call?

E.g.,

tfp$bijectors$Shift(shift)(tfp$bijectors$Scale(scale))

In my head that is a bit hectic, because it means calling the result of one bijector on another, which I wasn't expecting. Not sure if it should be two separate calls instead, which might be simpler to parse and maintain.
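
One way to weigh the nested call against two separate steps is to model the chain in plain R. The sketch below assumes that `tf_scalar_biject()` applies its steps in the same order as TFP's `Chain`, i.e. the last step listed runs first. That ordering is an assumption, inferred from the positive bijector above (Exp, then shift by `lower`). `shift_step()`, `scale_step()`, `exp_step()`, `chain_forward()` and `chain_inverse()` are hypothetical helpers, not greta or TFP functions.

```r
# Each step is a list with a forward and an inverse function (plain R, no TFP).
shift_step <- function(s) list(forward = function(x) x + s, inverse = function(y) y - s)
scale_step <- function(s) list(forward = function(x) x * s, inverse = function(y) y / s)
exp_step <- function() list(forward = exp, inverse = log)

# Forward: apply the steps right-to-left (last listed runs first).
chain_forward <- function(steps, x) {
  for (step in rev(steps)) x <- step$forward(x)
  x
}

# Inverse: undo the steps left-to-right.
chain_inverse <- function(steps, y) {
  for (step in steps) y <- step$inverse(y)
  y
}

upper <- 2
# Flat version of the negative bijector: Shift(upper), Scale(-1), Exp()
neg_steps <- list(shift_step(upper), scale_step(-1), exp_step())

free <- c(-3, 0, 1.5)                          # unconstrained values
constrained <- chain_forward(neg_steps, free)  # upper - exp(free), always < upper
recovered <- chain_inverse(neg_steps, constrained)
```

If the ordering assumption holds, passing `tfp$bijectors$Shift(fl(upper))` and `tfp$bijectors$Scale(fl(-1))` as two separate steps ahead of `tfp$bijectors$Exp()` would avoid the nested call. That has not been checked against greta here.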

Also looks like the lognormal is working

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
greta_sitrep()
#> ℹ checking if python available
#> ✔ python (v3.8) available
#> 
#> ℹ checking if TensorFlow available
#> ✔ TensorFlow (v2.9.2) available
#> 
#> ℹ checking if TensorFlow Probability available
#> ✔ TensorFlow Probability (v0.17.0) available
#> 
#> ℹ checking if greta conda environment available
#> ✔ greta conda environment available
#> 
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
#> ℹ greta is ready to use!
ln <- lognormal(0,1)
m_ln <- model(ln)
#> Loaded Tensorflow version 2.9.2
draws_ln <- mcmc(m_ln, n_samples = 500, warmup = 500)
#> running 4 chains simultaneously on up to 8 CPU cores
#> 
#>     warmup ======================================== 500/500 | eta:  0s | 4% bad
#>   sampling ======================================== 500/500 | eta:  0s

library(coda)
#> 
#> Attaching package: 'coda'
#> 
#> The following object is masked from 'package:greta':
#> 
#>     mcmc
plot(draws_ln)

Created on 2022-08-19 by the reprex package (v2.0.1)

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22)
#>  os       macOS Monterey 12.3.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2022-08-19
#>  pandoc   2.18 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  abind         1.4-5      2016-07-21 [1] CRAN (R 4.2.0)
#>  backports     1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>  callr         3.7.1      2022-07-13 [1] CRAN (R 4.2.0)
#>  cli           3.3.0.9000 2022-06-15 [1] Github (r-lib/cli@31a5db5)
#>  coda        * 0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>  codetools     0.2-18     2020-11-04 [1] CRAN (R 4.2.0)
#>  crayon        1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>  curl          4.3.2      2021-06-23 [1] CRAN (R 4.2.0)
#>  digest        0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>  ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>  evaluate      0.16       2022-08-09 [1] CRAN (R 4.2.0)
#>  fansi         1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>  fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>  fs            1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>  future        1.27.0     2022-07-22 [1] CRAN (R 4.2.0)
#>  globals       0.16.0     2022-08-05 [1] CRAN (R 4.2.0)
#>  glue          1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  greta       * 0.4.2.9000 2022-08-19 [1] local
#>  here          1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>  highr         0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>  hms           1.1.1      2021-09-26 [1] CRAN (R 4.2.0)
#>  htmltools     0.5.3      2022-07-18 [1] CRAN (R 4.2.0)
#>  httr          1.4.3      2022-05-04 [1] CRAN (R 4.2.0)
#>  jsonlite      1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>  knitr         1.39       2022-04-26 [1] CRAN (R 4.2.0)
#>  lattice       0.20-45    2021-09-22 [1] CRAN (R 4.2.0)
#>  lifecycle     1.0.1      2021-09-24 [1] CRAN (R 4.2.0)
#>  listenv       0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>  magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  Matrix        1.4-1      2022-03-23 [1] CRAN (R 4.2.0)
#>  mime          0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>  parallelly    1.32.1     2022-07-21 [1] CRAN (R 4.2.0)
#>  pillar        1.8.0      2022-07-18 [1] CRAN (R 4.2.0)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  png           0.1-7      2013-12-03 [1] CRAN (R 4.2.0)
#>  prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>  processx      3.7.0      2022-07-07 [1] CRAN (R 4.2.0)
#>  progress      1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>  ps            1.7.1      2022-06-18 [1] CRAN (R 4.2.0)
#>  purrr         0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>  R.cache       0.16.0     2022-07-21 [1] CRAN (R 4.2.0)
#>  R.methodsS3   1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo          1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils       2.12.0     2022-06-28 [1] CRAN (R 4.2.0)
#>  R6            2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp          1.0.9      2022-07-08 [1] CRAN (R 4.2.0)
#>  reprex        2.0.1      2021-08-05 [1] CRAN (R 4.2.0)
#>  reticulate    1.25       2022-05-11 [1] CRAN (R 4.2.0)
#>  rlang         1.0.4      2022-07-12 [1] CRAN (R 4.2.0)
#>  rmarkdown     2.14       2022-04-25 [1] CRAN (R 4.2.0)
#>  rprojroot     2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>  rstudioapi    0.13       2020-11-12 [1] CRAN (R 4.2.0)
#>  sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi       1.7.8      2022-07-11 [1] CRAN (R 4.2.0)
#>  stringr       1.4.0      2019-02-10 [1] CRAN (R 4.2.0)
#>  styler        1.7.0      2022-03-13 [1] CRAN (R 4.2.0)
#>  tensorflow    2.9.0      2022-05-21 [1] CRAN (R 4.2.0)
#>  tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.2.0)
#>  tfruns        1.5.0      2021-02-26 [1] CRAN (R 4.2.0)
#>  tibble        3.1.8      2022-07-22 [1] CRAN (R 4.2.0)
#>  utf8          1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs         0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>  whisker       0.4        2019-08-28 [1] CRAN (R 4.2.0)
#>  withr         2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>  xfun          0.32.1     2022-08-11 [1] https://yihui.r-universe.dev (R 4.2.0)
#>  xml2          1.3.3      2021-11-30 [1] CRAN (R 4.2.0)
#>  yaml          2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#>
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#>
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16) [Clang 12.0.1 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.22.4
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow
#>
#>  NOTE: Python version was forced by use_python function
#>
#> ──────────────────────────────────────────────────────────────────────────────
```

njtierney commented 1 year ago

This has been resolved as above