Probabilistic reasoning and statistical analysis in TensorFlow
tfp.math.scan_associative doesn't work for all associative functions (it should be using `vmap` for `lowered_fn`) #1812
Open
Joshuaalbert opened 4 weeks ago
Here is a simple example of an associative function that `scan_associative` fails to handle, because it assumes the associative op broadcasts over a leading batch axis. The solution is to use `jax.vmap` to distribute elements in `lowered_fn` here, rather than rely on broadcasting.

MVCE
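The snippet below is a hedged sketch of the vmap-based fix. It is not the reporter's original reproduction. It assumes a JAX backend and uses `jax.lax.associative_scan` as a stand-in for `tfp.math.scan_associative`, since that scan also calls the op on stacked batches of elements. The names `matmul_op` and `batched_op` are hypothetical. The op multiplies one pair of 2x2 matrices with `jnp.dot`. That is associative, but it does not broadcast: on stacked `(k, 2, 2)` inputs, `jnp.dot` contracts across the stack instead of multiplying pairwise. Wrapping the op in `jax.vmap` maps it over the leading axis, which is the behavior the issue proposes for `lowered_fn`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def matmul_op(a, b):
    """Hypothetical associative op written for ONE pair of 2x2 matrices.

    On stacked (k, 2, 2) inputs, jnp.dot returns shape (k, 2, k, 2) rather
    than a pairwise product, so this op does not broadcast over a leading
    batch axis the way a parallel scan expects.
    """
    return jnp.dot(a, b)


# vmap applies the per-pair op along the leading axis of both arguments,
# mapping (k, 2, 2) x (k, 2, 2) -> (k, 2, 2) without relying on broadcasting.
batched_op = jax.vmap(matmul_op)

# Six well-conditioned 2x2 matrices near the identity.
rng = np.random.default_rng(0)
mats = (np.eye(2) + 0.3 * rng.standard_normal((6, 2, 2))).astype(np.float32)

# Parallel prefix product: prefix[i] = mats[0] @ mats[1] @ ... @ mats[i].
prefix = jax.lax.associative_scan(batched_op, jnp.asarray(mats))

# Sequential float64 reference computed left to right.
ref = [mats[0].astype(np.float64)]
for m in mats[1:]:
    ref.append(ref[-1] @ m.astype(np.float64))
ref = np.stack(ref)

print(prefix.shape)  # (6, 2, 2)
print(np.allclose(np.asarray(prefix), ref, rtol=1e-4, atol=1e-5))
```

Matrix multiplication is non-commutative, so the comparison against a left-to-right reference also checks that the scan keeps the operand order. Without the `vmap` wrapper, `matmul_op` would receive stacked inputs and return the wrong shape.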