greta-dev / greta

simple and scalable statistical modelling in R
https://greta-stats.org

M1 TF2 dev error: wishart and rwish aren't sampling the same #560

Closed: njtierney closed this issue 2 weeks ago

njtierney commented 1 year ago

NOTE: This is on the TensorFlow 2 development branch (https://github.com/greta-dev/greta/pull/534).

devtools::load_all(".")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> 
#> ✔ Initialising python and checking dependencies ... done!
#> Loaded Tensorflow version 2.9.2
source("tests/testthat/helpers.R")

sigma <- rwish(1, 5, diag(4))[1, , ]
prob <- t(runif(4))
prob <- prob / sum(prob)

compare_iid_samples(wishart,
                    rwish,
                    parameters = list(df = 7, Sigma = sigma)
)
#> Error: test_result$p.value is not more than `p_value_threshold`. Difference: -0.001

Created on 2022-10-05 by the reprex package (v2.0.1)

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22)
#>  os       macOS Monterey 12.3.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Brisbane
#>  date     2022-10-05
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package       * version    date (UTC) lib source
#>    abind           1.4-5      2016-07-21 [1] CRAN (R 4.2.0)
#>    backports       1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>    base64enc       0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>    boot            1.3-28     2021-05-03 [1] CRAN (R 4.2.0)
#>    brio            1.1.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    cachem          1.0.6      2021-08-19 [1] CRAN (R 4.2.0)
#>    callr           3.7.2      2022-08-22 [1] CRAN (R 4.2.0)
#>    cli             3.3.0.9000 2022-06-15 [1] Github (r-lib/cli@31a5db5)
#>    coda            0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>    codetools       0.2-18     2020-11-04 [1] CRAN (R 4.2.0)
#>    cramer          0.9-3      2019-01-05 [1] CRAN (R 4.2.0)
#>    crayon          1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>    desc            1.4.2      2022-09-08 [1] CRAN (R 4.2.0)
#>    devtools        2.4.4      2022-07-20 [1] CRAN (R 4.2.0)
#>    digest          0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>    ellipsis        0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>    evaluate        0.16       2022-08-09 [1] CRAN (R 4.2.0)
#>    fansi           1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>    fastmap         1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>    fs              1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>    future          1.27.0     2022-07-22 [1] CRAN (R 4.2.0)
#>    globals         0.16.0     2022-08-05 [1] CRAN (R 4.2.0)
#>    glue            1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  P greta         * 0.4.2.9000 2022-09-19 [?] load_all()
#>    here            1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>    highr           0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>    hms             1.1.1      2021-09-26 [1] CRAN (R 4.2.0)
#>    htmltools       0.5.3      2022-07-18 [1] CRAN (R 4.2.0)
#>    htmlwidgets     1.5.4      2021-09-08 [1] CRAN (R 4.2.0)
#>    httpuv          1.6.5      2022-01-05 [1] CRAN (R 4.2.0)
#>    jsonlite        1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>    knitr           1.40       2022-08-24 [1] CRAN (R 4.2.0)
#>    later           1.3.0      2021-08-18 [1] CRAN (R 4.2.0)
#>    lattice         0.20-45    2021-09-22 [1] CRAN (R 4.2.0)
#>    lifecycle       1.0.2      2022-09-09 [1] CRAN (R 4.2.0)
#>    listenv         0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>    magrittr        2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>    Matrix          1.4-1      2022-03-23 [1] CRAN (R 4.2.0)
#>    memoise         2.0.1      2021-11-26 [1] CRAN (R 4.2.0)
#>    mime            0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>    miniUI          0.1.1.1    2018-05-18 [1] CRAN (R 4.2.0)
#>    parallelly      1.32.1     2022-07-21 [1] CRAN (R 4.2.0)
#>    pillar          1.8.1      2022-08-19 [1] CRAN (R 4.2.0)
#>    pkgbuild        1.3.1      2021-12-20 [1] CRAN (R 4.2.0)
#>    pkgconfig       2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>    pkgload         1.3.0      2022-06-27 [1] CRAN (R 4.2.0)
#>    png             0.1-7      2013-12-03 [1] CRAN (R 4.2.0)
#>    prettyunits     1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>    processx        3.7.0      2022-07-07 [1] CRAN (R 4.2.0)
#>    profvis         0.3.7      2020-11-02 [1] CRAN (R 4.2.0)
#>    progress        1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>    promises        1.2.0.1    2021-02-11 [1] CRAN (R 4.2.0)
#>    ps              1.7.1      2022-06-18 [1] CRAN (R 4.2.0)
#>    purrr           0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>    R.cache         0.16.0     2022-07-21 [1] CRAN (R 4.2.0)
#>    R.methodsS3     1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>    R.oo            1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>    R.utils         2.12.0     2022-06-28 [1] CRAN (R 4.2.0)
#>    R6              2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>    Rcpp            1.0.9      2022-07-08 [1] CRAN (R 4.2.0)
#>    remotes         2.4.2      2021-11-30 [1] CRAN (R 4.2.0)
#>    reprex          2.0.1      2021-08-05 [1] CRAN (R 4.2.0)
#>    reticulate      1.25       2022-05-11 [1] CRAN (R 4.2.0)
#>    rlang           1.0.5      2022-08-31 [1] CRAN (R 4.2.0)
#>    rmarkdown       2.16       2022-08-24 [1] CRAN (R 4.2.0)
#>    rprojroot       2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>    rstudioapi      0.13       2020-11-12 [1] CRAN (R 4.2.0)
#>    sessioninfo     1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>    shiny           1.7.2      2022-07-19 [1] CRAN (R 4.2.0)
#>    stringi         1.7.8      2022-07-11 [1] CRAN (R 4.2.0)
#>    stringr         1.4.1      2022-08-20 [1] CRAN (R 4.2.0)
#>    styler          1.7.0      2022-03-13 [1] CRAN (R 4.2.0)
#>    tensorflow      2.9.0      2022-05-21 [1] CRAN (R 4.2.0)
#>    testthat      * 3.1.4      2022-04-26 [1] CRAN (R 4.2.0)
#>    tfautograph     0.3.2      2021-09-17 [1] CRAN (R 4.2.0)
#>    tfruns          1.5.0      2021-02-26 [1] CRAN (R 4.2.0)
#>    tibble          3.1.8      2022-07-22 [1] CRAN (R 4.2.0)
#>    urlchecker      1.0.1      2021-11-30 [1] CRAN (R 4.2.0)
#>    usethis         2.1.6      2022-05-25 [1] CRAN (R 4.2.0)
#>    utf8            1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>    vctrs           0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>    whisker         0.4        2019-08-28 [1] CRAN (R 4.2.0)
#>    withr           2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>    xfun            0.33       2022-09-12 [1] CRAN (R 4.2.0)
#>    xtable          1.8-4      2019-04-21 [1] CRAN (R 4.2.0)
#>    yaml            2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#>    yesno           0.1.2      2020-07-10 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#> 
#>  P ── Loaded and on-disk path mismatch.
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16) [Clang 12.0.1 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.22.4
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow
#> 
#>  NOTE: Python version was forced by use_python function
#> 
#> ──────────────────────────────────────────────────────────────────────────────
```
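`compare_iid_samples()` lives in greta's test helpers (`tests/testthat/helpers.R`) and its internals aren't shown in this thread. As a rough, hypothetical illustration of why a sampler that returns Cholesky factors instead of full matrices fails this kind of check, the sketch below compares the `[1, 1]` element of two sets of draws with base R's two-sample `ks.test()`. Using a KS test on a single element is an assumption made for illustration, not necessarily what the helper does.

```r
set.seed(2022)

p <- 4
df <- 7
Sigma <- diag(p)
n_draws <- 1000

# two independent sets of full Wishart draws, stored p x p x n_draws
full_a <- rWishart(n_draws, df, Sigma)
full_b <- rWishart(n_draws, df, Sigma)

# "buggy" draws: the lower-triangular Cholesky factor of each full matrix,
# mimicking samples returned in Cholesky form
chol_b <- array(
  apply(full_b, 3, function(w) t(chol(w))),
  dim = c(p, p, n_draws)
)

# compare the [1, 1] element across sets of draws
p_same <- ks.test(full_a[1, 1, ], full_b[1, 1, ])$p.value
p_chol <- ks.test(full_a[1, 1, ], chol_b[1, 1, ])$p.value

c(same_distribution = p_same, full_vs_cholesky = p_chol)
```

The full-vs-full comparison should usually give an unremarkable p-value, while full-vs-Cholesky gives a p-value that is essentially zero, because `chol_b[1, 1, ]` is the square root of a chi-squared draw rather than the draw itself.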
njtierney commented 1 year ago

It turns out that the TFP distribution returns each Wishart draw as a lower-triangular Cholesky factor. This is because we are using

tfp$distributions$WishartTriL

https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/WishartTriL

which, according to the documentation, returns samples as lower-triangular Cholesky factors when input_output_cholesky is TRUE.
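For reference, a lower-triangular Cholesky factor L determines the full matrix through W = L %*% t(L). A minimal base-R sketch of going from one representation to the other (no greta or TFP involved; the seed and Sigma = diag(4) are arbitrary choices):

```r
set.seed(1)

# one full Wishart draw (4 x 4, symmetric positive definite)
W <- rWishart(1, df = 7, Sigma = diag(4))[, , 1]

# its lower-triangular Cholesky factor, so that W = L %*% t(L)
# (base R's chol() returns the upper factor, hence the transpose)
L <- t(chol(W))

# rebuild the full matrix from the factor
W_rebuilt <- tcrossprod(L)  # same as L %*% t(L)

all.equal(W, W_rebuilt)
#> [1] TRUE
```

Applying `tcrossprod()` to each Cholesky-form draw would be one way to turn such draws back into full Wishart matrices before comparing them with `rWishart()`.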

We can see that here, looking at random draws from this code:

sigma <- rwish(1, 5, diag(4))[1, , ]
sigma
r_samples <- rWishart(
  n = 3,
  df = 7,
  Sigma = sigma
) 
r_samples
, , 1

          [,1]      [,2]      [,3]      [,4]
[1,]  3.502007 -6.585769 -2.661935  4.202719
[2,] -6.585769 15.045059 -2.043331 -9.495109
[3,] -2.661935 -2.043331 26.932117  3.272535
[4,]  4.202719 -9.495109  3.272535 18.246596

, , 2

          [,1]      [,2]      [,3]      [,4]
[1,] 11.542086 -9.200114 -9.765930  7.876086
[2,] -9.200114  9.105552  8.881690 -8.947543
[3,] -9.765930  8.881690 22.683447 -4.968877
[4,]  7.876086 -8.947543 -4.968877 45.715534

, , 3

          [,1]      [,2]      [,3]      [,4]
[1,]  3.661248 -6.890945  1.342439 -6.535986
[2,] -6.890945 19.554629 -2.524039 12.780041
[3,]  1.342439 -2.524039  1.744909 -5.763078
[4,] -6.535986 12.780041 -5.763078 30.628326
hist(r_samples)

[image: histogram of r_samples]

Compare these with the greta draws:

g_wish <- wishart(
  df = 7,
  Sigma = sigma
)

greta_samples <- calculate(g_wish, nsim = 3)
greta_samples$g_wish
, , 1

         [,1]      [,2]       [,3]      [,4]
[1,] 3.535723 -1.790248 -7.7369216 -4.057937
[2,] 1.692909 -1.432415 -1.2130116 -2.531842
[3,] 2.056882 -1.753289 -0.4312736 -2.063981

, , 2

     [,1]     [,2]       [,3]       [,4]
[1,]    0 2.631537  0.1488081 -0.4031088
[2,]    0 1.900439 -1.3181909  0.6885713
[3,]    0 4.974351 -4.1338302 -4.8195536

, , 3

     [,1] [,2]     [,3]      [,4]
[1,]    0    0 7.422673 0.2911123
[2,]    0    0 3.380941 0.8407127
[3,]    0    0 3.719833 1.8276059

, , 4

     [,1] [,2] [,3]     [,4]
[1,]    0    0    0 3.253616
[2,]    0    0    0 2.183522
[3,]    0    0    0 2.094107
hist(unlist(greta_samples))

[image: histogram of the greta samples]

The greta draws look suspiciously triangular (every element above the diagonal of each draw is zero), unlike the symmetric rWishart() draws.

I believe this is because the argument input_output_cholesky is set to TRUE. The documentation for this argument states:

Python bool. If True, functions whose input or output have the semantics of samples assume inputs are in Cholesky form and return outputs in Cholesky form. In particular, if this flag is True, input to log_prob is presumed of Cholesky form and output from sample, mean, and mode are of Cholesky form. Setting this argument to True is purely a computational optimization and does not change the underlying distribution; for instance, mean returns the Cholesky of the mean, not the mean of Cholesky factors. The variance and stddev methods are unaffected by this flag. Default value: False (i.e., input/output does not have Cholesky semantics).
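To make the last point of that quote concrete ("mean returns the Cholesky of the mean, not the mean of Cholesky factors"), here is a small base-R sketch, independent of TFP and greta, comparing the Cholesky factor of the Wishart mean, df * Sigma, with the element-wise average of the Cholesky factors of many draws. With Sigma = I, the Bartlett decomposition gives L[j, j]^2 ~ chi-squared with df - j + 1 degrees of freedom, so the averaged diagonal should sit below sqrt(df). The seed and number of draws are arbitrary.

```r
set.seed(42)

df <- 7
p <- 4
Sigma <- diag(p)

# 10,000 full Wishart draws, stored p x p x n by rWishart()
draws <- rWishart(10000, df, Sigma)

# lower-triangular Cholesky factor of every draw, in the same layout
chol_draws <- array(apply(draws, 3, function(w) t(chol(w))), dim = dim(draws))

# "Cholesky of the mean": factorise E[W] = df * Sigma
chol_of_mean <- t(chol(df * Sigma))

# "mean of the Cholesky factors": average the factors element-wise
mean_of_chol <- apply(chol_draws, c(1, 2), mean)

# compare the diagonals of the two lower-triangular matrices
rbind(
  chol_of_mean = diag(chol_of_mean),
  mean_of_chol = diag(mean_of_chol)
)
```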

If we change this from TRUE to FALSE:

tfp$distributions$WishartTriL(
  df = df,
  scale_tril = sigma_chol,
  input_output_cholesky = FALSE
)

Then we get the following for greta arrays:

g_wish <- wishart(
  df = 7,
  Sigma = sigma
)

greta_samples <- calculate(g_wish, nsim = 3)
greta_samples$g_wish
, , 1

          [,1]       [,2]       [,3]      [,4]
[1,] 19.339180 -27.928833 -24.664543 -0.951870
[2,]  4.786275  -4.952132  -6.593055  4.589708
[3,]  4.006417  -2.028697  -6.779876 -4.908824

, , 2

           [,1]      [,2]      [,3]       [,4]
[1,] -27.928833 62.309693 39.529848  -1.758316
[2,]  -4.952132 18.451103  3.839761 -13.512347
[3,]  -2.028697  8.711624 11.558550   1.504732

, , 3

           [,1]      [,2]     [,3]      [,4]
[1,] -24.664543 39.529848 51.80217 -1.375958
[2,]  -6.593055  3.839761 15.54168 -5.367311
[3,]  -6.779876 11.558550 26.49733  6.708346

, , 4

          [,1]       [,2]      [,3]      [,4]
[1,] -0.951870  -1.758316 -1.375958  4.046637
[2,]  4.589708 -13.512347 -5.367311 18.103990
[3,] -4.908824   1.504732  6.708346 30.378239
hist(unlist(greta_samples))

[image: histogram of the greta samples with input_output_cholesky = FALSE]

This looks much more like r_samples, although the values are not identical (which is expected, since these are independent random draws). The tests in test_iid_samples now pass with this change, so I think we are in the clear again?

However, I am curious why the layout of the arrays from greta seems different from the ones from R 🤔
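One likely explanation for the different layout: `rWishart()` returns a `p x p x n` array with one matrix per slice, whereas the greta output here is `n x p x p` (3 rows for `nsim = 3`), with draws along the first dimension; `rwish()` from the test helpers appears to use the same layout, judging by the `[1, , ]` indexing. Printing an `n x p x p` array slices over the last dimension, so each `, , k` block shows column k of every draw. A sketch of converting between the two layouts with `aperm()`, using stand-in draws from `rWishart()` rather than real greta output:

```r
set.seed(3)

n <- 3
p <- 4

# base R layout: p x p x n, one matrix per slice
r_layout <- rWishart(n, df = 7, Sigma = diag(p))

# greta-style layout: n x p x p, draws along the first dimension
greta_layout <- aperm(r_layout, c(3, 1, 2))

# draw 2 as a p x p matrix, extracted from each layout
draw_2_r <- r_layout[, , 2]
draw_2_greta <- greta_layout[2, , ]

# converting back: move the draw dimension to the end
back_to_r <- aperm(greta_layout, c(2, 3, 1))

dim(greta_layout)
#> [1] 3 4 4
```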

njtierney commented 1 year ago

I'm tempted to say that this is resolved, but I'd like a check from @goldingn.

njtierney commented 1 year ago

OK, so the density calculations for wishart are no longer the same when we change input_output_cholesky to FALSE, which breaks more tests.

njtierney commented 1 year ago

OK, here's what TF2 looks like; I'll show greta on TF1 soon.

devtools::load_all("../greta/")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> 
#> ✔ Initialising python and checking dependencies ... done!
#> Loaded Tensorflow version 2.10.0
source("../greta/tests/testthat/helpers.R")

sigma <- rwish(1, 5, diag(4))[1, , ]
prob <- t(runif(4))
prob <- prob / sum(prob)

g_wish <- wishart(
  df = 7,
  Sigma = sigma
)

greta_samples <- calculate(g_wish, nsim = 3)
greta_samples$g_wish
#> , , 1
#> 
#>          [,1]      [,2]     [,3]       [,4]
#> [1,] 7.561697 -3.905852 6.006358  1.4588092
#> [2,] 7.145607 -2.866151 7.978184 -0.2109818
#> [3,] 8.041428 -4.140150 6.452004  1.5914972
#> 
#> , , 2
#> 
#>      [,1]     [,2]     [,3]      [,4]
#> [1,]    0 3.085177 5.416986 -2.531217
#> [2,]    0 1.479217 2.908386 -2.259755
#> [3,]    0 3.298409 5.069922 -1.380517
#> 
#> , , 3
#> 
#>      [,1] [,2]     [,3]       [,4]
#> [1,]    0    0 3.662858 -1.4163186
#> [2,]    0    0 1.399411 -0.2632336
#> [3,]    0    0 1.702768  0.5577520
#> 
#> , , 4
#> 
#>      [,1] [,2] [,3]     [,4]
#> [1,]    0    0    0 2.485087
#> [2,]    0    0    0 2.535196
#> [3,]    0    0    0 2.604749

r_wish <- rwish(
  n = 3,
  df = 6,
  Sigma = sigma
)

r_wish
#> , , 1
#> 
#>          [,1]       [,2]     [,3]       [,4]
#> [1,] 21.72889  -4.956895 25.60117 -5.6430929
#> [2,] 40.02885 -14.222116 40.47071 -0.4652243
#> [3,] 31.50728 -16.426614 27.22179  6.8740127
#> 
#> , , 2
#> 
#>            [,1]     [,2]       [,3]      [,4]
#> [1,]  -4.956895  5.69157  -3.128652  4.310667
#> [2,] -14.222116 12.96749  -3.250030 -2.446697
#> [3,] -16.426614 10.69993 -11.524312 -2.751033
#> 
#> , , 3
#> 
#>          [,1]       [,2]     [,3]      [,4]
#> [1,] 25.60117  -3.128652 38.54822 -7.193464
#> [2,] 40.47071  -3.250030 58.38350 -5.931755
#> [3,] 27.22179 -11.524312 33.58387  5.110906
#> 
#> , , 4
#> 
#>            [,1]      [,2]      [,3]      [,4]
#> [1,] -5.6430929  4.310667 -7.193464 10.014733
#> [2,] -0.4652243 -2.446697 -5.931755  3.203074
#> [3,]  6.8740127 -2.751033  5.110906  6.955262

# regular R wishart
rWishart(n = 3,
         df = 7,
         Sigma = sigma)
#> , , 1
#> 
#>           [,1]       [,2]       [,3]      [,4]
#> [1,] 11.721441 -9.7199183  1.4354515  2.748119
#> [2,] -9.719918  9.9798060  0.7573852 -2.673627
#> [3,]  1.435451  0.7573852  7.7456930 -3.022828
#> [4,]  2.748119 -2.6736271 -3.0228284  8.361974
#> 
#> , , 2
#> 
#>            [,1]       [,2]      [,3]      [,4]
#> [1,]  18.084787 -12.480217 11.769252  5.711503
#> [2,] -12.480217  10.031460 -9.311690 -2.572676
#> [3,]  11.769252  -9.311690 10.880932  1.931768
#> [4,]   5.711503  -2.572676  1.931768  6.164143
#> 
#> , , 3
#> 
#>           [,1]       [,2]      [,3]      [,4]
#> [1,]  59.60473 -20.377244  65.04302 14.859133
#> [2,] -20.37724   7.911757 -20.50039 -6.243963
#> [3,]  65.04302 -20.500391  79.07283 12.188564
#> [4,]  14.85913  -6.243963  12.18856 14.733156

compare_iid_samples(wishart,
                    rwish,
                    parameters = list(df = 7, Sigma = sigma)
)
#> Error: test_result$p.value is not more than `p_value_threshold`. Difference: -0.001

Created on 2022-12-16 with reprex v2.0.2

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.1 (2022-06-23)
#>  os       macOS Monterey 12.3.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Brisbane
#>  date     2022-12-16
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package       * version    date (UTC) lib source
#>    abind           1.4-5      2016-07-21 [1] CRAN (R 4.2.0)
#>    assertthat      0.2.1      2019-03-21 [1] CRAN (R 4.2.0)
#>    backports       1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>    base64enc       0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>    boot            1.3-28.1   2022-11-22 [1] CRAN (R 4.2.1)
#>    brio            1.1.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    broom           1.0.1      2022-08-29 [1] CRAN (R 4.2.0)
#>    cachem          1.0.6      2021-08-19 [1] CRAN (R 4.2.0)
#>    callr           3.7.3      2022-11-02 [1] CRAN (R 4.2.0)
#>    cellranger      1.1.0      2016-07-27 [1] CRAN (R 4.2.0)
#>    cli             3.4.1      2022-09-23 [1] CRAN (R 4.2.0)
#>    coda            0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>    codetools       0.2-18     2020-11-04 [1] CRAN (R 4.2.1)
#>    colorspace      2.0-3      2022-02-21 [1] CRAN (R 4.2.0)
#>    cramer          0.9-3      2019-01-05 [1] CRAN (R 4.2.0)
#>    crayon          1.5.2      2022-09-29 [1] CRAN (R 4.2.0)
#>    DBI             1.1.3      2022-06-18 [1] CRAN (R 4.2.0)
#>    dbplyr          2.2.1      2022-06-27 [1] CRAN (R 4.2.0)
#>    desc            1.4.2      2022-09-08 [1] CRAN (R 4.2.0)
#>    devtools        2.4.5      2022-10-11 [1] CRAN (R 4.2.0)
#>    digest          0.6.30     2022-10-18 [1] CRAN (R 4.2.0)
#>    dplyr         * 1.0.10     2022-09-01 [1] CRAN (R 4.2.0)
#>    ellipsis        0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>    evaluate        0.18       2022-11-07 [1] CRAN (R 4.2.0)
#>    fansi           1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>    fastmap         1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>    forcats       * 0.5.2      2022-08-19 [1] CRAN (R 4.2.0)
#>    fs              1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>    future          1.29.0     2022-11-06 [1] CRAN (R 4.2.0)
#>    gargle          1.2.1      2022-09-08 [1] CRAN (R 4.2.0)
#>    generics        0.1.3      2022-07-05 [1] CRAN (R 4.2.0)
#>    ggplot2       * 3.4.0      2022-11-04 [1] CRAN (R 4.2.0)
#>    globals         0.16.2     2022-11-21 [1] CRAN (R 4.2.1)
#>    glue            1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>    googledrive     2.0.0      2021-07-08 [1] CRAN (R 4.2.0)
#>    googlesheets4   1.0.1      2022-08-13 [1] CRAN (R 4.2.0)
#>  P greta         * 0.4.2.9000 2022-12-14 [?] load_all()
#>    gtable          0.3.1      2022-09-01 [1] CRAN (R 4.2.0)
#>    haven           2.5.1      2022-08-22 [1] CRAN (R 4.2.0)
#>    here            1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>    highr           0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>    hms             1.1.2      2022-08-19 [1] CRAN (R 4.2.0)
#>    htmltools       0.5.3      2022-07-18 [1] CRAN (R 4.2.0)
#>    htmlwidgets     1.5.4      2021-09-08 [1] CRAN (R 4.2.0)
#>    httpuv          1.6.6      2022-09-08 [1] CRAN (R 4.2.0)
#>    httr            1.4.4      2022-08-17 [1] CRAN (R 4.2.0)
#>    jsonlite        1.8.3      2022-10-21 [1] CRAN (R 4.2.0)
#>    knitr           1.41       2022-11-18 [1] CRAN (R 4.2.0)
#>    later           1.3.0      2021-08-18 [1] CRAN (R 4.2.0)
#>    lattice         0.20-45    2021-09-22 [1] CRAN (R 4.2.1)
#>    lifecycle       1.0.3      2022-10-07 [1] CRAN (R 4.2.0)
#>    listenv         0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>    lubridate       1.9.0      2022-11-06 [1] CRAN (R 4.2.0)
#>    magrittr        2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>    Matrix          1.5-3      2022-11-11 [1] CRAN (R 4.2.0)
#>    memoise         2.0.1      2021-11-26 [1] CRAN (R 4.2.0)
#>    mime            0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>    miniUI          0.1.1.1    2018-05-18 [1] CRAN (R 4.2.0)
#>    modelr          0.1.10     2022-11-11 [1] CRAN (R 4.2.0)
#>    munsell         0.5.0      2018-06-12 [1] CRAN (R 4.2.0)
#>    parallelly      1.32.1     2022-07-21 [1] CRAN (R 4.2.0)
#>    pillar          1.8.1      2022-08-19 [1] CRAN (R 4.2.0)
#>    pkgbuild        1.4.0      2022-11-27 [1] CRAN (R 4.2.1)
#>    pkgconfig       2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>    pkgload         1.3.2      2022-11-16 [1] CRAN (R 4.2.0)
#>    png             0.1-8      2022-11-29 [1] CRAN (R 4.2.0)
#>    prettyunits     1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>    processx        3.8.0      2022-10-26 [1] CRAN (R 4.2.0)
#>    profvis         0.3.7      2020-11-02 [1] CRAN (R 4.2.0)
#>    progress        1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>    promises        1.2.0.1    2021-02-11 [1] CRAN (R 4.2.0)
#>    ps              1.7.2      2022-10-26 [1] CRAN (R 4.2.0)
#>    purrr         * 0.3.5      2022-10-06 [1] CRAN (R 4.2.0)
#>    R.cache         0.16.0     2022-07-21 [1] CRAN (R 4.2.0)
#>    R.methodsS3     1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>    R.oo            1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>    R.utils         2.12.2     2022-11-11 [1] CRAN (R 4.2.0)
#>    R6              2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>    Rcpp            1.0.9      2022-07-08 [1] CRAN (R 4.2.0)
#>    readr         * 2.1.3      2022-10-01 [1] CRAN (R 4.2.0)
#>    readxl          1.4.1      2022-08-17 [1] CRAN (R 4.2.0)
#>    remotes         2.4.2      2021-11-30 [1] CRAN (R 4.2.0)
#>    reprex          2.0.2      2022-08-17 [1] CRAN (R 4.2.0)
#>    reticulate      1.26       2022-08-31 [1] CRAN (R 4.2.0)
#>    rlang           1.0.6      2022-09-24 [1] CRAN (R 4.2.0)
#>    rmarkdown       2.18       2022-11-09 [1] CRAN (R 4.2.0)
#>    rprojroot       2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>    rstudioapi      0.14       2022-08-22 [1] CRAN (R 4.2.0)
#>    rvest           1.0.3      2022-08-19 [1] CRAN (R 4.2.0)
#>    scales          1.2.1      2022-08-20 [1] CRAN (R 4.2.0)
#>    sessioninfo     1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>    shiny           1.7.3      2022-10-25 [1] CRAN (R 4.2.0)
#>    stringi         1.7.8      2022-07-11 [1] CRAN (R 4.2.0)
#>    stringr       * 1.5.0      2022-12-02 [1] CRAN (R 4.2.0)
#>    styler          1.8.1      2022-11-07 [1] CRAN (R 4.2.0)
#>    tensorflow      2.9.0      2022-05-21 [1] CRAN (R 4.2.0)
#>    testthat      * 3.1.5      2022-10-08 [1] CRAN (R 4.2.0)
#>    tfautograph     0.3.2      2021-09-17 [1] CRAN (R 4.2.0)
#>    tfruns          1.5.1      2022-09-05 [1] CRAN (R 4.2.0)
#>    tibble        * 3.1.8      2022-07-22 [1] CRAN (R 4.2.0)
#>    tidyr         * 1.2.1      2022-09-08 [1] CRAN (R 4.2.0)
#>    tidyselect      1.2.0      2022-10-10 [1] CRAN (R 4.2.0)
#>    tidyverse     * 1.3.2      2022-07-18 [1] CRAN (R 4.2.0)
#>    timechange      0.1.1      2022-11-04 [1] CRAN (R 4.2.0)
#>    tzdb            0.3.0      2022-03-28 [1] CRAN (R 4.2.0)
#>    urlchecker      1.0.1      2021-11-30 [1] CRAN (R 4.2.0)
#>    usethis         2.1.6      2022-05-25 [1] CRAN (R 4.2.0)
#>    utf8            1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>    vctrs           0.5.1      2022-11-16 [1] CRAN (R 4.2.0)
#>    whisker         0.4        2019-08-28 [1] CRAN (R 4.2.0)
#>    withr           2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>    xfun            0.35       2022-11-16 [1] CRAN (R 4.2.0)
#>    xml2            1.3.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    xtable          1.8-4      2019-04-21 [1] CRAN (R 4.2.0)
#>    yaml            2.3.6      2022-10-18 [1] CRAN (R 4.2.0)
#>    yesno           0.1.2      2020-07-10 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#> 
#>  P ── Loaded and on-disk path mismatch.
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16) [Clang 12.0.1 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.23.2
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow
#> 
#>  NOTE: Python version was forced by use_python function
#> 
#> ──────────────────────────────────────────────────────────────────────────────
```
njtierney commented 1 year ago

OK, it turns out that we need input_output_cholesky FALSE (off) for sampling, but TRUE (on) for calculating the log prob!
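This split is consistent with the TFP documentation quoted earlier in the thread: with input_output_cholesky = TRUE, log_prob presumes its input is in Cholesky form, while the flag is described as a computational optimisation that does not change the underlying distribution. If greta passes Cholesky factors into the density calculation (an assumption here, not confirmed in this thread), the flag needs to be TRUE for log_prob even though samples should be full matrices. As a base-R illustration (not greta's or TFP's code; all helper names are hypothetical), the Wishart log density

log p(W) = (df - p - 1)/2 * log|W| - tr(Sigma^-1 W)/2 - (df * p / 2) * log(2) - (df / 2) * log|Sigma| - log Gamma_p(df / 2)

can be evaluated either from W or from its lower Cholesky factor L, and both give the same value:

```r
# log of the multivariate gamma function, Gamma_p(a)
lmvgamma <- function(a, p) {
  p * (p - 1) / 4 * log(pi) + sum(lgamma(a + (1 - seq_len(p)) / 2))
}

# shared part of the Wishart log density, given log|W| and tr(Sigma^-1 W)
wishart_lpdf_core <- function(logdet_W, trace_term, df, Sigma) {
  p <- nrow(Sigma)
  logdet_S <- as.numeric(determinant(Sigma, logarithm = TRUE)$modulus)
  (df - p - 1) / 2 * logdet_W - trace_term / 2 -
    df * p / 2 * log(2) - df / 2 * logdet_S - lmvgamma(df / 2, p)
}

# evaluate the density on the full matrix W
wishart_lpdf_full <- function(W, df, Sigma) {
  logdet_W <- as.numeric(determinant(W, logarithm = TRUE)$modulus)
  trace_term <- sum(diag(solve(Sigma, W)))
  wishart_lpdf_core(logdet_W, trace_term, df, Sigma)
}

# evaluate the same density on the lower Cholesky factor L, where W = L %*% t(L)
wishart_lpdf_chol <- function(L, df, Sigma) {
  logdet_W <- 2 * sum(log(diag(L)))       # log|W| from the diagonal of L
  trace_term <- sum(solve(Sigma, L) * L)  # tr(Sigma^-1 L t(L)) without forming W
  wishart_lpdf_core(logdet_W, trace_term, df, Sigma)
}

set.seed(10)
Sigma <- rWishart(1, 5, diag(4))[, , 1]
W <- rWishart(1, 7, Sigma)[, , 1]
L <- t(chol(W))

wishart_lpdf_full(W, df = 7, Sigma = Sigma)
wishart_lpdf_chol(L, df = 7, Sigma = Sigma)
```

As an extra sanity check, a 1 x 1 Wishart(df, s) is a Gamma distribution with shape df / 2 and scale 2 * s, so `wishart_lpdf_full(matrix(2.5), 6, matrix(1.3))` should match `dgamma(2.5, shape = 3, scale = 2.6, log = TRUE)`.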

Proof that this works:

devtools::load_all("../greta/")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> 
#> ✔ Initialising python and checking dependencies ... done!
#> Loaded Tensorflow version 2.10.0
source("../greta/tests/testthat/helpers.R")

sigma <- rwish(1, 5, diag(4))[1, , ]
prob <- t(runif(4))
prob <- prob / sum(prob)

g_wish <- wishart(
  df = 7,
  Sigma = sigma
)

greta_samples <- calculate(g_wish, nsim = 3)
greta_samples$g_wish
#> , , 1
#> 
#>          [,1]      [,2]     [,3]      [,4]
#> [1,] 57.17926 -29.53487 45.41826 11.031073
#> [2,] 51.05970 -20.48039 57.00897 -1.507593
#> [3,] 64.66457 -33.29272 51.88332 12.797910
#> 
#> , , 2
#> 
#>           [,1]     [,2]       [,3]       [,4]
#> [1,] -29.53487 24.77400  -6.747586 -13.507148
#> [2,] -20.48039 10.40290 -18.564547  -2.737962
#> [3,] -33.29272 28.02034  -9.989587 -11.142546
#> 
#> , , 3
#> 
#>          [,1]       [,2]     [,3]       [,4]
#> [1,] 45.41826  -6.747586 78.83660 -10.137212
#> [2,] 57.00897 -18.564547 74.06848  -8.623862
#> [3,] 51.88332  -9.989587 70.23188   4.218955
#> 
#> , , 4
#> 
#>           [,1]       [,2]       [,3]     [,4]
#> [1,] 11.031073 -13.507148 -10.137212 16.71680
#> [2,] -1.507593  -2.737962  -8.623862 11.64752
#> [3,] 12.797910 -11.142546   4.218955 11.53449

r_wish <- rwish(
  n = 3,
  df = 6,
  Sigma = sigma
)

r_wish
#> , , 1
#> 
#>          [,1]       [,2]     [,3]       [,4]
#> [1,] 21.72889  -4.956895 25.60117 -5.6430929
#> [2,] 40.02885 -14.222116 40.47071 -0.4652243
#> [3,] 31.50728 -16.426614 27.22179  6.8740127
#> 
#> , , 2
#> 
#>            [,1]     [,2]       [,3]      [,4]
#> [1,]  -4.956895  5.69157  -3.128652  4.310667
#> [2,] -14.222116 12.96749  -3.250030 -2.446697
#> [3,] -16.426614 10.69993 -11.524312 -2.751033
#> 
#> , , 3
#> 
#>          [,1]       [,2]     [,3]      [,4]
#> [1,] 25.60117  -3.128652 38.54822 -7.193464
#> [2,] 40.47071  -3.250030 58.38350 -5.931755
#> [3,] 27.22179 -11.524312 33.58387  5.110906
#> 
#> , , 4
#> 
#>            [,1]      [,2]      [,3]      [,4]
#> [1,] -5.6430929  4.310667 -7.193464 10.014733
#> [2,] -0.4652243 -2.446697 -5.931755  3.203074
#> [3,]  6.8740127 -2.751033  5.110906  6.955262

# regular R wishart
rWishart(n = 3,
         df = 7,
         Sigma = sigma)
#> , , 1
#> 
#>           [,1]       [,2]       [,3]      [,4]
#> [1,] 11.721441 -9.7199183  1.4354515  2.748119
#> [2,] -9.719918  9.9798060  0.7573852 -2.673627
#> [3,]  1.435451  0.7573852  7.7456930 -3.022828
#> [4,]  2.748119 -2.6736271 -3.0228284  8.361974
#> 
#> , , 2
#> 
#>            [,1]       [,2]      [,3]      [,4]
#> [1,]  18.084787 -12.480217 11.769252  5.711503
#> [2,] -12.480217  10.031460 -9.311690 -2.572676
#> [3,]  11.769252  -9.311690 10.880932  1.931768
#> [4,]   5.711503  -2.572676  1.931768  6.164143
#> 
#> , , 3
#> 
#>           [,1]       [,2]      [,3]      [,4]
#> [1,]  59.60473 -20.377244  65.04302 14.859133
#> [2,] -20.37724   7.911757 -20.50039 -6.243963
#> [3,]  65.04302 -20.500391  79.07283 12.188564
#> [4,]  14.85913  -6.243963  12.18856 14.733156

compare_iid_samples(wishart,
                    rwish,
                    parameters = list(df = 7, Sigma = sigma)
)

# also check the density

# parameters to test
m <- 5
df <- m + 1
sig <- rWishart(1, df, diag(m))[, , 1]

# wrapper for argument names
dwishart <- function(x, df, Sigma, log = FALSE) { # nolint
  ans <- MCMCpack::dwish(W = x, v = df, S = Sigma)
  if (log) {
    ans <- log(ans)
  }
  ans
}

# no vectorised wishart, so loop through all of these
replicate(
  10,
  compare_distribution(
    greta::wishart,
    dwishart,
    parameters = list(
      df = df,
      Sigma = sig
    ),
    x = rWishart(1, df, sig)[, , 1],
    multivariate = TRUE
  )
)
#>  [1] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE

Created on 2022-12-16 with reprex v2.0.2

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.1 (2022-06-23)
#>  os       macOS Monterey 12.3.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Brisbane
#>  date     2022-12-16
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package       * version    date (UTC) lib source
#>    abind           1.4-5      2016-07-21 [1] CRAN (R 4.2.0)
#>    backports       1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>    base64enc       0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>    boot            1.3-28.1   2022-11-22 [1] CRAN (R 4.2.1)
#>    brio            1.1.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    cachem          1.0.6      2021-08-19 [1] CRAN (R 4.2.0)
#>    callr           3.7.3      2022-11-02 [1] CRAN (R 4.2.0)
#>    cli             3.4.1      2022-09-23 [1] CRAN (R 4.2.0)
#>    coda            0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>    codetools       0.2-18     2020-11-04 [1] CRAN (R 4.2.1)
#>    cramer          0.9-3      2019-01-05 [1] CRAN (R 4.2.0)
#>    crayon          1.5.2      2022-09-29 [1] CRAN (R 4.2.0)
#>    desc            1.4.2      2022-09-08 [1] CRAN (R 4.2.0)
#>    devtools        2.4.5      2022-10-11 [1] CRAN (R 4.2.0)
#>    digest          0.6.30     2022-10-18 [1] CRAN (R 4.2.0)
#>    ellipsis        0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>    evaluate        0.18       2022-11-07 [1] CRAN (R 4.2.0)
#>    fastmap         1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>    fs              1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>    future          1.29.0     2022-11-06 [1] CRAN (R 4.2.0)
#>    globals         0.16.2     2022-11-21 [1] CRAN (R 4.2.1)
#>    glue            1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  P greta         * 0.4.2.9000 2022-12-14 [?] load_all()
#>    here            1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>    highr           0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>    hms             1.1.2      2022-08-19 [1] CRAN (R 4.2.0)
#>    htmltools       0.5.3      2022-07-18 [1] CRAN (R 4.2.0)
#>    htmlwidgets     1.5.4      2021-09-08 [1] CRAN (R 4.2.0)
#>    httpuv          1.6.6      2022-09-08 [1] CRAN (R 4.2.0)
#>    jsonlite        1.8.3      2022-10-21 [1] CRAN (R 4.2.0)
#>    knitr           1.41       2022-11-18 [1] CRAN (R 4.2.0)
#>    later           1.3.0      2021-08-18 [1] CRAN (R 4.2.0)
#>    lattice         0.20-45    2021-09-22 [1] CRAN (R 4.2.1)
#>    lifecycle       1.0.3      2022-10-07 [1] CRAN (R 4.2.0)
#>    listenv         0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>    magrittr        2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>    MASS            7.3-58.1   2022-08-03 [1] CRAN (R 4.2.0)
#>    Matrix          1.5-3      2022-11-11 [1] CRAN (R 4.2.0)
#>    MatrixModels    0.5-1      2022-09-11 [1] CRAN (R 4.2.0)
#>    mcmc            0.9-7      2020-03-21 [1] CRAN (R 4.2.0)
#>    MCMCpack        1.6-3      2022-04-13 [1] CRAN (R 4.2.0)
#>    memoise         2.0.1      2021-11-26 [1] CRAN (R 4.2.0)
#>    mime            0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>    miniUI          0.1.1.1    2018-05-18 [1] CRAN (R 4.2.0)
#>    parallelly      1.32.1     2022-07-21 [1] CRAN (R 4.2.0)
#>    pkgbuild        1.4.0      2022-11-27 [1] CRAN (R 4.2.1)
#>    pkgconfig       2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>    pkgload         1.3.2      2022-11-16 [1] CRAN (R 4.2.0)
#>    png             0.1-8      2022-11-29 [1] CRAN (R 4.2.0)
#>    prettyunits     1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>    processx        3.8.0      2022-10-26 [1] CRAN (R 4.2.0)
#>    profvis         0.3.7      2020-11-02 [1] CRAN (R 4.2.0)
#>    progress        1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>    promises        1.2.0.1    2021-02-11 [1] CRAN (R 4.2.0)
#>    ps              1.7.2      2022-10-26 [1] CRAN (R 4.2.0)
#>    purrr           0.3.5      2022-10-06 [1] CRAN (R 4.2.0)
#>    quantreg        5.94       2022-07-20 [1] CRAN (R 4.2.0)
#>    R.cache         0.16.0     2022-07-21 [1] CRAN (R 4.2.0)
#>    R.methodsS3     1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>    R.oo            1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>    R.utils         2.12.2     2022-11-11 [1] CRAN (R 4.2.0)
#>    R6              2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>    Rcpp            1.0.9      2022-07-08 [1] CRAN (R 4.2.0)
#>    remotes         2.4.2      2021-11-30 [1] CRAN (R 4.2.0)
#>    reprex          2.0.2      2022-08-17 [1] CRAN (R 4.2.0)
#>    reticulate      1.26       2022-08-31 [1] CRAN (R 4.2.0)
#>    rlang           1.0.6      2022-09-24 [1] CRAN (R 4.2.0)
#>    rmarkdown       2.18       2022-11-09 [1] CRAN (R 4.2.0)
#>    rprojroot       2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>    rstudioapi      0.14       2022-08-22 [1] CRAN (R 4.2.0)
#>    sessioninfo     1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>    shiny           1.7.3      2022-10-25 [1] CRAN (R 4.2.0)
#>    SparseM         1.81       2021-02-18 [1] CRAN (R 4.2.0)
#>    stringi         1.7.8      2022-07-11 [1] CRAN (R 4.2.0)
#>    stringr         1.5.0      2022-12-02 [1] CRAN (R 4.2.0)
#>    styler          1.8.1      2022-11-07 [1] CRAN (R 4.2.0)
#>    survival        3.4-0      2022-08-09 [1] CRAN (R 4.2.0)
#>    tensorflow      2.9.0      2022-05-21 [1] CRAN (R 4.2.0)
#>    testthat      * 3.1.5      2022-10-08 [1] CRAN (R 4.2.0)
#>    tfautograph     0.3.2      2021-09-17 [1] CRAN (R 4.2.0)
#>    tfruns          1.5.1      2022-09-05 [1] CRAN (R 4.2.0)
#>    urlchecker      1.0.1      2021-11-30 [1] CRAN (R 4.2.0)
#>    usethis         2.1.6      2022-05-25 [1] CRAN (R 4.2.0)
#>    vctrs           0.5.1      2022-11-16 [1] CRAN (R 4.2.0)
#>    waldo           0.4.0      2022-03-16 [1] CRAN (R 4.2.0)
#>    whisker         0.4        2019-08-28 [1] CRAN (R 4.2.0)
#>    withr           2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>    xfun            0.35       2022-11-16 [1] CRAN (R 4.2.0)
#>    xtable          1.8-4      2019-04-21 [1] CRAN (R 4.2.0)
#>    yaml            2.3.6      2022-10-18 [1] CRAN (R 4.2.0)
#>    yesno           0.1.2      2020-07-10 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#> 
#>  P ── Loaded and on-disk path mismatch.
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16) [Clang 12.0.1 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.23.2
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow
#> 
#>  NOTE: Python version was forced by use_python function
#> 
#> ──────────────────────────────────────────────────────────────────────────────
```