Closed goldingn closed 8 months ago
OK, so it looks like the issue is that in TF2 greta, we define the data itself as a constant (no batch dimension) but then explicitly tile it to match the batch dimension to avoid later broadcasting hiccups: https://github.com/greta-dev/greta/blob/15b0157af4c66dd0173733b836cdb5a5dc34a0f3/R/node_types.R#L37-L56
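To make the tiling step concrete, here is a minimal sketch using the tensorflow R package. It is not the code at the link above: `data_value`, `batch_size`, and the 1 x 3 x 1 layout are assumptions chosen only to illustrate a constant being repeated along a leading batch dimension.

```r
library(tensorflow)

# Hypothetical data: a 3 x 1 column vector with a leading dummy batch dimension.
data_value <- array(c(1, 2, 3), dim = c(1L, 3L, 1L))
data_tensor <- tf$constant(data_value, dtype = tf$float64)

# Hypothetical batch size (for example, the number of MCMC chains).
batch_size <- 4L

# Repeat the constant along the first (batch) dimension only,
# leaving the other two dimensions unchanged.
tiled <- tf$tile(data_tensor, multiples = c(batch_size, 1L, 1L))

# The constant has a leading dimension of 1; the tiled tensor has batch_size.
print(data_tensor$shape)
print(tiled$shape)
```

Tiling up front means later elementwise operations see matching batch dimensions and do not rely on broadcasting, at the cost of the tensor's shape now depending on the batch size.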
I think the most profitable approach will be to see whether greta.dynamics can be modified to set the expected shapes of the tensors passed into `tf$while_loop()` (e.g. here) so that they all have a batch dimension.
`greta_stash` so that greta can create the constant tensors with the correct dimensions.

I also got temporarily tripped up because I was importing `greta_stash` into greta.dynamics, but it turns out that breaks everything, so here it is accessed directly via `.internals`.
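For readers unfamiliar with how expected shapes can be declared for `tf$while_loop()`, the sketch below uses its `shape_invariants` argument. This is only one possible reading of "set the expected shapes" and is not taken from greta.dynamics; `cond`, `body`, `solve_steps`, and the 3 x 1 state are hypothetical.

```r
library(tensorflow)

# Hypothetical loop: run 3 iterations, doubling the state each time.
cond <- function(i, state) tf$less(i, 3L)
body <- function(i, state) list(i + 1L, state * 2)

solve_steps <- function(state) {
  tf$while_loop(
    cond = cond,
    body = body,
    loop_vars = list(tf$constant(0L), state),
    shape_invariants = list(
      tf$TensorShape(list()),             # scalar iteration counter
      tf$TensorShape(list(NULL, 3L, 1L))  # unknown batch size, fixed 3 x 1 state
    )
  )
}

# A batch of 4 states, each a 3 x 1 column of ones.
state <- tf$ones(shape = c(4L, 3L, 1L), dtype = tf$float64)
result <- solve_steps(state)
final_state <- result[[2]]
print(final_state$shape)
```

This runs eagerly for simplicity. As far as I understand, the invariants are only checked when the loop is traced into a graph, which is the setting in which the shape errors described in this thread would arise.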
The current release of greta.dynamics does not work with TF2 because it relies on contrib for ODE solvers (#25).

Once that is fixed, another issue will appear, which is that `greta:as_tf_function()` (which is used internally in greta.dynamics functions) is failing to correctly handle the definition of data greta array nodes inside the function. This is causing issues in experimental branches of greta.dynamics (e.g. `greta_2`) which use custom functions (but not ODE solvers).

The ODE example code in the current release is one example that will fail, because it includes the term `(1 - Prey / K)`, where the `1` will be turned into a data greta array inside the function. When doing MCMC, this yields an error message about tensors changing their shape inside the loops used to solve dynamics.

Here's a reprex using the `greta_2` branch and its `iterate_dynamic_function()` function, using TF2:

Created on 2024-03-13 with reprex v2.0.2
Note I did a slightly weird delayed execution thing here with a default value for the `one` argument so I could use the same function in both examples, but it's the same thing that happens with `state + 1`.

This code was used in TF1 greta to overcome the issue, by ensuring that the tensors for data greta arrays were defined as constants (no batch dimensions) rather than placeholders. It seems to still be in place though, so I'll try to debug now.