greta-dev / greta

simple and scalable statistical modelling in R
https://greta-stats.org

`calculate()` errors if user working environment has an object called `batch_size` #634

Open goldingn opened 4 weeks ago

goldingn commented 4 weeks ago

If the user's working environment contains an object called `batch_size`, calling `calculate()` with the `values` argument set to MCMC samples fails with an obscure TensorFlow dimension error. Reprex:

```r
library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
x <- normal(0, 1)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
y <- x ^ 2
m <- model(x)
draws <- mcmc(m,
              chains = 1,
              warmup = 10,
              n_samples = 10)
#> 
#>     warmup ========================================== 10/10 | eta:  0s
#>   sampling ========================================== 10/10 | eta:  0s

calculate(x, values = draws, nsim = 1)
#> $x
#> , , 1
#> 
#>            [,1]
#> [1,] -0.1318054
batch_size <- 13
calculate(x, values = draws, nsim = 1)
#> Error in eval(expr, envir, enclos): RuntimeError: ValueError: Shape must be rank 3 but is rank 2 for '{{node Tile_1}} = Tile[T=DT_DOUBLE, Tmultiples=DT_INT32](Const, Tile_1/multiples)' with input shapes: [1,1,1], [2].
```

Created on 2024-06-16 with reprex v2.0.2

It's probably down to my shonky coding interacting with R's lexical scoping in a weird way, triggered by this bit of code: https://github.com/greta-dev/greta/blob/tf2-poke-tf-fun/R/calculate.R#L466-L472

This was not especially fun to debug. For now, the workaround is to give the `batch_size` object in my environment any other name.
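For anyone unfamiliar with how this kind of collision can arise, here is a minimal sketch in plain R. It is not greta's actual code: `tile_dims()` and `user_env` are hypothetical names, and the assumption is only that an internal helper assigns `batch_size` on one branch and later checks for it by name. When the assignment is skipped, the lookup falls through to the enclosing environment and may pick up a user object with the same name.

```r
# Hypothetical stand-in for a package-internal helper (NOT greta's code).
# `batch_size` is only assigned when `n_draws` is supplied.
tile_dims <- function(n_draws = NULL) {
  if (!is.null(n_draws)) {
    batch_size <- n_draws
  }
  # exists() defaults to inherits = TRUE, so it also searches the
  # environments that enclose this function.
  if (exists("batch_size")) {
    c(batch_size, 1L, 1L)
  } else {
    c(1L, 1L)
  }
}

# Stand-in for the user's working environment, made the enclosure of the
# helper so the example does not depend on the global environment.
user_env <- new.env(parent = baseenv())
environment(tile_dims) <- user_env

clean <- tile_dims()    # no `batch_size` visible: a length-2 vector

assign("batch_size", 13, envir = user_env)
leaked <- tile_dims()   # the user's object is found: a length-3 vector
```

The same call returns vectors of different lengths depending only on what happens to be defined in the enclosing environment, which is the general kind of behaviour that could surface downstream as a rank mismatch.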

goldingn commented 4 weeks ago

FML

njtierney commented 4 weeks ago

Another workaround could be to call it `.batch_size` internally. Adding a `.` prefix is usually not done in production code, but it can be a nice way to indicate internal use?
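A rough sketch of what the rename might look like, again in plain R rather than greta's real internals. `tile_dims_dot()` and `user_env` are hypothetical names used only for illustration, and the helper's logic is an assumption, not the code at the linked lines.

```r
# Hypothetical helper (NOT greta's code) using a dot-prefixed internal name.
tile_dims_dot <- function(n_draws = NULL) {
  if (!is.null(n_draws)) {
    .batch_size <- n_draws
  }
  # Still an inherited lookup, but for a name users are unlikely to define.
  if (exists(".batch_size")) {
    c(.batch_size, 1L, 1L)
  } else {
    c(1L, 1L)
  }
}

# Stand-in for the user's working environment, containing `batch_size`.
user_env <- new.env(parent = baseenv())
environment(tile_dims_dot) <- user_env
assign("batch_size", 13, envir = user_env)

with_user_object <- tile_dims_dot()              # user's `batch_size` is ignored
with_draws <- tile_dims_dot(n_draws = 10L)       # internal value is used
```

Note that the prefix only makes a clash less likely: a user object named `.batch_size` would still be found by the same lookup.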