If I try to `jit` the `log_prob` method of the `PoissonLogNormalQuadratureCompound` distribution:
from jax import jit
import tensorflow_probability.substrates.jax.distributions as tfd
jit(tfd.PoissonLogNormalQuadratureCompound(0.0, 1.0).log_prob)(1.)
I get the following error:
TypeError: Shapes must be 1D sequences of concrete values of integer type, got
Traced<ShapedArray(int32[1])>with<DynamicJaxprTrace(level=1/0)>.
operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /var/folders/vr/x979b72d00jgztv5gxjgl5j00000gn/T/ipykernel_26756/1664260717.py:1 (<module>)
operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /var/folders/vr/x979b72d00jgztv5gxjgl5j00000gn/T/ipykernel_26756/1664260717.py:1 (<module>)
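
For context, this class of error can be reproduced without TFP. The sketch below is not taken from the TFP source; `make_zeros` is a hypothetical function used only to show how JAX raises this `TypeError` when a value traced by `jit` is used as an array shape. Whether the same thing happens inside `log_prob` is an assumption based on the message above.

```python
import jax
import jax.numpy as jnp


def make_zeros(n):
    # Hypothetical example: n is used as a shape, so JAX needs it to be
    # a concrete integer while the function is being traced.
    return jnp.zeros((n,))


# Under plain jit, n is replaced by a tracer, so the shape is not concrete.
try:
    jax.jit(make_zeros)(3)
    failed = False
except TypeError:
    failed = True

# Marking the argument as static keeps it a concrete Python int.
static_zeros = jax.jit(make_zeros, static_argnums=0)
result = static_zeros(3)
```

Here the first call fails with the same "Shapes must be 1D sequences of concrete values" message, and the `static_argnums` version returns an array of shape `(3,)`. This only illustrates the error itself; it is not a workaround for the distribution.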