rstudio / tfprobability

R interface to TensorFlow Probability
https://rstudio.github.io/tfprobability/

Runtime error with custom distributions #140

Open: dirmeier opened this issue 3 years ago

dirmeier commented 3 years ago

Hello all,

Thanks for making TensorFlow Probability accessible from R; it's hugely helpful. I am porting some of my Python code to R and ran into the following error:

```
> x_ <- tfkl$Input(shape = shape(2L), dtype = tf$float32)
> log_prob_ <- distribution$log_prob(x_)
Error in py_call_impl(callable, dots$args, dots$keywords) :
  RuntimeError: Evaluation error: non-numeric argument to binary operator.
```

This seems to happen when a custom distribution is built through a change of variables, i.e. a `TransformedDistribution` with a custom bijector. A minimal reproducible example is below; the bijector is taken from your blog.
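For reference, a `TransformedDistribution` scores a point with the change-of-variables formula. The leaky-ReLU bijector f in the example below is the identity for non-negative inputs and has slope `alpha` for negative ones. It is a bijection only for `alpha > 0`, because its inverse divides by `alpha`. Each coordinate is transformed on its own, so the Jacobian of f^{-1} is diagonal with entries 1 or 1/alpha. The formula then reduces to the form below, and the last line is what `inverse_log_det_jacobian_fn` is meant to return, summed over the event dimension.

```latex
\log p_Y(y) = \log p_X\big(f^{-1}(y)\big) + \log\big|\det J_{f^{-1}}(y)\big|,
\qquad
f^{-1}(y)_i =
\begin{cases}
  y_i,        & y_i \ge 0,\\
  y_i/\alpha, & y_i < 0,
\end{cases}

\log\big|\det J_{f^{-1}}(y)\big|
  = \sum_i \mathbb{1}[y_i < 0]\,\log\frac{1}{\alpha}
  = -\,\#\{i : y_i < 0\}\,\log\alpha .
```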

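```r
library(tensorflow)
library(tfprobability)

# Aliases used below (assumed; the original snippet uses them without defining them)
tfd  <- reticulate::import("tensorflow_probability")$distributions
tfkl <- tf$keras$layers

X <- tfd$Normal(0.0, 1.0)$sample(shape(5, 2))

# Leaky-ReLU bijector: identity for x >= 0, slope alpha for x < 0
bijector_leaky_relu <- function(alpha) {
  tfb_inline(
    forward_fn = function(x)
      tf$where(tf$greater_equal(x, 0), x, alpha * x),
    inverse_fn = function(y)
      tf$where(tf$greater_equal(y, 0), y, 1/alpha * y),
    inverse_log_det_jacobian_fn = function(y) {
      I <- tf$ones_like(y)
      J_inv <- tf$where(tf$greater_equal(y, 0), I, 1/alpha * I)
      log_abs_det_J_inv <- tf$math$log(tf$abs(J_inv))
      # sum the per-dimension log-Jacobian terms over the event axis
      tf$reduce_sum(log_abs_det_J_inv, axis = 1L)
    },
    forward_min_event_ndims = 1L,
    inverse_min_event_ndims = 1L
  )
}

# alpha must be positive: the inverse divides by alpha
bij <- bijector_leaky_relu(0.5)
bij$forward(X)
bij$inverse(bij$forward(X))  # recovers X up to floating-point error

# With this TransformedDistribution, log_prob() on a Keras Input fails (see below)
distribution <- tfd$TransformedDistribution(
  distribution = tfd_multivariate_normal_diag(loc = c(0, 0)),
  bijector = bij
)

# Using the plain base distribution instead of the one above works:
# distribution <- tfd_multivariate_normal_diag(loc = c(0, 0))

distribution$sample(shape(5))
distribution$log_prob(distribution$sample(shape(5)))
```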
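As a TensorFlow-free cross-check of what that `TransformedDistribution` should return, here is a minimal base-R sketch of the same log density. It assumes a standard-normal base; `tfd_multivariate_normal_diag(loc = c(0, 0))` with no scale has unit scales, so it factorizes into independent standard normals. `leaky_relu_log_prob` is a hypothetical helper, not part of tfprobability, and it needs `alpha > 0`.

```r
# Hypothetical helper: log density of Y = f(X), X ~ N(0, I), where f is the
# leaky-ReLU bijector with slope alpha (> 0) on negative inputs.
# Each row of `y` is one event; a plain vector is treated as n 1-D events.
leaky_relu_log_prob <- function(y, alpha) {
  stopifnot(alpha > 0)
  y <- as.matrix(y)
  x <- ifelse(y >= 0, y, y / alpha)                      # inverse transform f^{-1}(y)
  base_lp <- rowSums(matrix(dnorm(x, log = TRUE), nrow = nrow(y)))  # log p_X(f^{-1}(y))
  ildj <- rowSums(ifelse(y >= 0, 0, -log(alpha)))       # log|det J_{f^{-1}}(y)|
  base_lp + ildj
}

y <- rbind(c(1, -1), c(0.3, 2))
leaky_relu_log_prob(y, alpha = 0.5)

# Sanity check in one dimension: the density should integrate to 1.
# Split at 0 because the density jumps there (phi(0)/alpha vs. phi(0)).
dens <- function(t) exp(leaky_relu_log_prob(matrix(t, ncol = 1), alpha = 0.5))
integrate(dens, -Inf, 0)$value + integrate(dens, 0, Inf)$value
```

Comparing these values with `distribution$log_prob()` on the same eager inputs is one way to confirm that the bijector itself is correct, separately from the Keras `Input` problem.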

```
> x_ <- tfkl$Input(shape = shape(2L), dtype = tf$float32)
> log_prob_ <- distribution$log_prob(x_)
Error in py_call_impl(callable, dots$args, dots$keywords) :
  RuntimeError: Evaluation error: non-numeric argument to binary operator.
```
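The message text itself is base R's, not TensorFlow's. It appears to be raised inside the R bijector callback and handed back through Python as a `RuntimeError`. R reports "non-numeric argument to binary operator" when `*` or `/` gets an operand that is neither numeric nor of a class with an `Ops` method. The bijector's `alpha * x` and `1/alpha * y` are such calls. One plausible, unconfirmed reading that fits the version explanation in the reply is this: under TF 2.4, the symbolic tensor returned by `tfkl$Input()` is not of a class that R's arithmetic dispatch recognizes. The sketch below only reproduces the R mechanism, with a hypothetical class `opaque_tensor`, and does not involve TensorFlow.

```r
# Hypothetical S3 class standing in for a tensor type R has no arithmetic for.
opaque <- structure(list(value = 3), class = "opaque_tensor")

# Without an Ops method, `*` falls through to R's internal numeric code and fails.
msg <- tryCatch(0.5 * opaque, error = function(e) conditionMessage(e))
print(msg)

# An Ops group-generic method lets `*`, `/`, etc. dispatch to the class
# (binary operators only in this sketch).
Ops.opaque_tensor <- function(e1, e2) {
  unwrap <- function(v) if (inherits(v, "opaque_tensor")) v$value else v
  get(.Generic)(unwrap(e1), unwrap(e2))
}
print(0.5 * opaque)
print(opaque / 2)
```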

Do you have any idea what the problem might be? Thank you very much.

Best, Simon

skeydan commented 3 years ago

Hi Simon,

Sorry about that. I suspect this is because you are using the latest version, TFP 0.12 (which pairs with TF 2.4).

We have not yet updated the R code for that release, but will try to get to it ASAP! Until then, please use TFP 0.11 / TF 2.3. Thanks!
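To follow the advice above and stay on TFP 0.11 / TF 2.3 for now, a small guard can warn early when a newer pairing is loaded. This is a minimal sketch: `check_tfp_pairing` is a hypothetical base-R helper, not part of any package, and it only compares version strings.

```r
# Hypothetical guard: warn unless both TF < 2.4 and TFP < 0.12,
# the pairing recommended above until the R code is updated.
check_tfp_pairing <- function(tf_ver, tfp_ver) {
  tf_ver  <- as.character(tf_ver)   # also accepts package_version objects
  tfp_ver <- as.character(tfp_ver)
  ok <- package_version(tfp_ver) < "0.12" && package_version(tf_ver) < "2.4"
  if (!ok) {
    warning(sprintf(
      "TF %s / TFP %s may not yet work with tfprobability; try TF 2.3 / TFP 0.11.",
      tf_ver, tfp_ver), call. = FALSE)
  }
  invisible(ok)
}

check_tfp_pairing("2.3.1", "0.11.1")  # silent
check_tfp_pairing("2.4.0", "0.12.0")  # warns
```

With the packages loaded, the live versions could be passed in, for example `check_tfp_pairing(tensorflow::tf_version(), tfprobability::tfp_version())`. This assumes both functions are available in your installed versions. Pre-release strings such as `"2.4.0-rc0"` are not valid `package_version` input and would need trimming first.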