greta-dev / greta

simple and scalable statistical modelling in R
https://greta-stats.org

TF2 error - Failure (`test_distributions.R:1101`): Cholesky factor of Wishart should be an upper triangular matrix #593

Open njtierney opened 8 months ago

njtierney commented 8 months ago

This test fails because computing the Cholesky factor of a Wishart-distributed greta array returns a matrix of 1s instead of an upper triangular matrix.

Note that this uses the branch in #587.

It is related to #585 and is currently being addressed in #587

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
x <- wishart(df = 4, Sigma = diag(3))
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
chol_x <- chol(x)
calc_chol <- calculate(x, chol_x, nsim = 1)
calc_chol
#> $x
#> , , 1
#> 
#>          [,1]       [,2]      [,3]
#> [1,] 2.682823 -0.3437398 -1.058342
#> 
#> , , 2
#> 
#>            [,1]     [,2]     [,3]
#> [1,] -0.3437398 3.249632 1.320124
#> 
#> , , 3
#> 
#>           [,1]     [,2]    [,3]
#> [1,] -1.058342 1.320124 1.29027
#> 
#> 
#> $chol_x
#> , , 1
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
#> 
#> , , 2
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
#> 
#> , , 3
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
# calc_chol$chol_x should be upper tri, but is just 1s

Created on 2023-11-09 with reprex v2.0.2

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.3.2 (2023-10-31)
#> os macOS Sonoma 14.0
#> system aarch64, darwin20
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Australia/Hobart
#> date 2023-11-09
#> pandoc 3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> backports 1.4.1 2021-12-13 [1] CRAN (R 4.3.0)
#> base64enc 0.1-3 2015-07-28 [1] CRAN (R 4.3.0)
#> callr 3.7.3 2022-11-02 [1] CRAN (R 4.3.0)
#> cli 3.6.1 2023-03-23 [1] CRAN (R 4.3.0)
#> coda 0.19-4 2020-09-30 [1] CRAN (R 4.3.0)
#> codetools 0.2-19 2023-02-01 [1] CRAN (R 4.3.2)
#> crayon 1.5.2 2022-09-29 [1] CRAN (R 4.3.0)
#> digest 0.6.33 2023-07-07 [1] CRAN (R 4.3.0)
#> evaluate 0.23 2023-11-01 [1] CRAN (R 4.3.1)
#> fastmap 1.1.1 2023-02-24 [1] CRAN (R 4.3.0)
#> fs 1.6.3 2023-07-20 [1] CRAN (R 4.3.0)
#> future 1.33.0 2023-07-01 [1] CRAN (R 4.3.0)
#> globals 0.16.2 2022-11-21 [1] CRAN (R 4.3.0)
#> glue 1.6.2 2022-02-24 [1] CRAN (R 4.3.0)
#> greta * 0.4.3.9000 2023-11-09 [1] local
#> hms 1.1.3 2023-03-21 [1] CRAN (R 4.3.0)
#> htmltools 0.5.7 2023-11-03 [1] CRAN (R 4.3.1)
#> jsonlite 1.8.7 2023-06-29 [1] CRAN (R 4.3.0)
#> knitr 1.45 2023-10-30 [1] CRAN (R 4.3.1)
#> lattice 0.21-9 2023-10-01 [1] CRAN (R 4.3.2)
#> lifecycle 1.0.3 2022-10-07 [1] CRAN (R 4.3.0)
#> listenv 0.9.0 2022-12-16 [1] CRAN (R 4.3.0)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.3.0)
#> Matrix 1.6-1.1 2023-09-18 [1] CRAN (R 4.3.2)
#> parallelly 1.36.0 2023-05-26 [1] CRAN (R 4.3.0)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.3.0)
#> png 0.1-8 2022-11-29 [1] CRAN (R 4.3.0)
#> prettyunits 1.2.0 2023-09-24 [1] CRAN (R 4.3.1)
#> processx 3.8.2 2023-06-30 [1] CRAN (R 4.3.0)
#> progress 1.2.2 2019-05-16 [1] CRAN (R 4.3.0)
#> ps 1.7.5 2023-04-18 [1] CRAN (R 4.3.0)
#> purrr 1.0.2 2023-08-10 [1] CRAN (R 4.3.0)
#> R.cache 0.16.0 2022-07-21 [1] CRAN (R 4.3.0)
#> R.methodsS3 1.8.2 2022-06-13 [1] CRAN (R 4.3.0)
#> R.oo 1.25.0 2022-06-12 [1] CRAN (R 4.3.0)
#> R.utils 2.12.2 2022-11-11 [1] CRAN (R 4.3.0)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.3.0)
#> Rcpp 1.0.11 2023-07-06 [1] CRAN (R 4.3.0)
#> reprex 2.0.2 2022-08-17 [1] CRAN (R 4.3.0)
#> reticulate 1.34.0 2023-10-12 [1] CRAN (R 4.3.1)
#> rlang 1.1.1 2023-04-28 [1] CRAN (R 4.3.0)
#> rmarkdown 2.25 2023-09-18 [1] CRAN (R 4.3.1)
#> rstudioapi 0.15.0 2023-07-07 [1] CRAN (R 4.3.0)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.3.0)
#> styler 1.9.1 2023-03-04 [1] CRAN (R 4.3.0)
#> tensorflow 2.14.0 2023-09-28 [1] CRAN (R 4.3.1)
#> tfautograph 0.3.2 2021-09-17 [1] CRAN (R 4.3.0)
#> tfruns 1.5.1 2022-09-05 [1] CRAN (R 4.3.0)
#> vctrs 0.6.4 2023-10-12 [1] CRAN (R 4.3.1)
#> whisker 0.4.1 2022-12-05 [1] CRAN (R 4.3.0)
#> withr 2.5.2 2023-10-30 [1] CRAN (R 4.3.1)
#> xfun 0.41 2023-11-01 [1] CRAN (R 4.3.1)
#> yaml 2.3.7 2023-01-23 [1] CRAN (R 4.3.0)
#>
#> [1] /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/library
#>
#> ─ Python configuration ───────────────────────────────────────────────────────
#> python: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#> libpython: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.8.dylib
#> pythonhome: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#> version: 3.8.15 | packaged by conda-forge | (default, Nov 22 2022, 08:49:06) [Clang 14.0.6 ]
#> numpy: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/numpy
#> numpy_version: 1.23.2
#> tensorflow: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow
#>
#> NOTE: Python version was forced by use_python() function
#>
#> ──────────────────────────────────────────────────────────────────────────────
```
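For reference, here is a minimal base-R sketch of what the test expects. It draws from a Wishart with `stats::rWishart()` rather than greta (so it says nothing about greta's own sampler), takes the Cholesky factor with base `chol()`, and checks the two properties the buggy output lacks: zeros below the diagonal, and `t(U) %*% U` reproducing the draw. The seed and the helper names `is_upper_tri()` and `reconstructs()` are illustrative, not part of greta.

```r
# Reference check in base R (not greta): what the Cholesky factor of a
# Wishart draw should look like.
set.seed(2023)
W <- stats::rWishart(1, df = 4, Sigma = diag(3))[, , 1]  # one 3 x 3 Wishart draw

U <- chol(W)  # base R returns the upper triangular factor, with W == t(U) %*% U

# Property 1: every entry below the diagonal is zero
is_upper_tri <- function(m) all(m[lower.tri(m)] == 0)

# Property 2: the factor reconstructs the original matrix
reconstructs <- function(u, w) isTRUE(all.equal(crossprod(u), w))

ones <- matrix(1, nrow = 3, ncol = 3)  # the shape of the buggy result above

print(c(
  chol_upper_tri = is_upper_tri(U),
  chol_reconstructs = reconstructs(U, W),
  ones_upper_tri = is_upper_tri(ones),
  ones_reconstructs = reconstructs(ones, W)
))
```

The first two checks should be TRUE for the real factor and the last two FALSE for the matrix of 1s, which is essentially what the failing test is asserting.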
njtierney commented 8 months ago

So this error seems to first appear at

https://github.com/greta-dev/greta/blob/1bbc661a07c915ceb3b2f8af417b865fa7453d12/R/calculate.R#L506

Perhaps I am missing where the chol operation is happening, but this is where the matrix of 1s is first returned.

njtierney commented 8 months ago

Debugging the internal greta function tf_chol:

tf_chol <- function(x) {
  x_chol <- tf$linalg$cholesky(x)
  x_chol_t <- tf_transpose(x_chol)
  x_chol_t
}

When this is computed, we get what looks like a correct upper triangular Cholesky factor:

Browse[3]> x_chol_t
tf.Tensor(
[[[ 3.54538307  0.22028915 -0.05668089]
  [ 0.          1.45304198 -0.45842791]
  [ 0.          0.          0.28949603]]], shape=(1, 3, 3), dtype=float64)
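As a base-R analogue of the transpose inside tf_chol(), here is a small sketch: TensorFlow's tf$linalg$cholesky() returns the lower triangular factor L (with W equal to L %*% t(L)), whereas R's chol() returns the upper triangular factor U, which equals t(L). Transposing is therefore what turns TensorFlow's result into the upper triangular matrix shown above. The matrix W is an arbitrary positive definite example, not greta output.

```r
# Arbitrary symmetric positive definite matrix (illustrative, not greta output)
W <- matrix(c(4.0, 2.0, 0.6,
              2.0, 5.0, 1.5,
              0.6, 1.5, 3.0), nrow = 3, byrow = TRUE)

U <- chol(W)  # R convention: upper triangular, W == t(U) %*% U
L <- t(U)     # lower triangular factor, the convention tf$linalg$cholesky() uses

# L has zeros above the diagonal, U has zeros below it
stopifnot(all(L[upper.tri(L)] == 0), all(U[lower.tri(U)] == 0))

# Both conventions reconstruct the same matrix
stopifnot(isTRUE(all.equal(L %*% t(L), W)),
          isTRUE(all.equal(t(U) %*% U, W)))

# So transposing the lower factor (as tf_chol() does) gives R's upper factor
identical(t(L), U)
```

This only covers the single-matrix case; the tensor above has a leading batch dimension of size 1, so greta's tf_transpose presumably swaps only the last two dimensions.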

So something is happening at

https://github.com/greta-dev/greta/blob/1bbc661a07c915ceb3b2f8af417b865fa7453d12/R/calculate.R#L506

where perhaps the wrong tensor is being retrieved. Debugging this is a bit tricky.

njtierney commented 8 months ago

OK so looking further into what is inside tfe at

https://github.com/greta-dev/greta/blob/1bbc661a07c915ceb3b2f8af417b865fa7453d12/R/calculate.R#L506

It looks like we are grabbing the wrong slot, or some mislabelling is happening here.

We are getting tfe$all_sampling_operation_1, which contains 1s:

Browse[2]> tfe$all_sampling_operation_1
tf.Tensor(
[[[1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]]], shape=(1, 3, 3), dtype=float64)

But if we look at the variable named tfe$all_sampling_, then we get:

Browse[4]> tfe$all_sampling_
tf.Tensor(
[[[ 3.54538307  0.22028915 -0.05668089]
  [ 0.          1.45304198 -0.45842791]
  [ 0.          0.          0.28949603]]], shape=(1, 3, 3), dtype=float64)

Which is the output we want.

So something is going wrong in tf_name or somewhere nearby; will investigate!
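When hunting for the right slot in the browser, it can help to scan every object in the environment at once rather than guessing names. Below is a hypothetical helper, find_upper_tri(), which is not part of greta. It assumes the entries are plain R matrices; the real tfe holds TensorFlow tensors with a leading batch dimension, which would need converting (for example with as.array() and dropping that dimension) first.

```r
# Hypothetical debugging helper (not part of greta): list every object in an
# environment and report which ones hold an upper triangular matrix.
find_upper_tri <- function(env) {
  nms <- ls(env)
  is_upper <- vapply(nms, function(nm) {
    value <- get(nm, envir = env)
    # Assumption: entries are plain R matrices; real tfe entries are TF
    # tensors (often with a batch dimension) and would need converting first.
    is.matrix(value) && all(value[lower.tri(value)] == 0)
  }, logical(1))
  nms[is_upper]
}

# Stand-in for tfe, holding stand-ins for the two objects seen in the browser
tfe <- new.env()
assign("all_sampling_operation_1", matrix(1, 3, 3), envir = tfe)
assign("all_sampling_", chol(diag(c(4, 9, 16))), envir = tfe)

find_upper_tri(tfe)
```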

njtierney commented 8 months ago

OK, so there are a few things going on here that are a bit hard to untangle.

For context, here is the code that I'm running

load_all()
set.seed(2023-11-10-1209)
x <- wishart(df = 4, Sigma = diag(3))
chol_x <- chol(x)
calc_chol <- calculate(chol_x, nsim = 1)

In this branch: https://github.com/njtierney/greta/tree/cholesky-585

OK, so, in:

https://github.com/njtierney/greta/blob/cholesky-585/R/node_types.R#L162-L169

      # if sampling get the distribution constructor and sample this
      if (mode == "sampling") {
        tensor <- dag$draw_sample(self$distribution)
        if (has_representation(self, "cholesky")) {
          cholesky_tensor <- tf_chol(tensor)
          cholesky_tf_name <- dag$tf_name(self$representation$cholesky)
          assign(cholesky_tf_name, cholesky_tensor, envir = dag$tf_environment)
        }
      }

We are using the wrong field name for the representation here:

https://github.com/njtierney/greta/blob/cholesky-585/R/node_types.R#L166

cholesky_tf_name <- dag$tf_name(self$representation$cholesky)

This should be representations (with an s).
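Why the typo fails silently rather than erroring: a small base-R sketch. It assumes that greta's node objects are R6 objects, and R6 objects are environments, so a plain environment stands in for self here. The value "cholesky representation" is a placeholder string, not a real greta array.

```r
# Stand-in for a greta node: R6 objects are environments, so a plain
# environment behaves the same way for `$` lookups (assumption: self is R6).
node <- new.env()
node$representations <- list(cholesky = "cholesky representation")

# `$` on an environment only matches exact names, so the typo returns NULL
# with no error or warning
typo <- node$representation
is.null(typo)

# The correctly spelled field finds the representation
node$representations$cholesky

# By contrast, `$` on a plain list partially matches a unique prefix,
# so the same typo would have "worked" on a list
as_list <- list(representations = list(cholesky = "cholesky representation"))
as_list$representation$cholesky
```

So self$representation$cholesky evaluates to NULL, and whatever dag$tf_name() does with NULL is what ends up determining cholesky_tf_name.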

But if you do

cholesky_tf_name <- dag$tf_name(self$representations$cholesky)

Then this will not work, since self$representations$cholesky is a greta array:

Browse[3]> self$representations$cholesky
greta array (variable)

     [,1] [,2] [,3]
[1,]  ?    ?    ?  
[2,] 0     ?    ?  
[3,] 0    0     ?  

So I think it was supposed to be

cholesky_tensor <- tf_chol(tensor)
cholesky_tf_name <- dag$tf_name(self)
assign(cholesky_tf_name, cholesky_tensor, envir = dag$tf_environment)

However, cholesky_tf_name then resolves to all_sampling_operation_2, which just duplicates the name computed by this earlier line:

https://github.com/njtierney/greta/blob/cholesky-585/R/node_types.R#L158

tf_name <- dag$tf_name(self)

And in any case, this value then gets overwritten at

https://github.com/njtierney/greta/blob/cholesky-585/R/node_types.R#L191

assign(tf_name, tensor, envir = dag$tf_environment)

And ultimately, the only tf_name that gets resolved here:

https://github.com/njtierney/greta/blob/cholesky-585/R/calculate.R#L506

target_tensor_list <- lapply(target_names_list, get, envir = tfe)

is all_sampling_operation_1, not all_sampling_operation_2, and all_sampling_operation_1 contains only 1s.
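The collision can be reproduced in miniature with base R. This is a hypothetical reconstruction, not greta code: a plain environment stands in for dag$tf_environment and R matrices stand in for tensors. It shows that even if calculate() looked up all_sampling_operation_2, it would get the raw sample back rather than the Cholesky factor, because the second assign() silently replaces the first.

```r
# Hypothetical reconstruction of the naming collision (not greta code)
tfe <- new.env()

sample_tensor <- diag(c(4, 9, 16))      # stand-in for dag$draw_sample(...)
cholesky_tensor <- chol(sample_tensor)   # stand-in for tf_chol(tensor)

# Both names come from dag$tf_name(self), so they are the same string
tf_name <- "all_sampling_operation_2"
cholesky_tf_name <- tf_name

assign(cholesky_tf_name, cholesky_tensor, envir = tfe)  # Cholesky factor stored...
assign(tf_name, sample_tensor, envir = tfe)             # ...then silently overwritten

# Only one object survives under that name, and it is not the Cholesky factor
ls(tfe)
identical(get(tf_name, envir = tfe), cholesky_tensor)

# Retrieval in the style of calculate.R
target_names_list <- list(chol_x = tf_name)
target_tensor_list <- lapply(target_names_list, get, envir = tfe)
identical(target_tensor_list$chol_x, sample_tensor)
```

Giving the Cholesky factor its own name (or its own node) would avoid the overwrite, but that still leaves the separate question of why chol_x maps to all_sampling_operation_1.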

So I think I've started to hit on where the error is, but it looks like we might have a more complex problem, and I can't immediately recall what our intention or workflow was in doing something like

https://github.com/njtierney/greta/blob/cholesky-585/R/node_types.R#L162-L169

      # if sampling get the distribution constructor and sample this
      if (mode == "sampling") {
        tensor <- dag$draw_sample(self$distribution)
        if (has_representation(self, "cholesky")) {
          cholesky_tensor <- tf_chol(tensor)
          cholesky_tf_name <- dag$tf_name(self$representation$cholesky)
          assign(cholesky_tf_name, cholesky_tensor, envir = dag$tf_environment)
        }
      }

in the first place.

goldingn commented 8 months ago

How is this working with the other distributions that use representations?

njtierney commented 8 months ago

They don't really use it the same way. I am beginning to think that this branch, https://github.com/njtierney/greta/tree/cholesky-585, was a bit more experimental.

The typical usage is to set some flag, e.g., https://github.com/greta-dev/greta/blob/93aaf361c04591a0d20f73c25b6e6693023482fd/R/probability_distributions.R#L162-L164

I've tried going down the debugging warrens back in the greta tf2-poke-fun branch, with

load_all()
x <- wishart(df = 4, Sigma = diag(3))
log_x <- log(x)
debugonce(calculate_target_tensor_list)
res <- calculate(log_x, nsim = 1)

and comparing this to

load_all()
x <- wishart(df = 4, Sigma = diag(3))
chol_x <- chol(x)
debugonce(calculate_target_tensor_list)
res <- calculate(chol_x, nsim = 1)

Where I'm landing currently is that, in the log version, we end up at

https://github.com/greta-dev/greta/blob/master/R/node_types.R#L167

where the operation is log.

But when we do chol, the operation is identity, and we never end up back there.

I might need some guidance on driving the debugger to the right spot.

njtierney commented 8 months ago

The issue is related to the DAG created by calling chol on a wishart:

x <- wishart(df = 4, Sigma = diag(3)) 
chol_x <- chol(x) 
m <- model(chol_x)
plot(m)

[Figure: model-wishart, plot of the DAG for model(chol_x)]

As @goldingn said, the issue looks like it lies in how the variable (the circle in the plotted DAG) is replaced:

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
# devtools::load_all()
x <- wishart(df = 4, Sigma = diag(3)) 
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
chol_x <- chol(x) 
chol_x_orig <- greta:::representation(x, "cholesky") 
res <- calculate(chol_x, chol_x_orig, nsim = 1) 
res 
#> $chol_x
#> , , 1
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
#> 
#> , , 2
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
#> 
#> , , 3
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
#> 
#> 
#> $chol_x_orig
#> , , 1
#> 
#>          [,1] [,2] [,3]
#> [1,] 3.291763    0    0
#> 
#> , , 2
#> 
#>             [,1]     [,2] [,3]
#> [1,] 0.002398802 1.579327    0
#> 
#> , , 3
#> 
#>           [,1]      [,2]     [,3]
#> [1,] -1.216082 0.9263707 2.086389
# check names to make sure we got the right ones 
m <- model(chol_x)
m$dag$tf_name(greta:::get_node(x)) 
#> [1] "all_forward_operation_2"
m$dag$tf_name(greta:::get_node(chol_x)) 
#> [1] "all_forward_operation_1"
m$dag$tf_name(greta:::get_node(chol_x_orig))
#> [1] "all_forward_variable_1"

Created on 2023-11-17 with reprex v2.0.2

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.3.2 (2023-10-31)
#> os macOS Sonoma 14.0
#> system aarch64, darwin20
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Australia/Hobart
#> date 2023-11-17
#> pandoc 3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> backports 1.4.1 2021-12-13 [1] CRAN (R 4.3.0)
#> base64enc 0.1-3 2015-07-28 [1] CRAN (R 4.3.0)
#> callr 3.7.3 2022-11-02 [1] CRAN (R 4.3.0)
#> cli 3.6.1 2023-03-23 [1] CRAN (R 4.3.0)
#> coda 0.19-4 2020-09-30 [1] CRAN (R 4.3.0)
#> codetools 0.2-19 2023-02-01 [1] CRAN (R 4.3.2)
#> crayon 1.5.2 2022-09-29 [1] CRAN (R 4.3.0)
#> digest 0.6.33 2023-07-07 [1] CRAN (R 4.3.0)
#> evaluate 0.23 2023-11-01 [1] CRAN (R 4.3.1)
#> fastmap 1.1.1 2023-02-24 [1] CRAN (R 4.3.0)
#> fs 1.6.3 2023-07-20 [1] CRAN (R 4.3.0)
#> future 1.33.0 2023-07-01 [1] CRAN (R 4.3.0)
#> globals 0.16.2 2022-11-21 [1] CRAN (R 4.3.0)
#> glue 1.6.2 2022-02-24 [1] CRAN (R 4.3.0)
#> greta * 0.4.3.9000 2023-11-16 [1] local
#> hms 1.1.3 2023-03-21 [1] CRAN (R 4.3.0)
#> htmltools 0.5.7 2023-11-03 [1] CRAN (R 4.3.1)
#> jsonlite 1.8.7 2023-06-29 [1] CRAN (R 4.3.0)
#> knitr 1.45 2023-10-30 [1] CRAN (R 4.3.1)
#> lattice 0.21-9 2023-10-01 [1] CRAN (R 4.3.2)
#> lifecycle 1.0.3 2022-10-07 [1] CRAN (R 4.3.0)
#> listenv 0.9.0 2022-12-16 [1] CRAN (R 4.3.0)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.3.0)
#> Matrix 1.6-1.1 2023-09-18 [1] CRAN (R 4.3.2)
#> parallelly 1.36.0 2023-05-26 [1] CRAN (R 4.3.0)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.3.0)
#> png 0.1-8 2022-11-29 [1] CRAN (R 4.3.0)
#> prettyunits 1.2.0 2023-09-24 [1] CRAN (R 4.3.1)
#> processx 3.8.2 2023-06-30 [1] CRAN (R 4.3.0)
#> progress 1.2.2 2019-05-16 [1] CRAN (R 4.3.0)
#> ps 1.7.5 2023-04-18 [1] CRAN (R 4.3.0)
#> purrr 1.0.2 2023-08-10 [1] CRAN (R 4.3.0)
#> R.cache 0.16.0 2022-07-21 [1] CRAN (R 4.3.0)
#> R.methodsS3 1.8.2 2022-06-13 [1] CRAN (R 4.3.0)
#> R.oo 1.25.0 2022-06-12 [1] CRAN (R 4.3.0)
#> R.utils 2.12.2 2022-11-11 [1] CRAN (R 4.3.0)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.3.0)
#> Rcpp 1.0.11 2023-07-06 [1] CRAN (R 4.3.0)
#> reprex 2.0.2 2022-08-17 [1] CRAN (R 4.3.0)
#> reticulate 1.34.0 2023-10-12 [1] CRAN (R 4.3.1)
#> rlang 1.1.1 2023-04-28 [1] CRAN (R 4.3.0)
#> rmarkdown 2.25 2023-09-18 [1] CRAN (R 4.3.1)
#> rstudioapi 0.15.0 2023-07-07 [1] CRAN (R 4.3.0)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.3.0)
#> styler 1.9.1 2023-03-04 [1] CRAN (R 4.3.0)
#> tensorflow 2.14.0 2023-09-28 [1] CRAN (R 4.3.1)
#> tfautograph 0.3.2 2021-09-17 [1] CRAN (R 4.3.0)
#> tfruns 1.5.1 2022-09-05 [1] CRAN (R 4.3.0)
#> vctrs 0.6.4 2023-10-12 [1] CRAN (R 4.3.1)
#> whisker 0.4.1 2022-12-05 [1] CRAN (R 4.3.0)
#> withr 2.5.2 2023-10-30 [1] CRAN (R 4.3.1)
#> xfun 0.41 2023-11-01 [1] CRAN (R 4.3.1)
#> yaml 2.3.7 2023-01-23 [1] CRAN (R 4.3.0)
#>
#> [1] /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/library
#>
#> ─ Python configuration ───────────────────────────────────────────────────────
#> python: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#> libpython: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.8.dylib
#> pythonhome: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#> version: 3.8.15 | packaged by conda-forge | (default, Nov 22 2022, 08:49:06) [Clang 14.0.6 ]
#> numpy: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/numpy
#> numpy_version: 1.23.2
#> tensorflow: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow
#>
#> NOTE: Python version was forced by use_python() function
#>
#> ──────────────────────────────────────────────────────────────────────────────
```
njtierney commented 7 months ago

Following up from our meeting on this: we need to analyse when representations are used. Specifically:

Everything that is a representation of something else (e.g., a Cholesky factor) needs to know what it is a representation of, and add that node as its parent. However, this could lead to infinite recursion. We need to analyse the cases where a variable has a representation and also has a distribution that can't update that representation; in those cases we tell the representation that it has parents and define the appropriate parents, while making sure we avoid infinite recursion.
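One way to get the "know your parent, but don't recurse forever" behaviour is a depth-first walk that tracks visited nodes, plus a check that refuses links which would create a cycle. The sketch below is a hypothetical toy in base R, not greta's node or DAG classes; new_node(), ancestors() and add_parent() are illustrative names only.

```r
# Hypothetical sketch, not greta's DAG code: nodes are environments with an
# id and a list of parent nodes.
new_node <- function(id) {
  node <- new.env()
  node$id <- id
  node$parents <- list()
  node
}

# Collect the ids of all ancestors, tracking visited ids so that any cycle
# (e.g. a variable and its representation pointing at each other) terminates.
ancestors <- function(node, visited = character(0)) {
  for (parent in node$parents) {
    if (!(parent$id %in% visited)) {
      visited <- ancestors(parent, c(visited, parent$id))
    }
  }
  visited
}

# Add `parent` as a parent of `child`, refusing any link that would make
# `child` its own ancestor.
add_parent <- function(child, parent) {
  if (identical(child$id, parent$id) || child$id %in% ancestors(parent)) {
    return(FALSE)
  }
  child$parents <- c(child$parents, list(parent))
  TRUE
}

x <- new_node("wishart_variable")
chol_rep <- new_node("cholesky_representation")

linked <- add_parent(chol_rep, x)     # the representation knows what it represents
back_link <- add_parent(x, chol_rep)  # the reverse link would form a cycle
c(linked = linked, back_link = back_link)
ancestors(chol_rep)
```

Refusing the back-link is one design choice; an alternative is to allow it and rely only on the visited set in ancestors(), at the cost of every traversal of the graph having to be cycle-aware.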