tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Can't jit PoissonLogNormalQuadratureCompound log_prob #1803

Status: Open. GianmarcoCallegher opened this issue 3 months ago.

GianmarcoCallegher commented 3 months ago

If I try to jit the log_prob method of PoissonLogNormalQuadratureCompound using the JAX substrate:

from jax import jit
import tensorflow_probability.substrates.jax.distributions as tfd

jit(tfd.PoissonLogNormalQuadratureCompound(0.0, 1.0).log_prob)(1.)

I get the following error:

TypeError: Shapes must be 1D sequences of concrete values of integer type, got Traced<ShapedArray(int32[1])>with<DynamicJaxprTrace(level=1/0)>.

operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /var/folders/vr/x979b72d00jgztv5gxjgl5j00000gn/T/ipykernel_26756/1664260717.py:1 (<module>)

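This message is JAX's generic complaint when an array shape is an abstract tracer instead of a concrete integer sequence. The minimal sketch below does not use TFP and does not claim to reproduce what PoissonLogNormalQuadratureCompound does internally; it only assumes that something inside log_prob builds a shape with array operations. Under jit, jnp operations are staged out even when their inputs are constants, so a shape computed with jnp ends up as a traced int32[1] array like the one in the error, while the same shape computed with NumPy stays concrete. The helper names zeros_traced_shape and zeros_static_shape are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp


def zeros_traced_shape(x):
    # Hypothetical helper. Inside jit, jnp.concatenate is staged out, so this
    # int32[1] shape array is a tracer with no concrete value at trace time.
    shape = jnp.concatenate([jnp.array([3], dtype=jnp.int32),
                             jnp.zeros([0], dtype=jnp.int32)])
    return jnp.zeros(shape) + x


def zeros_static_shape(x):
    # Hypothetical helper. The same int32[1] shape built with NumPy is a
    # concrete array, so jnp.zeros can read its values while tracing.
    shape = np.concatenate([np.array([3], dtype=np.int32),
                            np.zeros([0], dtype=np.int32)])
    return jnp.zeros(shape) + x


try:
    jax.jit(zeros_traced_shape)(1.0)
    error_message = None
except TypeError as err:
    # Same class of error as in the report:
    # "Shapes must be 1D sequences of concrete values of integer type, ..."
    error_message = str(err)

result = jax.jit(zeros_static_shape)(1.0)
print(error_message)
print(result)  # an array of three ones
```

The usual way around this class of error is to derive shapes from static Python or NumPy values (or from .shape attributes) rather than from traced arrays; whether that is where a fix belongs in TFP's JAX substrate is not established here.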