greta-dev / greta

simple and scalable statistical modelling in R
https://greta-stats.org

calculate() fails for hierarchical models #347

Open lionel68 opened 4 years ago

lionel68 commented 4 years ago

While trying to use calculate() to get the prior distribution of a hierarchical model with correlated varying effects, I get an error:

library(greta)

# design matrix: intercept and Sepal.Width
modmat <- model.matrix(~ Sepal.Width, iris)
# index of species
jj <- as.numeric(iris$Species)

M <- ncol(modmat) # number of varying coefficients
N <- max(jj) # number of species

# prior on the standard deviations of the varying coefficients
tau <- exponential(0.5, dim = M)

# prior on the correlation between the varying coefficients
Omega <- lkj_correlation(3, M)

# optimization of the varying coefficient sampling through
# Cholesky factorization and whitening
Omega_U <- chol(Omega)
Sigma_U <- sweep(Omega_U, 2, tau, "*")
z <- normal(0, 1, dim = c(N, M))
ab <- z %*% Sigma_U # each row: ab[i, ] ~ multi_normal(0, t(Sigma_U) %*% Sigma_U)

# the linear predictor
mu <- rowSums(ab[jj,] * modmat)

# the residual standard deviation
sigma_e <- cauchy(0, 3, truncation = c(0, Inf))

# model
y <- as_data(iris$Sepal.Length)
distribution(y) <- normal(mu, sigma_e)

# get priors
calculate(y)

Returns an error:

 Error in py_call_impl(callable, dots$args, dots$keywords) :
  InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder_3' with dtype double and shape [1,2,2]
     [[node Placeholder_3 (defined at /ops/array_ops.py:2143) ]]

Original stack trace for 'Placeholder_3':
  File "/ops/array_ops.py", line 2143, in placeholder
    return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
  File "/ops/gen_array_ops.py", line 6262, in placeholder
    "Placeholder", dtype=dtype, shape=shape, name=name)
  File "/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/framework/ops.py", line 3616, in create_op
    op_def=op_def)
  File "/framework/ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()

This was done with the following package versions:

R version 3.6.2 greta_0.3.1.9011
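For readers following along, the whitening step and the linear predictor in the model can be checked outside greta. The sketch below is plain base R (no greta or TensorFlow) and uses hypothetical fixed values of tau and Omega standing in for a single prior draw. It shows that the rows of z %*% Sigma_U have covariance diag(tau) %*% Omega %*% diag(tau), and that rowSums(ab[jj, ] * modmat) picks out each species' intercept and slope. It only illustrates what the model is meant to compute; it does not reproduce or work around the calculate() error.

```r
# Base-R sketch (no greta): checks the whitening trick and linear predictor
# used in the model above. tau and Omega are HYPOTHETICAL fixed values that
# stand in for one draw from the exponential() and lkj_correlation() priors.
tau   <- c(0.8, 0.3)
Omega <- matrix(c(1, 0.4,
                  0.4, 1), nrow = 2)

# Upper-triangular Cholesky factor: t(Omega_U) %*% Omega_U equals Omega
Omega_U <- chol(Omega)

# Multiply column j by tau[j]; same as Omega_U %*% diag(tau)
Sigma_U <- sweep(Omega_U, 2, tau, "*")

# Implied covariance of each row of z %*% Sigma_U
Sigma  <- crossprod(Sigma_U)                 # t(Sigma_U) %*% Sigma_U
target <- diag(tau) %*% Omega %*% diag(tau)  # correlation scaled by the SDs
all.equal(Sigma, target)

# Monte Carlo check: empirical covariance of the whitened draws
set.seed(1)
z  <- matrix(rnorm(2e5 * 2), ncol = 2)
ab <- z %*% Sigma_U
round(cov(ab), 3)

# Linear predictor: a hypothetical 3 x 2 coefficient draw (one row per species)
modmat  <- model.matrix(~ Sepal.Width, iris)
jj      <- as.numeric(iris$Species)
ab_draw <- ab[1:3, ]
mu      <- rowSums(ab_draw[jj, ] * modmat)

# Equivalent explicit form: species intercept + species slope * Sepal.Width
mu_check <- ab_draw[jj, 1] + ab_draw[jj, 2] * iris$Sepal.Width
all.equal(unname(mu), mu_check)
```

Because Sigma_U is the upper Cholesky factor of the covariance rather than the covariance itself, sampling z as independent standard normals and multiplying by Sigma_U avoids sampling the correlated coefficients directly.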

adknudson commented 4 years ago

Using your code, I was not able to reproduce the error.

R v3.6.1, greta v0.3.1, Python v3.6, TF v1.14, TF-Probability v0.7

njtierney commented 3 years ago

Hi there, are you able to confirm whether this error still happens with the new version of greta?

There is a new approach to installation in {greta}: when you install greta, some interactive prompts help you set up a Python environment with a specific version of Python and the other Python packages greta needs. This should make installation more reproducible and easier. Are you able to give this a try? Let us know if you run into any issues and we will try to resolve them as soon as possible.

Install the current master branch of greta

# install.packages("remotes")
remotes::install_github("greta-dev/greta")

Restart R

Load greta with library(greta)

library(greta)

Create a greta model

This initialises Python and triggers internal checks that the required Python packages are installed. A short model like this one is enough to trigger them:

model(normal(0,1))

Follow these instructions:

Install greta dependencies

install_greta_deps()

Restart R and run library(greta)

library(greta)

Create a greta model

model(normal(0,1))

Let us know if this works! 😄

njtierney commented 2 years ago

Just confirming that I get the same error with the latest version of greta (0.4.1 from CRAN):

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
modmat <- model.matrix(~ Sepal.Width, iris) 
# index of species
jj <- as.numeric(iris$Species)

M <- ncol(modmat) # number of varying coefficients
N <- max(jj) # number of species

# prior on the standard deviations of the varying coefficients
tau <- exponential(0.5, dim = M)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✓ Initialising python and checking dependencies ... done!
#> 

# prior on the correlation between the varying coefficients
Omega <- lkj_correlation(3, M)

# optimization of the varying coefficient sampling through
# Cholesky factorization and whitening
Omega_U <- chol(Omega)
Sigma_U <- sweep(Omega_U, 2, tau, "*")
z <- normal(0, 1, dim = c(N, M))
ab <- z %*% Sigma_U # each row: ab[i, ] ~ multi_normal(0, t(Sigma_U) %*% Sigma_U)

# the linear predictor
mu <- rowSums(ab[jj,] * modmat)

# the residual standard deviation
sigma_e <- cauchy(0, 3, truncation = c(0, Inf))

# model
y <- as_data(iris$Sepal.Length)
distribution(y) <- normal(mu, sigma_e)

# get priors
calculate(y)
#> Error in py_call_impl(callable, dots$args, dots$keywords): InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder_3' with dtype double and shape [1,2,2]
#>   [[node Placeholder_3 (defined at /ops/array_ops.py:2143) ]]
#> 
#> Original stack trace for 'Placeholder_3':
#>   File "/ops/array_ops.py", line 2143, in placeholder
#>     return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
#>   File "/ops/gen_array_ops.py", line 6262, in placeholder
#>     "Placeholder", dtype=dtype, shape=shape, name=name)
#>   File "/framework/op_def_library.py", line 788, in _apply_op_helper
#>     op_def=op_def)
#>   File "/util/deprecation.py", line 507, in new_func
#>     return func(*args, **kwargs)
#>   File "/framework/ops.py", line 3616, in create_op
#>     op_def=op_def)
#>   File "/framework/ops.py", line 2005, in __init__
#>     self._traceback = tf_stack.extract_stack()
#> 
#> 
#> Detailed traceback:
#>   File "/Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 950, in run
#>     run_metadata_ptr)
#>   File "/Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1173, in _run
#>     feed_dict_tensor, options, run_metadata)
#>   File "/Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
#>     run_metadata)
#>   File "/Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
#>     raise type(e)(node_def, op, message)

Created on 2022-03-18 by the reprex package (v2.0.1)

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.1.3 (2022-03-10)
#>  os       macOS Big Sur/Monterey 10.16
#>  system   x86_64, darwin17.0
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2022-03-18
#>  pandoc   2.17.1.1 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  backports     1.4.1      2021-12-13 [1] CRAN (R 4.1.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.1.0)
#>  callr         3.7.0      2021-04-20 [1] CRAN (R 4.1.0)
#>  cli           3.2.0      2022-02-14 [1] CRAN (R 4.1.2)
#>  coda          0.19-4     2020-09-30 [1] CRAN (R 4.1.0)
#>  codetools     0.2-18     2020-11-04 [1] CRAN (R 4.1.3)
#>  crayon        1.5.0      2022-02-14 [1] CRAN (R 4.1.2)
#>  digest        0.6.29     2021-12-01 [1] CRAN (R 4.1.0)
#>  ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.1.0)
#>  evaluate      0.15       2022-02-18 [1] CRAN (R 4.1.2)
#>  fansi         1.0.2      2022-01-14 [1] CRAN (R 4.1.2)
#>  fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.1.0)
#>  fs            1.5.2      2021-12-08 [1] CRAN (R 4.1.0)
#>  future        1.24.0     2022-02-19 [1] CRAN (R 4.1.2)
#>  globals       0.14.0     2020-11-22 [1] CRAN (R 4.1.0)
#>  glue          1.6.2      2022-02-24 [1] CRAN (R 4.1.2)
#>  greta       * 0.4.1      2022-03-15 [1] CRAN (R 4.1.2)
#>  here          1.0.1      2020-12-13 [1] CRAN (R 4.1.0)
#>  highr         0.9        2021-04-16 [1] CRAN (R 4.1.0)
#>  hms           1.1.1      2021-09-26 [1] CRAN (R 4.1.0)
#>  htmltools     0.5.2      2021-08-25 [1] CRAN (R 4.1.0)
#>  jsonlite      1.8.0      2022-02-22 [1] CRAN (R 4.1.2)
#>  knitr         1.37       2021-12-16 [1] CRAN (R 4.1.0)
#>  lattice       0.20-45    2021-09-22 [1] CRAN (R 4.1.3)
#>  lifecycle     1.0.1      2021-09-24 [1] CRAN (R 4.1.0)
#>  listenv       0.8.0      2019-12-05 [1] CRAN (R 4.1.0)
#>  magrittr      2.0.2      2022-01-26 [1] CRAN (R 4.1.2)
#>  Matrix        1.4-0      2021-12-08 [1] CRAN (R 4.1.3)
#>  parallelly    1.30.0     2021-12-17 [1] CRAN (R 4.1.0)
#>  pillar        1.7.0      2022-02-01 [1] CRAN (R 4.1.2)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.1.0)
#>  png           0.1-7      2013-12-03 [1] CRAN (R 4.1.0)
#>  prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.1.0)
#>  processx      3.5.2      2021-04-30 [1] CRAN (R 4.1.0)
#>  progress      1.2.2      2019-05-16 [1] CRAN (R 4.1.0)
#>  ps            1.6.0      2021-02-28 [1] CRAN (R 4.1.0)
#>  purrr         0.3.4      2020-04-17 [1] CRAN (R 4.1.0)
#>  R.cache       0.15.0     2021-04-30 [1] CRAN (R 4.1.0)
#>  R.methodsS3   1.8.1      2020-08-26 [1] CRAN (R 4.1.0)
#>  R.oo          1.24.0     2020-08-26 [1] CRAN (R 4.1.0)
#>  R.utils       2.11.0     2021-09-26 [1] CRAN (R 4.1.0)
#>  R6            2.5.1      2021-08-19 [1] CRAN (R 4.1.0)
#>  Rcpp          1.0.8.2    2022-03-11 [1] CRAN (R 4.1.2)
#>  reprex        2.0.1      2021-08-05 [1] CRAN (R 4.1.0)
#>  reticulate    1.24       2022-01-26 [1] CRAN (R 4.1.2)
#>  rlang         1.0.2      2022-03-04 [1] CRAN (R 4.1.2)
#>  rmarkdown     2.11       2021-09-14 [1] CRAN (R 4.1.0)
#>  rprojroot     2.0.2      2020-11-15 [1] CRAN (R 4.1.0)
#>  rstudioapi    0.13       2020-11-12 [1] CRAN (R 4.1.0)
#>  sessioninfo   1.2.2.9000 2022-03-01 [1] Github (r-lib/sessioninfo@d70760d)
#>  stringi       1.7.6      2021-11-29 [1] CRAN (R 4.1.0)
#>  stringr       1.4.0      2019-02-10 [1] CRAN (R 4.1.0)
#>  styler        1.6.2      2021-09-23 [1] CRAN (R 4.1.0)
#>  tensorflow    2.8.0      2022-02-09 [1] CRAN (R 4.1.2)
#>  tfruns        1.5.0      2021-02-26 [1] CRAN (R 4.1.0)
#>  tibble        3.1.6      2021-11-07 [1] CRAN (R 4.1.0)
#>  utf8          1.2.2      2021-07-24 [1] CRAN (R 4.1.0)
#>  vctrs         0.3.8      2021-04-29 [1] CRAN (R 4.1.0)
#>  whisker       0.4        2019-08-28 [1] CRAN (R 4.1.0)
#>  withr         2.5.0      2022-03-03 [1] CRAN (R 4.1.2)
#>  xfun          0.30       2022-03-02 [1] CRAN (R 4.1.2)
#>  yaml          2.3.5      2022-02-21 [1] CRAN (R 4.1.2)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.1/Resources/library
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/njtierney/Library/r-miniconda/envs/greta-env/bin/python
#>  libpython:      /Users/njtierney/Library/r-miniconda/envs/greta-env/lib/libpython3.7m.dylib
#>  pythonhome:     /Users/njtierney/Library/r-miniconda/envs/greta-env:/Users/njtierney/Library/r-miniconda/envs/greta-env
#>  version:        3.7.12 | packaged by conda-forge | (default, Oct 26 2021, 05:59:23) [Clang 11.1.0 ]
#>  numpy:          /Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/numpy
#>  numpy_version:  1.16.4
#>  tensorflow:     /Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow
#> 
#>  NOTE: Python version was forced by use_python function
#> 
#> ──────────────────────────────────────────────────────────────────────────────
```

njtierney commented 2 years ago

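Working through this, here is a more minimal reproducible example. Something is going wrong with how calculate() accesses or computes chol() of the LKJ correlation matrix: calculate(Omega) works, but calculate(Omega_U) fails with the same placeholder error.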

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson, sd
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
Omega <- lkj_correlation(3, 2)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✓ Initialising python and checking dependencies ... done!
#> 

# optimization of the varying coefficient sampling through
# Cholesky factorization and whitening
Omega_U <- chol(Omega)

calculate(Omega, nsim = 3)
#> $Omega
#> , , 1
#> 
#>      [,1]      [,2]
#> [1,]    1 0.2870519
#> [2,]    1 0.5621991
#> [3,]    1 0.3855464
#> 
#> , , 2
#> 
#>           [,1] [,2]
#> [1,] 0.2870519    1
#> [2,] 0.5621991    1
#> [3,] 0.3855464    1
calculate(Omega_U, nsim = 3)
#> Error in py_call_impl(callable, dots$args, dots$keywords): InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder_1' with dtype double and shape [1,2,2]
#>   [[node Placeholder_1 (defined at /ops/array_ops.py:2143) ]]
#> 
#> Original stack trace for 'Placeholder_1':
#>   File "/ops/array_ops.py", line 2143, in placeholder
#>     return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
#>   File "/ops/gen_array_ops.py", line 6262, in placeholder
#>     "Placeholder", dtype=dtype, shape=shape, name=name)
#>   File "/framework/op_def_library.py", line 788, in _apply_op_helper
#>     op_def=op_def)
#>   File "/util/deprecation.py", line 507, in new_func
#>     return func(*args, **kwargs)
#>   File "/framework/ops.py", line 3616, in create_op
#>     op_def=op_def)
#>   File "/framework/ops.py", line 2005, in __init__
#>     self._traceback = tf_stack.extract_stack()
#> 
#> 
#> Detailed traceback:
#>   File "/Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 950, in run
#>     run_metadata_ptr)
#>   File "/Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1173, in _run
#>     feed_dict_tensor, options, run_metadata)
#>   File "/Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
#>     run_metadata)
#>   File "/Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
#>     raise type(e)(node_def, op, message)

Created on 2022-03-31 by the reprex package (v2.0.1)

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.1.3 (2022-03-10)
#>  os       macOS Big Sur/Monterey 10.16
#>  system   x86_64, darwin17.0
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2022-03-31
#>  pandoc   2.17.1.1 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  backports     1.4.1      2021-12-13 [1] CRAN (R 4.1.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.1.0)
#>  callr         3.7.0      2021-04-20 [1] CRAN (R 4.1.0)
#>  cli           3.2.0      2022-02-14 [1] CRAN (R 4.1.2)
#>  coda          0.19-4     2020-09-30 [1] CRAN (R 4.1.0)
#>  codetools     0.2-18     2020-11-04 [1] CRAN (R 4.1.3)
#>  crayon        1.5.0      2022-02-14 [1] CRAN (R 4.1.2)
#>  digest        0.6.29     2021-12-01 [1] CRAN (R 4.1.0)
#>  ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.1.0)
#>  evaluate      0.15       2022-02-18 [1] CRAN (R 4.1.2)
#>  fansi         1.0.2      2022-01-14 [1] CRAN (R 4.1.2)
#>  fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.1.0)
#>  fs            1.5.2      2021-12-08 [1] CRAN (R 4.1.0)
#>  future        1.24.0     2022-02-19 [1] CRAN (R 4.1.2)
#>  globals       0.14.0     2020-11-22 [1] CRAN (R 4.1.0)
#>  glue          1.6.2      2022-02-24 [1] CRAN (R 4.1.2)
#>  greta       * 0.4.2.9000 2022-03-31 [1] local
#>  here          1.0.1      2020-12-13 [1] CRAN (R 4.1.0)
#>  highr         0.9        2021-04-16 [1] CRAN (R 4.1.0)
#>  hms           1.1.1      2021-09-26 [1] CRAN (R 4.1.0)
#>  htmltools     0.5.2      2021-08-25 [1] CRAN (R 4.1.0)
#>  jsonlite      1.8.0      2022-02-22 [1] CRAN (R 4.1.2)
#>  knitr         1.37       2021-12-16 [1] CRAN (R 4.1.0)
#>  lattice       0.20-45    2021-09-22 [1] CRAN (R 4.1.3)
#>  lifecycle     1.0.1      2021-09-24 [1] CRAN (R 4.1.0)
#>  listenv       0.8.0      2019-12-05 [1] CRAN (R 4.1.0)
#>  magrittr      2.0.2      2022-01-26 [1] CRAN (R 4.1.2)
#>  Matrix        1.4-0      2021-12-08 [1] CRAN (R 4.1.3)
#>  parallelly    1.30.0     2021-12-17 [1] CRAN (R 4.1.0)
#>  pillar        1.7.0      2022-02-01 [1] CRAN (R 4.1.2)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.1.0)
#>  png           0.1-7      2013-12-03 [1] CRAN (R 4.1.0)
#>  prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.1.0)
#>  processx      3.5.2      2021-04-30 [1] CRAN (R 4.1.0)
#>  progress      1.2.2      2019-05-16 [1] CRAN (R 4.1.0)
#>  ps            1.6.0      2021-02-28 [1] CRAN (R 4.1.0)
#>  purrr         0.3.4      2020-04-17 [1] CRAN (R 4.1.0)
#>  R.cache       0.15.0     2021-04-30 [1] CRAN (R 4.1.0)
#>  R.methodsS3   1.8.1      2020-08-26 [1] CRAN (R 4.1.0)
#>  R.oo          1.24.0     2020-08-26 [1] CRAN (R 4.1.0)
#>  R.utils       2.11.0     2021-09-26 [1] CRAN (R 4.1.0)
#>  R6            2.5.1      2021-08-19 [1] CRAN (R 4.1.0)
#>  Rcpp          1.0.8.2    2022-03-11 [1] CRAN (R 4.1.2)
#>  reprex        2.0.1      2021-08-05 [1] CRAN (R 4.1.0)
#>  reticulate    1.24       2022-01-26 [1] CRAN (R 4.1.2)
#>  rlang         1.0.2      2022-03-04 [1] CRAN (R 4.1.2)
#>  rmarkdown     2.11       2021-09-14 [1] CRAN (R 4.1.0)
#>  rprojroot     2.0.2      2020-11-15 [1] CRAN (R 4.1.0)
#>  rstudioapi    0.13       2020-11-12 [1] CRAN (R 4.1.0)
#>  sessioninfo   1.2.2.9000 2022-03-01 [1] Github (r-lib/sessioninfo@d70760d)
#>  stringi       1.7.6      2021-11-29 [1] CRAN (R 4.1.0)
#>  stringr       1.4.0      2019-02-10 [1] CRAN (R 4.1.0)
#>  styler        1.6.2      2021-09-23 [1] CRAN (R 4.1.0)
#>  tensorflow    2.8.0      2022-02-09 [1] CRAN (R 4.1.2)
#>  tfruns        1.5.0      2021-02-26 [1] CRAN (R 4.1.0)
#>  tibble        3.1.6      2021-11-07 [1] CRAN (R 4.1.0)
#>  utf8          1.2.2      2021-07-24 [1] CRAN (R 4.1.0)
#>  vctrs         0.3.8      2021-04-29 [1] CRAN (R 4.1.0)
#>  whisker       0.4        2019-08-28 [1] CRAN (R 4.1.0)
#>  withr         2.5.0      2022-03-03 [1] CRAN (R 4.1.2)
#>  xfun          0.30       2022-03-02 [1] CRAN (R 4.1.2)
#>  yaml          2.3.5      2022-02-21 [1] CRAN (R 4.1.2)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.1/Resources/library
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/njtierney/Library/r-miniconda/envs/greta-env/bin/python
#>  libpython:      /Users/njtierney/Library/r-miniconda/envs/greta-env/lib/libpython3.7m.dylib
#>  pythonhome:     /Users/njtierney/Library/r-miniconda/envs/greta-env:/Users/njtierney/Library/r-miniconda/envs/greta-env
#>  version:        3.7.12 | packaged by conda-forge | (default, Oct 26 2021, 05:59:23) [Clang 11.1.0 ]
#>  numpy:          /Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/numpy
#>  numpy_version:  1.16.4
#>  tensorflow:     /Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow
#> 
#>  NOTE: Python version was forced by use_python function
#> 
#> ──────────────────────────────────────────────────────────────────────────────
```
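In the 2 x 2 case it is easy to write down what calculate(Omega_U, nsim = 3) would be expected to return. The base-R sketch below takes the three correlations printed by calculate(Omega, nsim = 3) in the reprex above, applies R's chol() to each simulated matrix, and compares the result with the closed form U = [1, r; 0, sqrt(1 - r^2)]. The nsim x 2 x 2 layout is an assumption based on how calculate() returned Omega; the sketch does not exercise greta and does not explain the placeholder error, but it gives a reference to compare against once calculate(Omega_U) runs.

```r
# Base-R sketch (no greta): expected chol() of each simulated LKJ draw.
# Correlations copied from the calculate(Omega, nsim = 3) output above.
r_draws <- c(0.2870519, 0.5621991, 0.3855464)
nsim    <- length(r_draws)

# Assumed layout: nsim x 2 x 2, matching how calculate() returned Omega
expected_U <- array(NA_real_, dim = c(nsim, 2, 2))
for (i in seq_len(nsim)) {
  Omega_i <- matrix(c(1, r_draws[i],
                      r_draws[i], 1), nrow = 2)
  expected_U[i, , ] <- chol(Omega_i)  # upper-triangular factor
}

# Closed form for a 2 x 2 correlation matrix, stored column-major
# as (U[1,1], U[2,1], U[1,2], U[2,2]) = (1, 0, r, sqrt(1 - r^2))
closed_form <- t(sapply(r_draws, function(r) c(1, 0, r, sqrt(1 - r^2))))

all.equal(matrix(expected_U, nrow = nsim), closed_form)
expected_U
```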