Open njtierney opened 8 months ago
This error seems to first show up at
https://github.com/greta-dev/greta/blob/1bbc661a07c915ceb3b2f8af417b865fa7453d12/R/calculate.R#L506
Perhaps I am missing where the chol operation is happening, but this is where the matrix of 1s is returned for the first time.
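Debugging the internal tf_chol function in greta:
tf_chol <- function(x) {
  # tf$linalg$cholesky returns the lower-triangular factor
  x_chol <- tf$linalg$cholesky(x)
  # transpose to get the upper-triangular factor, as R's chol() returns
  x_chol_t <- tf_transpose(x_chol)
  x_chol_t
}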
When this is computed, we get something that looks right:
Browse[3]> x_chol_t
tf.Tensor(
[[[ 3.54538307 0.22028915 -0.05668089]
[ 0. 1.45304198 -0.45842791]
[ 0. 0. 0.28949603]]], shape=(1, 3, 3), dtype=float64)
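As a sanity check on the orientation, here is a small base R sketch (not greta code; the matrix is made up for illustration). It shows the relationship tf_chol relies on: R's chol() returns the upper-triangular factor, and its transpose is the lower-triangular factor, which is the form tf$linalg$cholesky returns.

```r
# a small symmetric positive-definite matrix, made up for illustration
x <- matrix(c(4, 2, 2, 3), nrow = 2)

# base R returns the upper-triangular factor
upper <- chol(x)

# its transpose is the lower-triangular factor
lower <- t(upper)

# lower %*% upper recovers the original matrix
reconstructed <- lower %*% upper
```

So an upper-triangular result like the one printed above is the expected shape for the output of tf_chol.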
So something is happening at
https://github.com/greta-dev/greta/blob/1bbc661a07c915ceb3b2f8af417b865fa7453d12/R/calculate.R#L506
where perhaps it isn't grabbing the right piece of data. Debugging this is a bit tricky.
OK, so looking further into what is inside tfe at
https://github.com/greta-dev/greta/blob/1bbc661a07c915ceb3b2f8af417b865fa7453d12/R/calculate.R#L506
it looks like we are grabbing the wrong slot, or there is some mislabelling happening here.
We are getting tfe$all_sampling_operation_1, which contains 1s:
Browse[2]> tfe$all_sampling_operation_1
tf.Tensor(
[[[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]]], shape=(1, 3, 3), dtype=float64)
But if we look at the variable named tfe$all_sampling_, then we get:
#> Browse[4]> tfe$all_sampling_
#> tf.Tensor(
#> [[[ 3.54538307 0.22028915 -0.05668089]
#> [ 0. 1.45304198 -0.45842791]
#> [ 0. 0. 0.28949603]]], shape=(1, 3, 3), dtype=float64)
Which is the output we want.
So something wrong is happening in tf_name or thereabouts; will investigate!
OK, so there are a few things going on here that are a bit hard to untangle.
For context, here is the code that I'm running:
load_all()
set.seed(2023-11-10-1209)
x <- wishart(df = 4, Sigma = diag(3))
chol_x <- chol(x)
calc_chol <- calculate(chol_x, nsim = 1)
In this branch: https://github.com/njtierney/greta/tree/cholesky-585
OK, so, in:
https://github.com/njtierney/greta/blob/cholesky-585/R/node_types.R#L162-L169
# if sampling get the distribution constructor and sample this
if (mode == "sampling") {
  tensor <- dag$draw_sample(self$distribution)
  if (has_representation(self, "cholesky")) {
    cholesky_tensor <- tf_chol(tensor)
    cholesky_tf_name <- dag$tf_name(self$representation$cholesky)
    assign(cholesky_tf_name, cholesky_tensor, envir = dag$tf_environment)
  }
}
We are using the wrong name for the representation here:
https://github.com/njtierney/greta/blob/cholesky-585/R/node_types.R#L166
cholesky_tf_name <- dag$tf_name(self$representation$cholesky)
This should be representations (with an s).
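Part of why a typo like this is easy to miss: in R, looking up a field that does not exist on an environment returns NULL rather than raising an error. The sketch below uses a plain environment as a stand-in for the node (an assumption: it only mimics how `$` behaves on environment-based objects, it is not greta's node class, and the stored value is a placeholder).

```r
# a plain environment standing in for a node object (hypothetical stand-in)
node <- new.env()
node$representations <- list(cholesky = "cholesky node")

# misspelled field: environments do no partial matching, so this is NULL,
# and NULL$cholesky is also NULL, with no error or warning
wrong <- node$representation$cholesky

# correctly spelled field returns the stored value
right <- node$representations$cholesky
```

Here `wrong` is silently NULL, so nothing flags the mistake at the point where the name is looked up.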
But if you do
cholesky_tf_name <- dag$tf_name(self$representations$cholesky)
then this will not work, since self$representations$cholesky is a greta array:
Browse[3]> self$representations$cholesky
greta array (variable)
     [,1] [,2] [,3]
[1,]  ?    ?    ?
[2,]  0    ?    ?
[3,]  0    0    ?
So I think it was supposed to be
cholesky_tensor <- tf_chol(tensor)
cholesky_tf_name <- dag$tf_name(self)
assign(cholesky_tf_name, cholesky_tensor, envir = dag$tf_environment)
However, cholesky_tf_name resolves to all_sampling_operation_2, which is just duplicating the above line:
https://github.com/njtierney/greta/blob/cholesky-585/R/node_types.R#L158
tf_name <- dag$tf_name(self)
And overall, this is getting overwritten in
https://github.com/njtierney/greta/blob/cholesky-585/R/node_types.R#L191
assign(tf_name, tensor, envir = dag$tf_environment)
And ultimately, the only tf_name that gets resolved here:
https://github.com/njtierney/greta/blob/cholesky-585/R/calculate.R#L506
target_tensor_list <- lapply(target_names_list, get, envir = tfe)
is all_sampling_operation_1, not all_sampling_operation_2, and all_sampling_operation_1 is a bunch of 1s.
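To make the name mismatch concrete, here is a minimal base R sketch of the pattern, not greta's actual code. The environment, the matrix of 1s, and the factor stored under the second name are all stand-ins; only the two names and the lapply() lookup mirror what is described above.

```r
# stand-in for dag$tf_environment
tfe <- new.env()

# stand-in for whatever ends up under the first name: a matrix of 1s
assign("all_sampling_operation_1", matrix(1, 3, 3), envir = tfe)

# the Cholesky factor gets written under a different name
chol_factor <- chol(diag(3) * 4)
assign("all_sampling_operation_2", chol_factor, envir = tfe)

# the lookup only ever asks for the first name
target_names_list <- list("all_sampling_operation_1")
target_tensor_list <- lapply(target_names_list, get, envir = tfe)
```

The correct factor exists in the environment, but because the target name points elsewhere, the lookup returns the matrix of 1s.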
So I think I've started to hit on where the error is, but it seems like we have a slightly more complex problem, and I can't immediately recall what our intention/workflow was with doing something like
https://github.com/njtierney/greta/blob/cholesky-585/R/node_types.R#L162-L169
# if sampling get the distribution constructor and sample this
if (mode == "sampling") {
  tensor <- dag$draw_sample(self$distribution)
  if (has_representation(self, "cholesky")) {
    cholesky_tensor <- tf_chol(tensor)
    cholesky_tf_name <- dag$tf_name(self$representation$cholesky)
    assign(cholesky_tf_name, cholesky_tensor, envir = dag$tf_environment)
  }
}
in the first place.
How is this working with the other distributions that use representations?
They don't really use it the same way. I am beginning to think that this branch: https://github.com/njtierney/greta/tree/cholesky-585 was a bit more experimental.
The typical usage is to set some flag, e.g., https://github.com/greta-dev/greta/blob/93aaf361c04591a0d20f73c25b6e6693023482fd/R/probability_distributions.R#L162-L164
I've tried going through the debugging warrens back in greta's tf2-poke-fun branch, with
load_all()
x <- wishart(df = 4, Sigma = diag(3))
log_x <- log(x)
debugonce(calculate_target_tensor_list)
res <- calculate(log_x, nsim = 1)
and comparing this to
load_all()
x <- wishart(df = 4, Sigma = diag(3))
chol_x <- chol(x)
debugonce(calculate_target_tensor_list)
res <- calculate(chol_x, nsim = 1)
And where I'm landing currently is that in the log version, we end up at
https://github.com/greta-dev/greta/blob/master/R/node_types.R#L167
where the operation is log.
But when we want to do chol, the operation is identity, and we don't end up back there.
I might need some guidance on driving the debugger to the right spot.
The issue is related to the DAG that results from using chol on a wishart:
x <- wishart(df = 4, Sigma = diag(3))
chol_x <- chol(x)
m <- model(chol_x)
plot(m)
As @goldingn said, the issue looks like it is in replacing the variable (the circle in the plotted dag):
library(greta)
#>
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#>
#> binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#>
#> %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#> eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#> tapply
# devtools::load_all()
x <- wishart(df = 4, Sigma = diag(3))
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#>
chol_x <- chol(x)
chol_x_orig <- greta:::representation(x, "cholesky")
res <- calculate(chol_x, chol_x_orig, nsim = 1)
res
#> $chol_x
#> , , 1
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
#>
#> , , 2
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
#>
#> , , 3
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
#>
#>
#> $chol_x_orig
#> , , 1
#>
#> [,1] [,2] [,3]
#> [1,] 3.291763 0 0
#>
#> , , 2
#>
#> [,1] [,2] [,3]
#> [1,] 0.002398802 1.579327 0
#>
#> , , 3
#>
#> [,1] [,2] [,3]
#> [1,] -1.216082 0.9263707 2.086389
# check names to make sure we got the right ones
m <- model(chol_x)
m$dag$tf_name(greta:::get_node(x))
#> [1] "all_forward_operation_2"
m$dag$tf_name(greta:::get_node(chol_x))
#> [1] "all_forward_operation_1"
m$dag$tf_name(greta:::get_node(chol_x_orig))
#> [1] "all_forward_variable_1"
Created on 2023-11-17 with reprex v2.0.2
Follow-up from the meeting on this: we need to do some analysis of when the representations are used. Specifically:
Everything that is a representation of something else (e.g., cholesky) needs to know what it is a representation of, and add that node as its parent. However, this could be infinitely recursive. We need to do an analysis of when representations are used in variables that have a distribution, and the distribution can't update it; then we tell it that it has parents, and define the appropriate parents. Remember that we want to avoid infinite recursion.
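One common way to guard against this kind of recursion is to track which nodes have already been visited while walking parents. The sketch below is only an illustration of that idea in base R: the `parents` list, the node names, and `collect_ancestors()` are all hypothetical and not part of greta.

```r
# walk up the parent links, skipping any node we have already seen
collect_ancestors <- function(name, parents, seen = character()) {
  if (name %in% seen) {
    # already visited: stop here instead of recursing forever
    return(seen)
  }
  seen <- c(seen, name)
  for (parent in parents[[name]]) {
    seen <- collect_ancestors(parent, parents, seen)
  }
  seen
}

# hypothetical cycle: the representation and the variable point at each other
parents <- list(wishart = "cholesky", cholesky = "wishart")
visited <- collect_ancestors("cholesky", parents)
```

With the visited check in place, the walk terminates even though the two nodes list each other as parents.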
This error arises because the Cholesky factorisation of a Wishart produces a bunch of 1s.
Note that this is using the branch in #587
It is related to #585 and is currently being addressed in #587
Created on 2023-11-09 with reprex v2.0.2
Session info
``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.3.2 (2023-10-31)
#>  os       macOS Sonoma 14.0
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Hobart
#>  date     2023-11-09
#>  pandoc   3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  backports     1.4.1      2021-12-13 [1] CRAN (R 4.3.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.3.0)
#>  callr         3.7.3      2022-11-02 [1] CRAN (R 4.3.0)
#>  cli           3.6.1      2023-03-23 [1] CRAN (R 4.3.0)
#>  coda          0.19-4     2020-09-30 [1] CRAN (R 4.3.0)
#>  codetools     0.2-19     2023-02-01 [1] CRAN (R 4.3.2)
#>  crayon        1.5.2      2022-09-29 [1] CRAN (R 4.3.0)
#>  digest        0.6.33     2023-07-07 [1] CRAN (R 4.3.0)
#>  evaluate      0.23       2023-11-01 [1] CRAN (R 4.3.1)
#>  fastmap       1.1.1      2023-02-24 [1] CRAN (R 4.3.0)
#>  fs            1.6.3      2023-07-20 [1] CRAN (R 4.3.0)
#>  future        1.33.0     2023-07-01 [1] CRAN (R 4.3.0)
#>  globals       0.16.2     2022-11-21 [1] CRAN (R 4.3.0)
#>  glue          1.6.2      2022-02-24 [1] CRAN (R 4.3.0)
#>  greta       * 0.4.3.9000 2023-11-09 [1] local
#>  hms           1.1.3      2023-03-21 [1] CRAN (R 4.3.0)
#>  htmltools     0.5.7      2023-11-03 [1] CRAN (R 4.3.1)
#>  jsonlite      1.8.7      2023-06-29 [1] CRAN (R 4.3.0)
#>  knitr         1.45       2023-10-30 [1] CRAN (R 4.3.1)
#>  lattice       0.21-9     2023-10-01 [1] CRAN (R 4.3.2)
#>  lifecycle     1.0.3      2022-10-07 [1] CRAN (R 4.3.0)
#>  listenv       0.9.0      2022-12-16 [1] CRAN (R 4.3.0)
#>  magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.3.0)
#>  Matrix        1.6-1.1    2023-09-18 [1] CRAN (R 4.3.2)
#>  parallelly    1.36.0     2023-05-26 [1] CRAN (R 4.3.0)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.3.0)
#>  png           0.1-8      2022-11-29 [1] CRAN (R 4.3.0)
#>  prettyunits   1.2.0      2023-09-24 [1] CRAN (R 4.3.1)
#>  processx      3.8.2      2023-06-30 [1] CRAN (R 4.3.0)
#>  progress      1.2.2      2019-05-16 [1] CRAN (R 4.3.0)
#>  ps            1.7.5      2023-04-18 [1] CRAN (R 4.3.0)
#>  purrr         1.0.2      2023-08-10 [1] CRAN (R 4.3.0)
#>  R.cache       0.16.0     2022-07-21 [1] CRAN (R 4.3.0)
#>  R.methodsS3   1.8.2      2022-06-13 [1] CRAN (R 4.3.0)
#>  R.oo          1.25.0     2022-06-12 [1] CRAN (R 4.3.0)
#>  R.utils       2.12.2     2022-11-11 [1] CRAN (R 4.3.0)
#>  R6            2.5.1      2021-08-19 [1] CRAN (R 4.3.0)
#>  Rcpp          1.0.11     2023-07-06 [1] CRAN (R 4.3.0)
#>  reprex        2.0.2      2022-08-17 [1] CRAN (R 4.3.0)
#>  reticulate    1.34.0     2023-10-12 [1] CRAN (R 4.3.1)
#>  rlang         1.1.1      2023-04-28 [1] CRAN (R 4.3.0)
#>  rmarkdown     2.25       2023-09-18 [1] CRAN (R 4.3.1)
#>  rstudioapi    0.15.0     2023-07-07 [1] CRAN (R 4.3.0)
#>  sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.3.0)
#>  styler        1.9.1      2023-03-04 [1] CRAN (R 4.3.0)
#>  tensorflow    2.14.0     2023-09-28 [1] CRAN (R 4.3.1)
#>  tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.3.0)
#>  tfruns        1.5.1      2022-09-05 [1] CRAN (R 4.3.0)
#>  vctrs         0.6.4      2023-10-12 [1] CRAN (R 4.3.1)
#>  whisker       0.4.1      2022-12-05 [1] CRAN (R 4.3.0)
#>  withr         2.5.2      2023-10-30 [1] CRAN (R 4.3.1)
#>  xfun          0.41       2023-11-01 [1] CRAN (R 4.3.1)
#>  yaml          2.3.7      2023-01-23 [1] CRAN (R 4.3.0)
#>
#>  [1] /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/library
#>
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.8.15 | packaged by conda-forge | (default, Nov 22 2022, 08:49:06) [Clang 14.0.6 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.23.2
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow
#>
#>  NOTE: Python version was forced by use_python() function
#>
#> ──────────────────────────────────────────────────────────────────────────────
```