greta-dev / greta

simple and scalable statistical modelling in R
https://greta-stats.org

TF2 error - Failure (`test_inference.R:293`): mcmc supports rwmh sampler with normal/uniform proposals #575

Closed njtierney closed 2 weeks ago

njtierney commented 1 year ago
devtools::load_all(".")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> 
#> ✔ Initialising python and checking dependencies ... done!
#> Loaded Tensorflow version 2.10.0
source(here::here("tests", "testthat", "helpers.R"))
x <- normal(0, 1)
m <- model(x)
expect_ok(draws <- mcmc(m,
  sampler = rwmh("normal"),
  n_samples = 100, warmup = 100,
  verbose = FALSE
))
#> Error: `expr` threw an unexpected error.
#> Message: greta hit a tensorflow error:
#> Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError:
#> Evaluation error: ValueError: Cannot reshape a tensor with 0 elements to shape
#> [1] (1 elements) for '{{node Reshape}} = Reshape[T=DT_DOUBLE,
#> Tshape=DT_INT32](Mul, Reshape/shape)' with input shapes: [0], [1] and with
#> input tensors computed as partial shapes: input[1] = [1]. .
#> Class:   simpleError/error/condition

Created on 2022-12-09 with reprex v2.0.2

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.1 (2022-06-23)
#>  os       macOS Monterey 12.3.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Perth
#>  date     2022-12-09
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package     * version    date (UTC) lib source
#>    abind         1.4-5      2016-07-21 [1] CRAN (R 4.2.0)
#>    backports     1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>    base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>    brio          1.1.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    cachem        1.0.6      2021-08-19 [1] CRAN (R 4.2.0)
#>    callr         3.7.3      2022-11-02 [1] CRAN (R 4.2.0)
#>    cli           3.4.1      2022-09-23 [1] CRAN (R 4.2.0)
#>    coda          0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>    codetools     0.2-18     2020-11-04 [1] CRAN (R 4.2.1)
#>    crayon        1.5.2      2022-09-29 [1] CRAN (R 4.2.0)
#>    desc          1.4.2      2022-09-08 [1] CRAN (R 4.2.0)
#>    devtools      2.4.5      2022-10-11 [1] CRAN (R 4.2.0)
#>    digest        0.6.30     2022-10-18 [1] CRAN (R 4.2.0)
#>    ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>    evaluate      0.18       2022-11-07 [1] CRAN (R 4.2.0)
#>    fansi         1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>    fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>    fs            1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>    future        1.29.0     2022-11-06 [1] CRAN (R 4.2.0)
#>    globals       0.16.2     2022-11-21 [1] CRAN (R 4.2.1)
#>    glue          1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#> VP greta       * 0.4.2.9000 2022-09-08 [?] CRAN (R 4.2.0) (on disk 0.4.3)
#>    here          1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>    highr         0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>    hms           1.1.2      2022-08-19 [1] CRAN (R 4.2.0)
#>    htmltools     0.5.3      2022-07-18 [1] CRAN (R 4.2.0)
#>    htmlwidgets   1.5.4      2021-09-08 [1] CRAN (R 4.2.0)
#>    httpuv        1.6.6      2022-09-08 [1] CRAN (R 4.2.0)
#>    jsonlite      1.8.3      2022-10-21 [1] CRAN (R 4.2.0)
#>    knitr         1.41       2022-11-18 [1] CRAN (R 4.2.0)
#>    later         1.3.0      2021-08-18 [1] CRAN (R 4.2.0)
#>    lattice       0.20-45    2021-09-22 [1] CRAN (R 4.2.1)
#>    lifecycle     1.0.3      2022-10-07 [1] CRAN (R 4.2.0)
#>    listenv       0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>    magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>    Matrix        1.5-3      2022-11-11 [1] CRAN (R 4.2.0)
#>    memoise       2.0.1      2021-11-26 [1] CRAN (R 4.2.0)
#>    mime          0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>    miniUI        0.1.1.1    2018-05-18 [1] CRAN (R 4.2.0)
#>    parallelly    1.32.1     2022-07-21 [1] CRAN (R 4.2.0)
#>    pillar        1.8.1      2022-08-19 [1] CRAN (R 4.2.0)
#>    pkgbuild      1.4.0      2022-11-27 [1] CRAN (R 4.2.1)
#>    pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>    pkgload       1.3.2      2022-11-16 [1] CRAN (R 4.2.0)
#>    png           0.1-8      2022-11-29 [1] CRAN (R 4.2.0)
#>    prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>    processx      3.8.0      2022-10-26 [1] CRAN (R 4.2.0)
#>    profvis       0.3.7      2020-11-02 [1] CRAN (R 4.2.0)
#>    progress      1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>    promises      1.2.0.1    2021-02-11 [1] CRAN (R 4.2.0)
#>    ps            1.7.2      2022-10-26 [1] CRAN (R 4.2.0)
#>    purrr         0.3.5      2022-10-06 [1] CRAN (R 4.2.0)
#>    R.cache       0.16.0     2022-07-21 [1] CRAN (R 4.2.0)
#>    R.methodsS3   1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>    R.oo          1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>    R.utils       2.12.2     2022-11-11 [1] CRAN (R 4.2.0)
#>    R6            2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>    Rcpp          1.0.9      2022-07-08 [1] CRAN (R 4.2.0)
#>    remotes       2.4.2      2021-11-30 [1] CRAN (R 4.2.0)
#>    reprex        2.0.2      2022-08-17 [1] CRAN (R 4.2.0)
#>    reticulate    1.26       2022-08-31 [1] CRAN (R 4.2.0)
#>    rlang         1.0.6      2022-09-24 [1] CRAN (R 4.2.0)
#>    rmarkdown     2.18       2022-11-09 [1] CRAN (R 4.2.0)
#>    rprojroot     2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>    rstudioapi    0.14       2022-08-22 [1] CRAN (R 4.2.0)
#>    sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>    shiny         1.7.3      2022-10-25 [1] CRAN (R 4.2.0)
#>    stringi       1.7.8      2022-07-11 [1] CRAN (R 4.2.0)
#>    stringr       1.5.0      2022-12-02 [1] CRAN (R 4.2.0)
#>    styler        1.8.1      2022-11-07 [1] CRAN (R 4.2.0)
#>    tensorflow    2.9.0      2022-05-21 [1] CRAN (R 4.2.0)
#>    testthat    * 3.1.5      2022-10-08 [1] CRAN (R 4.2.0)
#>    tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.2.0)
#>    tfruns        1.5.1      2022-09-05 [1] CRAN (R 4.2.0)
#>    urlchecker    1.0.1      2021-11-30 [1] CRAN (R 4.2.0)
#>    usethis       2.1.6      2022-05-25 [1] CRAN (R 4.2.0)
#>    utf8          1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>    vctrs         0.5.1      2022-11-16 [1] CRAN (R 4.2.0)
#>    whisker       0.4        2019-08-28 [1] CRAN (R 4.2.0)
#>    withr         2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>    xfun          0.35       2022-11-16 [1] CRAN (R 4.2.0)
#>    xtable        1.8-4      2019-04-21 [1] CRAN (R 4.2.0)
#>    yaml          2.3.6      2022-10-18 [1] CRAN (R 4.2.0)
#>    yesno         0.1.2      2020-07-10 [1] CRAN (R 4.2.0)
#>
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#>
#>  V ── Loaded and on-disk version mismatch.
#>  P ── Loaded and on-disk path mismatch.
#>
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16) [Clang 12.0.1 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.23.2
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow
#>
#>  NOTE: Python version was forced by use_python function
#>
#> ──────────────────────────────────────────────────────────────────────────────
```
njtierney commented 1 year ago

Related - TF2 error - Failure (test_inference.R:305): mcmc supports rwmh sampler with uniform proposals

devtools::load_all(".")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> 
#> ✔ Initialising python and checking dependencies ... done!
#> Loaded Tensorflow version 2.10.0
source(here::here("tests", "testthat", "helpers.R"))
set.seed(5)
x <- uniform(0, 1)
m <- model(x)
expect_ok(draws <- mcmc(m,
  sampler = rwmh("uniform"),
  n_samples = 100, warmup = 100,
  verbose = FALSE
))
#> Error: `expr` threw an unexpected error.
#> Message: greta hit a tensorflow error:
#> Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError:
#> Evaluation error: ValueError: Cannot reshape a tensor with 0 elements to shape
#> [1] (1 elements) for '{{node Reshape}} = Reshape[T=DT_DOUBLE,
#> Tshape=DT_INT32](Mul, Reshape/shape)' with input shapes: [0], [1] and with
#> input tensors computed as partial shapes: input[1] = [1]. .
#> Class:   simpleError/error/condition

Created on 2022-12-09 with reprex v2.0.2

njtierney commented 1 year ago

Current progress is recorded locally in the file tf2-test-inference-mcmcm-rwmh-normal-uniform.qmd.

The file contains the following:

We get an error for rwmh, like so:

x <- normal(0, 1)
m <- model(x)

# this errors
draws <- mcmc(m,
              sampler = rwmh("normal"),
              n_samples = 100, 
              warmup = 100,
              verbose = FALSE)

The error is:

Error: greta hit a tensorflow error:
Error in py_call_impl(callable,
dots$args, dots$keywords):
RuntimeError: Evaluation error:
ValueError: Cannot reshape a tensor
with 0 elements to shape [1] (1
elements) for '{{node Reshape}} =
Reshape[T=DT_DOUBLE,
Tshape=DT_INT32](Mul, Reshape/shape)'
with input shapes: [0], [1] and with
input tensors computed as partial
shapes: input[1] = [1]. .

The minimal code that fails is:

x <- normal(0, 1)
m <- model(x)
draws <- mcmc(m, sampler = rwmh())

So my thought is that something has the wrong shape or dimension somewhere in the rwmh sampler.

We don't get this error with the default sampler, hmc.

x <- normal(0, 1)
m <- model(x)
draws <- mcmc(m, sampler = hmc())
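
For reference, the same family of error can be provoked in isolation. The sketch below is not greta code: `scale_and_reshape` and its `TensorSpec` are made up for illustration, and it assumes the `tensorflow` R package with a working Python TensorFlow 2 install. It traces a function whose declared input has 0 elements, multiplies it, and reshapes the result to shape `[1]`, which should fail during tracing, before any sampling happens.

```r
# Illustrative only: provoke a trace-time reshape failure outside greta.
# Guarded so the script does nothing when TensorFlow is not available.
has_tf <- requireNamespace("tensorflow", quietly = TRUE) &&
  reticulate::py_module_available("tensorflow")

msg <- NULL
if (has_tf) {
  tf <- tensorflow::tf

  # Hypothetical function: multiply, then force the result to shape [1].
  scale_and_reshape <- function(x) {
    tf$reshape(tf$multiply(x, 2), shape = list(1L))
  }

  traced <- tensorflow::tf_function(
    scale_and_reshape,
    # The declared input has 0 elements, so the product does too.
    input_signature = list(
      tf$TensorSpec(shape = list(0L), dtype = tf$float64)
    ),
    autograph = FALSE
  )

  # Tracing runs shape inference; capture the message instead of stopping.
  msg <- tryCatch(
    {
      traced$get_concrete_function()
      NULL
    },
    error = function(e) conditionMessage(e)
  )
}
print(msg)
```

If the captured message has the same form as the one above, that would be consistent with some tensor in the rwmh path having zero elements by the time it reaches a reshape. Which tensor that is still has to be found.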

Let's go on a journey. First, step into mcmc:

debugonce(mcmc)
x <- normal(0, 1)
m <- model(x)
draws <- mcmc(m, sampler = rwmh())

OK, so the error occurs here, in sample_carefully:

#| eval: false
      result <- cleanly(
        self$tf_evaluate_sample_batch(
          free_state = tensorflow::as_tensor(
            free_state,
            dtype = tf_float()
          ),
          sampler_burst_length = tensorflow::as_tensor(sampler_burst_length),
          sampler_thin = tensorflow::as_tensor(sampler_thin),
          sampler_param_vec = tensorflow::as_tensor(
            sampler_param_vec,
            dtype = tf_float(),
            shape = length(sampler_param_vec)
          )
        )
      )

I'll place a browser there and take a look

x <- normal(0, 1)
m <- model(x)
draws <- mcmc(m, sampler = rwmh())

OK, and because the code is wrapped in tf_function, that wrapping needs to be undone somehow before it can be stepped through.

Breaking down this code

result <- cleanly(
  self$tf_evaluate_sample_batch(
    free_state = tensorflow::as_tensor(
      free_state,
      dtype = tf_float()
    ),
    sampler_burst_length = tensorflow::as_tensor(sampler_burst_length),
    sampler_thin = tensorflow::as_tensor(sampler_thin),
    sampler_param_vec = tensorflow::as_tensor(
      sampler_param_vec,
      dtype = tf_float(),
      shape = length(sampler_param_vec)
    )
  )
)

The values of free_state, sampler_burst_length, sampler_thin, and sampler_param_vec are:

free_state
sampler_burst_length
sampler_thin
sampler_param_vec
Browse[2]> free_state
     all_forward_variable_1
[1,]             0.04096190
[2,]             0.01090323
[3,]            -0.15063224
[4,]             0.08087937
Browse[2]> sampler_burst_length
[1] 3
Browse[2]> sampler_thin
[1] 1
Browse[2]> sampler_param_vec
rwmh_epsilon rwmh_diag_sd 
         0.1          1.0 

And what does the corresponding part look like with hmc()?

x <- normal(0, 1)
m <- model(x)
draws <- mcmc(m, sampler = hmc())
free_state
sampler_burst_length
sampler_thin
sampler_param_vec
Browse[2]> free_state
     all_forward_variable_1
[1,]             0.13124380
[2,]            -0.12720201
[3,]            -0.15659195
[4,]             0.08557503
Browse[2]> sampler_burst_length
[1] 3
Browse[2]> sampler_thin
[1] 1
Browse[2]> sampler_param_vec
      hmc_l hmc_epsilon hmc_diag_sd 
        6.0         0.1         1.0 

OK, so it's not going awry at that point...

Overall, tf_evaluate_sample_batch is a tf_function, written like so:

#| eval: false
self$tf_evaluate_sample_batch <- tensorflow::tf_function(
  f = self$define_tf_draws,
  input_signature = list(
    # free state
    tf$TensorSpec(
      shape = list(NULL, self$n_free),
      dtype = tf_float()
    ),
    # sampler_burst_length
    tf$TensorSpec(
      shape = list(),
      dtype = tf$int32
    ),
    # sampler_thin
    tf$TensorSpec(
      shape = list(),
      dtype = tf$int32
    ),
    # sampler_param_vec
    tf$TensorSpec(
      shape = list(length(unlist(self$sampler_parameter_values()))),
      dtype = tf_float()
    )
  )
)
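
As a sanity check on the inputs, the shape constraints in that input_signature can be mirrored in plain R and applied to the values printed in the browser sessions above. This is only a sketch: `check_sample_batch_inputs` is a hypothetical helper, not part of greta, and `n_free = 1` is read off the single `all_forward_variable_1` column.

```r
# Hypothetical helper mirroring the four TensorSpec entries in base R.
check_sample_batch_inputs <- function(free_state, sampler_burst_length,
                                      sampler_thin, sampler_param_vec,
                                      n_free, n_params) {
  c(
    # shape (NULL, n_free): any number of rows, exactly n_free columns
    free_state = is.matrix(free_state) && ncol(free_state) == n_free,
    # shape (): scalars
    sampler_burst_length = length(sampler_burst_length) == 1L,
    sampler_thin = length(sampler_thin) == 1L,
    # shape (n_params): one entry per sampler parameter
    sampler_param_vec = length(sampler_param_vec) == n_params
  )
}

# Values copied from the rwmh browser session.
rwmh_ok <- check_sample_batch_inputs(
  free_state = matrix(c(0.04096190, 0.01090323, -0.15063224, 0.08087937),
                      ncol = 1),
  sampler_burst_length = 3L,
  sampler_thin = 1L,
  sampler_param_vec = c(rwmh_epsilon = 0.1, rwmh_diag_sd = 1.0),
  n_free = 1L,
  n_params = 2L
)

# Values copied from the hmc browser session.
hmc_ok <- check_sample_batch_inputs(
  free_state = matrix(c(0.13124380, -0.12720201, -0.15659195, 0.08557503),
                      ncol = 1),
  sampler_burst_length = 3L,
  sampler_thin = 1L,
  sampler_param_vec = c(hmc_l = 6.0, hmc_epsilon = 0.1, hmc_diag_sd = 1.0),
  n_free = 1L,
  n_params = 3L
)

print(rwmh_ok)
print(hmc_ok)
```

Both sets of inputs pass these checks, which fits the observation that nothing looks wrong at the call site and points at something inside the traced function instead.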

I can't seem to work out how to remove the tf_function wrapper so we can debug the code inside it.

I tried to use self$define_tf_draws as its own function, passing the inputs (free_state etc.) either as the named R variables or as the Tensors above, but I think I'm missing something. Not sure how to proceed further from here.
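
One avenue that may help here: TensorFlow provides `tf.config.run_functions_eagerly()`, which makes code wrapped with `tf.function` execute eagerly rather than as a traced graph. Whether this plays well with greta's sampler internals is untested, so treat the sketch below as an assumption. It assumes the `tensorflow` R package and a TensorFlow 2 Python install.

```r
# Guarded so the script does nothing when TensorFlow is not available.
has_tf <- requireNamespace("tensorflow", quietly = TRUE) &&
  reticulate::py_module_available("tensorflow")

if (has_tf) {
  tf <- tensorflow::tf

  # Ask TensorFlow to execute tf.function-wrapped code eagerly.
  tf$config$run_functions_eagerly(TRUE)

  # ... re-run the failing mcmc() call here, e.g. under debugonce() or with
  # a browser() placed inside the wrapped function ...

  # Restore the default graph behaviour afterwards.
  tf$config$run_functions_eagerly(FALSE)
}
```

In eager mode the tensors inside the wrapped function carry concrete values, so printing their shapes in a browser session should show where a zero-length tensor first appears. The error text may differ from the graph-mode message, since shape checks then run at execution time rather than at trace time.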

hrlai commented 1 year ago

Hi @njtierney, I got a similar error from mcmc today when testing out the TF2 version from https://github.com/greta-dev/greta/pull/534

mcmc(m,
     sampler = hmc(10, 15),
     warmup = 4000,
     n_samples = 1000,
     chains = 1,
     one_by_one = TRUE)

gave:

    warmup                                           0/4000 | eta:  ?s          
Error in self$sample_carefully(free_state = self$free_state, sampler_burst_length = as.integer(n_samples),  : 
  object 'n_samples' not found

Sorry, I haven't been able to put together a minimal working example that reproduces the error yet, but I will keep looking into it. Posting here in case it is something that you're tackling at the moment.
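
For what it's worth, the message reads like a plain R scoping error rather than a TensorFlow one: an argument expression refers to `n_samples`, but no such object is visible where it gets evaluated. The base-R sketch below reproduces that pattern with made-up functions (`run_warmup` and `sample_carefully_stub` are hypothetical and only stand in for the real call chain).

```r
# Hypothetical stand-in for the method that receives the argument.
sample_carefully_stub <- function(sampler_burst_length) {
  sampler_burst_length
}

# Hypothetical caller: it is given `warmup`, but forwards `n_samples`,
# which is neither an argument nor defined in any enclosing scope.
run_warmup <- function(warmup) {
  sample_carefully_stub(sampler_burst_length = as.integer(n_samples))
}

# Capture the error message instead of stopping.
err <- tryCatch(
  run_warmup(warmup = 4000),
  error = function(e) conditionMessage(e)
)
print(err)
```

If the same thing is happening in the `one_by_one = TRUE` path, the fix would be on the R side (passing or defining `n_samples` where it is used), separate from the reshape error in this issue. That is an inference from the error text only.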