greta-dev / greta

simple and scalable statistical modelling in R
https://greta-stats.org

extra_samples fails when thin, pb_update and/or n_samples are not multiples of one another #567

Open njtierney opened 1 year ago

njtierney commented 1 year ago
library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

# define a simple Bayesian model
x <- rnorm(10)
mu <- normal(0, 5)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
sigma <- lognormal(1, 0.1)
distribution(x) <- normal(mu, sigma)
m <- model(mu, sigma)
#> Loaded Tensorflow version 2.10.0

# carry out mcmc on the model
draws <- mcmc(m, n_samples = 100)
#> running 4 chains simultaneously on up to 8 CPU cores
#> 
#>     warmup ====================================== 1000/1000 | eta:  0s | 1% bad
#>   sampling ======================================   100/100 | eta:  0s

# add some more samples
draws <- extra_samples(draws, 200)
#> running 4 chains simultaneously on up to 8 CPU cores
#>   sampling ======================================== 200/200 | eta:  0s

# This works as expected. But when thin, pb_update and/or n_samples are not
# multiples of one another, extra_samples() can error, e.g.:
draws <- extra_samples(draws, thin = 3, pb_update = 50, n_samples = 202)
#> running 4 chains simultaneously on up to 8 CPU cores
#>   sampling ======================================== 200/202 | eta:  0s
#> Error: greta hit a tensorflow error:
#> Error in py_call_impl(callable, dots$args, dots$keywords):
#> tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution
#> error: <... omitted ...>
#> "/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow_probability/python/internal/loop_util.py",
#> line 231, in <listcomp> initial_trace, [ta.stack() for ta in trace_arrays],
#> Node: 'mcmc_sample_chain/trace_scan/TensorArrayV2Stack_8/TensorListStack' 2
#> root error(s) found.  (0) INVALID_ARGUMENT: Tried to stack elements of an empty
#> list with non-fully-defined element_shape: [?,2] [[{{node
#> mcmc_sample_chain/trace_scan/TensorArrayV2Stack_8/TensorListStack}}]]
#> [[mcmc_sample_chain/trace_scan/while/exit/_155/_99]] (1) INVALID_ARGUMENT:
#> Tried to stack elements of an empty list with non-fully-defined element_shape:
#> [?,2] [[{{node
#> mcmc_sample_chain/trace_scan/TensorArrayV2Stack_8/TensorListStack}}]] 0
#> successful operations. 0 derived errors ignored. [Op:__inference_fn_2848] See
#> `reticulate::py_last_error()` for details
draws <- extra_samples(draws, thin = 101, pb_update = 50, n_samples = 202)
#> running 4 chains simultaneously on up to 8 CPU cores
#> sampling 0/202 | eta: ?s
#> Error: greta hit a tensorflow error:
#> Error in py_call_impl(callable, dots$args, dots$keywords):
#> tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution
#> error: <... omitted ...>ile
#> "/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow_probability/python/internal/loop_util.py",
#> line 231, in <listcomp> initial_trace, [ta.stack() for ta in trace_arrays],
#> Node: 'mcmc_sample_chain/trace_scan/TensorArrayV2Stack_7/TensorListStack' 2
#> root error(s) found.  (0) INVALID_ARGUMENT: Tried to stack elements of an empty
#> list with non-fully-defined element_shape: [?] [[{{node
#> mcmc_sample_chain/trace_scan/TensorArrayV2Stack_7/TensorListStack}}]]
#> [[mcmc_sample_chain/trace_scan/while/exit/_156/_101]] (1) INVALID_ARGUMENT:
#> Tried to stack elements of an empty list with non-fully-defined element_shape:
#> [?] [[{{node
#> mcmc_sample_chain/trace_scan/TensorArrayV2Stack_7/TensorListStack}}]] 0
#> successful operations. 0 derived errors ignored. [Op:__inference_fn_2848] See
#> `reticulate::py_last_error()` for details

Created on 2022-10-18 with reprex v2.0.2

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.1 (2022-06-23)
#>  os       macOS Monterey 12.3.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2022-10-18
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  abind         1.4-5      2016-07-21 [1] CRAN (R 4.2.0)
#>  backports     1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>  callr         3.7.2      2022-08-22 [1] CRAN (R 4.2.0)
#>  cli           3.4.1      2022-09-23 [1] CRAN (R 4.2.0)
#>  coda          0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>  codetools     0.2-18     2020-11-04 [1] CRAN (R 4.2.1)
#>  crayon        1.5.2      2022-09-29 [1] CRAN (R 4.2.0)
#>  digest        0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>  ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>  evaluate      0.17       2022-10-07 [1] CRAN (R 4.2.0)
#>  fansi         1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>  fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>  fs            1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>  future        1.28.0     2022-09-02 [1] CRAN (R 4.2.0)
#>  globals       0.16.1     2022-08-28 [1] CRAN (R 4.2.0)
#>  glue          1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  greta       * 0.4.2.9000 2022-10-18 [1] local
#>  here          1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>  highr         0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>  hms           1.1.2      2022-08-19 [1] CRAN (R 4.2.0)
#>  htmltools     0.5.3      2022-07-18 [1] CRAN (R 4.2.0)
#>  jsonlite      1.8.2      2022-10-02 [1] CRAN (R 4.2.0)
#>  knitr         1.40       2022-08-24 [1] CRAN (R 4.2.0)
#>  lattice       0.20-45    2021-09-22 [1] CRAN (R 4.2.1)
#>  lifecycle     1.0.3      2022-10-07 [1] CRAN (R 4.2.0)
#>  listenv       0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>  magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  Matrix        1.5-1      2022-09-13 [1] CRAN (R 4.2.0)
#>  parallelly    1.32.1     2022-07-21 [1] CRAN (R 4.2.0)
#>  pillar        1.8.1      2022-08-19 [1] CRAN (R 4.2.0)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  png           0.1-7      2013-12-03 [1] CRAN (R 4.2.0)
#>  prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>  processx      3.7.0      2022-07-07 [1] CRAN (R 4.2.0)
#>  progress      1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>  ps            1.7.1      2022-06-18 [1] CRAN (R 4.2.0)
#>  purrr         0.3.5      2022-10-06 [1] CRAN (R 4.2.0)
#>  R.cache       0.16.0     2022-07-21 [1] CRAN (R 4.2.0)
#>  R.methodsS3   1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo          1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils       2.12.0     2022-06-28 [1] CRAN (R 4.2.0)
#>  R6            2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp          1.0.9      2022-07-08 [1] CRAN (R 4.2.0)
#>  reprex        2.0.2      2022-08-17 [1] CRAN (R 4.2.0)
#>  reticulate    1.26       2022-08-31 [1] CRAN (R 4.2.0)
#>  rlang         1.0.6      2022-09-24 [1] CRAN (R 4.2.0)
#>  rmarkdown     2.17       2022-10-07 [1] CRAN (R 4.2.0)
#>  rprojroot     2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>  rstudioapi    0.14       2022-08-22 [1] CRAN (R 4.2.0)
#>  sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi       1.7.8      2022-07-11 [1] CRAN (R 4.2.0)
#>  stringr       1.4.1      2022-08-20 [1] CRAN (R 4.2.0)
#>  styler        1.7.0      2022-03-13 [1] CRAN (R 4.2.0)
#>  tensorflow    2.9.0      2022-05-21 [1] CRAN (R 4.2.0)
#>  tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.2.0)
#>  tfruns        1.5.1      2022-09-05 [1] CRAN (R 4.2.0)
#>  tibble        3.1.8      2022-07-22 [1] CRAN (R 4.2.0)
#>  utf8          1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs         0.4.2      2022-09-29 [1] CRAN (R 4.2.0)
#>  whisker       0.4        2019-08-28 [1] CRAN (R 4.2.0)
#>  withr         2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>  xfun          0.33       2022-09-12 [1] CRAN (R 4.2.0)
#>  yaml          2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16) [Clang 12.0.1 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.23.2
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow
#> 
#>  NOTE: Python version was forced by use_python function
#> 
#> ──────────────────────────────────────────────────────────────────────────────
```

As reported in https://forum.greta-stats.org/t/problem-with-extra-samples-command-when-thin-pb-update-and-or-n-samples-are-not-multiples/338/2
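One way to read both failures is through the progress-bar bursts. The sketch below rests on an assumption that has not been checked against greta's source: sampling runs in bursts of `pb_update` iterations, plus a shorter final burst for any remainder, and each burst keeps `floor(burst / thin)` draws. `burst_schedule()` is a hypothetical helper for illustration, not part of greta. Under that assumption, the `thin = 3` call ends with a 2-iteration burst that keeps no draws, which fits the progress bar reaching 200/202 before erroring. The `thin = 101` call keeps nothing in its very first burst, which fits the failure at 0/202. An empty burst would also be consistent with TensorFlow's "Tried to stack elements of an empty list" message.

```r
# Hypothetical helper (not part of greta). Assumes sampling runs in bursts of
# pb_update iterations plus a shorter final burst for any remainder, and that
# a burst of b iterations keeps b %/% thin draws after thinning.
burst_schedule <- function(n_samples, thin = 1, pb_update = 50) {
  n_full <- n_samples %/% pb_update      # number of full-length bursts
  remainder <- n_samples %% pb_update    # leftover iterations, if any
  iterations <- c(rep(pb_update, n_full), if (remainder > 0) remainder)
  data.frame(
    burst = seq_along(iterations),
    iterations = iterations,
    kept = iterations %/% thin           # draws retained in each burst
  )
}

# first failing call: bursts of 50, 50, 50, 50 and 2; the last keeps 0 draws
burst_schedule(n_samples = 202, thin = 3, pb_update = 50)

# second failing call: the very first burst (50 iterations) keeps 0 draws
burst_schedule(n_samples = 202, thin = 101, pb_update = 50)
```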

hrlai commented 1 year ago

Okay, this is weirdly timely, because the following also failed for me this morning. Using the same example as above:

library(greta)

x <- rnorm(10)
mu <- normal(0, 5)
sigma <- lognormal(1, 0.1)
distribution(x) <- normal(mu, sigma)
m <- model(mu, sigma)

draws <- mcmc(m, n_samples = 2000, thin = 2, one_by_one = TRUE)

This errors:

 tensorflow error:
Error in py_call_impl(callable, dots$args, dots$keywords):
tensorflow.python.framework.errors_impl.UnimplementedError: TensorArray has size zero, but
element shape [?] is not fully defined. Currently only static shapes are supported when
packing zero-size TensorArrays. <... omitted ...>gather(math_ops.range(0, self.size()),
name=name) File "/tensorflow/python/ops/tensor_array_ops.py", line 323, in gather
element_shape=element_shape) File "/tensorflow/python/ops/gen_data_flow_ops.py", line 6705, in
tensor_array_gather_v3 element_shape=element_shape, name=name) File
"/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper op_def=op_def)
File "/tensorflow/python/util/deprecation.py", line 507, in new_func return func(*args,
**kwargs) File "/tensorflow/python/framework/ops.py", line 3616, in create_op op_def=op_def)
File "/tensorflow/python/framework/ops.py", line 2005, in __init__ self._traceback =
tf_s
Error in trace_list_batches[[1]] : subscript out of bounds

Note that n_samples is a multiple of thin here, but one_by_one = TRUE still seems to break things...
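This case may fit the same pattern. greta's documentation describes `one_by_one = TRUE` as running the MCMC code one iteration at a time; if that means each burst is a single iteration (an assumption here), then any `thin > 1` leaves every burst with zero kept draws, whether or not `n_samples` is a multiple of `thin`. A pre-flight check along these lines could flag all three failing calls in this issue before any TensorFlow code runs. `bursts_keep_draws()` is hypothetical and illustrative only, not a greta function.

```r
# Hypothetical check (not part of greta). Assumes sampling runs in bursts of
# pb_update iterations (1 iteration when one_by_one = TRUE), with a shorter
# final burst for any remainder, and that a burst of b iterations keeps
# b %/% thin draws. Returns FALSE if any burst would keep zero draws.
bursts_keep_draws <- function(n_samples, thin = 1, pb_update = 50,
                              one_by_one = FALSE) {
  burst <- if (one_by_one) 1 else pb_update
  bursts <- c(rep(burst, n_samples %/% burst),
              if (n_samples %% burst > 0) n_samples %% burst)
  all(bursts %/% thin >= 1)
}

# the three failing calls from this issue
bursts_keep_draws(202, thin = 3, pb_update = 50)                      # FALSE
bursts_keep_draws(202, thin = 101, pb_update = 50)                    # FALSE
bursts_keep_draws(2000, thin = 2, pb_update = 50, one_by_one = TRUE)  # FALSE

# the extra_samples(draws, 200) call that worked, with thin left at 1
bursts_keep_draws(200, thin = 1, pb_update = 50)                      # TRUE
```

Under this model, choosing `pb_update` as a multiple of `thin` and `n_samples` as a multiple of `pb_update` makes every burst non-empty when `one_by_one = FALSE`. That is a suggested workaround derived from the assumption above, not something verified against greta itself.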