greta-dev / greta

simple and scalable statistical modelling in R
https://greta-stats.org

initial_values seem to fail with LKJ priors #440

Open hrlai opened 2 years ago

hrlai commented 2 years ago

Hi! Whenever the lkj_correlation() prior is part of a model, the initial_values argument seems to always cause mcmc() to fail.

Reproducible code from the greta example models page:

library(greta)

# model matrix
modmat <- model.matrix(~ Sepal.Width, iris) 
# index of species
jj <- as.numeric(iris$Species)

M <- ncol(modmat) # number of varying coefficients
N <- max(jj) # number of species

# prior on the standard deviation of the varying coefficient
tau <- exponential(0.5, dim = M)

# prior on the correlation between the varying coefficient
Omega <- lkj_correlation(3, M)

# optimization of the varying coefficient sampling through
# cholesky factorization and whitening
Omega_U <- chol(Omega)
Sigma_U <- sweep(Omega_U, 2, tau, "*")
z <- normal(0, 1, dim = c(N, M)) 
ab <- z %*% Sigma_U # equivalent to: ab ~ multi_normal(0, Sigma_U)

# the linear predictor
mu <- rowSums(ab[jj,] * modmat)

# the residual variance
sigma_e <- cauchy(0, 3, truncation = c(0, Inf))

#model
y <- iris$Sepal.Length
distribution(y) <- normal(mu, sigma_e)
m <- model(ab, sigma_e)

draws <- mcmc(m, chains = 4, initial_values = initials(sigma_e = 1))

On my computer this throws an error:

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  ValueError: Cannot feed value of shape (1, 13) for Tensor 'Placeholder:0', which has shape '(?, 10)'

Detailed traceback:
  File "/home/hrlai/.local/share/r-miniconda/envs/r-reticulate/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/home/hrlai/.local/share/r-miniconda/envs/r-reticulate/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1149, in _run
    str(subfeed_t.get_shape())))

Across various datasets, the message Cannot feed value of shape (1, 13) for Tensor 'Placeholder:0', which has shape '(?, 10)' always has a first number (here 13) that is larger than the second (here 10). The difference between them seems to be N, which led me to suspect the LKJ prior in the first place.
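
One way to account for the two numbers is to tally how many elements each variable contributes. The sketch below is only a guess at the bookkeeping: it assumes, without checking greta's source, that the sampler expects the M * (M - 1) / 2 free elements of the correlation matrix while the initial values are built for all M * M entries.

```r
# Iris example from the post: M = 2 varying coefficients, N = 3 species.
M <- 2
N <- 3

# Element counts for the variables that need starting values.
n_tau     <- M        # tau <- exponential(0.5, dim = M)
n_z       <- N * M    # z <- normal(0, 1, dim = c(N, M))
n_sigma_e <- 1        # sigma_e <- cauchy(0, 3, truncation = c(0, Inf))

# ASSUMPTION (not verified in greta): only the free elements of the M x M
# correlation matrix are expected, but all M * M entries are supplied.
n_omega_free <- M * (M - 1) / 2
n_omega_full <- M * M

expected_width <- n_tau + n_z + n_sigma_e + n_omega_free
supplied_width <- n_tau + n_z + n_sigma_e + n_omega_full

# expected = 10 and supplied = 13, matching '(?, 10)' and (1, 13) in the error
c(expected = expected_width, supplied = supplied_width)
```

Under this reading the gap would be M * (M + 1) / 2, which happens to equal N for the iris example, so a model in which M * (M + 1) / 2 differs from N would tell the two explanations apart.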

When I remove the LKJ prior from the model, initial values work again. This is reproducible (after restarting the R session) via:

library(greta)

# model matrix
modmat <- model.matrix(~ Sepal.Width, iris) 
# index of species
jj <- as.numeric(iris$Species)

M <- ncol(modmat) # number of varying coefficients
N <- max(jj) # number of species

# prior on the standard deviation of the varying coefficient
tau <- exponential(0.5, dim = M)

Sigma_U <- zeros(dim = c(M, M))
diag(Sigma_U) <- tau
z <- normal(0, 1, dim = c(N, M)) 
ab <- z %*% Sigma_U # equivalent to: ab ~ multi_normal(0, Sigma_U)

# the linear predictor
mu <- rowSums(ab[jj,] * modmat)

# the residual variance
sigma_e <- cauchy(0, 3, truncation = c(0, Inf))

#model
y <- iris$Sepal.Length
distribution(y) <- normal(mu, sigma_e)
m <- model(ab, sigma_e)

draws <- mcmc(m, chains = 4, initial_values = initials(sigma_e = 1))

I'd really like to keep the LKJ prior while still being able to specify initial values to help chain convergence. Looking forward to hearing your ideas!

njtierney commented 2 years ago

Thanks for posting! I confirm that I can get the same error:

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

# model matrix
modmat <- model.matrix(~ Sepal.Width, iris) 
# index of species
jj <- as.numeric(iris$Species)

M <- ncol(modmat) # number of varying coefficients
N <- max(jj) # number of species

# prior on the standard deviation of the varying coefficient
tau <- exponential(0.5, dim = M)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✓ Initialising python and checking dependencies ... done!
#> 

# prior on the correlation between the varying coefficient
Omega <- lkj_correlation(3, M)

# optimization of the varying coefficient sampling through
# cholesky factorization and whitening
Omega_U <- chol(Omega)
Sigma_U <- sweep(Omega_U, 2, tau, "*")
z <- normal(0, 1, dim = c(N, M)) 
ab <- z %*% Sigma_U # equivalent to: ab ~ multi_normal(0, Sigma_U)

# the linear predictor
mu <- rowSums(ab[jj,] * modmat)

# the residual variance
sigma_e <- cauchy(0, 3, truncation = c(0, Inf))

#model
y <- iris$Sepal.Length
distribution(y) <- normal(mu, sigma_e)
m <- model(ab, sigma_e)

draws <- mcmc(m, chains = 4, initial_values = initials(sigma_e = 1))
#> only one set of initial values was provided, and was used for all chains
#> Error in py_call_impl(callable, dots$args, dots$keywords): ValueError: Cannot feed value of shape (1, 13) for Tensor 'Placeholder:0', which has shape '(?, 10)'
#> 
#> Detailed traceback:
#>   File "/Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 950, in run
#>     run_metadata_ptr)
#>   File "/Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1149, in _run
#>     str(subfeed_t.get_shape())))

Created on 2021-09-28 by the reprex package (v2.0.1)

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.1.0 (2021-05-18)
#>  os       macOS Big Sur 10.16
#>  system   x86_64, darwin17.0
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2021-09-28
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date       lib source
#>  backports     1.2.1      2020-12-09 [1] CRAN (R 4.1.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.1.0)
#>  callr         3.7.0      2021-04-20 [1] CRAN (R 4.1.0)
#>  cli           3.0.1      2021-07-17 [1] CRAN (R 4.1.0)
#>  coda          0.19-4     2020-09-30 [1] CRAN (R 4.1.0)
#>  codetools     0.2-18     2020-11-04 [1] CRAN (R 4.1.0)
#>  crayon        1.4.1      2021-02-08 [1] CRAN (R 4.1.0)
#>  digest        0.6.27     2020-10-24 [1] CRAN (R 4.1.0)
#>  ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.1.0)
#>  evaluate      0.14       2019-05-28 [1] CRAN (R 4.1.0)
#>  fansi         0.5.0      2021-05-25 [1] CRAN (R 4.1.0)
#>  fs            1.5.0      2020-07-31 [1] CRAN (R 4.1.0)
#>  future        1.22.1     2021-08-25 [1] CRAN (R 4.1.0)
#>  globals       0.14.0     2020-11-22 [1] CRAN (R 4.1.0)
#>  glue          1.4.2      2020-08-27 [1] CRAN (R 4.1.0)
#>  greta       * 0.3.1.9012 2021-09-23 [1] local
#>  here          1.0.1      2020-12-13 [1] CRAN (R 4.1.0)
#>  highr         0.9        2021-04-16 [1] CRAN (R 4.1.0)
#>  hms           1.1.0      2021-05-17 [1] CRAN (R 4.1.0)
#>  htmltools     0.5.1.1    2021-01-22 [1] CRAN (R 4.1.0)
#>  jsonlite      1.7.2      2020-12-09 [1] CRAN (R 4.1.0)
#>  knitr         1.33       2021-04-24 [1] CRAN (R 4.1.0)
#>  lattice       0.20-44    2021-05-02 [1] CRAN (R 4.1.0)
#>  lifecycle     1.0.0      2021-02-15 [1] CRAN (R 4.1.0)
#>  listenv       0.8.0      2019-12-05 [1] CRAN (R 4.1.0)
#>  magrittr      2.0.1      2020-11-17 [1] CRAN (R 4.1.0)
#>  Matrix        1.3-4      2021-06-01 [1] CRAN (R 4.1.0)
#>  parallelly    1.28.1     2021-09-09 [1] CRAN (R 4.1.0)
#>  pillar        1.6.2      2021-07-29 [1] CRAN (R 4.1.0)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.1.0)
#>  png           0.1-7      2013-12-03 [1] CRAN (R 4.1.0)
#>  prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.1.0)
#>  processx      3.5.2      2021-04-30 [1] CRAN (R 4.1.0)
#>  progress      1.2.2      2019-05-16 [1] CRAN (R 4.1.0)
#>  ps            1.6.0      2021-02-28 [1] CRAN (R 4.1.0)
#>  purrr         0.3.4      2020-04-17 [1] CRAN (R 4.1.0)
#>  R6            2.5.1      2021-08-19 [1] CRAN (R 4.1.0)
#>  Rcpp          1.0.7      2021-07-07 [1] CRAN (R 4.1.0)
#>  reprex        2.0.1      2021-08-05 [1] CRAN (R 4.1.0)
#>  reticulate    1.22       2021-09-17 [1] CRAN (R 4.1.0)
#>  rlang         0.4.11     2021-04-30 [1] CRAN (R 4.1.0)
#>  rmarkdown     2.9        2021-06-15 [1] CRAN (R 4.1.0)
#>  rprojroot     2.0.2      2020-11-15 [1] CRAN (R 4.1.0)
#>  rstudioapi    0.13       2020-11-12 [1] CRAN (R 4.1.0)
#>  sessioninfo   1.1.1      2018-11-05 [1] CRAN (R 4.1.0)
#>  stringi       1.7.4      2021-08-25 [1] CRAN (R 4.1.0)
#>  stringr       1.4.0      2019-02-10 [1] CRAN (R 4.1.0)
#>  styler        1.4.1      2021-03-30 [1] CRAN (R 4.1.0)
#>  tensorflow    2.6.0      2021-08-19 [1] CRAN (R 4.1.0)
#>  tfruns        1.5.0      2021-02-26 [1] CRAN (R 4.1.0)
#>  tibble        3.1.4      2021-08-25 [1] CRAN (R 4.1.0)
#>  utf8          1.2.2      2021-07-24 [1] CRAN (R 4.1.0)
#>  vctrs         0.3.8      2021-04-29 [1] CRAN (R 4.1.0)
#>  whisker       0.4        2019-08-28 [1] CRAN (R 4.1.0)
#>  withr         2.4.2      2021-04-18 [1] CRAN (R 4.1.0)
#>  xfun          0.24       2021-06-15 [1] CRAN (R 4.1.0)
#>  yaml          2.2.1      2020-02-01 [1] CRAN (R 4.1.0)
#>
#> [1] /Library/Frameworks/R.framework/Versions/4.1/Resources/library
```

I'll have a think about this. We're a bit pushed for time at the moment, so I just wanted to give you a heads-up that we might not get to this as soon as we would like :)

hrlai commented 2 years ago

I was just browsing #314 and saw some discussion on dimensions and placeholders; noting it here in case the two issues are related.

hrlai commented 12 months ago

I was recently trying to do a prior predictive check and discovered that calculate() also fails on a greta array produced by a chol() operation:

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

M <- 3

Omega <- lkj_correlation(3, M)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
Omega_U <- chol(Omega)

calculate(Omega, nsim = 1)    # works
#> $Omega
#> , , 1
#> 
#>      [,1]       [,2]      [,3]
#> [1,]    1 -0.1751263 0.1297896
#> 
#> , , 2
#> 
#>            [,1] [,2]      [,3]
#> [1,] -0.1751263    1 0.1951698
#> 
#> , , 3
#> 
#>           [,1]      [,2] [,3]
#> [1,] 0.1297896 0.1951698    1
calculate(Omega_U, nsim = 1)  # fails
#> You must feed a value for placeholder tensor 'Placeholder_1' with dtype double and shape [1,3,3]
#>   [[node Placeholder_1 (defined at /ops/array_ops.py:2143) ]]
#> 
#> Original stack trace for 'Placeholder_1':
#>   File "/ops/array_ops.py", line 2143, in placeholder
#>     return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
#>   File "/ops/gen_array_ops.py", line 6262, in placeholder
#>     "Placeholder", dtype=dtype, shape=shape, name=name)
#>   File "/framework/op_def_library.py", line 788, in _apply_op_helper
#>     op_def=op_def)
#>   File "/util/deprecation.py", line 507, in new_func
#>     return func(*args, **kwargs)
#>   File "/framework/ops.py", line 3616, in create_op
#>     op_def=op_def)
#>   File "/framework/ops.py", line 2005, in __init__
#>     self._traceback = tf_stack.extract_stack()

Created on 2023-09-26 with reprex v2.0.2

In case they are related, I'm linking #585 here.
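
Until calculate() handles chol(), one possible workaround for the prior predictive check is to simulate Omega (which works in the reprex above) and factorise each draw in base R. This is a sketch under that assumption, not a tested fix: chol_each_draw is a hypothetical helper name, and the stand-in array reuses the Omega values printed above so that the block runs without greta.

```r
# Hypothetical helper: Cholesky-factorise every draw in an array shaped
# nsim x M x M, which is the layout calculate(Omega, nsim = ...) printed above.
chol_each_draw <- function(sims) {
  d <- dim(sims)
  stopifnot(length(d) == 3, d[2] == d[3])
  lapply(seq_len(d[1]), function(i) {
    # matrix() keeps the slice two-dimensional even when M is 1
    base::chol(matrix(sims[i, , ], d[2], d[3]))
  })
}

# Stand-in for calculate(Omega, nsim = 1)$Omega, using the values shown above.
omega_draw <- matrix(
  c( 1,         -0.1751263, 0.1297896,
    -0.1751263,  1,         0.1951698,
     0.1297896,  0.1951698, 1),
  nrow = 3, byrow = TRUE
)
sims <- array(omega_draw, dim = c(1, 3, 3))

# Upper-triangular factor U with t(U) %*% U equal to the simulated Omega.
Omega_U_draws <- chol_each_draw(sims)
U <- Omega_U_draws[[1]]
all.equal(crossprod(U), omega_draw)
```

With greta loaded, the stand-in array would be replaced by sims <- calculate(Omega, nsim = 100)$Omega, and the resulting factors could then be scaled by simulated tau values in ordinary R.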