greta-dev / greta.dynamics

a greta extension for modelling dynamical systems
https://greta-dev.github.io/greta.dynamics/

TF2 versions fail with data inside functions #27

Closed goldingn closed 6 months ago

goldingn commented 6 months ago

The current release of greta.dynamics does not work with TF2 because it relies on contrib for ODE solvers (#25).

Once that is fixed, another issue will appear: greta's as_tf_function() (which is used internally in greta.dynamics functions) fails to correctly handle data greta array nodes that are defined inside the function. This is causing issues in experimental branches of greta.dynamics (e.g. greta_2) which use custom functions (but not ODE solvers).

The ODE example code in the current release is one example that will fail, because it includes the term (1 - Prey / K), where the 1 will be turned into a data greta array inside the function. When doing MCMC, the resulting error message relates to tensors changing their shape inside the loops used to solve the dynamics.

Here's a reprex using the greta_2 branch and its iterate_dynamic_function() function, using TF2:

library(greta.dynamics)
#> Loading required package: greta
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

oneplus <- function(state, iter, one = 1) {
  state + one
}

# this fails, as 'one' is defined inside the function
x <- normal(0, 1)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
res <- iterate_dynamic_function(transition_function = oneplus,
                                initial_state = x,
                                tol = 0, 
                                niter = 3)
m <- model(res$all_states)
draws <- mcmc(m)
#> Error in eval(expr, envir, enclos): RuntimeError: ValueError: Input tensor `chain_of_reshape_of_identity/forward_log_det_jacobian/reshape/forward/Reshape:0` enters the loop with shape (1, 1, 1), but has shape (None, 1, 1) after one iteration. To allow the shape to vary across iterations, use the `shape_invariants` argument of tf.while_loop to specify a less-specific shape.

# this works, as 'one' is defined outside the function and passed in
x <- normal(0, 1)
res <- iterate_dynamic_function(transition_function = oneplus,
                                initial_state = x,
                                tol = 0,
                                one = ones(1),
                                niter = 3)
m <- model(res$all_states)
draws <- mcmc(m)
#> running 4 chains simultaneously on up to 8 CPU cores

Created on 2024-03-13 with reprex v2.0.2

Note I did a slightly weird delayed execution thing here with a default value for the one argument so I could use the same function in both examples, but it's the same thing that happens with state + 1.

This code was used in TF1 greta to overcome the issue, by ensuring that the tensors for data greta arrays were defined as constants (no batch dimensions) rather than placeholders. It seems to still be in place though, so I'll try to debug now.

goldingn commented 6 months ago

OK, so it looks like the issue is that in TF2 greta, we define the data itself as a constant (no batch dimension) but then explicitly tile it to match the batch dimension to avoid later broadcasting hiccups: https://github.com/greta-dev/greta/blob/15b0157af4c66dd0173733b836cdb5a5dc34a0f3/R/node_types.R#L37-L56
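To illustrate what tiling a constant to the batch dimension means, here is a minimal sketch in base R arrays rather than tensors. tile_batch() is a hypothetical helper written for this illustration; it is not part of greta and does not reproduce the linked code.

```r
# Hypothetical helper (not part of greta): repeat an array whose leading
# dimension is 1 along that dimension, analogous to tiling a constant
# tensor so that it matches the batch dimension.
tile_batch <- function(x, batch_size) {
  dims <- dim(x)
  stopifnot(dims[1] == 1)
  # R arrays are column-major, so the first index varies fastest;
  # repeating each element batch_size times fills the new leading dimension.
  array(rep(as.vector(x), each = batch_size),
        dim = c(batch_size, dims[-1]))
}

one <- array(1, dim = c(1, 1, 1))            # constant with leading dimension 1
state <- array(rnorm(4), dim = c(4, 1, 1))   # state for 4 chains (batch size 4)

# tile the constant so its leading dimension matches the state
tiled_one <- tile_batch(one, dim(state)[1])
new_state <- state + tiled_one
dim(new_state)
```

The point of the sketch is only that the constant ends up with the same leading dimension as the state, so the sum keeps one consistent shape from one iteration to the next.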

I think the most profitable approach will be to see whether greta.dynamics can be modified to set the expected shapes of the tensors passed into tf$while_loop() (e.g. here) so that they all have a batch dimension.

goldingn commented 6 months ago

#28 has the solution - the batch size needs to be found dynamically (inside the while loop to avoid graph weirdness) and put in greta_stash so that greta can create the constant tensors with the correct dimensions.
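The pattern described above (find the batch size inside the loop body, record it in a shared stash, then build constants to match) can be sketched in base R. This is an illustration under stated assumptions only: stash, make_constant() and loop_body() are hypothetical names, and the code stands in for greta_stash and the tensor code in the fix rather than reproducing them.

```r
# Hypothetical stand-in for a package-level stash (an environment).
# greta's real greta_stash is internal and is not reproduced here.
stash <- new.env()

# Build a "constant" whose leading dimension matches the batch size
# currently recorded in the stash (falling back to 1 if none is recorded).
make_constant <- function(value) {
  batch_size <- if (exists("batch_size", envir = stash, inherits = FALSE)) {
    stash$batch_size
  } else {
    1L
  }
  array(value, dim = c(batch_size, 1L, 1L))
}

# Loop body: read the batch size from the state at run time, record it
# in the stash, then create the constant with a matching shape.
loop_body <- function(state) {
  stash$batch_size <- dim(state)[1]
  state + make_constant(1)
}

state <- array(0, dim = c(4L, 1L, 1L))   # 4 chains
for (i in 1:3) {
  state <- loop_body(state)
}
dim(state)
```

Because the constant is always created with the recorded batch size, the state keeps the same shape on every pass through the loop, which is the property the tf.while_loop error in the reprex complains about.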

I also got temporarily tripped up because I was importing greta_stash into greta.dynamics, but it turns out that breaks everything - so here it is accessed directly via .internals