greta-dev / greta

simple and scalable statistical modelling in R
https://greta-stats.org

M1 TF2 cholesky factors all out of whack in calculate #585

Closed goldingn closed 1 month ago

goldingn commented 1 year ago

1) When running calculate() on cholesky factors of a Wishart-distributed matrix, the cholesky factor is wrong: it is a matrix of ones instead of the correct value.
2) When mcmc() is run, it gets even weirder. The value given for the original matrix is a cholesky factor, but it's transposed.

I'm not sure what's going on for 1), but it could be something to do with the 'representations' code that reuses the cholesky factor.

For 2), could it be that when MCMC happens, the tensor for the original matrix is being overwritten with the cholesky factor in the same environment, and that is somehow being picked up in calculate?

This seems to be breaking my sampler for a model that uses cholesky factorisation, so I'm guessing something is being overwritten. I can't test on TF1 right now.

Note: I can't set the seed to demonstrate the reprex more clearly because of #559, and this is possibly related to or explains #560

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
x <- wishart(df = 4,
             Sigma = diag(3))
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
# the wishart draw seems fine - it's symmetric
calculate(x, nsim = 1)
#> $x
#> , , 1
#> 
#>          [,1]     [,2]     [,3]
#> [1,] 1.551179 2.805199 1.886337
#> 
#> , , 2
#> 
#>          [,1]     [,2]     [,3]
#> [1,] 2.805199 5.957409 3.310655
#> 
#> , , 3
#> 
#>          [,1]     [,2]     [,3]
#> [1,] 1.886337 3.310655 4.102736

# but the cholesky factor is wrong. It should be an upper triangular matrix, but
# is instead a ones matrix
chol_x <- chol(x)
calculate(x, chol_x, nsim = 1)
#> $x
#> , , 1
#> 
#>          [,1]     [,2]     [,3]
#> [1,] 2.339474 -1.06748 2.827606
#> 
#> , , 2
#> 
#>          [,1]     [,2]     [,3]
#> [1,] -1.06748 5.393758 1.731507
#> 
#> , , 3
#> 
#>          [,1]     [,2]     [,3]
#> [1,] 2.827606 1.731507 7.348657
#> 
#> 
#> $chol_x
#> , , 1
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
#> 
#> , , 2
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
#> 
#> , , 3
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1

# it's even weirder once we run the MCMC sampler on this:
m <- model(x)
draws <- mcmc(m, warmup = 1, n_samples = 1)
#> running 4 chains simultaneously on up to 8 CPU cores
#> 
#> warmup 0/1 | eta: ?s sampling 0/1 | eta: ?s

# now the matrix which should be symmetric looks like a cholesky factor (but
# lower triangular, when it should be upper triangular), and cholesky factor is
# still coming out as ones
calculate(x,
          chol(x),
          nsim = 1)
#> $x
#> , , 1
#> 
#>          [,1] [,2] [,3]
#> [1,] 2.534822    0    0
#> 
#> , , 2
#> 
#>         [,1]     [,2] [,3]
#> [1,] 2.45283 1.169166    0
#> 
#> , , 3
#> 
#>          [,1]      [,2]     [,3]
#> [1,] 1.350118 0.3224047 0.922814
#> 
#> 
#> $`chol(x)`
#> , , 1
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
#> 
#> , , 2
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
#> 
#> , , 3
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
sessionInfo()
#> R version 4.2.2 (2022-10-31)
#> Platform: aarch64-apple-darwin20 (64-bit)
#> Running under: macOS Ventura 13.1
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] greta_0.4.3.9000
#> 
#> loaded via a namespace (and not attached):
#>  [1] Rcpp_1.0.10       compiler_4.2.2    base64enc_0.1-3   prettyunits_1.1.1
#>  [5] tools_4.2.2       progress_1.2.2    digest_0.6.31     jsonlite_1.8.4   
#>  [9] evaluate_0.20     lifecycle_1.0.3   lattice_0.20-45   png_0.1-8        
#> [13] pkgconfig_2.0.3   rlang_1.0.6       Matrix_1.5-3      reprex_2.0.2     
#> [17] cli_3.6.0         rstudioapi_0.14   yaml_2.3.7        parallel_4.2.2   
#> [21] xfun_0.37         fastmap_1.1.1     coda_0.19-4       withr_2.5.0      
#> [25] knitr_1.42        fs_1.6.1          vctrs_0.5.2       globals_0.16.2   
#> [29] hms_1.1.2         rprojroot_2.0.3   grid_4.2.2        here_1.0.1       
#> [33] reticulate_1.28   glue_1.6.2        listenv_0.9.0     R6_2.5.1         
#> [37] tfautograph_0.3.2 processx_3.8.0    parallelly_1.34.0 rmarkdown_2.20   
#> [41] magrittr_2.0.3    whisker_0.4.1     callr_3.7.3       backports_1.4.1  
#> [45] tfruns_1.5.1      codetools_0.2-18  ps_1.7.2          htmltools_0.5.4  
#> [49] ellipsis_0.3.2    abind_1.4-5       future_1.31.0     tensorflow_2.11.0
#> [53] crayon_1.5.2

Created on 2023-03-31 with reprex v2.0.2
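
For comparison, here is a minimal base R sketch of what the reprex above would be expected to show if the cholesky factor were computed correctly. It does not use greta at all: `stats::rWishart()` stands in for `wishart()`, and the seed is an arbitrary choice, so it only illustrates the expected properties (upper triangular, and reconstructing the original matrix), not the greta internals.

```r
# Base R stand-in for the greta reprex (assumption: no greta involved here)
set.seed(1)

# one 3 x 3 draw from a Wishart distribution with df = 4 and identity scale
x <- stats::rWishart(n = 1, df = 4, Sigma = diag(3))[, , 1]

# base R chol() returns the upper triangular factor
chol_x <- chol(x)

# the strictly lower triangle should be exactly zero
lower_is_zero <- all(chol_x[lower.tri(chol_x)] == 0)

# t(chol_x) %*% chol_x should recover the original matrix
reconstructs <- isTRUE(all.equal(crossprod(chol_x), x))

print(chol_x)
print(c(lower_is_zero = lower_is_zero, reconstructs = reconstructs))
```

A matrix of ones fails both checks, which is why either property can serve as a regression check for this issue.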

goldingn commented 1 year ago

Note: this is the same for the lkj_correlation(), e.g. try subbing in x <- lkj_correlation(eta = 3, dimension = 3) in place of the wishart in the reprex

njtierney commented 1 year ago

Perhaps adding a test like this will help

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
library(testthat)
x <- wishart(df = 4, Sigma = diag(3))
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
chol_x <- chol(x)
calc_chol <- calculate(x, chol_x, nsim = 1)
calc_chol
#> $x
#> , , 1
#> 
#>          [,1]       [,2]    [,3]
#> [1,] 8.300433 -0.4479606 0.57702
#> 
#> , , 2
#> 
#>            [,1]      [,2]        [,3]
#> [1,] -0.4479606 0.7182815 -0.01634725
#> 
#> , , 3
#> 
#>         [,1]        [,2]     [,3]
#> [1,] 0.57702 -0.01634725 1.586213
#> 
#> 
#> $chol_x
#> , , 1
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
#> 
#> , , 2
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
#> 
#> , , 3
#> 
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
expect_equal(dim(calc_chol$chol_x), c(1,3,3))
calc_chol_mat <- matrix(calc_chol$chol_x, nrow = 3, ncol = 3)
calc_chol_mat
#>      [,1] [,2] [,3]
#> [1,]    1    1    1
#> [2,]    1    1    1
#> [3,]    1    1    1
lower_tri <- calc_chol_mat[lower.tri(calc_chol_mat)]
lower_tri
#> [1] 1 1 1
expect_equal(lower_tri, c(0,0,0))
#> Error: `lower_tri` not equal to c(0, 0, 0).
#> 3/3 mismatches (average diff: 1)
#> [1] 1 - 0 == 1
#> [2] 1 - 0 == 1
#> [3] 1 - 0 == 1

Created on 2023-04-19 with reprex v2.0.2

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.3 (2023-03-15)
#>  os       macOS Ventura 13.2
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Brisbane
#>  date     2023-04-19
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  backports     1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>  brio          1.1.3      2021-11-30 [1] CRAN (R 4.2.0)
#>  callr         3.7.3      2022-11-02 [1] CRAN (R 4.2.0)
#>  cli           3.6.1      2023-03-23 [1] CRAN (R 4.2.0)
#>  coda          0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>  codetools     0.2-19     2023-02-01 [1] CRAN (R 4.2.3)
#>  crayon        1.5.2      2022-09-29 [1] CRAN (R 4.2.0)
#>  desc          1.4.2      2022-09-08 [1] CRAN (R 4.2.0)
#>  digest        0.6.31     2022-12-11 [1] CRAN (R 4.2.0)
#>  evaluate      0.20       2023-01-17 [1] CRAN (R 4.2.0)
#>  fansi         1.0.4      2023-01-22 [1] CRAN (R 4.2.0)
#>  fastmap       1.1.1      2023-02-24 [1] CRAN (R 4.2.0)
#>  fs            1.6.1      2023-02-06 [1] CRAN (R 4.2.0)
#>  future        1.32.0     2023-03-07 [1] CRAN (R 4.2.0)
#>  globals       0.16.2     2022-11-21 [1] CRAN (R 4.2.1)
#>  glue          1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  greta       * 0.4.3.9000 2023-04-19 [1] local
#>  here          1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>  hms           1.1.3      2023-03-21 [1] CRAN (R 4.2.0)
#>  htmltools     0.5.5      2023-03-23 [1] CRAN (R 4.2.0)
#>  jsonlite      1.8.4      2022-12-06 [1] CRAN (R 4.2.0)
#>  knitr         1.42       2023-01-25 [1] CRAN (R 4.2.0)
#>  lattice       0.21-8     2023-04-05 [1] CRAN (R 4.2.0)
#>  lifecycle     1.0.3      2022-10-07 [1] CRAN (R 4.2.0)
#>  listenv       0.9.0      2022-12-16 [1] CRAN (R 4.2.0)
#>  magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  Matrix        1.5-4      2023-04-04 [1] CRAN (R 4.2.0)
#>  parallelly    1.35.0     2023-03-23 [1] CRAN (R 4.2.0)
#>  pillar        1.9.0      2023-03-22 [1] CRAN (R 4.2.0)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  pkgload       1.3.2      2022-11-16 [1] CRAN (R 4.2.0)
#>  png           0.1-8      2022-11-29 [1] CRAN (R 4.2.0)
#>  prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>  processx      3.8.0      2022-10-26 [1] CRAN (R 4.2.0)
#>  progress      1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>  ps            1.7.4      2023-04-02 [1] CRAN (R 4.2.0)
#>  purrr         1.0.1      2023-01-10 [1] CRAN (R 4.2.0)
#>  R.cache       0.16.0     2022-07-21 [1] CRAN (R 4.2.0)
#>  R.methodsS3   1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo          1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils       2.12.2     2022-11-11 [1] CRAN (R 4.2.0)
#>  R6            2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp          1.0.10     2023-01-22 [1] CRAN (R 4.2.0)
#>  reprex        2.0.2      2022-08-17 [1] CRAN (R 4.2.0)
#>  reticulate    1.28       2023-01-27 [1] CRAN (R 4.2.0)
#>  rlang         1.1.0      2023-03-14 [1] CRAN (R 4.2.0)
#>  rmarkdown     2.21       2023-03-26 [1] CRAN (R 4.2.0)
#>  rprojroot     2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>  rstudioapi    0.14       2022-08-22 [1] CRAN (R 4.2.0)
#>  sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  styler        1.9.1      2023-03-04 [1] CRAN (R 4.2.0)
#>  tensorflow    2.11.0     2022-12-19 [1] CRAN (R 4.2.0)
#>  testthat    * 3.1.7      2023-03-12 [1] CRAN (R 4.2.0)
#>  tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.2.0)
#>  tfruns        1.5.1      2022-09-05 [1] CRAN (R 4.2.0)
#>  utf8          1.2.3      2023-01-31 [1] CRAN (R 4.2.0)
#>  vctrs         0.6.1      2023-03-22 [1] CRAN (R 4.2.0)
#>  whisker       0.4.1      2022-12-05 [1] CRAN (R 4.2.0)
#>  withr         2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>  xfun          0.38       2023-03-24 [1] CRAN (R 4.2.0)
#>  yaml          2.3.7      2023-01-23 [1] CRAN (R 4.2.0)
#>
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#>
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.8.15 | packaged by conda-forge | (default, Nov 22 2022, 08:49:06) [Clang 14.0.6 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.23.2
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow
#>
#>  NOTE: Python version was forced by use_python function
#>
#> ──────────────────────────────────────────────────────────────────────────────
```

njtierney commented 1 year ago

So then we'd end up with tests like this


test_that("Cholesky factor of Wishart should be an upper triangular matrix", {
  x <- wishart(df = 4, Sigma = diag(3))
  chol_x <- chol(x)
  calc_chol <- calculate(x, chol_x, nsim = 1)
  expect_equal(dim(calc_chol$chol_x), c(1,3,3))
  calc_chol_mat <- matrix(calc_chol$chol_x, nrow = 3, ncol = 3)
  expect_equal(calc_chol_mat[lower.tri(calc_chol_mat)], c(0,0,0))
})

test_that("Cholesky factor of lkj_correlation should be an upper triangular matrix", {
  x <- lkj_correlation(eta = 3, dimension = 3)
  chol_x <- chol(x)
  calc_chol <- calculate(x, chol_x, nsim = 1)
  expect_equal(dim(calc_chol$chol_x), c(1,3,3))
  calc_chol_mat <- matrix(calc_chol$chol_x, nrow = 3, ncol = 3)
  expect_equal(calc_chol_mat[lower.tri(calc_chol_mat)], c(0,0,0))
})
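
The tests above only check that the lower triangle is zero. A hedged sketch of a stricter check follows, which also confirms that the factor reconstructs the matrix it was computed from. It runs in base R only: `fake_calc` is a hypothetical stand-in that mimics the 1 x 3 x 3 arrays returned by `calculate(x, chol_x, nsim = 1)`, and `as_matrix()` is an illustrative helper, not part of greta.

```r
# Hypothetical stand-in for calculate() output; no greta call is made here
set.seed(2)
x_mat <- stats::rWishart(n = 1, df = 4, Sigma = diag(3))[, , 1]
fake_calc <- list(
  x = array(x_mat, dim = c(1, 3, 3)),
  chol_x = array(chol(x_mat), dim = c(1, 3, 3))
)

# drop the leading simulation dimension to get a plain 3 x 3 matrix
as_matrix <- function(arr) arr[1, , ]

chol_mat <- as_matrix(fake_calc$chol_x)
x_back <- as_matrix(fake_calc$x)

# check 1: strictly lower triangle is zero (what the tests above assert)
is_upper <- all(chol_mat[lower.tri(chol_mat)] == 0)

# check 2: t(chol) %*% chol gives back the simulated matrix
reconstructs <- isTRUE(all.equal(crossprod(chol_mat), x_back))

# a ones matrix, as seen in the bug, fails the first check
ones <- matrix(1, nrow = 3, ncol = 3)
ones_is_upper <- all(ones[lower.tri(ones)] == 0)
```

The reconstruction check only makes sense when `x` and `chol_x` are simulated jointly in one `calculate()` call, as they are in the tests above.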
njtierney commented 1 year ago

Notes from discussion on this: it seems to be an issue with the representation attribute not being properly accessed during sampling.

So the next step is to unpack where this is happening: is it at the operation node, and why are the parts that identify the cholesky factor not being used appropriately?

njtierney commented 2 months ago

For clarity, regarding this earlier comment:

"now the matrix which should be symmetric looks like a cholesky factor (but lower triangular, when it should be upper triangular), and cholesky factor is still coming out as ones"

A cholesky factor should be lower triangular (from https://en.wikipedia.org/wiki/Cholesky_decomposition)
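
The two conventions can be compared in a short base R sketch. Base R's `chol()` returns the upper triangular factor, and its transpose is the lower triangular factor used in the Wikipedia definition. The matrix `a` is an arbitrary positive definite example, and whether greta's `chol()` is meant to follow the base R convention is not settled by this snippet.

```r
# an arbitrary symmetric positive definite matrix
a <- matrix(c(4, 2,
              2, 3), nrow = 2, byrow = TRUE)

# base R convention: upper triangular factor, with t(upper) %*% upper == a
upper <- chol(a)

# textbook convention: lower triangular factor, with lower %*% t(lower) == a
lower <- t(upper)

upper_ok <- isTRUE(all.equal(crossprod(upper), a))
lower_ok <- isTRUE(all.equal(lower %*% t(lower), a))

print(upper)
print(lower)
```

Both factors describe the same decomposition, so the practical question is which orientation each function documents and returns.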

njtierney commented 1 month ago

This is now resolved; the remaining issue with the upper triangular factor is tracked in #692.