I replaced `len()` with `.shape[0]` and changed the `integration_domain` argument validity check so that a `tf.Variable` can be used as `integration_domain`. The boundary value checks are now skipped if the values are not concrete, which can happen when the function is compiled.
JIT compilation and gradient calculation work with JAX, Torch and TensorFlow.