Open njtierney opened 6 months ago
Currently the plan is to have the following code inside the definition of an operation node. Essentially, it emits a warning when something uses a cholesky representation and we are sampling from it.
tf = function(dag) {
# where to put it
tfe <- dag$tf_environment
# what to call the tensor object
tf_name <- dag$tf_name(self)
mode <- dag$how_to_define(self)
# if sampling get the distribution constructor and sample this
if (mode == "sampling") {
tensor <- dag$draw_sample(self$distribution)
if (has_representation(self, "cholesky")) {
## TF1/2
## This approach currently fails because of how we use representations
## within greta.
# We will now error here since when sampling from a cholesky
# represented variable, we don't really get consistent results
cli::cli_warn(
## Could note that there are false positives?
message = c(
"We currently cannot use {.fun calculate} to sample a greta \\
array with a cholesky factor, due to an internal issue with how \\
greta handles cholesky representations.",
"See issue here on github for more details:",
"{.url https://github.com/greta-dev/greta/issues/593}"
)
)
cholesky_tensor <- tf_chol(tensor)
cholesky_tf_name <- dag$tf_name(self$representation$cholesky)
assign(cholesky_tf_name, cholesky_tensor, envir = dag$tf_environment)
# tf_name <- cholesky_tf_name
# tensor <- cholesky_tensor
}
}
So we get the following behaviour:
library(greta)
#>
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#>
#> binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#>
#> %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#> eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#> tapply
# succeeds
sig <- lkj_correlation(2, dim = 2)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#>
w <- wishart(5, sig)
m <- model(w)
draws <- mcmc(m, warmup = 0, n_samples = 5, verbose = FALSE)
draws
#> $`11`
#> Markov Chain Monte Carlo (MCMC) output:
#> Start = 1
#> End = 5
#> Thinning interval = 1
#> w[1,1] w[2,1] w[1,2] w[2,2]
#> 1 0.06064959 -0.08968761 -0.08968761 0.2080036
#> 2 0.18153273 -0.16358580 -0.16358580 0.5022380
#> 3 0.11352640 -0.18384974 -0.18384974 0.5941550
#> 4 0.85345937 -0.98212422 -0.98212422 1.1355574
#> 5 0.77398713 -0.53847586 -0.53847586 0.5097147
#>
#> $`12`
#> Markov Chain Monte Carlo (MCMC) output:
#> Start = 1
#> End = 5
#> Thinning interval = 1
#> w[1,1] w[2,1] w[1,2] w[2,2]
#> 1 0.0007307899 0.001818318 0.001818318 0.004525567
#> 2 0.0007307899 0.001818318 0.001818318 0.004525567
#> 3 0.0007307899 0.001818318 0.001818318 0.004525567
#> 4 0.0007307899 0.001818318 0.001818318 0.004525567
#> 5 0.0007307899 0.001818318 0.001818318 0.004525567
#>
#> $`13`
#> Markov Chain Monte Carlo (MCMC) output:
#> Start = 1
#> End = 5
#> Thinning interval = 1
#> w[1,1] w[2,1] w[1,2] w[2,2]
#> 1 0.14602971 -0.13032268 -0.13032268 0.3595358
#> 2 0.08595726 -0.06137547 -0.06137547 0.8014967
#> 3 0.30209596 -0.02779237 -0.02779237 0.7412944
#> 4 0.02448499 -0.04408638 -0.04408638 0.8345146
#> 5 0.12286190 0.05137151 0.05137151 1.4068696
#>
#> $`14`
#> Markov Chain Monte Carlo (MCMC) output:
#> Start = 1
#> End = 5
#> Thinning interval = 1
#> w[1,1] w[2,1] w[1,2] w[2,2]
#> 1 0.1987510 -0.03785009 -0.03785009 0.1985830
#> 2 0.1984730 -0.03523137 -0.03523137 0.2151315
#> 3 0.7619678 -0.42175497 -0.42175497 0.2995763
#> 4 0.8598383 -0.57636239 -0.57636239 0.5204147
#> 5 1.4662869 -0.60184908 -0.60184908 0.9378646
#>
#> attr(,"class")
#> [1] "greta_mcmc_list" "mcmc.list"
#> attr(,"model_info")
#> attr(,"model_info")$raw_draws
#> $`11`
#> Markov Chain Monte Carlo (MCMC) output:
#> Start = 1
#> End = 5
#> Thinning interval = 1
#> 1 2 3 4
#> 1 0.4324554 0.2462714 -0.3641820 0.27454514
#> 2 0.5676601 0.4260666 -0.3839442 0.59567174
#> 3 0.7014721 0.3369368 -0.5456505 0.54444511
#> 4 1.1698393 0.9238286 -1.0631022 0.07328834
#> 5 1.0902837 0.8797654 -0.6120676 0.36754319
#>
#> $`12`
#> Markov Chain Monte Carlo (MCMC) output:
#> Start = 1
#> End = 5
#> Thinning interval = 1
#> 1 2 3 4
#> 1 -0.03417148 0.02703313 0.06726259 0.001144732
#> 2 -0.03417148 0.02703313 0.06726259 0.001144732
#> 3 -0.03417148 0.02703313 0.06726259 0.001144732
#> 4 -0.03417148 0.02703313 0.06726259 0.001144732
#> 5 -0.03417148 0.02703313 0.06726259 0.001144732
#>
#> $`13`
#> Markov Chain Monte Carlo (MCMC) output:
#> Start = 1
#> End = 5
#> Thinning interval = 1
#> 1 2 3 4
#> 1 0.1643967 0.3821383 -0.34103534 0.4931842
#> 2 0.2189142 0.2931847 -0.20934064 0.8704443
#> 3 0.6258184 0.5496326 -0.05056536 0.8594984
#> 4 0.6418233 0.1564768 -0.28174385 0.8689851
#> 5 1.0310803 0.3505166 0.14655942 1.1770259
#>
#> $`14`
#> Markov Chain Monte Carlo (MCMC) output:
#> Start = 1
#> End = 5
#> Thinning interval = 1
#> 1 2 3 4
#> 1 -0.19703020 0.4458150 -0.08490090 0.4374641
#> 2 0.17551313 0.4455031 -0.07908221 0.4570312
#> 3 0.32429059 0.8729076 -0.48316104 0.2571609
#> 4 0.13162990 0.9272746 -0.62156600 0.3661563
#> 5 -0.02499826 1.2109033 -0.49702488 0.8311624
#>
#> attr(,"class")
#> [1] "mcmc.list"
#>
#> attr(,"model_info")$samplers
#> attr(,"model_info")$samplers$`1`
#> Error in vapply(x, format, "", big.mark = big.mark, big.interval = big.interval, : values must be length 1,
#> but FUN(X[[4]]) result is length 4
# fails
x <- wishart(df = 4, Sigma = diag(3))
chol_x <- chol(x)
calc_chol <- calculate(x, chol_x, nsim = 1)
#> Warning: We currently cannot use `calculate()` to sample a greta array with a cholesky
#> factor, due to an internal issue with how greta handles cholesky
#> representations.
#> See issue here on github for more details:
#> <https://github.com/greta-dev/greta/issues/593>
calc_chol
#> $x
#> , , 1
#>
#> [,1] [,2] [,3]
#> [1,] 12.53445 3.035458 -2.170019
#>
#> , , 2
#>
#> [,1] [,2] [,3]
#> [1,] 3.035458 4.972589 -0.09387038
#>
#> , , 3
#>
#> [,1] [,2] [,3]
#> [1,] -2.170019 -0.09387038 1.807308
#>
#>
#> $chol_x
#> , , 1
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
#>
#> , , 2
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
#>
#> , , 3
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
Created on 2024-05-07 with reprex v2.1.0
Here's an attempted solution at this problem, in commit: https://github.com/greta-dev/greta/pull/534/commits/917f936427205758f1756d66fc38310ea7edadd8
This adds a special flag, "golden_cholesky", when `chol` is used, so we can identify those arrays and warn for them.
Unfortunately, it seems that using `chol(x)` propagates the cholesky flag I created.
Here's a reprex of the approach:
library(greta)
#>
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#>
#> binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#>
#> %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#> eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#> tapply
x <- wishart(df = 4, Sigma = diag(3))
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#>
x
#> greta array (operation following a wishart distribution)
#>
#> [,1] [,2] [,3]
#> [1,] ? ? ?
#> [2,] ? ? ?
#> [3,] ? ? ?
Don’t warn here, this should be fine
pre_mcmc <- calculate(x, nsim = 1)
pre_mcmc
#> $x
#> , , 1
#>
#> [,1] [,2] [,3]
#> [1,] 5.055917 0.6333886 3.368602
#>
#> , , 2
#>
#> [,1] [,2] [,3]
#> [1,] 0.6333886 1.445194 0.9147159
#>
#> , , 3
#>
#> [,1] [,2] [,3]
#> [1,] 3.368602 0.9147159 4.388693
This should warn
chol_x <- chol(x)
calculate(chol_x, nsim = 1)
#> Warning: Cannot use `calculate()` to sample a cholesky factor of a greta array
#> E.g., `x_chol <- chol(wishart(df = 4, Sigma = diag(3)))`
#> `calculate(x_chol)`
#> This is due to an internal issue with how greta handles cholesky
#> representations.
#> See issue here on github for more details:
#> <https://github.com/greta-dev/greta/issues/593>
#> $chol_x
#> , , 1
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
#>
#> , , 2
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
#>
#> , , 3
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
But then this will also warn (because `chol` was called on `x`?)
calculate(x, nsim = 1)
#> Warning: Cannot use `calculate()` to sample a cholesky factor of a greta array
#> E.g., `x_chol <- chol(wishart(df = 4, Sigma = diag(3)))`
#> `calculate(x_chol)`
#> This is due to an internal issue with how greta handles cholesky
#> representations.
#> See issue here on github for more details:
#> <https://github.com/greta-dev/greta/issues/593>
#> $x
#> , , 1
#>
#> [,1] [,2] [,3]
#> [1,] 5.687854 -0.3524896 -1.104498
#>
#> , , 2
#>
#> [,1] [,2] [,3]
#> [1,] -0.3524896 1.861157 -1.468
#>
#> , , 3
#>
#> [,1] [,2] [,3]
#> [1,] -1.104498 -1.468 2.467221
We initially thought that `chol_x + 1` would trigger `chol_x` to give the right result - alas.
chol_x_p1 <- chol_x + 1
calculate(x, chol_x, chol_x_p1, nsim = 1)
#> Warning: Cannot use `calculate()` to sample a cholesky factor of a greta array
#> E.g., `x_chol <- chol(wishart(df = 4, Sigma = diag(3)))`
#> `calculate(x_chol)`
#> This is due to an internal issue with how greta handles cholesky
#> representations.
#> See issue here on github for more details:
#> <https://github.com/greta-dev/greta/issues/593>
#> $x
#> , , 1
#>
#> [,1] [,2] [,3]
#> [1,] 1.883891 2.321163 0.5246651
#>
#> , , 2
#>
#> [,1] [,2] [,3]
#> [1,] 2.321163 4.046714 0.5877278
#>
#> , , 3
#>
#> [,1] [,2] [,3]
#> [1,] 0.5246651 0.5877278 0.6911986
#>
#>
#> $chol_x
#> , , 1
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
#>
#> , , 2
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
#>
#> , , 3
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
#>
#>
#> $chol_x_p1
#> , , 1
#>
#> [,1] [,2] [,3]
#> [1,] 2 2 2
#>
#> , , 2
#>
#> [,1] [,2] [,3]
#> [1,] 2 2 2
#>
#> , , 3
#>
#> [,1] [,2] [,3]
#> [1,] 2 2 2
Ideally this should error, specifically calling out `chol_x`, not `x`.
calculate(x, chol_x, nsim = 1)
#> Warning: Cannot use `calculate()` to sample a cholesky factor of a greta array
#> E.g., `x_chol <- chol(wishart(df = 4, Sigma = diag(3)))`
#> `calculate(x_chol)`
#> This is due to an internal issue with how greta handles cholesky
#> representations.
#> See issue here on github for more details:
#> <https://github.com/greta-dev/greta/issues/593>
#> $x
#> , , 1
#>
#> [,1] [,2] [,3]
#> [1,] 11.52874 1.026355 -1.737126
#>
#> , , 2
#>
#> [,1] [,2] [,3]
#> [1,] 1.026355 1.71024 0.5303688
#>
#> , , 3
#>
#> [,1] [,2] [,3]
#> [1,] -1.737126 0.5303688 1.886979
#>
#>
#> $chol_x
#> , , 1
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
#>
#> , , 2
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
#>
#> , , 3
#>
#> [,1] [,2] [,3]
#> [1,] 1 1 1
This is hard to do; I’m currently not sure how I can specifically call out `chol_x` and not `x`.
This gist does a comparison of `x` and `chol_x`:
https://gist.github.com/njtierney/6b8a5d6a8380f61570c6fffcf6a530b5
In addition, there are still issues with MCMC:
m <- model(x)
draws <- mcmc(m, warmup = 1, n_samples = 1)
#> running 4 chains simultaneously on up to 8 CPU cores
#>
#> warmup 0/1 | eta: ?s sampling 0/1 | eta: ?s
Now the matrix, which should be symmetric, looks like a cholesky factor (but lower triangular, when it should be upper triangular), and the cholesky factor is still coming out as ones.
post_mcmc <- calculate(x, nsim = 1)
#> Warning: Cannot use `calculate()` to sample a cholesky factor of a greta array
#> E.g., `x_chol <- chol(wishart(df = 4, Sigma = diag(3)))`
#> `calculate(x_chol)`
#> This is due to an internal issue with how greta handles cholesky
#> representations.
#> See issue here on github for more details:
#> <https://github.com/greta-dev/greta/issues/593>
post_mcmc
#> $x
#> , , 1
#>
#> [,1] [,2] [,3]
#> [1,] 1.048605 0 0
#>
#> , , 2
#>
#> [,1] [,2] [,3]
#> [1,] 1.126161 2.339339 0
#>
#> , , 3
#>
#> [,1] [,2] [,3]
#> [1,] 0.5542009 1.327417 2.427972
Created on 2024-05-10 with reprex v2.1.0
We currently have a bug that occurs when the cholesky is defining itself in sampling mode and it is a representation of something.
At the moment in order to deal with / delay the bugs in #593, #594, and #585, we can set this up to error early, rather than its current behaviour, which is to return a matrix of 1s.
This is a stopgap solution so that we can get TF2 greta onto CRAN.