greta-dev / greta

simple and scalable statistical modelling in R
https://greta-stats.org

M1 TF2 dev errors for new optimisers #564

Closed njtierney closed 2 weeks ago

njtierney commented 1 year ago

The optimisers need updating. That work is done in #550, but merging that branch into the TF2 branch will take some care, as it is probably a bit stale by now.

Should resolve #551

Below is an example of code that currently fails with opt():

devtools::load_all(".")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> 
#> ✔ Initialising python and checking dependencies ... done!
#> Loaded Tensorflow version 2.10.0
source("tests/testthat/helpers.R")
sd <- runif(5)
x <- rnorm(5, 2, 0.1)
z <- variable(dim = 5)
distribution(x) <- normal(z, sd)

m <- model(z)
o <- opt(m, hessian = TRUE)
#> Error in py_get_attr_impl(x, name, silent): AttributeError: module 'tensorflow' has no attribute 'contrib'

Created on 2022-10-12 by the reprex package (v2.0.1)
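
For context, `tf.contrib` was removed in TensorFlow 2.0, so any code path that still reaches for `tf$contrib` fails with this AttributeError under TF2. Below is a minimal sketch of a guard, assuming the R tensorflow and reticulate packages and a working TensorFlow installation. The names `is_tf2` and `has_contrib` are hypothetical and not part of greta.

```r
library(tensorflow)

# TRUE when the installed TensorFlow is version 2.0 or later
is_tf2 <- tf_version() >= "2.0"

# Check for the attribute up front instead of letting tf$contrib throw
has_contrib <- reticulate::py_has_attr(tf, "contrib")

if (is_tf2 && !has_contrib) {
  message("tf$contrib is unavailable; TF1-era contrib optimisers need replacing")
}
```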

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.1 (2022-06-23)
#>  os       macOS Monterey 12.3.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2022-10-12
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package     * version    date (UTC) lib source
#>    abind         1.4-5      2016-07-21 [1] CRAN (R 4.2.0)
#>    backports     1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>    base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>    brio          1.1.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    cachem        1.0.6      2021-08-19 [1] CRAN (R 4.2.0)
#>    callr         3.7.2      2022-08-22 [1] CRAN (R 4.2.0)
#>    cli           3.3.0.9000 2022-06-15 [1] Github (r-lib/cli@31a5db5)
#>    coda          0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>    codetools     0.2-18     2020-11-04 [1] CRAN (R 4.2.1)
#>    crayon        1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>    desc          1.4.2      2022-09-08 [1] CRAN (R 4.2.0)
#>    devtools      2.4.4      2022-07-20 [1] CRAN (R 4.2.0)
#>    digest        0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>    ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>    evaluate      0.16       2022-08-09 [1] CRAN (R 4.2.0)
#>    fansi         1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>    fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>    fs            1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>    future        1.27.0     2022-07-22 [1] CRAN (R 4.2.0)
#>    globals       0.16.0     2022-08-05 [1] CRAN (R 4.2.0)
#>    glue          1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  P greta       * 0.4.2.9000 2022-10-11 [?] load_all()
#>    here          1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>    highr         0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>    hms           1.1.1      2021-09-26 [1] CRAN (R 4.2.0)
#>    htmltools     0.5.3      2022-07-18 [1] CRAN (R 4.2.0)
#>    htmlwidgets   1.5.4      2021-09-08 [1] CRAN (R 4.2.0)
#>    httpuv        1.6.5      2022-01-05 [1] CRAN (R 4.2.0)
#>    jsonlite      1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>    knitr         1.40       2022-08-24 [1] CRAN (R 4.2.0)
#>    later         1.3.0      2021-08-18 [1] CRAN (R 4.2.0)
#>    lattice       0.20-45    2021-09-22 [1] CRAN (R 4.2.1)
#>    lifecycle     1.0.2      2022-09-09 [1] CRAN (R 4.2.0)
#>    listenv       0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>    magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>    Matrix        1.4-1      2022-03-23 [1] CRAN (R 4.2.1)
#>    memoise       2.0.1      2021-11-26 [1] CRAN (R 4.2.0)
#>    mime          0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>    miniUI        0.1.1.1    2018-05-18 [1] CRAN (R 4.2.0)
#>    parallelly    1.32.1     2022-07-21 [1] CRAN (R 4.2.0)
#>    pillar        1.8.1      2022-08-19 [1] CRAN (R 4.2.0)
#>    pkgbuild      1.3.1      2021-12-20 [1] CRAN (R 4.2.0)
#>    pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>    pkgload       1.3.0      2022-06-27 [1] CRAN (R 4.2.0)
#>    png           0.1-7      2013-12-03 [1] CRAN (R 4.2.0)
#>    prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>    processx      3.7.0      2022-07-07 [1] CRAN (R 4.2.0)
#>    profvis       0.3.7      2020-11-02 [1] CRAN (R 4.2.0)
#>    progress      1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>    promises      1.2.0.1    2021-02-11 [1] CRAN (R 4.2.0)
#>    ps            1.7.1      2022-06-18 [1] CRAN (R 4.2.0)
#>    purrr         0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>    R.cache       0.16.0     2022-07-21 [1] CRAN (R 4.2.0)
#>    R.methodsS3   1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>    R.oo          1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>    R.utils       2.12.0     2022-06-28 [1] CRAN (R 4.2.0)
#>    R6            2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>    Rcpp          1.0.9      2022-07-08 [1] CRAN (R 4.2.0)
#>    remotes       2.4.2      2021-11-30 [1] CRAN (R 4.2.0)
#>    reprex        2.0.1      2021-08-05 [1] CRAN (R 4.2.0)
#>    reticulate    1.25       2022-05-11 [1] CRAN (R 4.2.0)
#>    rlang         1.0.5      2022-08-31 [1] CRAN (R 4.2.0)
#>    rmarkdown     2.16       2022-08-24 [1] CRAN (R 4.2.0)
#>    rprojroot     2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>    rstudioapi    0.13       2020-11-12 [1] CRAN (R 4.2.0)
#>    sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>    shiny         1.7.2      2022-07-19 [1] CRAN (R 4.2.0)
#>    stringi       1.7.8      2022-07-11 [1] CRAN (R 4.2.0)
#>    stringr       1.4.1      2022-08-20 [1] CRAN (R 4.2.0)
#>    styler        1.7.0      2022-03-13 [1] CRAN (R 4.2.0)
#>    tensorflow    2.9.0      2022-05-21 [1] CRAN (R 4.2.0)
#>    testthat    * 3.1.4      2022-04-26 [1] CRAN (R 4.2.0)
#>    tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.2.0)
#>    tfruns        1.5.0      2021-02-26 [1] CRAN (R 4.2.0)
#>    tibble        3.1.8      2022-07-22 [1] CRAN (R 4.2.0)
#>    urlchecker    1.0.1      2021-11-30 [1] CRAN (R 4.2.0)
#>    usethis       2.1.6      2022-05-25 [1] CRAN (R 4.2.0)
#>    utf8          1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>    vctrs         0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>    whisker       0.4        2019-08-28 [1] CRAN (R 4.2.0)
#>    withr         2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>    xfun          0.33       2022-09-12 [1] CRAN (R 4.2.0)
#>    xtable        1.8-4      2019-04-21 [1] CRAN (R 4.2.0)
#>    yaml          2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#>    yesno         0.1.2      2020-07-10 [1] CRAN (R 4.2.0)
#>
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#>
#>  P ── Loaded and on-disk path mismatch.
#>
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16) [Clang 12.0.1 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.23.2
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow
#>
#>  NOTE: Python version was forced by use_python function
#>
#> ──────────────────────────────────────────────────────────────────────────────
```

njtierney commented 1 year ago
x <- rnorm(5, 2, 0.1)
z <- variable(dim = 5)
distribution(x) <- normal(z, 0.1)

m <- model(z)
o <- opt(model = m,
         initial_values = initials(z = rnorm(5)),
         optimiser = gradient_descent())

Putting a browser in the optimiser initialisation, we end up at:

# set up the model
initialize = function(initial_values,
                      model,
                      name,
                      method,
                      parameters,
                      other_args,
                      max_iterations,
                      tolerance,
                      adjust) {
  super$initialize(initial_values,
                   model,
                   parameters = list(),
                   seed = get_seed()
  )

  self$name <- name
  self$method <- method
  self$parameters <- parameters
  self$other_args <- other_args
  self$max_iterations <- as.integer(max_iterations)
  self$tolerance <- tolerance
  self$adjust <- adjust

  if ("uses_callbacks" %in% names(other_args)) {
    self$uses_callbacks <- other_args$uses_callbacks
  }

  browser()
  self$create_optimiser_objective()
  self$create_tf_minimiser()
}

The optimiser objective is created without any problems, so let's step into create_tf_minimiser():

create_tf_minimiser = function() {
  # browser()
  dag <- self$model$dag
  tfe <- dag$tf_environment

  self$sanitise_dtypes()

  optimise_fun <- eval(parse(text = self$method))
  # dag$on_graph(
  tfe$tf_optimiser <- do.call(
    optimise_fun,
    self$parameters
  )
  # )

  if (self$adjust) {
    browser()
    dag$tf_run(train <- tf_optimiser$minimize(optimiser_objective_adj))
  } else {
    dag$tf_run(train <- tf_optimiser$minimize(optimiser_objective))
  }
}

Now, tfe$optimiser_objective_adj is

tf.Tensor([232.99020842], shape=(1), dtype=float64)

and running this:

tfe$tf_optimiser$minimize(tfe$optimiser_objective_adj)
# Error in py_call_impl(callable, dots$args, dots$keywords) : 
#   TypeError: minimize() missing 1 required positional argument: 'var_list'

OK, fine. Let's pass the free state as the var_list. For reference, the free_state is:

#> self$free_state
#           [,1]     [,2]        [,3]       [,4]       [,5]
# [1,] -1.665815 1.807032 0.005928454 -0.2735771 -0.8771901

So:

tfe$tf_optimiser$minimize(
  tfe$optimiser_objective_adj,
  var_list = tf$Variable(initial_value = self$free_state)
)

# Error in py_call_impl(callable, dots$args, dots$keywords) : 
#   ValueError: `tape` is required when a `Tensor` loss is passed. Received: loss=[232.99020842], tape=None.

OK, let's pass a gradient tape:

tfe$tf_optimiser$minimize(tfe$optimiser_objective,
                          var_list = tf$Variable(self$free_state),
                          tape = tf$GradientTape())

# Error in py_call_impl(callable, dots$args, dots$keywords) : 
#   ValueError: No gradients provided for any variable: (['Variable:0'],). Provided `grads_and_vars` is ((None, <tf.Variable 'Variable:0' shape=(1, 5) dtype=float64, numpy=array([[-1.66581522,  1.80703249,  0.00592845, -0.27357705, -0.8771901 ]])>),).

Now I'm not sure what to make of this - I feel like I'm missing something obvious, like we are supposed to be optimising over something else?
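
One possible reading of the last error, sketched outside greta: in TF 2.10's Keras optimisers, `minimize()` accepts the loss either as a zero-argument function or as a tensor together with the tape that recorded it. A tape created at call time has recorded nothing that links the already-computed loss to the freshly created `tf$Variable`, which would explain the `None` gradient. The toy example below assumes the R tensorflow package with a TF 2.x installation that still provides `minimize()`. The quadratic loss and the names `x`, `target`, and `loss_fn` are illustrative stand-ins, not greta's objective or internals.

```r
library(tensorflow)

# Toy stand-ins for the free state and the objective (not greta internals)
x <- tf$Variable(c(-1.5, 0.5, 2.0), dtype = tf$float64)
target <- tf$constant(c(2, 2, 2), dtype = tf$float64)
loss_fn <- function() {
  tf$reduce_sum(tf$square(tf$subtract(x, target)))
}

# A loss computed before the tape exists is not recorded by that tape,
# so asking the tape for its gradient returns NULL (Python None)
stale_loss <- loss_fn()
with(tf$GradientTape() %as% tape, {
  NULL
})
stale_grad <- tape$gradient(stale_loss, x)

# Passing the loss as a function lets the optimiser re-evaluate it
# under its own tape on every step
optimiser <- tf$keras$optimizers$SGD(learning_rate = 0.1)
loss_before <- loss_fn()$numpy()
for (i in seq_len(50)) {
  optimiser$minimize(loss_fn, var_list = list(x))
}
loss_after <- loss_fn()$numpy()
```

The key design point is that `var_list` must hold the same variable object the loss is computed from; wrapping the current free state in a new `tf$Variable` at call time breaks that link.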

njtierney commented 1 year ago

This error relates to accessing hessians in the new optimiser interface. I'll try to open more specific issues for the individual problems we are encountering here.
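
As a point of reference for the hessian part, TensorFlow 2 can compute a Hessian eagerly by nesting gradient tapes and taking the Jacobian of the gradient. The sketch below assumes the R tensorflow package with a TF 2.x installation. The quadratic loss and the names `z`, `target`, `inner_tape`, and `outer_tape` are illustrative only and do not reflect how greta implements `hessian = TRUE`.

```r
library(tensorflow)

# Toy parameter vector and target (not greta's free state)
z <- tf$Variable(c(0.5, -1, 2), dtype = tf$float64)
target <- tf$constant(c(2, 2, 2), dtype = tf$float64)

with(tf$GradientTape() %as% outer_tape, {
  with(tf$GradientTape() %as% inner_tape, {
    loss <- tf$reduce_sum(tf$square(tf$subtract(z, target)))
  })
  # First derivatives; computed inside the outer tape so it records them
  grad <- inner_tape$gradient(loss, z)
})

# Second derivatives: Jacobian of the gradient with respect to z
hess <- outer_tape$jacobian(grad, z)
hess_r <- hess$numpy()
```

For this quadratic loss the Hessian is 2 times the identity matrix, which makes the result easy to check by hand.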

njtierney commented 1 year ago

Now works for all optimisers, as of https://github.com/njtierney/greta/commit/eef698f23effbc97595d61c3377cdfc1c8686de1